MNT fix same y_init used for different sessions (#37)

* MNT fix same y_init used for different sessions
* DOC add example of msa NARX
1 parent 711673e · commit 0b38856 · 10 changed files with 1,082 additions and 1,182 deletions
@@ -0,0 +1,142 @@
.. currentmodule:: fastcan.narx

.. _narx:

=======================================================
Reduced polynomial NARX model for system identification
=======================================================

:class:`NARX` is a class to build a reduced polynomial NARX (Nonlinear AutoRegressive eXogenous) model for system identification.
The structure of a polynomial NARX model is shown below:

.. math::
    y(k) = 0.5y(k-1) + 0.3u_0(k)^2 + 2u_0(k-1)u_0(k-3) + 1.5u_0(k-2)u_1(k-3) + 1

where :math:`k` is the time index, :math:`u_0` and :math:`u_1` are input signals, and :math:`y` is the output signal.
The NARX model is composed of three parts:

#. Terms: :math:`y(k-1)`, :math:`u_0(k)^2`, :math:`u_0(k-1)u_0(k-3)`, and :math:`u_0(k-2)u_1(k-3)`
#. Coefficients: 0.5, 0.3, 2, and 1.5
#. Intercept: 1
To build a reduced polynomial NARX model, we normally need the following steps (a code sketch follows the list):

#. Build the term library
#. Get time-shifted variables, e.g., :math:`u_0(k-2)` and :math:`y(k-1)`
#. Get terms by nonlinearising the time-shifted variables with a polynomial basis, e.g., :math:`u_0(k-1)u_0(k-3)` and :math:`u_0(k-2)u_1(k-3)`
#. Search for useful terms by feature selection, i.e., `Reduced Order Modelling (ROM) <https://www.mathworks.com/discovery/reduced-order-modeling.html>`_
#. Learn the coefficients by least-squares
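A minimal sketch of these steps with :func:`make_narx` is shown below; the placeholder data and parameter values are illustrative only, not from the original text.

.. code-block:: python

    import numpy as np
    from fastcan.narx import make_narx

    rng = np.random.default_rng(0)
    u = rng.normal(size=(100, 1))  # input time series, shape (n_samples, n_features)
    y = rng.normal(size=100)       # output time series (placeholder data)

    narx_model = make_narx(
        X=u,
        y=y,
        n_features_to_select=5,  # number of terms kept by feature selection
        max_delay=2,             # maximum time shift of the variables
        poly_degree=3,           # degree of the polynomial basis
    )
    narx_model.fit(u, y)         # learn the coefficients by least-squares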
.. rubric:: Examples

* See :ref:`sphx_glr_auto_examples_plot_narx.py` for an example of the basic usage of the NARX model.
Multiple time series and missing values
=======================================

For time series modelling, the continuity of the data is important.
Continuity means that all neighbouring samples are separated by the same time interval.
A discontinuous time series can be treated as multiple pieces of continuous time series.
The discontinuity can be caused by

#. Different measurement sessions, e.g., one measurement session is carried out on Monday and
   lasts for one hour at 1 Hz, while another is carried out on Tuesday and lasts for two hours at 1 Hz
#. Missing values

No matter how the discontinuity is caused, :class:`NARX` treats a discontinuous time series as multiple pieces of continuous time series in the same way.
.. note::
    It is easy to inform :class:`NARX` that the time series data come from different measurement sessions by inserting `np.nan` to break the data. For example,

    >>> import numpy as np
    >>> x0 = np.zeros((3, 2))  # First measurement session has 3 samples with 2 features
    >>> x1 = np.ones((5, 2))  # Second measurement session has 5 samples with 2 features
    >>> u = np.r_[x0, [[np.nan, np.nan]], x1]  # Insert np.nan to separate the two sessions

    It is important to separate different measurement sessions with `np.nan`, because otherwise
    the model will assume that the time interval between the two sessions is the same as the time interval within a session.
One-step-ahead and multiple-step-ahead prediction
=================================================

In system identification, two types of predictions can be made by NARX models:

#. One-step-ahead prediction: predict the output at the next time step given the measured input and the measured output at the past time steps, e.g., :math:`\hat{y}(k) = 0.5y(k-1) + 0.3u_0(k-1)^2`, where :math:`\hat{y}(k)` is the predicted output at time step :math:`k`
#. Multiple-step-ahead prediction: predict the output at the next time step given the measured input and the predicted output at the past time steps, e.g., :math:`\hat{y}(k) = 0.5\hat{y}(k-1) + 0.3u_0(k-1)^2`

Obviously, multiple-step-ahead prediction is more challenging, because the predicted output is fed back as an input for the next step, which may accumulate error.
The `predict` method of :class:`NARX` gives the multiple-step-ahead prediction.
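For example, continuing the sketch above, the `predict` method needs the first `max_delay` measured outputs, supplied via `y_init`, to start the recursion:

.. code-block:: python

    # Multiple-step-ahead prediction: only the first `max_delay` measured
    # outputs are used; later steps feed back the model's own predictions.
    y_hat = narx_model.predict(u, y_init=y[:2])  # 2 == max_delay in the sketch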
The two types of prediction should also be distinguished in model training.

#. If we assume the NARX model has a one-step-ahead prediction structure, all input data for training are known in advance,
   and the model can be trained by the simple ordinary least-squares (OLS) method
#. If we assume the NARX model has a multiple-step-ahead prediction structure, some of the input data, like :math:`\hat{y}(k-1)`, are
   unknown in advance. Therefore, the training data must first be generated by multiple-step-ahead prediction with
   the initial model coefficients, and then the coefficients can be updated recursively
ARX and OE model
----------------

To better understand the two types of training, it is helpful to know two linear time series model structures,
i.e., the `ARX (AutoRegressive eXogenous) model <https://www.mathworks.com/help/ident/ref/arx.html>`_ and the
`OE (output error) model <https://www.mathworks.com/help/ident/ref/oe.html>`_.

The ARX model structure is

.. math::
    (1-A(z))y(t) = B(z)u(t) + e(t)

where :math:`A(z)` and :math:`B(z)` are polynomials of :math:`z`, e.g.,
:math:`A(z) = 1.5 z^{-1} - 0.7 z^{-2}`, :math:`z` is a complex variable from the `Z-transform <https://en.wikipedia.org/wiki/Z-transform>`_,
:math:`z^{-1}` can be viewed as a time-shift operator, and :math:`e(t)` is white noise.
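As a concrete instance, taking :math:`A(z) = 1.5 z^{-1} - 0.7 z^{-2}` from above and, purely for illustration, :math:`B(z) = 0.5 z^{-1}`, the ARX model becomes the difference equation

.. math::
    y(t) = 1.5 y(t-1) - 0.7 y(t-2) + 0.5 u(t-1) + e(t)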
The OE model structure is

.. math::
    y(t) = \frac{B(z)}{1-A(z)} u(t) + e(t)

The ARX model and the OE model are very similar. If we rewrite the ARX model as :math:`y(t) = \frac{B(z)}{1-A(z)} u(t) + \frac{1}{1-A(z)} e(t)`,
we can see that both models have the same transfer function :math:`\frac{B(z)}{1-A(z)}`.
The two key differences are:

#. the noise assumption
#. the `best linear unbiased estimator (BLUE) <https://en.wikipedia.org/wiki/Gauss%E2%80%93Markov_theorem>`_

For the noise assumption, the noise of the ARX model passes through a filter :math:`\frac{1}{1-A(z)}`, while the noise of the OE model does not.
The ARX structure therefore implies that the measurement noise is a very special coloured noise, which is not common in practice.
For the difference in BLUE, to get the BLUE of a model, we need to rewrite the model in the form

.. math::
    e(t) = y(t) - \hat{y}(t)

where :math:`\hat{y}(t)` is a predictor, which describes how to compute the predicted output such that the error is white.
It is easy to find that for an ARX model

.. math::
    \hat{y}(t) = A(z)y(t) + B(z)u(t)

while for an OE model

.. math::
    \hat{y}(t) = \frac{B(z)}{1-A(z)} u(t) = A(z)\hat{y}(t) + B(z)u(t)

The difference between the two predictors is that the one for ARX models uses both the measured input and the measured output to make predictions,
while the one for OE models uses only the measured input.
In other words, the predictor of the BLUE for ARX computes a one-step-ahead prediction,
while the predictor of the BLUE for OE computes a multiple-step-ahead prediction.
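With the same illustrative polynomials as above, the ARX predictor is :math:`\hat{y}(t) = 1.5 y(t-1) - 0.7 y(t-2) + 0.5 u(t-1)`, which feeds on measured outputs, whereas the OE predictor is :math:`\hat{y}(t) = 1.5 \hat{y}(t-1) - 0.7 \hat{y}(t-2) + 0.5 u(t-1)`, which feeds on its own past predictions.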
In summary, using multiple-step-ahead prediction in model training is better aligned with the real noise conditions, but the training is harder;
using one-step-ahead prediction in model training is easier, but its noise assumption is too strong to be true in practice.
The `fit` method of :class:`NARX` uses the one-step-ahead prediction for training by default,
but the multiple-step-ahead prediction can be used by setting the parameter `coef_init`.
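A minimal sketch of the two training modes, reusing `narx_model`, `u`, and `y` from the earlier sketch:

.. code-block:: python

    # One-step-ahead training (default): plain ordinary least-squares
    narx_model.fit(u, y)

    # Multiple-step-ahead training: coefficients are initialised with the
    # one-step-ahead solution and then refined
    narx_model.fit(u, y, coef_init="one_step_ahead")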
.. rubric:: Examples

* See :ref:`sphx_glr_auto_examples_plot_narx_msa.py` for an illustration of the multiple-step-ahead NARX model.
@@ -11,4 +11,5 @@ User Guide
    unsupervised.rst
    multioutput.rst
    redundancy.rst
-   ols_and_omp.rst
+   ols_and_omp.rst
+   narx.rst
@@ -0,0 +1,182 @@
""" | ||
=========================== | ||
Multi-step-ahead NARX model | ||
=========================== | ||
.. currentmodule:: fastcan | ||
In this example, we will compare one-step-ahead NARX and multi-step-ahead NARX. | ||
""" | ||
|
||
# Authors: Sikai Zhang | ||
# SPDX-License-Identifier: MIT | ||
|
# %%
# Nonlinear system
# ----------------
#
# The `Duffing equation <https://en.wikipedia.org/wiki/Duffing_equation>`_ is used to
# generate simulated data. The mathematical model is given by
#
# .. math::
#     \ddot{y} + 0.1\dot{y} - y + 0.25y^3 = u
#
# where :math:`y` is the output signal and :math:`u` is the input signal, which is
# :math:`u(t) = 2.5\cos(2\pi t)`.
#
# The phase portraits of the Duffing equation are shown below.
import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import odeint


def duffing_equation(y, t):
    """Non-autonomous system"""
    y1, y2 = y
    u = 2.5 * np.cos(2 * np.pi * t)
    dydt = [y2, -0.1 * y2 + y1 - 0.25 * y1**3 + u]
    return dydt


def auto_duffing_equation(y, t):
    """Autonomous system"""
    y1, y2 = y
    dydt = [y2, -0.1 * y2 + y1 - 0.25 * y1**3]
    return dydt


dur = 10
n_samples = 1000

# 10 initial conditions spread around the unit circle
n_init = 10
x0 = np.linspace(0, 2, n_init)
y0_y = np.cos(np.pi * x0)
y0_x = np.sin(np.pi * x0)
y0 = np.c_[y0_x, y0_y]

t = np.linspace(0, dur, n_samples)
sol = np.zeros((n_init, n_samples, 2))
for i in range(n_init):
    sol[i] = odeint(auto_duffing_equation, y0[i], t)

for i in range(n_init):
    plt.plot(sol[i, :, 0], sol[i, :, 1], c="tab:blue")

plt.title("Phase portraits of Duffing equation")
plt.xlabel("y(t)")
plt.ylabel("dy/dt(t)")
plt.show()
# %%
# Generate training-test data
# ---------------------------
#
# The phase portraits show that the system has two stable equilibria.
# We use one to generate the training data and the other to generate the test data.

dur = 10
n_samples = 1000

t = np.linspace(0, dur, n_samples)

sol = odeint(duffing_equation, [0.6, 0.8], t)
u_train = 2.5 * np.cos(2 * np.pi * t).reshape(-1, 1)
y_train = sol[:, 0]

sol = odeint(auto_duffing_equation, [0.6, -0.8], t)  # autonomous system (no input term)
u_test = 2.5 * np.cos(2 * np.pi * t).reshape(-1, 1)
y_test = sol[:, 0]
# %%
# One-step-ahead vs. multi-step-ahead NARX
# ----------------------------------------
#
# First, we use :func:`make_narx` to obtain the reduced NARX model.
# Then, the NARX model will be fitted with the one-step-ahead predictor and the
# multi-step-ahead predictor, respectively. Generally, the training of a one-step-ahead
# (OSA) NARX is faster, while a multi-step-ahead (MSA) NARX is more accurate.
from sklearn.metrics import r2_score

from fastcan.narx import make_narx

max_delay = 2

narx_model = make_narx(
    X=u_train,
    y=y_train,
    n_features_to_select=10,
    max_delay=max_delay,
    poly_degree=3,
    verbose=0,
)


def plot_prediction(ax, t, y_true, y_pred, title):
    ax.plot(t, y_true, label="true")
    ax.plot(t, y_pred, label="predicted")
    ax.legend()
    ax.set_title(f"{title} (R2: {r2_score(y_true, y_pred):.5f})")
    ax.set_xlabel("t (s)")
    ax.set_ylabel("y(t)")


# One-step-ahead (OSA) training: coefficients fitted by ordinary least-squares
narx_model.fit(u_train, y_train)
y_train_osa_pred = narx_model.predict(u_train, y_init=y_train[:max_delay])
y_test_osa_pred = narx_model.predict(u_test, y_init=y_test[:max_delay])

# Multi-step-ahead (MSA) training: initialised with the OSA coefficients
narx_model.fit(u_train, y_train, coef_init="one_step_ahead")
y_train_msa_pred = narx_model.predict(u_train, y_init=y_train[:max_delay])
y_test_msa_pred = narx_model.predict(u_test, y_init=y_test[:max_delay])

fig, ax = plt.subplots(2, 2, figsize=(8, 6))
plot_prediction(ax[0, 0], t, y_train, y_train_osa_pred, "OSA NARX on Train")
plot_prediction(ax[0, 1], t, y_train, y_train_msa_pred, "MSA NARX on Train")
plot_prediction(ax[1, 0], t, y_test, y_test_osa_pred, "OSA NARX on Test")
plot_prediction(ax[1, 1], t, y_test, y_test_msa_pred, "MSA NARX on Test")
fig.tight_layout()
plt.show()
# %%
# Multiple measurement sessions
# -----------------------------
#
# The plot above shows that the NARX model cannot capture the dynamics at
# the left equilibrium shown in the phase portraits. To improve the performance, let us
# combine the training and test data for model training to include the dynamics of both
# equilibria. Here, we need to insert `np.nan` to indicate to the model that the training
# data and the test data are from different measurement sessions. The plot shows that the
# prediction performance of the NARX model on the test data is largely improved.
# Insert np.nan to separate the two measurement sessions
u_all = np.r_[u_train, [[np.nan]], u_test]
y_all = np.r_[y_train, [np.nan], y_test]
narx_model = make_narx(
    X=u_all,
    y=y_all,
    n_features_to_select=10,
    max_delay=max_delay,
    poly_degree=3,
    verbose=0,
)

narx_model.fit(u_all, y_all)
y_train_osa_pred = narx_model.predict(u_train, y_init=y_train[:max_delay])
y_test_osa_pred = narx_model.predict(u_test, y_init=y_test[:max_delay])

narx_model.fit(u_all, y_all, coef_init="one_step_ahead")
y_train_msa_pred = narx_model.predict(u_train, y_init=y_train[:max_delay])
y_test_msa_pred = narx_model.predict(u_test, y_init=y_test[:max_delay])

fig, ax = plt.subplots(2, 2, figsize=(8, 6))
plot_prediction(ax[0, 0], t, y_train, y_train_osa_pred, "OSA NARX on Train")
plot_prediction(ax[0, 1], t, y_train, y_train_msa_pred, "MSA NARX on Train")
plot_prediction(ax[1, 0], t, y_test, y_test_osa_pred, "OSA NARX on Test")
plot_prediction(ax[1, 1], t, y_test, y_test_msa_pred, "MSA NARX on Test")
fig.tight_layout()
plt.show()