-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add base class and some possible plot functions
- Loading branch information
Showing
3 changed files
with
226 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
""" | ||
Created on Dec. 19, 2024 | ||
@author: wangc, mandd | ||
Base Class for Anomaly Detection | ||
""" | ||
|
||
import abc | ||
from sklearn.base import BaseEstimator | ||
from sklearn.preprocessing import StandardScaler | ||
from sklearn.preprocessing import RobustScaler | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
class AnomalyBase(BaseEstimator): | ||
"""Anomaly detection base class | ||
""" | ||
|
||
def __init__(self, norm='robust'): | ||
"""Constructor | ||
""" | ||
self.print_tag = type(self).__name__ | ||
self.is_fitted = False | ||
self._features = None | ||
self._targets = None | ||
self._norm = norm | ||
if norm is None: | ||
logger.info('Use standard scalar (z-score) to transform input data.') | ||
self._scalar = StandardScaler() | ||
elif norm.lower() == 'robust': | ||
logger.info('Use robust scalar to transform input data.') | ||
self._scalar = RobustScaler() | ||
else: | ||
logger.warning('Unrecognized value for param "norm", using default "RobustScalar"') | ||
self._scalar = RobustScaler() | ||
self._meta = {} | ||
|
||
def reset(self): | ||
"""reset | ||
""" | ||
self.is_fitted = False | ||
self._features = None | ||
self._targets = None | ||
self._norm = 'robust' | ||
self._scalar = RobustScaler() | ||
self._meta = {} | ||
|
||
def get_params(self): | ||
""" | ||
Get parameters for this estimator. | ||
Returns | ||
------- | ||
params : dict | ||
Parameter names mapped to their values. | ||
""" | ||
params = super().get_params() | ||
return params | ||
|
||
def set_params(self, **params): | ||
"""Set the parameters of this estimator. | ||
Parameters | ||
---------- | ||
**params : dict | ||
Estimator parameters. | ||
Returns | ||
------- | ||
self : estimator instance | ||
Estimator instance. | ||
""" | ||
super().set_params(**params) | ||
|
||
|
||
def fit(self, X, y=None): | ||
"""perform fitting | ||
Args: | ||
X (array-like): (n_samples, n_features) | ||
y (array-like, optional): (n_samples, n_features). Defaults to None. | ||
""" | ||
logger.info('Train model.') | ||
self.is_fitted = True | ||
self._features = X | ||
self._targets = y | ||
X_transform = self._scalar.fit_transform(X) | ||
fit_obj = self._fit(X_transform, y) | ||
return fit_obj | ||
|
||
def predict(self, X): | ||
"""perform prediction | ||
Args: | ||
X (array-like): (n_samples, n_features) | ||
""" | ||
logger.info('Perform model forecast.') | ||
X_transform = self._scalar.fit_transform(X) | ||
y_new = self._predict(X_transform) | ||
|
||
return y_new | ||
|
||
################################################################################# | ||
# To be implemented in subclasses | ||
|
||
@abc.abstractmethod | ||
def _fit(self, X, y=None): | ||
"""perform fitting | ||
Args: | ||
X (array-like): (n_samples, n_features) | ||
y (array-like, optional): (n_samples, n_features). Defaults to None. | ||
""" | ||
|
||
|
||
@abc.abstractmethod | ||
def _predict(self, X): | ||
"""perform prediction | ||
Args: | ||
X (array-like): (n_samples, n_features) | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import numpy as np | ||
|
||
from .AnomalyBase import AnomalyBase | ||
|
||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class MatrixProfile(AnomalyBase): | ||
"""_summary_ | ||
Args: | ||
AnomalyBase (_type_): _description_ | ||
""" | ||
|
||
|
||
def __init__(self, norm='robust'): | ||
"""Constructor | ||
""" | ||
super().__init__(norm=norm) | ||
|
||
|
||
|
||
def _fit(self, X, y=None): | ||
"""perform fitting | ||
Args: | ||
X (array-like): (n_samples, n_features) | ||
y (array-like, optional): ignored, (n_samples, n_features). Defaults to None. | ||
""" | ||
|
||
|
||
|
||
def _predict(self, X): | ||
"""perform prediction | ||
Args: | ||
X (array-like): (n_samples, n_features) | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import logging | ||
logger = logging.getLogger(__name__) | ||
|
||
try: | ||
from matplotlib import pyplot as plt | ||
from matplotlib.ticker import FuncFormatter | ||
except ImportError: | ||
logger.error('Importing matplotlib failed. Plotting will not work.') | ||
|
||
try: | ||
import plotly.graph_objs as go | ||
from plotly.subplots import make_subplots | ||
except ImportError: | ||
logger.error('Importing plotly failed. Interactive plots will not work.') | ||
|
||
|
||
def plot(x, y, ax=None, xlabel='time', ylabel='ds', figsize=(10,6), include_legend=False): | ||
"""_summary_ | ||
Args: | ||
ax (_type_, optional): _description_. Defaults to None. | ||
xlabel (str, optional): _description_. Defaults to 'time'. | ||
ylabel (str, optional): _description_. Defaults to 'ds'. | ||
figsize (tuple, optional): _description_. Defaults to (10,6). | ||
include_legend (bool, optional): _description_. Defaults to False. | ||
""" | ||
user_provided_ax = False if ax is None else True | ||
if ax is None: | ||
fig = plt.figure(facecolor='w', figsize=figsize) | ||
ax = fig.add_subplot(111) | ||
else: | ||
fig = ax.get_figure() | ||
|
||
ax.plot(x, y, 'k.') | ||
ax.set_xlabel(xlabel) | ||
ax.set_ylabel(ylabel) | ||
if include_legend: | ||
ax.legend() | ||
if not user_provided_ax: | ||
fig.tight_layout() | ||
return fig | ||
|
||
def plot_plotly(x, y, xlabel='time', ylabel='ds', figsize=(900, 600)): | ||
"""_summary_ | ||
Args: | ||
x (_type_): _description_ | ||
y (_type_): _description_ | ||
figsize (_type_): _description_ | ||
xlabel (str, optional): _description_. Defaults to 'time'. | ||
ylabel (str, optional): _description_. Defaults to 'ds'. | ||
""" | ||
line_width = 2 | ||
marker_size = 4 | ||
|
||
data = [] | ||
|
||
data.append(go.Scatter(name='', x=x, y=y, model='lines')) | ||
layout = go.Layout(width=figsize[0], height=figsize[1], showlegend=False) | ||
fig = go.Figure(data=data, layout=layout) | ||
|
||
return fig | ||
|