Source code for oats.models.reconstruction.vae
"""
Variational Auto-Encoder (VAE)
-----------------
"""
from pyod.models.vae import VAE
from oats.models._pyod_model import PyODModel
[docs]class VAEModel(PyODModel):
"""VAE Model (Variational Auto-encover)
Using reconstruction error of the trained encoder-decoder network as anomaly scores.
Reference: https://pyod.readthedocs.io/en/latest/_modules/pyod/models/vae.html
"""
def __init__(self, window=10, **kwargs):
"""
initialization also accepts any parameters used by: https://pyod.readthedocs.io/en/latest/_modules/pyod/models/vae.html
Args:
window (int, optional): _description_. Defaults to 10.
"""
model_cls = VAE
self.window = window
super().__init__(model_cls, window, **kwargs)
[docs] def fit(self, train_data, **kwargs):
n_feat = (
train_data.shape[1]
if train_data.ndim > 1 and train_data.shape[1] > 1
else 1
)
if not self.params.get("encoder_neurons"):
self.params["encoder_neurons"] = [
n_feat * self.window,
n_feat * self.window // 2,
n_feat * self.window // 4,
]
if not self.params.get("decoder_neurons"):
self.params["decoder_neurons"] = [
n_feat * self.window // 4,
n_feat * self.window // 2,
n_feat * self.window,
]
self.model = VAE(**self.params)
super().fit(train_data, **kwargs)