Skip to content

Commit

Permalink
feat!: use numalogic v0.4a0 for using redis model store (#133)
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <ab93@users.noreply.github.com>
  • Loading branch information
ab93 authored May 23, 2023
1 parent 7cada11 commit 77675bb
Show file tree
Hide file tree
Showing 27 changed files with 1,207 additions and 1,710 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
- uses: actions/checkout@v3

- name: Install poetry
run: pipx install poetry
run: pipx install poetry==1.4.2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- uses: actions/checkout@v3

- name: Install poetry
run: pipx install poetry
run: pipx install poetry==1.4.2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
Expand All @@ -30,7 +30,8 @@ jobs:
- name: Install dependencies
run: |
poetry env use ${{ matrix.python-version }}
poetry install --all-extras --with dev,torch
poetry install --all-extras --with dev
poetry run pip install --no-cache -r requirements/requirements-torch.txt
- name: Run Coverage
run: |
Expand Down
28 changes: 5 additions & 23 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,12 @@ jobs:
black:
name: Black format
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]

steps:
- uses: actions/checkout@v3

- name: Install poetry
run: pipx install poetry

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: 'poetry'

- name: Install dependencies
run: |
poetry env use ${{ matrix.python-version }}
poetry install --with dev
- name: Black format check
run: poetry run black --check .
- uses: actions/checkout@v3
- uses: psf/black@stable
with:
options: "--check --verbose"
version: "~= 23.3"

flake8:
name: flake8 check
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- uses: actions/checkout@v3

- name: Install poetry
run: pipx install poetry
run: pipx install poetry==1.4.2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
Expand Down
20 changes: 0 additions & 20 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -33,26 +33,6 @@ RUN apt-get update \
&& chmod +x /dumb-init \
&& curl -sSL https://install.python-poetry.org | python3 -

####################################################################################################
# mlflow: used for running the mlflow server
####################################################################################################
FROM builder AS mlflow

WORKDIR $PYSETUP_PATH
COPY ./pyproject.toml ./poetry.lock ./
RUN poetry install --only mlflowserver --no-cache --no-root && \
rm -rf ~/.cache/pypoetry/

ADD . /app
WORKDIR /app

RUN chmod +x entry.sh

ENTRYPOINT ["/dumb-init", "--"]
CMD ["/app/entry.sh"]

EXPOSE 5000

####################################################################################################
# udf: used for running the udf vertices
####################################################################################################
Expand Down
56 changes: 51 additions & 5 deletions numaprom/clients/sentinel.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,41 @@
import os
from typing import Optional

from numalogic.tools.types import redis_client_t
from redis.backoff import ExponentialBackoff
from redis.client import Redis
from redis.exceptions import RedisClusterException, RedisError
from redis.retry import Retry
from redis.sentinel import Sentinel, MasterNotFoundError

from numaprom import get_logger
from numaprom._config import RedisConf
from numaprom.watcher import ConfigManager

_LOGGER = get_logger(__name__)
SENTINEL_MASTER_CLIENT: Optional[Redis] = None
SENTINEL_MASTER_CLIENT: Optional[redis_client_t] = None


def get_redis_client(
host: str, port: int, password: str, mastername: str, recreate: bool = False
) -> Redis:
host: str,
port: int,
password: str,
mastername: str,
decode_responses: bool = False,
recreate: bool = False,
) -> redis_client_t:
"""
Return a master redis client for sentinel connections, with retry.
Args:
host: Redis host
port: Redis port
password: Redis password
mastername: Redis sentinel master name
decode_responses: Whether to decode responses
recreate: Whether to flush and recreate the client
Returns:
Redis client instance
"""
global SENTINEL_MASTER_CLIENT

Expand All @@ -34,7 +53,11 @@ def get_redis_client(
MasterNotFoundError,
),
)
sentinel_args = {"sentinels": [(host, port)], "socket_timeout": 0.1, "decode_responses": True}
sentinel_args = {
"sentinels": [(host, port)],
"socket_timeout": 0.1,
"decode_responses": decode_responses,
}

_LOGGER.info("Sentinel redis params: %s", sentinel_args)

Expand All @@ -43,3 +66,26 @@ def get_redis_client(
)
SENTINEL_MASTER_CLIENT = sentinel.master_for(mastername)
return SENTINEL_MASTER_CLIENT


def get_redis_client_from_conf(
    redis_conf: Optional[RedisConf] = None, **kwargs
) -> redis_client_t:
    """
    Return a master redis client from config for sentinel connections, with retry.

    Args:
        redis_conf: RedisConf object with host, port, master_name, etc.
            When omitted (or falsy), the config returned by
            ``ConfigManager.get_redis_config()`` is used instead.
        **kwargs: Additional arguments to pass to get_redis_client
            (forwarded unchanged).

    Returns:
        Redis client instance
    """
    if not redis_conf:
        redis_conf = ConfigManager.get_redis_config()

    # The password is not taken from the config object: it is read from the
    # REDIS_AUTH environment variable, and is None when that variable is unset.
    return get_redis_client(
        redis_conf.host,
        redis_conf.port,
        password=os.getenv("REDIS_AUTH"),
        mastername=redis_conf.master_name,
        **kwargs
    )
56 changes: 2 additions & 54 deletions numaprom/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,18 @@
from datetime import timedelta, datetime
from functools import wraps
from json import JSONDecodeError
from typing import Optional, Sequence, List
from typing import List

import boto3
import numpy as np
import pandas as pd
import pytz
from botocore.session import get_session
from mlflow.entities.model_registry import ModelVersion
from mlflow.exceptions import RestException
from numalogic.config import PostprocessFactory
from numalogic.models.threshold import SigmoidThreshold
from numalogic.registry import MLflowRegistry, ArtifactData
from pynumaflow.function import Messages, Message

from numaprom import get_logger, MetricConf
from numaprom.entities import TrainerPayload, StreamPayload
from numaprom.clients.prometheus import Prometheus
from numaprom.entities import TrainerPayload, StreamPayload
from numaprom.watcher import ConfigManager

_LOGGER = get_logger(__name__)
Expand Down Expand Up @@ -121,37 +116,6 @@ def is_host_reachable(hostname: str, port=None, max_retries=5, sleep_sec=5) -> b
return False


def load_model(
    skeys: Sequence[str], dkeys: Sequence[str], artifact_type: str = "pytorch"
) -> Optional[ArtifactData]:
    """
    Load an artifact through an MLflowRegistry built from the registry config.

    Args:
        skeys: forwarded as ``skeys`` to ``MLflowRegistry.load``
        dkeys: forwarded as ``dkeys`` to ``MLflowRegistry.load``
        artifact_type: forwarded to the ``MLflowRegistry`` constructor

    Returns:
        Whatever the registry's ``load`` returns, or None when any exception
        is raised while building the registry or loading the artifact.
    """
    set_aws_session()
    try:
        conf = ConfigManager.get_registry_config()
        registry = MLflowRegistry(tracking_uri=conf.tracking_uri, artifact_type=artifact_type)
        return registry.load(skeys=skeys, dkeys=dkeys)
    except RestException as rest_err:
        # A 404 is swallowed silently; any other REST error is logged.
        # NOTE(review): confirm error_code is numeric for this exception type.
        if rest_err.error_code != 404:
            _LOGGER.warning("Non 404 error from mlflow: %r", rest_err)
    except Exception as err:
        _LOGGER.error("Unexpected error while loading model from MLflow database: %r", err)
    # Every failure path ends up here.
    return None


def save_model(
    skeys: Sequence[str], dkeys: Sequence[str], model, artifact_type="pytorch", **metadata
) -> Optional[ModelVersion]:
    """
    Save an artifact through an MLflowRegistry built from the registry config.

    Args:
        skeys: forwarded as ``skeys`` to ``MLflowRegistry.save``
        dkeys: forwarded as ``dkeys`` to ``MLflowRegistry.save``
        model: the object passed as ``artifact`` to ``MLflowRegistry.save``
        artifact_type: forwarded to the ``MLflowRegistry`` constructor
        **metadata: extra keyword arguments forwarded to ``MLflowRegistry.save``

    Returns:
        The value returned by the registry's ``save`` call.
    """
    # Refresh the default AWS session before touching the registry.
    set_aws_session()
    tracking_uri = ConfigManager.get_registry_config().tracking_uri
    registry = MLflowRegistry(tracking_uri=tracking_uri, artifact_type=artifact_type)
    return registry.save(skeys=skeys, dkeys=dkeys, artifact=model, **metadata)


