diff --git a/jobbergate-agent/jobbergate_agent/jobbergate/update.py b/jobbergate-agent/jobbergate_agent/jobbergate/update.py
index 4c947231..dc709aef 100644
--- a/jobbergate-agent/jobbergate_agent/jobbergate/update.py
+++ b/jobbergate-agent/jobbergate_agent/jobbergate/update.py
@@ -1,6 +1,7 @@
 import asyncio
 import json
 from itertools import chain
+from textwrap import dedent
 from typing import List

 import msgpack
@@ -17,7 +18,7 @@
     InfluxDBPointDict,
 )
 from jobbergate_agent.settings import SETTINGS
-from jobbergate_agent.utils.exception import JobbergateApiError, SbatchError
+from jobbergate_agent.utils.exception import JobbergateApiError, SbatchError, JobbergateAgentError
 from jobbergate_agent.utils.logging import log_error
 from jobbergate_agent.jobbergate.constants import INFLUXDB_MEASUREMENT
 from jobbergate_agent.utils.compute import aggregate_influx_measures
@@ -83,20 +84,39 @@ async def update_job_data(


 async def fetch_influx_data(
-    time: int, host: str, step: int, task: int, job: int, measurement: INFLUXDB_MEASUREMENT
+    job: int,
+    measurement: INFLUXDB_MEASUREMENT,
+    *,
+    time: int | None = None,
+    host: str | None = None,
+    step: int | None = None,
+    task: int | None = None,
 ) -> list[InfluxDBPointDict]:
     """
     Fetch data from InfluxDB for a given host, step and task.
     """
-    query = f"""
-    SELECT * FROM {measurement} WHERE time > $time AND host = $host AND step = $step AND task = $task AND job = $job
-    """
-    with JobbergateApiError.handle_errors("Failed to fetch data from InfluxDB", do_except=log_error):
+    with JobbergateAgentError.handle_errors("Failed to fetch measures from InfluxDB", do_except=log_error):
+        all_none = all(arg is None for arg in [time, host, step, task])
+        all_set = all(arg is not None for arg in [time, host, step, task])
+
+        if not (all_none or all_set):
+            raise ValueError("Invalid argument combination: all optional arguments must be either set or None.")
+
+        if all_set:
+            query = dedent(f"""
+            SELECT * FROM {measurement} WHERE time > $time AND host = $host AND step = $step AND task = $task AND job = $job
+            """)
+            params = {"time": time, "host": host, "step": str(step), "task": str(task), "job": str(job)}
+        else:
+            query = f"SELECT * FROM {measurement} WHERE job = $job"
+            params = {"job": str(job)}
+
         assert influxdb_client is not None  # mypy assertion
-        params = dict(time=time, host=host, step=str(step), task=str(task), job=str(job))
+        logger.debug(f"Querying InfluxDB with: {query=}, {params=}")
         result = influxdb_client.query(query, bind_params=params, epoch="us")
         logger.debug("Successfully fetched data from InfluxDB")
+
         return [
             InfluxDBPointDict(
                 time=point["time"],
@@ -140,18 +160,27 @@ async def update_job_metrics(active_job_submittion: ActiveJobSubmission) -> None

     influx_measurements = fetch_influx_measurements()

-    tasks = (
-        fetch_influx_data(
-            job_max_time.max_time,
-            job_max_time.node_host,
-            job_max_time.step,
-            job_max_time.task,
-            active_job_submittion.slurm_job_id,
-            measurement["name"],
+    if not job_max_times.max_times:
+        tasks = (
+            fetch_influx_data(
+                active_job_submittion.slurm_job_id,
+                measurement["name"],
+            )
+            for measurement in influx_measurements
+        )
+    else:
+        tasks = (
+            fetch_influx_data(
+                active_job_submittion.slurm_job_id,
+                measurement["name"],
+                time=job_max_time.max_time,
+                host=job_max_time.node_host,
+                step=job_max_time.step,
+                task=job_max_time.task,
+            )
+            for job_max_time in job_max_times.max_times
+            for measurement in influx_measurements
         )
-        for job_max_time in job_max_times.max_times
-        for measurement in influx_measurements
-    )
     results = await asyncio.gather(*list(tasks))
     data_points = chain.from_iterable(results)
     aggregated_data_points = aggregate_influx_measures(data_points)
diff --git a/jobbergate-agent/jobbergate_agent/utils/exception.py b/jobbergate-agent/jobbergate_agent/utils/exception.py
index 96fc2c35..eb69610c 100644
--- a/jobbergate-agent/jobbergate_agent/utils/exception.py
+++ b/jobbergate-agent/jobbergate_agent/utils/exception.py
@@ -8,31 +8,31 @@
 from buzz.tools import DoExceptParams, noop


-class ClusterAgentError(Buzz):
+class JobbergateAgentError(Buzz):
     """Raise exception when execution command returns an error"""


-class ProcessExecutionError(ClusterAgentError):
+class ProcessExecutionError(JobbergateAgentError):
     """Raise exception when execution command returns an error"""


-class AuthTokenError(ClusterAgentError):
+class AuthTokenError(JobbergateAgentError):
     """Raise exception when there are connection issues with the backend"""


-class SbatchError(ClusterAgentError):
+class SbatchError(JobbergateAgentError):
     """Raise exception when sbatch raises any error"""


-class JobbergateApiError(ClusterAgentError):
+class JobbergateApiError(JobbergateAgentError):
     """Raise exception when communication with Jobbergate API fails"""


-class JobSubmissionError(ClusterAgentError):
+class JobSubmissionError(JobbergateAgentError):
     """Raise exception when a job cannot be submitted raises any error"""


-class SlurmParameterParserError(ClusterAgentError):
+class SlurmParameterParserError(JobbergateAgentError):
     """Raise exception when Slurm mapper or SBATCH parser face any error"""
diff --git a/jobbergate-agent/tests/jobbergate/test_update.py b/jobbergate-agent/tests/jobbergate/test_update.py
index 1313c90e..1c182026 100644
--- a/jobbergate-agent/tests/jobbergate/test_update.py
+++ b/jobbergate-agent/tests/jobbergate/test_update.py
@@ -2,8 +2,10 @@
 import random
 from datetime import datetime
 from typing import get_args
+from textwrap import dedent
 from unittest import mock
 from collections.abc import Callable
+from itertools import combinations
 import contextlib

 import httpx
@@ -22,7 +24,7 @@
 )
 from jobbergate_agent.jobbergate.constants import INFLUXDB_MEASUREMENT
 from jobbergate_agent.settings import SETTINGS
-from jobbergate_agent.utils.exception import JobbergateApiError
+from jobbergate_agent.utils.exception import JobbergateApiError, JobbergateAgentError


 @pytest.fixture()
@@ -333,10 +335,11 @@ def _mocked_update_job_data(job_submission_id, slurm_job_data):


 @pytest.mark.asyncio
 @mock.patch("jobbergate_agent.jobbergate.update.influxdb_client")
-async def test_fetch_influx_data__success(mocked_influxdb_client: mock.MagicMock):
+async def test_fetch_influx_data__success_with_all_set(mocked_influxdb_client: mock.MagicMock):
     """
     Test that the ``fetch_influx_data()`` function can successfully retrieve
-    data from InfluxDB as a list of ``InfluxDBPointDict``.
+    data from InfluxDB as a list of ``InfluxDBPointDict`` when all the optional
+    arguments are passed.
""" time = random.randint(0, 1000) # noqa: F811 host = "test-host" @@ -357,6 +360,11 @@ async def test_fetch_influx_data__success(mocked_influxdb_client: mock.MagicMock ) ] + query = dedent(f""" + SELECT * FROM {measurement} WHERE time > $time AND host = $host AND step = $step AND task = $task AND job = $job + """) + params = dict(time=time, host=host, step=str(step), task=str(task), job=str(job)) + result = await fetch_influx_data( time=time, host=host, @@ -374,13 +382,92 @@ async def test_fetch_influx_data__success(mocked_influxdb_client: mock.MagicMock assert result[0]["task"] == task assert result[0]["value"] == measurement_value assert result[0]["measurement"] == measurement + mocked_influxdb_client.query.assert_called_once_with(query, bind_params=params, epoch="us") + + +@pytest.mark.asyncio +@mock.patch("jobbergate_agent.jobbergate.update.influxdb_client") +async def test_fetch_influx_data__success_with_all_None(mocked_influxdb_client: mock.MagicMock): + """ + Test that the ``fetch_influx_data()`` function can successfully retrieve + data from InfluxDB as a list of ``InfluxDBPointDict`` when some arguments + are None. + """ + time = random.randint(0, 1000) # noqa: F811 + host = "test-host" + step = random.randint(0, 1000) + task = random.randint(0, 1000) + job = random.randint(0, 1000) + measurement_value = random.uniform(1, 1000) + measurement = random.choice(get_args(INFLUXDB_MEASUREMENT)) + + mocked_influxdb_client.query.return_value.get_points.return_value = [ + dict( + time=time, + host=host, + job=job, + step=step, + task=task, + value=measurement_value, + ) + ] + + query = f"SELECT * FROM {measurement} WHERE job = $job" + params = {"job": str(job)} + + result = await fetch_influx_data(job, measurement) + + assert len(result) == 1 + assert result[0]["time"] == time + assert result[0]["host"] == host + assert result[0]["job"] == job + assert result[0]["step"] == step + assert result[0]["task"] == task + assert result[0]["value"] == measurement_value + assert result[0]["measurement"] == measurement + mocked_influxdb_client.query.assert_called_once_with(query, bind_params=params, epoch="us") @pytest.mark.asyncio +@pytest.mark.parametrize( + "time, host, step, task", + [ + tuple(random.randint(1, 100) if i not in combination else None for i in range(4)) + for r in range(1, 4) + for combination in combinations(range(4), r) + ], +) @mock.patch("jobbergate_agent.jobbergate.update.influxdb_client") -async def test_fetch_influx_data__raises_JobbergateApiError_if_query_fails(mocked_influxdb_client: mock.MagicMock): +async def test_fetch_influx_data__raises_JobbergateAgentError_if_bad_arguments_are_passed( + mocked_influxdb_client: mock.MagicMock, + time: int | None, + host: int | None, + step: int | None, + task: int | None, +): + job = random.randint(0, 100) + measurement = random.choice(get_args(INFLUXDB_MEASUREMENT)) + + with pytest.raises( + JobbergateAgentError, match="Invalid argument combination: all optional arguments must be either set or None." 
+    ):
+        await fetch_influx_data(
+            job,
+            measurement,
+            time=time,
+            host=str(host) if host is not None else None,
+            step=step,
+            task=task,
+        )
+
+    mocked_influxdb_client.query.assert_not_called()
+
+
+@pytest.mark.asyncio
+@mock.patch("jobbergate_agent.jobbergate.update.influxdb_client")
+async def test_fetch_influx_data__raises_JobbergateAgentError_if_query_fails(mocked_influxdb_client: mock.MagicMock):
     """
-    Test that the ``fetch_influx_data()`` function will raise a JobbergateApiError
+    Test that the ``fetch_influx_data()`` function will raise a JobbergateAgentError
     if the query to InfluxDB fails.
     """
     measurement = random.choice(get_args(INFLUXDB_MEASUREMENT))
@@ -393,33 +480,33 @@ async def test_fetch_influx_data__raises_JobbergateApiError_if_query_fails(mocke
     task = random.randint(0, 1000)
     job = random.randint(0, 1000)

-    query = f"""
+    query = dedent(f"""
     SELECT * FROM {measurement} WHERE time > $time AND host = $host AND step = $step AND task = $task AND job = $job
-    """
+    """)
     params = dict(time=time, host=host, step=str(step), task=str(task), job=str(job))

-    with pytest.raises(JobbergateApiError, match="Failed to fetch data from InfluxDB"):
+    with pytest.raises(JobbergateAgentError, match="Failed to fetch measures from InfluxDB -- Exception: BOOM!"):
         await fetch_influx_data(
+            job=job,
+            measurement=measurement,
             time=time,
             host=host,
             step=step,
             task=task,
-            job=job,
-            measurement=measurement,
         )

     mocked_influxdb_client.query.assert_called_once_with(query, bind_params=params, epoch="us")


 @pytest.mark.asyncio
-async def test_fetch_influx_data__raises_JobbergateApiError_if_influxdb_client_is_None():
+async def test_fetch_influx_data__raises_JobbergateAgentError_if_influxdb_client_is_None():
     """
-    Test that the ``fetch_influx_data()`` function will raise a JobbergateApiError
+    Test that the ``fetch_influx_data()`` function will raise a JobbergateAgentError
     if the influxdb_client is None.
     """
     measurement = random.choice(get_args(INFLUXDB_MEASUREMENT))

     with mock.patch("jobbergate_agent.jobbergate.update.influxdb_client", None):
-        with pytest.raises(JobbergateApiError, match="Failed to fetch data from InfluxDB"):
+        with pytest.raises(JobbergateAgentError, match="Failed to fetch measures from InfluxDB -- AssertionError:"):
             await fetch_influx_data(
                 time=random.randint(0, 1000),
                 host="test-host",
@@ -587,12 +674,12 @@ async def test_update_job_metrics__error_sending_metrics_to_api(
     mocked_fetch_influx_data.assert_has_calls(
         [
             mock.call(
-                job_max_time["max_time"],
-                job_max_time["node_host"],
-                job_max_time["step"],
-                job_max_time["task"],
                 slurm_job_id,
                 measurement["name"],
+                time=job_max_time["max_time"],
+                host=job_max_time["node_host"],
+                step=job_max_time["step"],
+                task=job_max_time["task"],
             )
             for job_max_time in job_max_times["max_times"]
             for measurement in measurements
@@ -635,8 +722,8 @@ async def test_update_job_metrics__success(
     job_max_times_response: Callable[[int, int, int, int], dict[str, int | list[dict[str, int | str]]]],
 ):
     """
-    Test that the ``update_job_metrics()`` function will log an error if it fails
-    to send the job metrics to the API.
+    Test that the ``update_job_metrics()`` function will execute its logic properly
+    when the API requests are successful.
""" active_job_submission = ActiveJobSubmission(id=job_submission_id, slurm_job_id=slurm_job_id) job_max_times = job_max_times_response(job_submission_id, num_hosts, num_steps, num_tasks) @@ -679,12 +766,12 @@ async def test_update_job_metrics__success( mocked_fetch_influx_data.assert_has_calls( [ mock.call( - job_max_time["max_time"], - job_max_time["node_host"], - job_max_time["step"], - job_max_time["task"], slurm_job_id, measurement["name"], + time=job_max_time["max_time"], + host=job_max_time["node_host"], + step=job_max_time["step"], + task=job_max_time["task"], ) for job_max_time in job_max_times["max_times"] for measurement in measurements @@ -695,3 +782,79 @@ async def test_update_job_metrics__success( ) mocked_aggregate_influx_measures.assert_called_once_with(iter_dummy_data_points) mocked_msgpack.packb.assert_called_once_with("super-dummy-aggregated-data") + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("mock_access_token") +@pytest.mark.parametrize( + "job_submission_id, slurm_job_id, measurements", + [ + (1, 22, [{"name": "measurement1"}, {"name": "measurement2"}]), + (2, 33, [{"name": "measurement1"}]), + (3, 11, [{"name": "measurement1"}, {"name": "measurement2"}, {"name": "measurement3"}]), + ], +) +@mock.patch("jobbergate_agent.jobbergate.update.fetch_influx_measurements") +@mock.patch("jobbergate_agent.jobbergate.update.fetch_influx_data") +@mock.patch("jobbergate_agent.jobbergate.update.aggregate_influx_measures") +@mock.patch("jobbergate_agent.jobbergate.update.msgpack") +@mock.patch("jobbergate_agent.jobbergate.update.chain") +async def test_update_job_metrics__success_with_max_times_empty( + mocked_chain: mock.MagicMock, + mocked_msgpack: mock.MagicMock, + mocked_aggregate_influx_measures: mock.MagicMock, + mocked_fetch_influx_data: mock.MagicMock, + mocked_fetch_influx_measurements: mock.MagicMock, + job_submission_id: int, + slurm_job_id: int, + measurements: list[dict[str, str]], + job_max_times_response: Callable[[int, int, int, int], dict[str, int | list[dict[str, int | str]]]], +): + """ + Test that the ``update_job_metrics()`` function will execute the proper logic when + the API response to get the `max_times` is empty. 
+ """ + active_job_submission = ActiveJobSubmission(id=job_submission_id, slurm_job_id=slurm_job_id) + job_max_times = job_max_times_response(job_submission_id, 0, 0, 0) + + dummy_data_point = { + "time": 1, + "host": "host_1", + "job": "1", + "step": "1", + "task": "1", + "value": 1.0, + "measurement": "measurement1", + } + dummy_data_points = [dummy_data_point] * len(measurements) + iter_dummy_data_points = iter(dummy_data_points) + + mocked_fetch_influx_measurements.return_value = measurements + mocked_fetch_influx_data.return_value = dummy_data_points + # doesn't return the real aggregated data due to test complexity + mocked_chain.from_iterable.return_value = iter_dummy_data_points + mocked_aggregate_influx_measures.return_value = "super-dummy-aggregated-data" + mocked_msgpack.packb.return_value = b"dummy-msgpack-data" + + with respx.mock: + respx.get(f"{SETTINGS.BASE_API_URL}/jobbergate/job-submissions/agent/metrics/{job_submission_id}").mock( + return_value=httpx.Response( + status_code=200, + json=job_max_times, + ) + ) + respx.put( + f"{SETTINGS.BASE_API_URL}/jobbergate/job-submissions/agent/metrics/{job_submission_id}", + content=b"dummy-msgpack-data", + headers={"Content-Type": "application/octet-stream"}, + ).mock(return_value=httpx.Response(status_code=200)) + + await update_job_metrics(active_job_submission) + + mocked_fetch_influx_measurements.assert_called_once_with() + mocked_fetch_influx_data.assert_has_calls( + [mock.call(slurm_job_id, measurement["name"]) for measurement in measurements] + ) + mocked_chain.from_iterable.assert_called_once_with([dummy_data_points] * len(measurements)) + mocked_aggregate_influx_measures.assert_called_once_with(iter_dummy_data_points) + mocked_msgpack.packb.assert_called_once_with("super-dummy-aggregated-data")