Skip to content

Commit

Permalink
feat: Include vectorized text in search queries (#953)
Browse files Browse the repository at this point in the history
  • Loading branch information
cecheta authored May 21, 2024
1 parent 290fb05 commit 8642df3
Show file tree
Hide file tree
Showing 20 changed files with 562 additions and 56 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
from typing import List
from urllib.parse import urljoin
from azure.identity import DefaultAzureCredential, get_bearer_token_provider

Expand All @@ -15,6 +14,7 @@ class AzureComputerVisionClient:

__TOKEN_SCOPE = "https://cognitiveservices.azure.com/.default"
__VECTORIZE_IMAGE_PATH = "computervision/retrieval:vectorizeImage"
__VECTORIZE_TEXT_PATH = "computervision/retrieval:vectorizeText"
__RESPONSE_VECTOR_KEY = "vector"

def __init__(self, env_helper: EnvHelper) -> None:
Expand All @@ -27,15 +27,29 @@ def __init__(self, env_helper: EnvHelper) -> None:
env_helper.AZURE_COMPUTER_VISION_VECTORIZE_IMAGE_MODEL_VERSION
)

def vectorize_image(self, image_url: str) -> List[float]:
def vectorize_image(self, image_url: str) -> list[float]:
logger.info(f"Making call to computer vision to vectorize image: {image_url}")
response = self.__make_request(image_url)
response = self.__make_request(
self.__VECTORIZE_IMAGE_PATH,
body={"url": image_url},
)
self.__validate_response(response)

response_json = self.__get_json_body(response)
return self.__get_vectors(response_json)

def vectorize_text(self, text: str) -> list[float]:
    """Return the Computer Vision embedding vector for ``text``.

    Posts ``{"text": text}`` to the vectorizeText path, checks the
    response status, and pulls the vector out of the JSON body.

    Raises:
        Exception: if the request itself fails, the response status is
            not 200, the body is not valid JSON, or the body carries no
            vector entry.
    """
    logger.debug(f"Making call to computer vision to vectorize text: {text}")

    payload = {"text": text}
    response = self.__make_request(self.__VECTORIZE_TEXT_PATH, body=payload)

    # Reject non-200 answers before attempting to parse the body.
    self.__validate_response(response)

    return self.__get_vectors(self.__get_json_body(response))

def __make_request(self, image_url: str) -> Response:
def __make_request(self, path: str, body) -> Response:
try:
headers = {}
if self.use_keys:
Expand All @@ -47,36 +61,36 @@ def __make_request(self, image_url: str) -> Response:
headers["Authorization"] = "Bearer " + token_provider()

return requests.post(
url=urljoin(self.host, self.__VECTORIZE_IMAGE_PATH),
url=urljoin(self.host, path),
params={
"api-version": self.api_version,
"model-version": self.model_version,
},
json={"url": image_url},
json=body,
headers=headers,
timeout=self.timeout,
)
except Exception as e:
raise Exception(f"Call to vectorize image failed: {image_url}") from e
raise Exception("Call to Azure Computer Vision failed") from e

def __validate_response(self, response: Response):
if response.status_code != 200:
raise Exception(
f"Call to vectorize image failed with status: {response.status_code} body: {response.text}"
f"Call to Azure Computer Vision failed with status: {response.status_code}, body: {response.text}"
)

def __get_json_body(self, response: Response) -> dict:
try:
return response.json()
except Exception as e:
raise Exception(
f"Call to vectorize image returned malformed response body: {response.text}",
f"Call to Azure Computer Vision returned malformed response body: {response.text}",
) from e

def __get_vectors(self, response_json: dict) -> List[float]:
def __get_vectors(self, response_json: dict) -> list[float]:
if self.__RESPONSE_VECTOR_KEY in response_json:
return response_json[self.__RESPONSE_VECTOR_KEY]
else:
raise Exception(
f"Call to vectorize image returned no vector: {response_json}"
f"Call to Azure Computer Vision returned no vector: {response_json}"
)
14 changes: 13 additions & 1 deletion code/backend/batch/utilities/helpers/azure_search_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
VectorSearchAlgorithmMetric,
VectorSearchProfile,
)

from ..helpers.azure_computer_vision_client import AzureComputerVisionClient
from .llm_helper import LLMHelper
from .env_helper import EnvHelper

Expand All @@ -32,6 +34,7 @@

class AzureSearchHelper:
_search_dimension: int | None = None
_image_search_dimension: int | None = None

def __init__(self):
self.llm_helper = LLMHelper()
Expand All @@ -40,6 +43,7 @@ def __init__(self):
search_credential = self._search_credential()
self.search_client = self._create_search_client(search_credential)
self.search_index_client = self._create_search_index_client(search_credential)
self.azure_computer_vision_client = AzureComputerVisionClient(self.env_helper)

def _search_credential(self):
if self.env_helper.is_auth_type_keys():
Expand Down Expand Up @@ -75,6 +79,14 @@ def search_dimensions(self) -> int:
)
return AzureSearchHelper._search_dimension

