Python Statsmodels: Using SARIMAX with exogenous regressors to get predicted mean and confidence intervals

Kishan Manani · Sep 26, 2016

I'm using sm.tsa.SARIMAX() to train a model with exogenous variables. Is there an equivalent of get_prediction() for a model trained with exogenous variables, so that the returned object contains the predicted mean and confidence intervals rather than just an array of predicted means? The predict() and forecast() methods accept exogenous variables, but they only return the predicted mean.

    import statsmodels.api as sm

    SARIMA_model = sm.tsa.SARIMAX(endog=y_train.astype('float64'),
                                  exog=ExogenousFeature_train.values.astype('float64'),
                                  order=(1, 0, 0),
                                  seasonal_order=(2, 1, 0, 7),
                                  simple_differencing=False)

    model_results = SARIMA_model.fit()

    # exog must have one row per predicted period (343 rows here)
    pred = model_results.predict(start=train_end_date,
                                 end=test_end_date,
                                 exog=ExogenousFeature_test.values.astype('float64').reshape(-1, 1),
                                 dynamic=False)

Here, pred is an array of predicted values rather than the object containing predicted means and confidence intervals that you would get from get_prediction(). Note that get_prediction() does not appear to accept exogenous variables.

My version of statsmodels is 0.8.

Answer

Vinay Kolar · Oct 14, 2016

Because of some backward-compatibility issues, the full results (with prediction intervals, etc.) are not exposed by default.

To get what you want now, use the get_prediction and get_forecast methods with the parameters shown below:

    import matplotlib.pyplot as plt

    # Call get_prediction on the fitted results object (model_results in the question)
    pred_res = model_results.get_prediction(exog=ExogenousFeature_train.values.astype('float64'),
                                            full_results=True, alpha=0.05)
    pred_means = pred_res.predicted_mean
    # Set the interval width with the alpha parameter; alpha=0.05 gives a 95% CI
    pred_cis = pred_res.conf_int(alpha=0.05)

    # Plot the data, the predicted mean, and the interval
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(1, 1, 1)
    # Actual data
    ax.plot(y_train.astype('float64'), '--', color="blue", label='data')
    # Predicted mean
    ax.plot(pred_means, lw=1, color="black", alpha=0.5, label='SARIMAX')
    # Shaded confidence band (alpha here is the fill transparency, not the CI level)
    ax.fill_between(pred_means.index, pred_cis.iloc[:, 0], pred_cis.iloc[:, 1], alpha=0.05)
    ax.legend(loc='upper right')
    plt.draw()

For more info, go to: