Recurrent Neural Networks (RNN)
- class oats.models.predictive.rnn.RNNModel(window: int = 10, n_steps: int = 1, use_gpu: bool = False, val_split: float = 0.2, rnn_model: str = 'RNN', **kwargs)
Bases: DartsModel
Recurrent Neural Network model.
Uses an RNN as a predictor. Anomaly scores are the deviations of the observed values from the model's predictions.
Reference: https://unit8co.github.io/darts/generated_api/darts.models.forecasting.rnn_model.html
Initialization also accepts any parameter accepted by the Darts RNN model documented at the reference above.
- Parameters:
window (int, optional) – rolling window size to feed into the predictor. Defaults to 10.
n_steps (int, optional) – number of steps to predict forward. Defaults to 1.
use_gpu (bool, optional) – whether to use GPU. Defaults to False.
val_split (float, optional) – proportion of data points reserved for validation. Defaults to 0.2.
rnn_model (str, optional) – type of recurrent cell: 'RNN' (vanilla RNN), 'LSTM', or 'GRU'. Defaults to 'RNN'.
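To make the roles of window, n_steps, and val_split concrete, the following is a minimal pure-Python sketch of how a series can be cut into (input window, target) pairs and how a validation share can be held out. The helper names make_windows and split_train_val are hypothetical and are not part of oats or Darts. Reserving the tail of the series for validation is an assumption; the library may split differently.

```python
def make_windows(series, window=10, n_steps=1):
    """Slice a series into (input window, target) pairs.

    Each pair holds `window` consecutive points as input and the
    `n_steps` points that immediately follow as the target.
    """
    pairs = []
    last_start = len(series) - window - n_steps
    for start in range(last_start + 1):
        inputs = series[start:start + window]
        targets = series[start + window:start + window + n_steps]
        pairs.append((inputs, targets))
    return pairs


def split_train_val(series, val_split=0.2):
    """Hold out the last `val_split` share of the points (assumed: tail split)."""
    n_val = int(len(series) * val_split)
    cut = len(series) - n_val
    return series[:cut], series[cut:]


series = [float(i) for i in range(20)]
train, val = split_train_val(series, val_split=0.2)
pairs = make_windows(train, window=10, n_steps=1)

print(len(train), len(val))  # sizes of the two parts
print(pairs[0])              # first (input window, target) pair
```

With the defaults, a series must contain at least window + n_steps points to yield a single training pair, which is worth keeping in mind when choosing window for short series.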
- fit(train_data: ndarray[Any, dtype[Any]], epochs: int = 15, **kwargs)
- get_scores(test_data: ndarray[Any, dtype[float32]], normalize=False, **kwargs) → Tuple[ndarray[Any, dtype[float32]]]
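The idea behind get_scores, an anomaly score defined as the deviation from a prediction, can be illustrated without any model training. The sketch below is an assumption-laden stand-in: naive_forecast (the mean of the preceding window) replaces the RNN, and both function names are hypothetical. It shows the scoring principle only, not the oats implementation.

```python
def naive_forecast(history):
    """Stand-in predictor: mean of the window. This is NOT an RNN."""
    return sum(history) / len(history)


def deviation_scores(series, window=10):
    """Score each point by its absolute deviation from a one-step forecast.

    In this sketch the first `window` points get no score, because no
    full input window precedes them.
    """
    scores = []
    for t in range(window, len(series)):
        predicted = naive_forecast(series[t - window:t])
        scores.append(abs(series[t] - predicted))
    return scores


# Flat series with a single injected spike at index 12.
series = [1.0] * 15
series[12] = 9.0

scores = deviation_scores(series, window=10)
spike_position = window_offset = 10 + scores.index(max(scores))
print(spike_position)  # index in `series` with the largest score
```

Note that the points right after the spike also receive non-zero scores here, because the spike enters their input window and shifts the forecast. A learned predictor can show a similar after-effect.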