Temporal Fusion Transformer (TFT)
- class oats.models.predictive.tft.TFTModel(window: int = 10, n_steps: int = 1, use_gpu: bool = 1, val_split: float = 0.2, **kwargs)
Bases: DartsModel
TFT Model (Temporal Fusion Transformer)
Uses TFT as a predictor. Anomaly scores are the deviations of the observed values from the predictions.
Reference: https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tft_model.html
Initialization also accepts any parameter supported by the Darts TFTModel: https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tft_model.html
- Parameters:
window (int, optional) – rolling window size to feed into the predictor. Defaults to 10.
n_steps (int, optional) – number of steps to predict forward. Defaults to 1.
use_gpu (bool, optional) – whether to use a GPU. Defaults to False.
val_split (float, optional) – proportion of data points reserved for validation. Defaults to 0.2.
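The sketch below illustrates how window, n_steps, and val_split could shape the training data: each input slice of window points is paired with the next n_steps target points, and a val_split fraction of the points is held out. The helpers make_windows and train_val_split are hypothetical, and holding out the most recent points (rather than a random sample) is an assumption about how the split is taken, not a description of the oats or Darts implementation.

```python
from typing import List, Sequence, Tuple


def make_windows(
    series: Sequence[float], window: int = 10, n_steps: int = 1
) -> List[Tuple[List[float], List[float]]]:
    """Hypothetical helper: pair each length-`window` input slice with the
    `n_steps` points that immediately follow it."""
    pairs = []
    # Stop early enough that every target slice has exactly n_steps points.
    for start in range(len(series) - window - n_steps + 1):
        inputs = list(series[start : start + window])
        targets = list(series[start + window : start + window + n_steps])
        pairs.append((inputs, targets))
    return pairs


def train_val_split(
    series: Sequence[float], val_split: float = 0.2
) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: reserve the last `val_split` fraction of points
    for validation (a chronological split is assumed here)."""
    n_val = int(len(series) * val_split)
    cut = len(series) - n_val
    return list(series[:cut]), list(series[cut:])
```

For example, with 20 points and val_split=0.2, the first 16 points are kept for training and, with window=10 and n_steps=1, they yield 6 input/target pairs.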
- fit(train_data: ndarray[Any, dtype[Any]], epochs: int = 15, **kwargs)
- get_scores(test_data: ndarray[Any, dtype[float32]], normalize=False, **kwargs) → Tuple[ndarray[Any, dtype[float32]]]
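To illustrate the fit / get_scores pattern and the idea that anomaly scores are deviations from predictions, the sketch below uses a hypothetical MeanWindowPredictor in place of TFT: it forecasts each point as the mean of the previous window points and scores the absolute deviation of the observation from that forecast. It is not the oats implementation; the real get_scores takes NumPy arrays, supports normalize, and is annotated to return a tuple, whereas this sketch takes plain sequences and returns a list. Assigning a score of 0.0 to the first window points is a choice made only for this sketch.

```python
from statistics import fmean
from typing import List, Sequence


class MeanWindowPredictor:
    """Hypothetical stand-in for TFTModel: forecasts the next point as the
    mean of the previous `window` points (no neural network involved)."""

    def __init__(self, window: int = 10):
        self.window = window

    def fit(self, train_data: Sequence[float], epochs: int = 15, **kwargs) -> None:
        # `epochs` only mirrors the documented signature; a mean forecaster
        # has nothing to train, so this just validates the input length.
        if len(train_data) <= self.window:
            raise ValueError("train_data must be longer than the window")

    def get_scores(self, test_data: Sequence[float], **kwargs) -> List[float]:
        scores = []
        for t in range(len(test_data)):
            if t < self.window:
                # Not enough history to forecast yet; this sketch assigns 0.0.
                scores.append(0.0)
                continue
            prediction = fmean(test_data[t - self.window : t])
            # Anomaly score: absolute deviation of the observation from the forecast.
            scores.append(abs(test_data[t] - prediction))
        return scores
```

Calling get_scores([1.0, 1.0, 1.0, 1.0, 9.0, 1.0]) on a window=3 instance gives the largest score (8.0) at the spike, plus a smaller residual score on the next point because the spike is still inside the averaging window.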