add_statsmodel_fit

plotly.stats.add_statsmodel_fit(
    fig: go.Figure,
    x: np.ndarray,
    y: np.ndarray,
    fitfunc: Callable = sm.OLS,
    row: int | None = None,
    col: int | None = None,
    ci_alpha: float = 0.05,
    show_ci: bool = True,
    show_obs_ci: bool = False,
    line_kwargs: dict = {'line': {'color': '#222'}},
    ci_kwargs: dict = {'fill': 'toself', 'fillcolor': '#222', 'line_color': '#222', 'opacity': 0.2},
    obs_ci_kwargs: dict = {'line': {'dash': 'dash', 'color': '#222'}, 'opacity': 0.5},
)

Add a statistical model fit, with confidence intervals, to a Plotly figure.

Fits a statsmodels regression model (OLS, GLM, etc.) to the data and adds the fitted line, an optional shaded confidence band, and optional prediction-interval lines to an existing Plotly figure.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fig | plotly.graph_objects.Figure | Plotly figure to add the fit to. | required |
| x | np.ndarray | 1D array of independent-variable values (predictor). | required |
| y | np.ndarray | 1D array of dependent-variable values (response). | required |
| fitfunc | Callable | Statsmodels model class used for fitting (e.g., sm.OLS, sm.GLM). | statsmodels.api.OLS |
| row | int or None | Subplot row index (1-based) to add the fit to. If None, the fit is added to the main plot. | None |
| col | int or None | Subplot column index (1-based) to add the fit to. If None, the fit is added to the main plot. | None |
| ci_alpha | float | Significance level for the confidence intervals. The default, 0.05, gives 95% intervals (bounded by the 2.5% and 97.5% quantiles). | 0.05 |
| show_ci | bool | If True, display a shaded confidence interval around the fit line. | True |
| show_obs_ci | bool | If True, display dashed lines for the prediction interval (observation CI). | False |
| line_kwargs | dict | Keyword arguments passed to the fit-line scatter trace. | {'line': {'color': '#222'}} |
| ci_kwargs | dict | Keyword arguments passed to the confidence-interval filled area. The default creates a semi-transparent shaded region. | {'fill': 'toself', 'fillcolor': '#222', 'line_color': '#222', 'opacity': 0.2} |
| obs_ci_kwargs | dict | Keyword arguments passed to the observation-CI (prediction-interval) lines. The default creates dashed lines. | {'line': {'dash': 'dash', 'color': '#222'}, 'opacity': 0.5} |
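Two of the parameters above translate directly into code. `ci_alpha` maps to the lower and upper quantiles α/2 and 1 - α/2, and the default `ci_kwargs` uses Plotly's `fill='toself'`, which closes a single trace into a shape by connecting its endpoints. A band therefore has to be passed as one outline that runs along the upper edge and back along the lower edge. The sketch below assumes the shaded region is built this way; `band_polygon` is a hypothetical helper, not part of this API.

```python
import numpy as np


def band_polygon(x, lower, upper):
    """Hypothetical helper: outline a band as one path for fill='toself'.

    Walks the upper edge left to right, then the lower edge right to left;
    Plotly joins the last point back to the first to close the shape.
    """
    x = np.asarray(x, dtype=float)
    xs = np.concatenate([x, x[::-1]])
    ys = np.concatenate([np.asarray(upper, dtype=float),
                         np.asarray(lower, dtype=float)[::-1]])
    return xs, ys


# ci_alpha = 0.05 corresponds to the 2.5% and 97.5% quantiles
ci_alpha = 0.05
quantiles = (ci_alpha / 2, 1 - ci_alpha / 2)

# A tiny band: lower and upper edges over three x positions
x = np.array([0.0, 1.0, 2.0])
xs, ys = band_polygon(x, lower=[1.0, 2.0, 3.0], upper=[2.0, 3.0, 4.0])
print(xs, ys)
```

The resulting arrays could then be drawn with, for example, `go.Scatter(x=xs, y=ys, fill='toself', opacity=0.2)`.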

Returns

| Type | Description |
| --- | --- |
| plotly.graph_objects.Figure | The modified figure, with the fit line and any requested intervals added. |

Examples

>>> import numpy as np
>>> import plotly.express as px
>>> import statsmodels.api as sm
>>> from plotly.stats import add_statsmodel_fit
>>> # Create sample data: linear relationship plus Gaussian noise (seeded for reproducibility)
>>> rng = np.random.default_rng(0)
>>> x = np.linspace(0, 10, 50)
>>> y = 2 * x + 5 + rng.normal(0, 2, 50)
>>> # Create a scatter plot
>>> fig = px.scatter(x=x, y=y)
>>> # Add an OLS fit with a 95% confidence band
>>> fig = add_statsmodel_fit(fig, x=x, y=y, fitfunc=sm.OLS, show_ci=True)
>>> # fig.show()
>>> # Add the fit to a specific subplot, with prediction-interval lines
>>> from plotly.subplots import make_subplots
>>> fig = make_subplots(rows=1, cols=2)
>>> _ = fig.add_scatter(x=x, y=y, mode='markers', row=1, col=1)
>>> fig = add_statsmodel_fit(
...     fig, x=x, y=y, row=1, col=1,
...     show_ci=True, show_obs_ci=True
... )

Notes

  • The function automatically adds a constant (intercept) term to the model
  • x-values are sorted before fitting so that the fit line is drawn left to right
  • The confidence interval (CI) reflects uncertainty in the estimated mean response at each x
  • The observation CI (prediction interval) also includes the noise of a single new observation, so it is wider than the CI
  • Other statsmodels model classes can be used via the fitfunc parameter
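The difference between the two intervals in the notes can be made concrete with the textbook formulas for a single predictor with an intercept. Both half-widths use the same t quantile and residual variance s²; the prediction interval adds 1 inside the square root for the noise of one new observation. This sketch uses numpy and scipy directly rather than statsmodels, and `interval_half_widths` is a hypothetical name. For simple OLS, the results should agree with the `mean_ci_*` and `obs_ci_*` bounds statsmodels reports, though that is not checked here.

```python
import numpy as np
from scipy import stats


def interval_half_widths(x, y, x0, alpha=0.05):
    """Hypothetical helper: half-widths of the mean CI and the prediction interval.

    Textbook simple-linear-regression formulas; s2 is the residual variance.
    """
    x, y, x0 = (np.asarray(a, dtype=float) for a in (x, y, x0))
    n = x.size
    slope, intercept = np.polyfit(x, y, 1)      # least-squares line with intercept
    resid = y - (intercept + slope * x)
    s2 = resid @ resid / (n - 2)                 # residual variance with n - 2 dof
    sxx = np.sum((x - x.mean()) ** 2)
    leverage = 1.0 / n + (x0 - x.mean()) ** 2 / sxx
    t = stats.t.ppf(1 - alpha / 2, df=n - 2)
    mean_hw = t * np.sqrt(s2 * leverage)         # uncertainty in the fitted mean
    obs_hw = t * np.sqrt(s2 * (1.0 + leverage))  # adds the noise of one new observation
    return mean_hw, obs_hw


rng = np.random.default_rng(1)
x = np.linspace(0, 10, 50)
y = 2 * x + 5 + rng.normal(0, 2, 50)
grid = np.linspace(0, 10, 11)
mean_hw, obs_hw = interval_half_widths(x, y, grid)
print(np.round(mean_hw, 2), np.round(obs_hw, 2))
```

Both widths are smallest at the mean of x and grow toward the edges of the data, which is why a shaded confidence band flares out at its ends.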