@property
def image_search_dimensions(self) -> int:
    """Length of the vectors returned by Computer Vision text vectorization.

    The value is measured once, by vectorizing a short sample string and
    taking the length of the result, and is then kept on the class so
    later reads (from any instance) do not call the service again.
    """
    cached = AzureSearchHelper._image_search_dimension
    if cached is not None:
        return cached

    sample_vector = self.azure_computer_vision_client.vectorize_text("Text")
    AzureSearchHelper._image_search_dimension = len(sample_vector)
    return AzureSearchHelper._image_search_dimension

def create_index(self):
fields = [
SimpleField(
Expand Down Expand Up @@ -128,7 +140,7 @@ def create_index(self):
name="image_vector",
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
searchable=True,
vector_search_dimensions=1024,
vector_search_dimensions=self.image_search_dimensions,
vector_search_profile_name="myHnswProfile",
),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ def __embed(
and file_extension
in self.config.get_advanced_image_processing_image_types()
):
logger.warning("Advanced image processing is not supported yet")

caption = self.__generate_image_caption(source_url)
caption_vector = self.llm_helper.generate_embeddings(caption)

Expand Down
60 changes: 52 additions & 8 deletions code/backend/batch/utilities/search/azure_search_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List
from .search_handler_base import SearchHandlerBase
from ..helpers.llm_helper import LLMHelper
from ..helpers.azure_computer_vision_client import AzureComputerVisionClient
from ..helpers.azure_search_helper import AzureSearchHelper
from ..common.source_document import SourceDocument
import json
Expand All @@ -9,13 +10,12 @@


class AzureSearchHandler(SearchHandlerBase):

_ENCODER_NAME = "cl100k_base"
_VECTOR_FIELD = "content_vector"

def __init__(self, env_helper):
super().__init__(env_helper)
self.llm_helper = LLMHelper()
self.azure_computer_vision_client = AzureComputerVisionClient(env_helper)

def create_search_client(self):
return AzureSearchHelper().get_search_client()
Expand Down Expand Up @@ -66,22 +66,50 @@ def delete_files(self, files):
def query_search(self, question) -> List[SourceDocument]:
encoding = tiktoken.get_encoding(self._ENCODER_NAME)
tokenised_question = encoding.encode(question)

if self.env_helper.USE_ADVANCED_IMAGE_PROCESSING:
vectorized_question = self.azure_computer_vision_client.vectorize_text(
question
)
else:
vectorized_question = None

if self.env_helper.AZURE_SEARCH_USE_SEMANTIC_SEARCH:
results = self._semantic_search(question, tokenised_question)
results = self._semantic_search(
question, tokenised_question, vectorized_question
)
else:
results = self._hybrid_search(question, tokenised_question)
results = self._hybrid_search(
question, tokenised_question, vectorized_question
)

return self._convert_to_source_documents(results)

def _semantic_search(self, question: str, tokenised_question: list[int]):
def _semantic_search(
self,
question: str,
tokenised_question: list[int],
vectorized_question: list[float] | None,
):
return self.search_client.search(
search_text=question,
vector_queries=[
VectorizedQuery(
vector=self.llm_helper.generate_embeddings(tokenised_question),
k_nearest_neighbors=self.env_helper.AZURE_SEARCH_TOP_K,
fields=self._VECTOR_FIELD,
)
),
*(
[
VectorizedQuery(
vector=vectorized_question,
k_nearest_neighbors=self.env_helper.AZURE_SEARCH_TOP_K,
fields=self._IMAGE_VECTOR_FIELD,
)
]
if vectorized_question is not None
else []
),
],
filter=self.env_helper.AZURE_SEARCH_FILTER,
query_type="semantic",
Expand All @@ -91,7 +119,12 @@ def _semantic_search(self, question: str, tokenised_question: list[int]):
top=self.env_helper.AZURE_SEARCH_TOP_K,
)

