Skip to content

Commit

Permalink
feat: [IV] Reprocess All documents functionality (#870)
Browse files Browse the repository at this point in the history
Co-authored-by: Adam Dougal <adamdougal@microsoft.com>
  • Loading branch information
komalg1 and adamdougal authored May 14, 2024
1 parent 13602ee commit 89e328b
Show file tree
Hide file tree
Showing 9 changed files with 175 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ unittest-frontend: build-frontend ## 🧪 Unit test the Frontend webapp

functionaltest: ## 🧪 Run the functional tests
@echo -e "\e[34m$@\e[0m" || true
@ poetry run pytest -m "functional"
@poetry run pytest code/tests/functional -m "functional"

uitest: ## 🧪 Run the ui tests in headless mode
@echo -e "\e[34m$@\e[0m" || true
Expand Down
21 changes: 17 additions & 4 deletions code/backend/batch/BatchStartProcessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import json
import azure.functions as func

from utilities.helpers.embedders.integrated_vectorization_embedder import (
IntegratedVectorizationEmbedder,
)
from utilities.helpers.env_helper import EnvHelper
from utilities.helpers.azure_blob_storage_client import (
AzureBlobStorageClient,
create_queue_client,
Expand All @@ -16,19 +20,28 @@
@bp_batch_start_processing.route(route="BatchStartProcessing")
def batch_start_processing(req: func.HttpRequest) -> func.HttpResponse:
logger.info("Requested to start processing all documents received")
env_helper: EnvHelper = EnvHelper()
# Set up Blob Storage Client
azure_blob_storage_client = AzureBlobStorageClient()
# Get all files from Blob Storage
files_data = azure_blob_storage_client.get_all_files()

files_data = list(map(lambda x: {"filename": x["filename"]}, files_data))

# Send a message to the queue for each file
queue_client = create_queue_client()
for fd in files_data:
queue_client.send_message(json.dumps(fd).encode("utf-8"))
if env_helper.AZURE_SEARCH_USE_INTEGRATED_VECTORIZATION:
reprocess_integrated_vectorization(env_helper)
else:
# Send a message to the queue for each file
queue_client = create_queue_client()
for fd in files_data:
queue_client.send_message(json.dumps(fd).encode("utf-8"))

return func.HttpResponse(
f"Conversion started successfully for {len(files_data)} documents.",
status_code=200,
)


def reprocess_integrated_vectorization(env_helper: EnvHelper):
indexer_embedder = IntegratedVectorizationEmbedder(env_helper)
indexer_embedder.reprocess_all()
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@

class EmbedderBase(ABC):
@abstractmethod
def embed_file(self, source_url: str, file_name: str):
def embed_file(self, source_url: str, file_name: str = None):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, env_helper: EnvHelper):
self.env_helper = env_helper
self.llm_helper: LLMHelper = LLMHelper()

def embed_file(self, source_url: str, file_name: str):
def embed_file(self, source_url: str, file_name: str = None):
self.process_using_integrated_vectorization(source_url=source_url)

def process_using_integrated_vectorization(self, source_url: str):
Expand All @@ -39,3 +39,10 @@ def process_using_integrated_vectorization(self, source_url: str):
except Exception as e:
logger.error(f"Error processing {source_url}: {e}")
raise e

def reprocess_all(self):
search_indexer = AzureSearchIndexer(self.env_helper)
if search_indexer.indexer_exists(self.env_helper.AZURE_SEARCH_INDEXER_NAME):
search_indexer.run_indexer(self.env_helper.AZURE_SEARCH_INDEXER_NAME)
else:
self.process_using_integrated_vectorization(source_url="all")
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,14 @@ def create_or_update_indexer(self, indexer_name: str, skillset_name: str):
)
return indexer_result

# To be updated for 'Reprocess All'
def run_indexer(self, indexer_name: str):
self.indexer_client.reset_indexer(indexer_name)
self.indexer_client.run_indexer(indexer_name)
logger.info(
f" {indexer_name} is created and running. If queries return no results, please wait a bit and try again."
)

def indexer_exists(self, indexer_name: str):
return indexer_name in [
name for name in self.indexer_client.get_indexer_names()
]
4 changes: 2 additions & 2 deletions code/backend/pages/01_Ingest_Data.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
st.markdown(mod_page_style, unsafe_allow_html=True)


def remote_convert_files_and_add_embeddings():
def reprocess_all():
backend_url = urllib.parse.urljoin(
env_helper.BACKEND_URL, "/api/BatchStartProcessing"
)
Expand Down Expand Up @@ -103,7 +103,7 @@ def add_url_embeddings(urls: list[str]):
with col3:
st.button(
"Reprocess all documents in the Azure Storage account",
on_click=remote_convert_files_and_add_embeddings,
on_click=reprocess_all,
)

with st.expander("Add URLs to the knowledge base", expanded=True):
Expand Down
57 changes: 55 additions & 2 deletions code/tests/test_BatchStartProcessing.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,34 @@
import sys
import os
import pytest
from unittest.mock import call, patch, Mock

sys.path.append(os.path.join(os.path.dirname(sys.path[0]), "backend", "batch"))

from backend.batch.BatchStartProcessing import batch_start_processing # noqa: E402


@pytest.fixture(autouse=True)
def env_helper_mock():
with patch("backend.batch.BatchStartProcessing.EnvHelper") as mock:
env_helper = mock.return_value
env_helper.AZURE_SEARCH_INDEXER_NAME = "AZURE_SEARCH_INDEXER_NAME"

yield env_helper


@pytest.fixture(autouse=True)
def mock_integrated_vectorization_embedder():
with patch(
"backend.batch.BatchStartProcessing.IntegratedVectorizationEmbedder"
) as mock:
yield mock


