add_statsmodel_fit
plotly.stats.add_statsmodel_fit(
fig: go.Figure,
x: np.ndarray,
y: np.ndarray,
fitfunc: Callable = sm.OLS,
row: int | None = None,
col: int | None = None,
ci_alpha: float = 0.05,
show_ci: bool = True,
show_obs_ci: bool = False,
line_kwargs: dict = {'line': {'color': '#222'}},
ci_kwargs: dict = {'fill': 'toself', 'fillcolor': '#222', 'line_color': '#222', 'opacity': 0.2},
obs_ci_kwargs: dict = {'line': {'dash': 'dash', 'color': '#222'}, 'opacity': 0.5},
)
Add a statistical model fit with confidence intervals to a Plotly figure.
Fits a statsmodels regression model (OLS, GLM, etc.) to the data and adds the fitted line, confidence intervals, and optional prediction intervals to an existing Plotly figure.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| fig | plotly.graph_objects.Figure | Plotly figure to add the fit to. | required |
| x | np.ndarray | 1D array of independent variable values (predictor). | required |
| y | np.ndarray | 1D array of dependent variable values (response). | required |
| fitfunc | Callable | Statsmodels model class to use for fitting (e.g., sm.OLS, sm.GLM). | statsmodels.api.OLS |
| row | int or None | Subplot row index (1-based) to add the fit to. If None, the fit is added to the main plot. | None |
| col | int or None | Subplot column index (1-based) to add the fit to. If None, the fit is added to the main plot. | None |
| ci_alpha | float | Significance level for confidence intervals. The default of 0.05 gives 95% confidence intervals (2.5% and 97.5% quantiles). | 0.05 |
| show_ci | bool | If True, display a shaded confidence interval around the fit line. | True |
| show_obs_ci | bool | If True, display dashed lines for the prediction interval (observation CI). | False |
| line_kwargs | dict | Keyword arguments passed to the scatter trace for the fit line. | {'line': {'color': '#222'}} |
| ci_kwargs | dict | Keyword arguments passed to the confidence interval filled area. The default creates a semi-transparent shaded region. | {'fill': 'toself', 'fillcolor': '#222', 'line_color': '#222', 'opacity': 0.2} |
| obs_ci_kwargs | dict | Keyword arguments passed to the observation CI (prediction interval) lines. The default creates dashed lines. | {'line': {'dash': 'dash', 'color': '#222'}, 'opacity': 0.5} |
Returns
| Type | Description |
|---|---|
| plotly.graph_objects.Figure | The modified figure with the added fit line and optional confidence and prediction intervals. |
Examples
>>> import numpy as np
>>> import plotly.express as px
>>> import statsmodels.api as sm
>>> # Create sample data with linear relationship + noise
>>> x = np.linspace(0, 10, 50)
>>> rng = np.random.default_rng(0)  # seeded for reproducibility
>>> y = 2 * x + 5 + rng.normal(0, 2, 50)
>>> # Create scatter plot
>>> fig = px.scatter(x=x, y=y)
>>> # Add OLS fit with 95% CI
>>> fig = add_statsmodel_fit(fig, x=x, y=y, show_ci=True)
>>> # fig.show()
>>> # Add fit to specific subplot with prediction interval
>>> from plotly.subplots import make_subplots
>>> fig = make_subplots(rows=1, cols=2)
>>> fig.add_scatter(x=x, y=y, mode='markers', row=1, col=1)
>>> fig = add_statsmodel_fit(
... fig, x=x, y=y, row=1, col=1,
... show_ci=True, show_obs_ci=True
... )
Notes
- The function automatically adds a constant (intercept) to the model
- x-values are sorted before fitting to ensure proper line rendering
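- The confidence interval (CI) represents uncertainty in the estimated mean response at each x
- The observation CI (prediction interval) represents uncertainty for a single new observation; because it also includes residual noise, it is wider than the CI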
- Custom statsmodels models can be used via the fitfunc parameter