def _hybrid_search(self, question: str, tokenised_question: list[int]):
def _hybrid_search(
self,
question: str,
tokenised_question: list[int],
vectorized_question: list[float] | None,
):
return self.search_client.search(
search_text=question,
vector_queries=[
Expand All @@ -100,7 +133,18 @@ def _hybrid_search(self, question: str, tokenised_question: list[int]):
k_nearest_neighbors=self.env_helper.AZURE_SEARCH_TOP_K,
filter=self.env_helper.AZURE_SEARCH_FILTER,
fields=self._VECTOR_FIELD,
)
),
*(
[
VectorizedQuery(
vector=vectorized_question,
k_nearest_neighbors=self.env_helper.AZURE_SEARCH_TOP_K,
fields=self._IMAGE_VECTOR_FIELD,
)
]
if vectorized_question is not None
else []
),
],
query_type="simple", # this is the default value
filter=self.env_helper.AZURE_SEARCH_FILTER,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _hybrid_search(self, question: str):
vector_query = VectorizableTextQuery(
text=question,
k_nearest_neighbors=self.env_helper.AZURE_SEARCH_TOP_K,
fields="content_vector",
fields=self._VECTOR_FIELD,
exhaustive=True,
)
return self.search_client.search(
Expand All @@ -94,7 +94,7 @@ def _semantic_search(self, question: str):
vector_query = VectorizableTextQuery(
text=question,
k_nearest_neighbors=self.env_helper.AZURE_SEARCH_TOP_K,
fields="content_vector",
fields=self._VECTOR_FIELD,
exhaustive=True,
)
return self.search_client.search(
Expand Down
7 changes: 5 additions & 2 deletions code/backend/batch/utilities/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@
from ..search.integrated_vectorization_search_handler import (
IntegratedVectorizationSearchHandler,
)
from ..search.search_handler_base import SearchHandlerBase
from ..common.source_document import SourceDocument
from ..helpers.env_helper import EnvHelper


class Search:
@staticmethod
def get_search_handler(env_helper: EnvHelper):
def get_search_handler(env_helper: EnvHelper) -> SearchHandlerBase:
if env_helper.AZURE_SEARCH_USE_INTEGRATED_VECTORIZATION:
return IntegratedVectorizationSearchHandler(env_helper)
else:
return AzureSearchHandler(env_helper)

@staticmethod
def get_source_documents(search_handler, question) -> list[SourceDocument]:
def get_source_documents(
search_handler: SearchHandlerBase, question: str
) -> list[SourceDocument]:
return search_handler.query_search(question)
7 changes: 5 additions & 2 deletions code/backend/batch/utilities/search/search_handler_base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from abc import ABC, abstractmethod
from ..helpers.env_helper import EnvHelper

from ..common.source_document import SourceDocument
from azure.search.documents import SearchClient


class SearchHandlerBase(ABC):
_VECTOR_FIELD = "content_vector"
_IMAGE_VECTOR_FIELD = "image_vector"

def __init__(self, env_helper: EnvHelper):
self.env_helper = env_helper
self.search_client = self.create_search_client()
Expand All @@ -20,7 +23,7 @@ def get_unique_files(self, results, facet_key: str):
return []

@abstractmethod
def create_search_client(self):
def create_search_client(self) -> SearchClient:
pass

@abstractmethod
Expand Down
2 changes: 2 additions & 0 deletions code/tests/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@

# Request path and HTTP method for the Computer Vision vectorizeImage call;
# the functional-test conftest registers a mock response for this pair.
COMPUTER_VISION_VECTORIZE_IMAGE_PATH = "/computervision/retrieval:vectorizeImage"
COMPUTER_VISION_VECTORIZE_IMAGE_REQUEST_METHOD = "POST"
# Request path and HTTP method for the Computer Vision vectorizeText call,
# mocked the same way in the functional-test conftest.
COMPUTER_VISION_VECTORIZE_TEXT_PATH = "/computervision/retrieval:vectorizeText"
COMPUTER_VISION_VECTORIZE_TEXT_REQUEST_METHOD = "POST"
7 changes: 7 additions & 0 deletions code/tests/functional/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
AZURE_STORAGE_CONFIG_FILE_NAME,
COMPUTER_VISION_VECTORIZE_IMAGE_PATH,
COMPUTER_VISION_VECTORIZE_IMAGE_REQUEST_METHOD,
COMPUTER_VISION_VECTORIZE_TEXT_PATH,
COMPUTER_VISION_VECTORIZE_TEXT_REQUEST_METHOD,
)


Expand Down Expand Up @@ -128,6 +130,11 @@ def setup_default_mocking(httpserver: HTTPServer, app_config: AppConfig):
COMPUTER_VISION_VECTORIZE_IMAGE_REQUEST_METHOD,
).respond_with_json({"modelVersion": "2022-04-11", "vector": [1.0, 2.0, 3.0]})

httpserver.expect_request(
COMPUTER_VISION_VECTORIZE_TEXT_PATH,
COMPUTER_VISION_VECTORIZE_TEXT_REQUEST_METHOD,
).respond_with_json({"modelVersion": "2022-04-11", "vector": [1.0, 2.0, 3.0]})

httpserver.expect_request(
f"/indexes('{app_config.get('AZURE_SEARCH_INDEX')}')/docs/search.index",
method="POST",
Expand Down
2 changes: 2 additions & 0 deletions code/tests/functional/tests/backend_api/default/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def app_config(make_httpserver, ca):
"AZURE_CONTENT_SAFETY_ENDPOINT": f"https://localhost:{make_httpserver.port}/",
"AZURE_SPEECH_REGION_ENDPOINT": f"https://localhost:{make_httpserver.port}/",
"AZURE_STORAGE_ACCOUNT_ENDPOINT": f"https://localhost:{make_httpserver.port}/",
"AZURE_COMPUTER_VISION_ENDPOINT": f"https://localhost:{make_httpserver.port}/",
"USE_ADVANCED_IMAGE_PROCESSING": "True",
"SSL_CERT_FILE": ca_temp_path,
"CURL_CA_BUNDLE": ca_temp_path,
}
Expand Down
Loading

0 comments on commit 8642df3

Please sign in to comment.