@patch("backend.batch.BatchStartProcessing.create_queue_client")
@patch("backend.batch.BatchStartProcessing.AzureBlobStorageClient")
def test_batch_start_processing_processes_all(
mock_blob_storage_client, mock_create_queue_client
mock_blob_storage_client, mock_create_queue_client, env_helper_mock
):
# given
mock_http_request = Mock()
Expand All @@ -22,7 +40,7 @@ def test_batch_start_processing_processes_all(
{"filename": "file_name_one", "embeddings_added": False},
{"filename": "file_name_two", "embeddings_added": True},
]

env_helper_mock.AZURE_SEARCH_USE_INTEGRATED_VECTORIZATION = False
# when
response = batch_start_processing.build().get_user_function()(mock_http_request)

Expand All @@ -34,3 +52,38 @@ def test_batch_start_processing_processes_all(
assert len(send_message_calls) == 2
assert send_message_calls[0] == call(b'{"filename": "file_name_one"}')
assert send_message_calls[1] == call(b'{"filename": "file_name_two"}')


@patch("backend.batch.BatchStartProcessing.create_queue_client")
@patch("backend.batch.BatchStartProcessing.AzureBlobStorageClient")
def test_batch_start_processing_processes_all_integrated_vectorization(
mock_blob_storage_client,
mock_create_queue_client,
mock_integrated_vectorization_embedder,
env_helper_mock,
):
# given
mock_http_request = Mock()
mock_http_request.params = dict()

mock_queue_client = Mock()
mock_create_queue_client.return_value = mock_queue_client
mock_blob_storage_client.return_value.get_all_files.return_value = [
{"filename": "file_name_one", "embeddings_added": False},
{"filename": "file_name_two", "embeddings_added": True},
]
mock_integrated_vectorization_embedder.return_value.reprocess_all.return_value = (
None
)
env_helper_mock.AZURE_SEARCH_USE_INTEGRATED_VECTORIZATION = True

# when
response = batch_start_processing.build().get_user_function()(mock_http_request)

# then
assert response.status_code == 200
assert response.get_body() == b"Conversion started successfully for 2 documents."

send_message_calls = mock_queue_client.send_message.call_args_list
assert len(send_message_calls) == 0
mock_integrated_vectorization_embedder.return_value.reprocess_all.assert_called_once()
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,48 @@ def test_process_using_integrated_vectorization(
azure_search_iv_skillset_helper_mock.return_value.create_skillset.return_value.skills
is not None
)


def test_reprocess_all_runs_indexer_when_indexer_exists(
env_helper_mock: MagicMock,
llm_helper_mock: MagicMock,
azure_search_iv_index_helper_mock: MagicMock,
azure_search_iv_datasource_helper_mock: MagicMock,
azure_search_iv_skillset_helper_mock: MagicMock,
azure_search_iv_indexer_helper_mock: MagicMock,
mock_config_helper,
):
# Given
azure_search_iv_indexer_helper_mock.indexer_exists.return_value = True
azure_search_iv_indexer_helper_mock.run_indexer.return_value = "Indexer result"

# When
embedder = IntegratedVectorizationEmbedder(env_helper_mock)
embedder.reprocess_all()

# Then
azure_search_iv_indexer_helper_mock.return_value.run_indexer.assert_called_once_with(
env_helper_mock.AZURE_SEARCH_INDEXER_NAME
)
azure_search_iv_indexer_helper_mock.return_value.create_or_update_indexer.assert_not_called()


def test_reprocess_all_calls_process_using_integrated_vectorization_when_indexer_does_not_exist(
env_helper_mock: MagicMock,
llm_helper_mock: MagicMock,
azure_search_iv_index_helper_mock: MagicMock,
azure_search_iv_datasource_helper_mock: MagicMock,
azure_search_iv_skillset_helper_mock: MagicMock,
azure_search_iv_indexer_helper_mock: MagicMock,
mock_config_helper,
):
# Given
azure_search_iv_indexer_helper_mock.return_value.indexer_exists.return_value = False

# When
embedder = IntegratedVectorizationEmbedder(env_helper_mock)
embedder.reprocess_all()

# Then
azure_search_iv_indexer_helper_mock.return_value.run_indexer.assert_not_called()
azure_search_iv_indexer_helper_mock.return_value.create_or_update_indexer.assert_called_once()
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,43 @@ def test_create_or_update_indexer_rbac(
data_source_name=env_helper_mock.AZURE_SEARCH_DATASOURCE_NAME,
field_mappings=ANY,
)


def test_run_indexer(
env_helper_mock: MagicMock,
search_indexer_client_mock: MagicMock,
search_indexer_mock: MagicMock,
):
# given
indexer_name = "indexer_name"
azure_search_indexer = AzureSearchIndexer(env_helper_mock)

# when
azure_search_indexer.run_indexer(indexer_name)

# then
azure_search_indexer.indexer_client.reset_indexer.assert_called_once_with(
indexer_name
)
azure_search_indexer.indexer_client.run_indexer.assert_called_once_with(
indexer_name
)


def test_indexer_exists(
env_helper_mock: MagicMock,
search_indexer_client_mock: MagicMock,
search_indexer_mock: MagicMock,
):
# given
indexer_name = "indexer_name"
azure_search_indexer = AzureSearchIndexer(env_helper_mock)
search_indexer_client_mock.return_value.get_indexer_names.return_value = [
"indexer_name"
]

# when
result = azure_search_indexer.indexer_exists(indexer_name)

# then
assert result is True

0 comments on commit 89e328b

Please sign in to comment.