Error Correction Neural Network

Module

Prosper_nn provides implementations for specialized time series forecasting neural networks and related utility functions.

Copyright (C) 2022 Nico Beck, Julia Schemm, Henning Frechen, Jacob Fidorra, Denni Schmidt, Sai Kiran Srivatsav Gollapalli

This file is part of Prosper_nn.

Prosper_nn is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with this program. If not, see <http://www.gnu.org/licenses/>.

class prosper_nn.models.ecnn.ecnn.ECNN(n_features_U: int, n_state_neurons: int, past_horizon: int, forecast_horizon: int = 1, recurrent_cell_type: str = 'elman', kwargs_recurrent_cell: dict = {}, approach: str = 'backward', learn_init_state: bool = True, n_features_Y: int = 1, future_U: bool = False)

Bases: Module

The ECNN class creates an Error Correction Neural Network.

An ECNN is an extension of an RNN in which the forecast error at each time step is interpreted as a reaction to external influences that are unknown to the model. To correct the forecast, this error is used as an additional input for the next time step, ideally substituting for the unknown external information.

The general architecture is given by

\[s_t = \tanh(A s_{t-1} + B u_t + D(\hat{y}_{t-1} - y_{t-1}))\]
\[\hat{y}_t = C s_t\]

where $y_t$ is the target variable, $u_t$ is the explanatory feature and $s_t$ is the hidden state. $A$, $B$, $C$ and $D$ are matrices.

For forecasting steps where neither $u_{t+2}$ nor $y_{t+1}$ is known, the recursion reduces to

\[s_{t+2} = \tanh(A s_{t+1})\]
\[\hat{y}_{t+2} = C s_{t+2}\]

These formulae are implemented in the ECNNCell class, which is used to construct this ECNN model.
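As an illustration, here is a minimal PyTorch sketch of these two update rules, with hypothetical weight matrices and dimensions (not the library's internal implementation):

import torch

n_state_neurons, n_features_U, n_features_Y = 4, 2, 1

# A, B, C, D are the trainable matrices from the formulae above
A = torch.randn(n_state_neurons, n_state_neurons)
B = torch.randn(n_state_neurons, n_features_U)
C = torch.randn(n_features_Y, n_state_neurons)
D = torch.randn(n_state_neurons, n_features_Y)

s_prev = torch.zeros(n_state_neurons)   # s_{t-1}
u_t = torch.randn(n_features_U)         # u_t, known explanatory features
y_prev = torch.randn(n_features_Y)      # observed y_{t-1}
y_hat_prev = C @ s_prev                 # model output \hat{y}_{t-1}

# Past time step: the forecast error feeds back into the state
s_t = torch.tanh(A @ s_prev + B @ u_t + D @ (y_hat_prev - y_prev))
y_hat_t = C @ s_t

# Future time step: neither u nor y is known, so the state evolves freely
s_next = torch.tanh(A @ s_t)
y_hat_next = C @ s_next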

This implementation is based on Zimmermann HG., Neuneier R., Grothmann R. (2002) Modeling Dynamical Systems by Error Correction Neural Networks. In: Soofi A.S., Cao L. (eds) Modelling and Forecasting Financial Data. Studies in Computational Finance, vol 2. Springer, Boston, MA. https://doi.org/10.1007/978-1-4615-0931-8_12

Parameters:
  • n_features_U (int) – The number of inputs, i.e. the number of elements of U at each time step.

  • n_state_neurons (int) – The number of neurons of the hidden layer, i.e. the hidden state at each time step.

  • past_horizon (int) – The length of the sequence of inputs and outputs used for prediction.

  • forecast_horizon (int) – The forecast horizon.

  • recurrent_cell_type (str) – Possible choices: elman, lstm, gru or gru_3_variant.

  • kwargs_recurrent_cell (dict) – Parameters for the recurrent cell. Activation function can be set here.

  • approach (str) – Either “backward” or “forward”. With the backward approach, the external features at time t have a direct impact on the hidden state at time t. With the forward approach, the external features at time t only have a direct impact on the hidden state at time t+1.

  • learn_init_state (bool) – Whether the initial hidden state is learned during training.

  • n_features_Y (int) – The number of outputs, i.e. the number of elements of Y at each time step. The default is 1.

  • future_U (bool) – If False, U is assumed to be known only in the past and thus has length past_horizon. If True, U is assumed to be known in the future as well (e.g. weekdays) and thus has length past_horizon+forecast_horizon.

Return type:

None
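As an illustration of the constructor arguments, a sketch of an ECNN with a GRU state transition, a forward approach, and inputs that are also known in the future (the parameter values are arbitrary):

from prosper_nn.models.ecnn import ECNN

ecnn = ECNN(
    n_features_U=2,
    n_state_neurons=4,
    past_horizon=10,
    forecast_horizon=5,
    recurrent_cell_type='gru',
    approach='forward',
    future_U=True,
)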

check_sizes(U: Tensor, Y: Tensor) → None

Checks whether U and Y have the correct shapes.

forward(U: Tensor, Y: Tensor) → Tensor
Parameters:
  • U (torch.Tensor) – A batch of input features sequences for the ECNN. U should have shape=(past_horizon, batchsize, n_features_U) or shape=(past_horizon+forecast_horizon, batchsize, n_features_U), depending on future_U.

  • Y (torch.Tensor) – A batch of output sequences for the ECNN. Y should have shape=(past_horizon, batchsize, n_features_Y).

Returns:

A tensor that contains past_error, the forecast errors along the past_horizon where Y is known, and forecast, the forecast along the forecast_horizon. Both can be used for backpropagation. shape=(past_horizon+forecast_horizon, batchsize, n_features_Y)

Return type:

torch.Tensor
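Since the returned tensor stacks the past errors and the forecast along the first dimension, the two parts can be separated with torch.split, assuming a model and batch shaped as in the example below:

model_output = ecnn(U_batch, Y_batch)
past_error, forecast = torch.split(model_output, past_horizon)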

repeat_init_state(batchsize)
set_init_state()

Example

import torch

import prosper_nn.utils.generate_time_series_data as gtsd
import prosper_nn.utils.create_input_ecnn_hcnn as ci
from prosper_nn.models.ecnn import ECNN

# Define network and data parameters
past_horizon = 10
forecast_horizon = 5
n_features_U = 2
n_features_Y = 2
n_data = 20
n_state_neurons = 4
batchsize = 1

# Initialise Error Correction Neural Network
ecnn = ECNN(n_features_U,
            n_state_neurons,
            past_horizon,
            forecast_horizon,
            n_features_Y=n_features_Y)

# Generate data
Y, U = gtsd.sample_data(n_data, n_features_Y, n_features_U)
Y_batches, U_batches = ci.create_input(
    Y=Y,
    past_horizon=past_horizon,
    batchsize=batchsize,
    U=U,
    forecast_horizon=forecast_horizon,
)

# The past forecast errors are compared against zero targets, so training drives them toward zero
targets = torch.zeros((past_horizon, batchsize, n_features_Y))

# Train model
optimizer = torch.optim.Adam(ecnn.parameters())
loss_function = torch.nn.MSELoss()

for epoch in range(10):
    for batch_index in range(0, U_batches.shape[0]):
        U_batch = U_batches[batch_index]
        Y_batch = Y_batches[batch_index]
        model_output = ecnn(U_batch, Y_batch)
        # Split the stacked output into errors on the known past and the forecast
        past_error, forecast = torch.split(model_output, past_horizon)

        ecnn.zero_grad()
        loss = loss_function(past_error, targets)
        loss.backward()
        optimizer.step()
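After training, forecasts can be produced the same way; only the forecast part of the output is of interest then. A sketch, reusing the last batch from above:

with torch.no_grad():  # no gradients needed for inference
    model_output = ecnn(U_batches[-1], Y_batches[-1])
    _, forecast = torch.split(model_output, past_horizon)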

Reference

Zimmermann HG., Tietz C., Grothmann R. (2012) Forecasting with Recurrent Neural Networks: 12 Tricks. In: Montavon G., Orr G.B., Müller KR. (eds) Neural Networks: Tricks of the Trade. Lecture Notes in Computer Science, vol 7700. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-642-35289-8_37

Error Correction Neural Network Cell


class prosper_nn.models.ecnn.ecnn_cell.ECNNCell(n_features_U: int, n_state_neurons: int, n_features_Y: int = 1, recurrent_cell_type: str = 'elman', kwargs_recurrent_cell: dict = {})

Bases: Module

Cell of an Error Correction Neural Network (ECNN). It models one time step of an ECNN:

\[s_t = \tanh(A s_{t-1} + B u_t + D(\hat{y}_{t-1} - y_{t-1}))\]
\[\hat{y}_t = C s_t\]
Parameters:
  • n_features_U (int) – The number of inputs, i.e. the number of elements of U at each time step.

  • n_state_neurons (int) – The number of neurons of the hidden layer, i.e. the hidden state at each time step.

  • n_features_Y (int) – The number of outputs, i.e. the number of elements of Y at each time step. The default is 1.

  • recurrent_cell_type (str) – Select the cell for the state transition. The cells elman, lstm, gru (all from pytorch) and gru_3_variant (from prosper_nn) are supported.

  • kwargs_recurrent_cell (dict) – Parameters for the recurrent cell. Activation function can be set here.

Return type:

None

forward(state: Tensor | Tuple[Tensor], U: Tensor | None = None, Y: Tensor | None = None) → Tuple[Tensor, Tensor]

Calculates one time step with the inputs and returns the prediction and the state of the next time step.

Parameters:
  • state (torch.Tensor) – The hidden state of the ECNN at time t-1. state should have shape=(batchsize, n_state_neurons).

  • U (torch.Tensor) – The input for the ECNN at time t (if known). U should have shape=(batchsize, n_features_U).

  • Y (torch.Tensor) – The output of the ECNN at time t-1. Y should have shape=(batchsize, n_features_Y). The Y of the last time step (if known) is used to calculate the error for the error-correction.

Returns:

Contains output, which is the error or the forecast at time t-1 and has the same dimensions as Y, and state, which is the hidden state at time t.

Return type:

tuple
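The optional arguments determine what the cell computes: if Y is given, the output is the error-correction term; if neither U nor Y is given, the cell rolls the state forward and the output is the forecast. A sketch, assuming a cell and tensors u_t and y_t shaped as documented above:

# Past time step: U and Y are known, so the output is the error at t-1
error, state = ecnn_cell(state, U=u_t, Y=y_t)

# Future time step: neither U nor Y is known, so the output is the forecast
forecast, state = ecnn_cell(state)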

get_batchsize(state)

Example

import torch
from prosper_nn.models.ecnn.ecnn_cell import ECNNCell

n_features_U, n_state_neurons = 5, 10
ecnn_cell = ECNNCell(n_features_U, n_state_neurons)
state = torch.randn(1, n_state_neurons)
U = torch.randn(1, n_features_U)
Y = torch.randn(1, 1)
outputs = []
for i in range(6):
    output, state = ecnn_cell(state, U, Y)
    outputs.append(output)

Reference

Zimmermann HG., Tietz C., Grothmann R. (2012) Forecasting with Recurrent Neural Networks: 12 Tricks. In: Montavon G., Orr G.B., Müller KR. (eds) Neural Networks: Tricks of the Trade. Lecture Notes in Computer Science, vol 7700. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-642-35289-8_37

Error Correction Neural Network GRU 3 Variant


class prosper_nn.models.ecnn.gru_cell_variant.GRU_3_variant(input_size: int, hidden_size: int, activation: Type[torch.autograd.function.Function] = torch.tanh)

Bases: Module

GRU_3_variant cell. The state $s_t$ is calculated as:

\[s_t = (1 - \sigma(update\_vector)) \circ s_{t-1} + \sigma(update\_vector) \circ \tanh(A s_{t-1} + B x_t)\]

The implementation is similar to version 3 of the GRU variants in the following paper. One difference is that $r_t$ is fixed to a vector of ones in our implementation.

R. Dey and F. M. Salem, “Gate-variants of Gated Recurrent Unit (GRU) neural networks,” 2017 IEEE 60th International Midwest Symposium on Circuits and Systems (MWSCAS), Boston, MA, USA, 2017, pp. 1597-1600, doi: 10.1109/MWSCAS.2017.8053243

If the update vector contains large values, the sigmoid function converges toward 1 and the architecture reduces to a regular RNNCell. If, on the other hand, the update vector contains large negative values, the sigmoid converges toward 0, so $s_t = s_{t-1}$ and the state is remembered completely.
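A minimal sketch of this update rule with hypothetical weights (in the real cell, A, B and the update vector are learned parameters):

import torch

hidden_size, input_size = 4, 2
A = torch.randn(hidden_size, hidden_size)
B = torch.randn(hidden_size, input_size)
update_vector = torch.zeros(hidden_size)  # learned in the real cell

s_prev = torch.randn(hidden_size)  # s_{t-1}
x_t = torch.randn(input_size)      # x_t

gate = torch.sigmoid(update_vector)
s_t = (1 - gate) * s_prev + gate * torch.tanh(A @ s_prev + B @ x_t)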

Parameters:
  • input_size (int) – The number of inputs, i.e. the number of elements of input at each time step.

  • hidden_size (int) – The number of neurons of the hidden layer, i.e. the hidden state at each time step.

  • activation (nn.functional, optional) – The activation function that is applied on the output of the hidden layers. The same function is used on all hidden layers.

Return type:

None

forward(input: Tensor, state: Tensor) → Tuple[Tensor, Tensor]

Forward pass of the GRU 3 variant.

Parameters:
  • input (torch.Tensor) – The input for the cell at time t. It should have shape=(batchsize, input_size).

  • state (torch.Tensor) – The hidden state at time t-1. state should have shape=(batchsize, hidden_size).

Returns:

Contains state, which is the hidden state for the next time step.

Return type:

torch.Tensor

Example

import torch
from prosper_nn.models.ecnn.gru_cell_variant import GRU_3_variant

input_size, hidden_size = 5, 10
gru_cell = GRU_3_variant(input_size, hidden_size)
state = torch.randn(1, hidden_size)
x = torch.randn(1, input_size)
states = []
for i in range(6):
    state = gru_cell(x, state)
    states.append(state)

References

R. Dey and F. M. Salem, “Gate-variants of Gated Recurrent Unit (GRU) neural networks,” 2017 IEEE 60th International Midwest Symposium on Circuits and Systems (MWSCAS), Boston, MA, USA, 2017, pp. 1597-1600, doi: 10.1109/MWSCAS.2017.8053243

Zimmermann HG., Tietz C., Grothmann R. (2012) Forecasting with Recurrent Neural Networks: 12 Tricks. In: Montavon G., Orr G.B., Müller KR. (eds) Neural Networks: Tricks of the Trade. Lecture Notes in Computer Science, vol 7700. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-642-35289-8_37