Skip to content

Commit

Permalink
test: Add functional tests for batch_push_results (#873)
Browse files Browse the repository at this point in the history
  • Loading branch information
cecheta authored May 13, 2024
1 parent 27e1e06 commit 828ccd2
Show file tree
Hide file tree
Showing 29 changed files with 363 additions and 78 deletions.
42 changes: 15 additions & 27 deletions code/backend/batch/utilities/helpers/azure_blob_storage_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
ContentSettings,
UserDelegationKey,
)
from azure.core.credentials import AzureNamedKeyCredential
from azure.storage.queue import QueueClient, BinaryBase64EncodePolicy
import chardet
from .env_helper import EnvHelper
Expand Down Expand Up @@ -48,38 +49,25 @@ def __init__(
env_helper: EnvHelper = EnvHelper()

self.auth_type = env_helper.AZURE_AUTH_TYPE
self.account_name = account_name or env_helper.AZURE_BLOB_ACCOUNT_NAME
self.container_name = container_name or env_helper.AZURE_BLOB_CONTAINER_NAME
self.endpoint = env_helper.AZURE_STORAGE_ACCOUNT_ENDPOINT

if self.auth_type == "rbac":
self.account_name = (
account_name if account_name else env_helper.AZURE_BLOB_ACCOUNT_NAME
)
self.account_key = None
self.container_name: str = (
container_name
if container_name
else env_helper.AZURE_BLOB_CONTAINER_NAME
)
self.blob_service_client = BlobServiceClient(
account_url=f"https://{self.account_name}.blob.core.windows.net/",
credential=DefaultAzureCredential(),
account_url=self.endpoint, credential=DefaultAzureCredential()
)
self.user_delegation_key = self.request_user_delegation_key(
blob_service_client=self.blob_service_client
)
else:
self.account_name = (
account_name if account_name else env_helper.AZURE_BLOB_ACCOUNT_NAME
)
self.account_key = (
account_key if account_key else env_helper.AZURE_BLOB_ACCOUNT_KEY
)
self.connect_str = connection_string(self.account_name, self.account_key)
self.container_name: str = (
container_name
if container_name
else env_helper.AZURE_BLOB_CONTAINER_NAME
)
self.blob_service_client: BlobServiceClient = (
BlobServiceClient.from_connection_string(self.connect_str)
self.account_key = account_key or env_helper.AZURE_BLOB_ACCOUNT_KEY
self.blob_service_client = BlobServiceClient(
self.endpoint,
credential=AzureNamedKeyCredential(
name=self.account_name, key=self.account_key
),
)
self.user_delegation_key = None

Expand Down Expand Up @@ -202,7 +190,7 @@ def get_all_files(self):
if blob.metadata
else False
),
"fullpath": f"https://{self.account_name}.blob.core.windows.net/{self.container_name}/{blob.name}?{sas}",
"fullpath": f"{self.endpoint}{self.container_name}/{blob.name}?{sas}",
"converted_filename": (
blob.metadata.get("converted_filename", "")
if blob.metadata
Expand All @@ -213,7 +201,7 @@ def get_all_files(self):
)
else:
converted_files[blob.name] = (
f"https://{self.account_name}.blob.core.windows.net/{self.container_name}/{blob.name}?{sas}"
f"{self.endpoint}{self.container_name}/{blob.name}?{sas}"
)

for file in files:
Expand Down Expand Up @@ -249,7 +237,7 @@ def get_container_sas(self):
def get_blob_sas(self, file_name):
# Generate a SAS URL to the blob and return it
return (
f"https://{self.account_name}.blob.core.windows.net/{self.container_name}/{file_name}"
f"{self.endpoint}{self.container_name}/{file_name}"
+ "?"
+ generate_blob_sas(
account_name=self.account_name,
Expand Down
4 changes: 4 additions & 0 deletions code/backend/batch/utilities/helpers/env_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ def __load_config(self, **kwargs) -> None:
"AZURE_BLOB_ACCOUNT_KEY"
)
self.AZURE_BLOB_CONTAINER_NAME = os.getenv("AZURE_BLOB_CONTAINER_NAME", "")
self.AZURE_STORAGE_ACCOUNT_ENDPOINT = os.getenv(
"AZURE_STORAGE_ACCOUNT_ENDPOINT",
f"https://{self.AZURE_BLOB_ACCOUNT_NAME}.blob.core.windows.net/",
)
# Azure Form Recognizer
self.AZURE_FORM_RECOGNIZER_ENDPOINT = os.getenv(
"AZURE_FORM_RECOGNIZER_ENDPOINT", ""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import logging
import os

Expand All @@ -9,7 +10,9 @@ class AppConfig:
config: dict[str, str | None] = {
"APPLICATIONINSIGHTS_ENABLED": "False",
"AZURE_AUTH_TYPE": "keys",
"AZURE_BLOB_ACCOUNT_KEY": "some-blob-account-key",
"AZURE_BLOB_ACCOUNT_KEY": str(
base64.b64encode(b"some-blob-account-key"), "utf-8"
),
"AZURE_BLOB_ACCOUNT_NAME": "some-blob-account-name",
"AZURE_BLOB_CONTAINER_NAME": "some-blob-container-name",
"AZURE_CONTENT_SAFETY_ENDPOINT": "some-content-safety-endpoint",
Expand Down Expand Up @@ -61,12 +64,12 @@ class AppConfig:
"BACKEND_URL": "some-backend-url",
"DOCUMENT_PROCESSING_QUEUE_NAME": "some-document-processing-queue-name",
"FUNCTION_KEY": "some-function-key",
"LOAD_CONFIG_FROM_BLOB_STORAGE": "False",
"LOAD_CONFIG_FROM_BLOB_STORAGE": "True",
"LOGLEVEL": "DEBUG",
"ORCHESTRATION_STRATEGY": "openai_function",
"AZURE_SPEECH_RECOGNIZER_LANGUAGES": "en-US,es-ES",
"TIKTOKEN_CACHE_DIR": f"{os.path.dirname(os.path.realpath(__file__))}/resources",
"USE_ADVANCED_IMAGE_PROCESSING": "False",
"USE_ADVANCED_IMAGE_PROCESSING": "True",
"USE_KEY_VAULT": "False",
# These values are set directly within EnvHelper, adding them here ensures
# that they are removed from the environment when remove_from_environment() runs
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import ssl
import pytest
from pytest_httpserver import HTTPServer
from tests.functional.backend_api.app_config import AppConfig
from tests.functional.app_config import AppConfig
from backend.batch.utilities.helpers.config.config_helper import (
CONFIG_CONTAINER_NAME,
CONFIG_FILE_NAME,
)
import trustme


Expand Down Expand Up @@ -38,6 +42,91 @@ def httpclient_ssl_context(ca):

@pytest.fixture(scope="function", autouse=True)
def setup_default_mocking(httpserver: HTTPServer, app_config: AppConfig):
httpserver.expect_request(
f"/{CONFIG_CONTAINER_NAME}/{CONFIG_FILE_NAME}",
method="HEAD",
).respond_with_data()

httpserver.expect_request(
f"/{CONFIG_CONTAINER_NAME}/{CONFIG_FILE_NAME}",
method="GET",
).respond_with_json(
{
"prompts": {
"condense_question_prompt": "",
"answering_system_prompt": "system prompt",
"answering_user_prompt": "## Retrieved Documents\n{sources}\n\n## User Question\n{question}",
"use_on_your_data_format": True,
"post_answering_prompt": "post answering prompt",
"enable_post_answering_prompt": False,
"enable_content_safety": True,
},
"messages": {"post_answering_filter": "post answering filer"},
"example": {
"documents": '{"retrieved_documents":[{"[doc1]":{"content":"content"}}]}',
"user_question": "user question",
"answer": "answer",
},
"document_processors": [
{
"document_type": "pdf",
"chunking": {"strategy": "layout", "size": 500, "overlap": 100},
"loading": {"strategy": "layout"},
"use_advanced_image_processing": False,
},
{
"document_type": "txt",
"chunking": {"strategy": "layout", "size": 500, "overlap": 100},
"loading": {"strategy": "web"},
"use_advanced_image_processing": False,
},
{
"document_type": "url",
"chunking": {"strategy": "layout", "size": 500, "overlap": 100},
"loading": {"strategy": "web"},
"use_advanced_image_processing": False,
},
{
"document_type": "md",
"chunking": {"strategy": "layout", "size": 500, "overlap": 100},
"loading": {"strategy": "web"},
"use_advanced_image_processing": False,
},
{
"document_type": "html",
"chunking": {"strategy": "layout", "size": 500, "overlap": 100},
"loading": {"strategy": "web"},
"use_advanced_image_processing": False,
},
{
"document_type": "docx",
"chunking": {"strategy": "layout", "size": 500, "overlap": 100},
"loading": {"strategy": "docx"},
"use_advanced_image_processing": False,
},
{
"document_type": "jpg",
"chunking": {"strategy": "layout", "size": 500, "overlap": 100},
"loading": {"strategy": "layout"},
"use_advanced_image_processing": True,
},
{
"document_type": "png",
"chunking": {"strategy": "layout", "size": 500, "overlap": 100},
"loading": {"strategy": "layout"},
"use_advanced_image_processing": False,
},
],
"logging": {"log_user_interactions": True, "log_tokens": True},
"orchestrator": {"strategy": "openai_function"},
"integrated_vectorization_config": None,
},
headers={
"Content-Type": "application/json",
"Content-Range": "bytes 0-12882/12883",
},
)

httpserver.expect_request(
f"/openai/deployments/{app_config.get('AZURE_OPENAI_EMBEDDING_MODEL')}/embeddings",
method="POST",
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import pytest
from tests.functional.backend_api.app_config import AppConfig
from tests.functional.backend_api.common import get_free_port, start_app
from tests.functional.app_config import AppConfig
from tests.functional.tests.backend_api.common import get_free_port, start_app
from backend.batch.utilities.helpers.config.config_helper import ConfigHelper
from backend.batch.utilities.helpers.env_helper import EnvHelper

Expand Down Expand Up @@ -29,6 +29,7 @@ def app_config(make_httpserver, ca):
"AZURE_SEARCH_SERVICE": f"https://localhost:{make_httpserver.port}/",
"AZURE_CONTENT_SAFETY_ENDPOINT": f"https://localhost:{make_httpserver.port}/",
"AZURE_SPEECH_REGION_ENDPOINT": f"https://localhost:{make_httpserver.port}/",
"AZURE_STORAGE_ACCOUNT_ENDPOINT": f"https://localhost:{make_httpserver.port}/",
"SSL_CERT_FILE": ca_temp_path,
"CURL_CA_BUNDLE": ca_temp_path,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import requests
from string import Template

from tests.functional.backend_api.request_matching import (
from tests.functional.request_matching import (
RequestMatcher,
verify_request_made,
)
from tests.functional.backend_api.app_config import AppConfig
from tests.functional.app_config import AppConfig

pytestmark = pytest.mark.functional

Expand Down
Loading

0 comments on commit 828ccd2

Please sign in to comment.