From 82959eaabfea6b9572546518597e7a6e930bd233 Mon Sep 17 00:00:00 2001 From: Ayush Mishra Date: Wed, 18 Sep 2024 20:09:45 +0530 Subject: [PATCH] update for e2e test --- .../components/driver/batch_score_oss/spec.yaml | 2 +- .../components/driver/tests/e2e/conftest.py | 13 ++++++++----- .../driver/tests/e2e/serverless_endpoint_test.py | 9 ++++----- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/assets/batch_score/components/driver/batch_score_oss/spec.yaml b/assets/batch_score/components/driver/batch_score_oss/spec.yaml index 2c769e4902..fc5628e053 100644 --- a/assets/batch_score/components/driver/batch_score_oss/spec.yaml +++ b/assets/batch_score/components/driver/batch_score_oss/spec.yaml @@ -51,7 +51,7 @@ input_data: ${{inputs.data_input_table}} task: code: ../src/batch_score_oss type: run_function - entry_script: root.main + entry_script: main # Enable PRS safe append row configuration that is needed when dealing with large outputs with Unicode characters. # Using --append_row_safe_output true program_arguments: >- diff --git a/assets/batch_score/components/driver/tests/e2e/conftest.py b/assets/batch_score/components/driver/tests/e2e/conftest.py index 3dd2c76aff..bb0c90f202 100644 --- a/assets/batch_score/components/driver/tests/e2e/conftest.py +++ b/assets/batch_score/components/driver/tests/e2e/conftest.py @@ -15,6 +15,9 @@ from .util import _get_component_name, _set_and_get_component_name_ver, create_copy +BATCH_SCORE_COMPONENT_YAML_NAME = "batch_score_oss" + + # Marks all tests in this directory as e2e tests @pytest.fixture(autouse=True, params=[pytest.param(None, marks=pytest.mark.e2e)]) def mark_as_e2e_test(): @@ -129,25 +132,25 @@ def register_components(main_worker_lock, asset_version): if not _is_main_worker(main_worker_lock): return - _register_component("batch_score", asset_version) + _register_component(BATCH_SCORE_COMPONENT_YAML_NAME, asset_version) @pytest.fixture(scope="session") def batch_score_yml_component(asset_version): """Return the component name batch_score.yml.""" - return _get_component_metadata("batch_score.yml", asset_version) + return _get_component_metadata(f"{BATCH_SCORE_COMPONENT_YAML_NAME}.yml", asset_version) @pytest.fixture(scope="session") def llm_batch_score_yml_component(asset_version): """Return the component version for batch_score_llm.yml.""" - return _get_component_metadata("batch_score", asset_version) + return _get_component_metadata(BATCH_SCORE_COMPONENT_YAML_NAME, asset_version) def _register_component(component_yml_name, asset_version): # Copy component to a temporary file to not muddle dev environments batch_score_component_filepath = os.path.join( - pytest.source_dir, "assets", "managed_batch_inference", "components", component_yml_name, "spec.yaml" + pytest.source_dir, component_yml_name, "spec.yaml" ) create_copy(batch_score_component_filepath, pytest.copied_batch_score_component_filepath) @@ -169,6 +172,6 @@ def _register_component(component_yml_name, asset_version): def _get_component_metadata(component_yml_name, asset_version): batch_score_component_filepath = os.path.join( - pytest.source_dir, "assets", "managed_batch_inference", "components", component_yml_name, "spec.yaml" + pytest.source_dir, component_yml_name, "spec.yaml" ) return _get_component_name(batch_score_component_filepath), asset_version diff --git a/assets/batch_score/components/driver/tests/e2e/serverless_endpoint_test.py b/assets/batch_score/components/driver/tests/e2e/serverless_endpoint_test.py index 9e681f98e6..73e7da73e4 100644 --- a/assets/batch_score/components/driver/tests/e2e/serverless_endpoint_test.py +++ b/assets/batch_score/components/driver/tests/e2e/serverless_endpoint_test.py @@ -6,15 +6,14 @@ import os import pytest +from pathlib import Path from pydantic.utils import deep_update from .util import _submit_job_and_monitor_till_completion, set_component # Common configuration -source_dir = os.getcwd() -gated_llm_pipeline_filepath = os.path.join(source_dir, - "assets", "managed_batch_inference", "components", "tests", "batch_score", - "e2e", "prs_pipeline_templates", "base_llm.yml") +source_dir = Path(__file__).parent +gated_llm_pipeline_filepath = os.path.join(source_dir, "prs_pipeline_templates", "base_llm.yml") JOB_NAME = "gated_batch_score_llm" # Should be equivalent to base_llm.yml's job name YAML_COMPONENT = {"jobs": {JOB_NAME: {"component": None}}} # Placeholder for component name set below. @@ -54,7 +53,7 @@ @pytest.mark.smoke -@pytest.mark.e2e +# @pytest.mark.e2e @pytest.mark.timeout(20 * 60) def test_gated_serverless_endpoint_batch_score_completion(llm_batch_score_yml_component): """Test gate for batch score serverless endpoints completion models."""