Forecasting cropland vegetation condition¶
Keywords data used; sentinel 2, data used; crop_mask, :index: data methods; forecasting, :index: data methods; autoregression
Background¶
This notebook conducts timeseries forecasting of vegetation condition (NDVI) using SARIMAX, a variation on autoregressivemovingaverage (ARMA) models which includes an integrated (I) component to difference the timeseries so it becomes stationary, a seasonal (S) component, and has the capacity to consider exogenous (X) variables.
In this example, we will conduct a forecast on a univariate NDVI timeseries. That is, our forecast will be built on temporal patterns in NDVI. Conversely, multivariate approaches can account for influences of variables such as soil moisture and rainfall.
Description¶
In this notebook, we generate a NDVI timeseries from Sentinel2, then use it develop a forecasting algorithm.
The following steps are taken:
Load Sentinel2 data and calculate NDVI.
Mask NDVI to cropland using the crop mask.
Iterate through SARIMAX parameters and conduct model selection based on crossvalidation.
Inspect model diagnostics
Forecast NDVI into the future and visualise the result.
Load packages¶
Import Python packages that are used for the analysis.
Important note: Scipy has updated and has some incompatibilities with old versions of statsmodels. If the loading packages cell below returns an error, try running pip install statsmodels
or pip install statsmodels upgrade
in a code cell, then load the packages again.
[1]:
%matplotlib inline
from itertools import product
import datacube
import numpy as np
import pandas as pd
import statsmodels.api as sm
import xarray as xr
from datacube import Datacube
from deafrica_tools.bandindices import calculate_indices
from deafrica_tools.dask import create_local_dask_cluster
from deafrica_tools.datahandling import load_ard
from deafrica_tools.plotting import display_map
from matplotlib import pyplot as plt
from statsmodels.tools.eval_measures import rmse
from tqdm.notebook import tqdm
Set up a Dask cluster¶
Dask can be used to better manage memory use down and conduct the analysis in parallel. For an introduction to using Dask with Digital Earth Africa, see the Dask notebook.
Note: We recommend opening the Dask processing window to view the different computations that are being executed; to do this, see the Dask dashboard in DE Africa section of the Dask notebook.
To use Dask, set up the local computing cluster using the cell below.
[2]:
create_local_dask_cluster()
Client

Cluster