def fetch_data(
payload: TrainerPayload,
metric_config: MetricConf,
Expand Down Expand Up @@ -183,22 +147,6 @@ def fetch_data(
return df


def set_aws_session() -> None:
    """
    Set up the default boto3 session using freshly fetched credentials.

    If the botocore session returns no credentials object, a debug message
    is logged and the default session is left untouched.
    """
    creds = get_session().get_credentials()
    if creds:
        boto3.setup_default_session(
            aws_access_key_id=creds.access_key,
            aws_secret_access_key=creds.secret_key,
            aws_session_token=creds.token,
        )
        return
    _LOGGER.debug("No AWS credentials object returned")


def calculate_static_thresh(payload: StreamPayload, upper_limit: float):
"""
Calculates anomaly scores using static thresholding.
Expand Down
44 changes: 24 additions & 20 deletions numaprom/udf/inference.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
import time
from datetime import datetime, timedelta

from numalogic.config import NumalogicConf
from numalogic.models.autoencoder import AutoencoderTrainer
from numalogic.registry import ArtifactData
from numalogic.registry import ArtifactData, RedisRegistry
from numalogic.tools.data import StreamingDataset
from numalogic.tools.exceptions import RedisRegistryError
from orjson import orjson
from pynumaflow.function import Datum
from torch.utils.data import DataLoader

from numaprom import get_logger, MetricConf
from numaprom import get_logger
from numaprom.clients.sentinel import get_redis_client_from_conf
from numaprom.entities import PayloadFactory
from numaprom.entities import Status, StreamPayload, Header
from numaprom.tools import load_model, msg_forward
from numaprom.tools import msg_forward
from numaprom.watcher import ConfigManager

_LOGGER = get_logger(__name__)
REDIS_CLIENT = get_redis_client_from_conf()


def _run_inference(
Expand All @@ -40,17 +42,6 @@ def _run_inference(
return payload


def _is_model_stale(
    payload: StreamPayload, artifact_data: ArtifactData, metric_config: MetricConf
) -> bool:
    """
    Tell whether the loaded artifact is older than the retrain frequency.

    Args:
        payload: stream payload; only its uuid and composite_keys are used, for logging
        artifact_data: artifact whose ``extras["last_updated_timestamp"]`` is inspected
        metric_config: metric config providing ``retrain_freq_hr``

    Returns:
        True if the artifact's last update is earlier than
        ``now - retrain_freq_hr`` hours, False otherwise.
    """
    # Scaled by 1/1000 before comparing against a timestamp in seconds
    # (presumably the stored value is in epoch milliseconds - confirm).
    last_updated = artifact_data.extras["last_updated_timestamp"] / 1000
    retrain_window = timedelta(hours=int(metric_config.retrain_freq_hr))
    cutoff = (datetime.now() - retrain_window).timestamp()
    if not last_updated < cutoff:
        return False
    _LOGGER.info("%s - Model found is stale for %s", payload.uuid, payload.composite_keys)
    return True


@msg_forward
def inference(_: str, datum: Datum) -> bytes:
_start_time = time.perf_counter()
Expand All @@ -74,10 +65,23 @@ def inference(_: str, datum: Datum) -> bytes:
numalogic_conf = metric_config.numalogic_conf

# Load inference model
artifact_data = load_model(
skeys=[payload.composite_keys["namespace"], payload.composite_keys["name"]],
dkeys=[numalogic_conf.model.name],
)
model_registry = RedisRegistry(client=REDIS_CLIENT)
try:
artifact_data = model_registry.load(
skeys=[payload.composite_keys["namespace"], payload.composite_keys["name"]],
dkeys=[numalogic_conf.model.name],
)
except RedisRegistryError as err:
_LOGGER.exception(
"%s - Error while fetching inference artifact, keys: %s, err: %r",
payload.uuid,
payload.composite_keys,
err,
)
payload.set_header(Header.STATIC_INFERENCE)
payload.set_status(Status.RUNTIME_ERROR)
return orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY)

if not artifact_data:
_LOGGER.info(
"%s - Inference artifact not found, forwarding for static thresholding. Keys: %s",
Expand All @@ -89,7 +93,7 @@ def inference(_: str, datum: Datum) -> bytes:
return orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY)

# Check if current model is stale
if _is_model_stale(payload, artifact_data, metric_config):
if RedisRegistry.is_artifact_stale(artifact_data, int(metric_config.retrain_freq_hr)):
payload.set_header(Header.MODEL_STALE)

# Generate predictions
Expand Down
Loading

0 comments on commit 77675bb

Please sign in to comment.