Skip to content

Commit

Permalink
Backend: Link multiple jobs to a single experiment (#570)
Browse files Browse the repository at this point in the history
Create a new experiment instance in the database and link each job the
new experiment spawns to it.

Refs #570

Signed-off-by: Dimitris Poulopoulos <dimitris@mozilla.ai>
  • Loading branch information
dpoulopoulos committed Jan 10, 2025
1 parent 6f0f1f5 commit fbe3fab
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 11 deletions.
6 changes: 4 additions & 2 deletions lumigator/python/mzai/backend/backend/api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from backend.db import session_manager
from backend.repositories.datasets import DatasetRepository
from backend.repositories.experiments import ExperimentRepository
from backend.repositories.jobs import JobRepository, JobResultRepository
from backend.services.completions import MistralCompletionService, OpenAICompletionService
from backend.services.datasets import DatasetService
Expand Down Expand Up @@ -61,9 +62,10 @@ def get_job_service(session: DBSessionDep, dataset_service: DatasetServiceDep) -


def get_experiment_service(
job_service: JobServiceDep, dataset_service: DatasetServiceDep
session: DBSessionDep, job_service: JobServiceDep, dataset_service: DatasetServiceDep
) -> ExperimentService:
return ExperimentService(job_service, dataset_service)
experiment_repo = ExperimentRepository(session)
return ExperimentService(experiment_repo, job_service, dataset_service)


ExperimentServiceDep = Annotated[ExperimentService, Depends(get_experiment_service)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ def get_experiment(service: JobServiceDep, experiment_id: UUID) -> ExperimentRes

@router.get("/")
def list_experiments(
service: JobServiceDep,
service: ExperimentServiceDep,
skip: int = 0,
limit: int = 100,
) -> ListingResponse[ExperimentResponse]:
return ListingResponse[ExperimentResponse].model_validate(
service.list_jobs(skip, limit).model_dump()
service.list_experiments(skip, limit).model_dump()
)


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from sqlalchemy.orm import Session

from backend.records.experiments import ExperimentRecord
from backend.repositories.base import BaseRepository


class ExperimentRepository(BaseRepository[ExperimentRecord]):
    """Repository providing CRUD access to experiment rows.

    Thin specialization of ``BaseRepository`` bound to ``ExperimentRecord``;
    all query/create/count behavior is inherited from the base class.
    """

    def __init__(self, session: Session):
        """Bind the repository to ``ExperimentRecord`` using *session*.

        :param session: active SQLAlchemy ORM session used for all operations.
        """
        super().__init__(ExperimentRecord, session)
42 changes: 37 additions & 5 deletions lumigator/python/mzai/backend/backend/services/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,28 @@
from fastapi import BackgroundTasks, UploadFile
from lumigator_schemas.datasets import DatasetFormat
from lumigator_schemas.experiments import ExperimentCreate, ExperimentResponse
from lumigator_schemas.extras import ListingResponse
from lumigator_schemas.jobs import (
JobEvalCreate,
JobInferenceCreate,
JobStatus,
)
from s3fs import S3FileSystem

from backend.repositories.experiments import ExperimentRepository
from backend.services.datasets import DatasetService
from backend.services.jobs import JobService
from backend.settings import settings


class ExperimentService:
def __init__(self, job_service: JobService, dataset_service: DatasetService):
def __init__(
self,
experiment_repo: ExperimentRepository,
job_service: JobService,
dataset_service: DatasetService,
):
self._experiment_repo = experiment_repo
self._job_service = job_service
self._dataset_service = dataset_service

Expand Down Expand Up @@ -113,7 +121,9 @@ async def on_job_complete(self, job_id: UUID, task: Callable = None, *args):
if task is not None:
task(*args)

def _run_eval(self, inference_job_id: UUID, request: ExperimentCreate):
def _run_eval(
self, inference_job_id: UUID, request: ExperimentCreate, experiment_id: UUID | None = None
):
# use the inference job id to recover the dataset record
dataset_record = self._dataset_service._get_dataset_record_by_job_id(inference_job_id)

Expand All @@ -127,7 +137,9 @@ def _run_eval(self, inference_job_id: UUID, request: ExperimentCreate):
}

# submit the job
self._job_service.create_job(JobEvalCreate.model_validate(job_eval_dict))
self._job_service.create_job(
JobEvalCreate.model_validate(job_eval_dict), experiment_id=experiment_id
)

# TODO: do something with the job_response.id (e.g. add to the experiments' job list)

Expand All @@ -140,6 +152,11 @@ def create_experiment(
# and will run only once the response has been sent.
# See here: https://www.starlette.io/background/

experiment_record = self._experiment_repo.create(
name=request.name, description=request.description
)
loguru.logger.info(f"Created experiment '{request.name}' with ID '{experiment_record.id}'.")

# input is ExperimentCreate, we need to split the configs and generate one
# JobInferenceCreate and one JobEvalCreate
job_inference_dict = {
Expand All @@ -154,7 +171,8 @@ def create_experiment(

# submit inference job first
job_response = self._job_service.create_job(
JobInferenceCreate.model_validate(job_inference_dict)
JobInferenceCreate.model_validate(job_inference_dict),
experiment_id=experiment_record.id,
)

# Inference jobs produce a new dataset
Expand All @@ -170,7 +188,21 @@ def create_experiment(
# run evaluation job afterwards
# (NOTE: tasks in starlette are executed sequentially: https://www.starlette.io/background/)
background_tasks.add_task(
self.on_job_complete, job_response.id, self._run_eval, job_response.id, request
self.on_job_complete,
job_response.id,
self._run_eval,
job_response.id,
request,
experiment_record.id,
)

return job_response

def list_experiments(
self, skip: int = 0, limit: int = 100
) -> ListingResponse[ExperimentResponse]:
records = self._experiment_repo.list(skip, limit)
return ListingResponse(
total=self._experiment_repo.count(),
items=[ExperimentResponse.model_validate(x) for x in records],
)
8 changes: 6 additions & 2 deletions lumigator/python/mzai/backend/backend/services/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ def _get_job_params(self, job_type: str, record, request: BaseModel) -> dict:

return job_params

def create_job(self, request: JobEvalCreate | JobInferenceCreate) -> JobResponse:
def create_job(
self, request: JobEvalCreate | JobInferenceCreate, experiment_id: UUID | None = None
) -> JobResponse:
"""Creates a new evaluation workload to run on Ray and returns the response status."""
if isinstance(request, JobEvalCreate):
job_type = JobType.EVALUATION
Expand All @@ -173,7 +175,9 @@ def create_job(self, request: JobEvalCreate | JobInferenceCreate) -> JobResponse
raise HTTPException(status.HTTP_501_NOT_IMPLEMENTED, "Job type not implemented.")

# Create a db record for the job
record = self.job_repo.create(name=request.name, description=request.description)
record = self.job_repo.create(
name=request.name, description=request.description, experiment_id=experiment_id
)

# prepare configuration parameters, which depend both on the user inputs
# (request) and on the job type
Expand Down

0 comments on commit fbe3fab

Please sign in to comment.