Connect to the datacube¶
[3]:
dc = datacube.Datacube(app="NDVI_forecast")
Analysis parameters¶
lat
,lon
: The central latitude and longitude to analyse. In this example we’ll use an agricultural area in Ethiopia.buffer
: The number of square degrees to load around the central latitude and longitude. For reasonable loading times, set this as 0.1 or lower.products
: The satellite data to load, in the example we will use Sentinel2.time_range
: The date range to analyse. The longer the daterange, the more data the model has to derive patterns in the NDVI timeseries.freq
: The frequency we want to resample the timeseries to e.g. for monthly time steps use'1M'
, for fortinightly use'2W'
.forecast_length
: The length of time beyond the latest observation in the dataset that we want the model to forecast, expressed in units of resample frequencyfreq
. A longerforecast_length
means greater forecast uncertainty.resolution
: The pixel resolution (in metres) to use for loading Sentinel2 data.dask_chunks
: How to chunk the datasets to work with dask.
[4]:
# Define the analysis region (LatLon box)
lat, lon = 8.5593, 40.6975
buffer = 0.04
# the satellite product to load
products = "s2_l2a"
# Define the time window for defining the model
time_range = ("20170101", "202201")
# resample frequency
freq = "1M"
# number of timesteps to forecast (in units of `freq`)
forecast_length = 6
# resolution of Sentinel2 pixels
resolution = (20, 20)
# dask chunk sizes
dask_chunks = {"x": 500, "y": 500, "time": 1}
Display analysis area on an interactive map¶
[5]:
lon = (lon  buffer, lon + buffer)
lat = (lat  buffer, lat + buffer)
display_map(lon, lat)
[5]:
Load the satellite data¶
Using the parameters we defined above.
[6]:
# set up a datcube query object
query = {
"x": lon,
"y": lat,
"time": time_range,
"measurements": ["red", "nir"],
"output_crs": "EPSG:6933",
"resolution": resolution,
"resampling": {"fmask": "nearest", "*": "bilinear"},
}
[7]:
# load the satellite data
ds = load_ard(dc=dc, dask_chunks=dask_chunks, products=products, **query)
print(ds)
Using pixel quality parameters for Sentinel 2
Finding datasets
s2_l2a
Applying pixel quality/cloud mask
Returning 702 time steps as a dask array
<xarray.Dataset>
Dimensions: (time: 702, y: 505, x: 387)
Coordinates:
* time (time) datetime64[ns] 20170106T07:42:19 ... 20220130T07:...
* y (y) float64 1.093e+06 1.093e+06 ... 1.083e+06 1.083e+06
* x (x) float64 3.923e+06 3.923e+06 ... 3.931e+06 3.931e+06
spatial_ref int32 6933
Data variables:
red (time, y, x) float32 dask.array<chunksize=(702, 500, 387), meta=np.ndarray>
nir (time, y, x) float32 dask.array<chunksize=(702, 500, 387), meta=np.ndarray>
Attributes:
crs: EPSG:6933
grid_mapping: spatial_ref
Mask region with DE Africa’s cropland extent map¶
Load the cropland mask over the region of interest. The default region we’re analysing is in Ethiopia, so we need to load either the crop_mask product which covers the entire African continent, or the crop_mask_eastern product, which cover the countries of Ethiopia, Kenya, Tanzania, Rwanda, and Burundi. If you change the analysis region from the default one, you may need to load a different crop mask  see the docs page to find out more.
[8]:
cm = dc.load(
product="crop_mask",
time=("2019"),
measurements="filtered",
resampling="nearest",
like=ds.geobox,
).filtered.squeeze()
# convert the missing values (255) to NaN
cm = cm.where(cm != 255)
cm.plot.imshow(add_colorbar=False, figsize=(6, 6))
plt.title("Cropland Extent");
Now we will use the cropland map to mask the regions in the Sentinel2 data that only have cropping.
[9]:
ds = ds.where(cm == 1)
Calculate NDVI and clean the timeseries¶
After calculating NDVI, we will smooth and interpolate the data to ensure we are working with a consistent timeseries. This is a very important step in the workflow and there are many ways to smooth, interpolate, gapfill, remove outliers, or curvefit the data to ensure a consistent timeseries. If not using the default example, you may have to define additional methods to those used here.
To do this we take two steps:
Resample the data to monthly timesteps using the mean
Calculate a rolling mean with a window of 4 steps
[10]:
# calculate NDVI
ndvi = calculate_indices(ds, "NDVI", drop=True, satellite_mission="s2")
Dropping bands ['red', 'nir']
[11]:
# resample and smooth
window = 4
ndvi = (
ndvi.resample(time=freq)
.mean()
.rolling(time=window, min_periods=1, center=True)
.mean()
)
Reduce the timeseries to one dimension¶
In this example, we’re generating a forecast on a simple 1D timeseries. This timeseries represents the spatially averaged NDVI at each timestep in the series.
In this step, all the calculations above are triggered and the dataset is brought into memory so this step can take a few minutes to complete.
[12]:
ndvi = ndvi.mean(["x", "y"])
ndvi = ndvi.NDVI.compute()
CPLReleaseMutex: Error = 1 (Operation not permitted)
Plot the NDVI timeseries¶
[13]:
ndvi.plot(figsize=(11, 5), linestyle="dashed", marker="o")
plt.title("NDVI")
plt.ylim(0, ndvi.max().values + 0.05);
Split data and fit a model¶
Crossvalidation is a common method for evaluating model performance. It involves dividing data into a training set on which the model is trained, and test (or validation) set, to which the model is applied to produce predictions which are compared against actual values (that weren’t used in model training).
[14]:
ndvi = ndvi.drop("spatial_ref").to_dataframe()
train_data = ndvi["NDVI"][
: len(ndvi)  10
] # remove the last ten observations and keep them as test data
test_data = ndvi["NDVI"][len(ndvi)  10 :]
Iteratively find the best parameters for the SARIMAX model¶
SARIMAX models are fitted with parameters for both the trend and seasonal components of the timeseries. The parameters can be defined as: * Trend elements * p: Autoregression order. This is the number of immediately preceding values in the series that are used to predict the value at the present time. * d: Difference order. The number of times that differencing is performed is called the difference order. * q: Moving average order. The size of the moving average window. * Seasonal elements are as above, but for the seasonal component of the timeseries. * P * D * Q * We also need to define the length of season. * s: In this case we use 6, which is in units of resample frequency so refers to 6 months.
In the cell below, initial values and a range are given for the parameters above. Using range(0, 3)
means the values 0, 1, and 2 are iterated through for each of p, d, q and P, D, Q. This means that there are \(3^2 \times 3^2 = 81\) possible combinations.
[15]:
# Set initial values and some bounds
p = range(0, 3)
d = 1
q = range(0, 3)
P = range(0, 3)
D = 1
Q = range(0, 3)
s = 6
# Create a list with all possible combinations of parameters
parameters = product(p, q, P, Q)
parameters_list = list(parameters)
print("Number of iterations to run:", len(parameters_list))
# Train many SARIMA models to find the best set of parameters
def optimize_SARIMA(parameters_list, d, D, s):
"""
Return an ordered (ascending RMSE) dataframe with parameters,
corresponding AIC, and RMSE.
parameters_list  list with (p, q, P, Q) tuples
d  integration order
D  seasonal integration order
s  length of season
"""
results = []
best_aic = float("inf")
for param in tqdm(parameters_list):
try:
import warnings
warnings.filterwarnings("ignore")
model = sm.tsa.statespace.SARIMAX(
train_data,
order=(param[0], d, param[1]),
seasonal_order=(param[2], D, param[3], s),
).fit(disp=1)
pred = model.predict(start=len(train_data), end=(len(ndvi)  1))
error = rmse(pred, test_data)
except:
continue
aic = model.aic
# Save best model, AIC and parameters
if aic < best_aic:
best_model = model
best_aic = aic
best_param = param
results.append([param, model.aic, error])
result_table = pd.DataFrame(results)
result_table.columns = ["parameters", "aic", "error"]
# Sort in ascending order, lower AIC is better
result_table = result_table.sort_values(by="error", ascending=True).reset_index(
drop=True
)
return result_table
Number of iterations to run: 81
Now will will run the function above for every iteration of parameters we have defined. Depending on the number of iterations, this can take a few minutes to run. A progress bar is printed below.
[16]:
# run the function above
result_table = optimize_SARIMA(parameters_list, d, D, s)
Model selection¶
The rootmeansquare error (RMSE) is a common metric used to evaluate model or forecast performance. It is the standard deviation of residuals (difference between forecast and actual value) expressed in units of the variable of interest e.g. NDVI. We can calculate RMSE of our forecast because we withheld some observations as test or validation data.
We can also use either the Akaike information criterion (AIC) or Bayesian information criterion (BIC) for model selection. Both these criteria aim to optimise the tradeoff between goodness of fit and model simplicity. We are aiming to find the model that can explain the most variation in the timeseries with the least complexity, as added complexity may lead to overfitting. The BIC penalises additional parameters (greater complexity) more than the AIC.
There are different schools of thought on which criterion to use. A general rule of thumb is that the BIC should be used for inference and interpretation whereas the AIC should be used for prediction. As our goal is prediction (forecasting), we could select the model with the lowest AIC, though this approach is often reserved for when there is no test data available for crossvalidation.
The cell below presents the top 15 models based on AIC and the RMSE on the crossvalidation.
[17]:
# Sort table by the lowest AIC (Akaike Information Criteria) where the RMSE is low
result_table = result_table.sort_values("aic").sort_values("error")
print(result_table[0:15])
parameters aic error
0 (0, 0, 1, 2) 167.730806 0.018184
1 (0, 1, 2, 2) 187.450473 0.018201
2 (0, 0, 2, 1) 163.183898 0.018655
3 (0, 1, 1, 1) 186.749586 0.019855
4 (2, 2, 0, 2) 203.806815 0.020542
5 (0, 0, 2, 2) 166.948974 0.021318
6 (2, 2, 1, 2) 201.434439 0.023771
7 (2, 2, 0, 1) 205.304806 0.023861
8 (0, 1, 1, 2) 188.009648 0.024562
9 (0, 0, 1, 1) 163.806008 0.024623
10 (0, 1, 2, 1) 185.070702 0.026774
11 (2, 1, 0, 0) 193.296201 0.027527
12 (1, 1, 2, 2) 191.548041 0.028136
13 (0, 2, 1, 2) 190.940525 0.029760
14 (0, 2, 2, 2) 188.050115 0.033009
Select model and predict¶
In the cell below. We wil select a model from the list above. In this case we’ve selected model 0
as it has the lowest RMSE, though you can select any model by setting the index number in the cell below using the model_sel_index
parameter.
[18]:
# selected model
model_sel_index = 0
# store parameters from selected model
p, q, P, Q = result_table.iloc[model_sel_index].parameters
print(result_table.iloc[model_sel_index])
# fit the model with the parameters identified above
best_model = sm.tsa.statespace.SARIMAX(
train_data, order=(p, d, q), seasonal_order=(P, D, Q, s)
).fit(disp=1)
parameters (0, 0, 1, 2)
aic 167.730806
error 0.018184
Name: 0, dtype: object
Plot model diagnostics¶
There are some typical plots we can use to evaluate our model.
Standardised residuals (topleft) The standardised residuals are plotted against x (time) values. This allows us to check that variance (distance of residuals from 0) is constant across time values. There should be no obvious patterns.
Histogram and estimated density (topright) A kernel density estimation (KDE) is an estimated probability density function fitted on the actual distribution (histogram) of standardised residuals. A normal distribution (N (0,1)) is shown for reference. This plot shows that the distribution of our standardised residuals is close to normally distributed.
Normal quantilequantile (QQ) plot (bottomleft) This plot shows ‘expected’ or ‘theoretical’ quantiles drawn from a normal distribution on the xaxis against quantiles taken from the sample of residuals on the yaxis. If the observations in blue match the 1:1 line in red, then we can conclude that our residuals are normally distributed.
Correlogram (bottomright) The correlations for lags greater than 0 should not be statistically significant. That is, they should not be outside the blue ribbon.
Note: The QQ plot and correlogram generated for model
0
show there is some pattern in the residuals. That is, there is remaining variation in the data which the model has not accounted for. You could experiment with different parameter values or model selection in the prior steps to see if this can be addressed.
[19]:
fig = plt.figure(figsize=(16, 9))
fig = best_model.plot_diagnostics(lags=25, fig=fig)
Backtest forecast¶
We saved the last 10 observations as test data above. Now we can use our model to predict NDVI for those timesteps and compare those predictions with actual values. We can do this visually in the graph below and also quantify the error with the rootmeansquare error (RMSE).
[20]:
pred = best_model.predict(start=len(train_data), end=(len(ndvi)  1))
plt.figure(figsize=(11, 5))
pred.plot(label="forecast", linestyle="dashed", marker="o")
train_data.plot(label="training data", linestyle="dashed", marker="o")
test_data.plot(label="test data", linestyle="dashed", marker="o")
plt.legend(loc="upper left");
Plot the result of our forecast¶
To forecast NDVI into the future, we’ll run a model on the entire time series so we can include the latest observations. We can see that the forecast uncertainty, expressed as the 95% confidence interval, increases with time.
[21]:
final_model = sm.tsa.statespace.SARIMAX(
ndvi, order=(p, d, q), seasonal_order=(P, D, Q, s)
).fit(disp=1)
yhat = final_model.get_forecast(forecast_length);
[22]:
fig, ax = plt.subplots(1, 1, figsize=(11, 5))
yhat.predicted_mean.plot(label="NDVI forecast", ax=ax, linestyle="dashed", marker="o")
ax.fill_between(
yhat.predicted_mean.index,
yhat.conf_int()["lower NDVI"],
yhat.conf_int()["upper NDVI"],
alpha=0.2,
)
ndvi[36:].plot(label="Observaions", ax=ax, linestyle="dashed", marker="o")
plt.legend(loc="upper left");
Our forecast looks reasonable in the context of the timeseries above.
Additional information¶
License: The code in this notebook is licensed under the Apache License, Version 2.0. Digital Earth Australia data is licensed under the Creative Commons by Attribution 4.0 license.
Contact: If you need assistance, please post a question on the Open Data Cube Slack channel or on the GIS Stack Exchange using the opendatacube
tag (you can view previously asked questions here). If you would like to report an issue with this notebook, you can file one on
Github.
Last modified: January 2022
Compatible datacube version:
[23]:
print(datacube.__version__)
1.8.6
Last Tested:
[24]:
from datetime import datetime
datetime.today().strftime("%Y%m%d")
[24]:
'20220707'