feat: Generate and add image captions to search index when image is ingested. (#928)

Co-authored-by: Chinedum Echeta <60179183+cecheta@users.noreply.github.com>
Co-authored-by: Ross Smith <ross-p-smith@users.noreply.github.com>
3 people authored May 16, 2024
1 parent f27b68e commit b8e34aa
Showing 10 changed files with 224 additions and 28 deletions.
53 changes: 44 additions & 9 deletions code/backend/batch/utilities/helpers/embedders/push_embedder.py
@@ -24,6 +24,7 @@

class PushEmbedder(EmbedderBase):
def __init__(self, blob_client: AzureBlobStorageClient, env_helper: EnvHelper):
self.env_helper = env_helper
self.llm_helper = LLMHelper()
self.azure_search_helper = AzureSearchHelper()
self.azure_computer_vision_client = AzureComputerVisionClient(env_helper)
@@ -59,13 +60,15 @@ def __embed(
in self.config.get_advanced_image_processing_image_types()
):
logger.warning("Advanced image processing is not supported yet")
- image_vectors = self.azure_computer_vision_client.vectorize_image(
-     source_url
- )
- logger.info("Image vectors: " + str(image_vectors))

+ caption = self.__generate_image_caption(source_url)
+ caption_vector = self.llm_helper.generate_embeddings(caption)
+
+ image_vector = self.azure_computer_vision_client.vectorize_image(source_url)
documents_to_upload.append(
- self.__create_image_document(source_url, image_vectors)
+ self.__create_image_document(
+     source_url, image_vector, caption, caption_vector
+ )
)
else:
documents: List[SourceDocument] = self.document_loading.load(
@@ -85,6 +88,32 @@ def __embed(
logger.error("Failed to upload documents to search index")
raise Exception(response)

def __generate_image_caption(self, source_url):
model = self.env_helper.AZURE_OPENAI_VISION_MODEL
caption_system_message = """You are an assistant that generates rich descriptions of images.
You need to be accurate in the information you extract and detailed in the descriptions you generate.
Do not abbreviate anything and do not shorten sentences. Explain the image completely.
If you are provided with an image of a flow chart, describe the flow chart in detail.
If the image is mostly text, use OCR to extract the text as it is displayed in the image."""

messages = [
{"role": "system", "content": caption_system_message},
{
"role": "user",
"content": [
{
"text": "Describe this image in detail. Limit the response to 500 words.",
"type": "text",
},
{"image_url": source_url, "type": "image_url"},
],
},
]

response = self.llm_helper.get_chat_completion(messages, model)
caption = response.choices[0].message.content
return caption

def __convert_to_search_document(self, document: SourceDocument):
embedded_content = self.llm_helper.generate_embeddings(document.content)
metadata = {
@@ -111,7 +140,13 @@ def __generate_document_id(self, source_url: str) -> str:
hash_key = hashlib.sha1(f"{source_url}_1".encode("utf-8")).hexdigest()
return f"doc_{hash_key}"

- def __create_image_document(self, source_url: str, image_vectors: List[float]):
+ def __create_image_document(
+     self,
+     source_url: str,
+     image_vector: List[float],
+     content: str,
+     content_vector: List[float],
+ ):
parsed_url = urlparse(source_url)

file_url = parsed_url.scheme + "://" + parsed_url.netloc + parsed_url.path
@@ -127,9 +162,9 @@ def __create_image_document(self, source_url: str, image_vectors: List[float]):

return {
"id": document_id,
"content": "",
"content_vector": [],
"image_vector": image_vectors,
"content": content,
"content_vector": content_vector,
"image_vector": image_vector,
"metadata": json.dumps(
{
"id": document_id,
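Net effect of the push_embedder.py changes: an ingested image now contributes three artifacts to a single search document: the Computer Vision image embedding, an LLM-generated caption, and a text embedding of that caption. A minimal sketch of the flow, with hypothetical stand-ins (generate_caption, llm_helper, vision_client, search_client) passed in for the helpers used above:

# Sketch only; the real implementation is PushEmbedder.__embed /
# __generate_image_caption above.
def index_image(source_url, generate_caption, llm_helper, vision_client, search_client):
    caption = generate_caption(source_url)                    # vision chat model
    caption_vector = llm_helper.generate_embeddings(caption)  # caption -> text vector
    image_vector = vision_client.vectorize_image(source_url)  # image -> image vector
    search_client.upload_documents(
        [
            {
                "content": caption,
                "content_vector": caption_vector,
                "image_vector": image_vector,
            }
        ]
    )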
1 change: 1 addition & 0 deletions code/backend/batch/utilities/helpers/env_helper.py
@@ -86,6 +86,7 @@ def __load_config(self, **kwargs) -> None:
self.AZURE_OPENAI_MODEL_NAME = os.getenv(
"AZURE_OPENAI_MODEL_NAME", "gpt-35-turbo"
)
self.AZURE_OPENAI_VISION_MODEL = os.getenv("AZURE_OPENAI_VISION_MODEL", "gpt-4")
self.AZURE_OPENAI_TEMPERATURE = os.getenv("AZURE_OPENAI_TEMPERATURE", "0")
self.AZURE_OPENAI_TOP_P = os.getenv("AZURE_OPENAI_TOP_P", "1.0")
self.AZURE_OPENAI_MAX_TOKENS = os.getenv("AZURE_OPENAI_MAX_TOKENS", "1000")
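The new setting follows the same os.getenv pattern as its neighbours, so deployments that never set AZURE_OPENAI_VISION_MODEL silently fall back to "gpt-4". A quick illustration of that fallback (sketch, not repo code):

import os

# Unset -> the "gpt-4" default; set -> the override, as EnvHelper reads it above.
os.environ.pop("AZURE_OPENAI_VISION_MODEL", None)
assert os.getenv("AZURE_OPENAI_VISION_MODEL", "gpt-4") == "gpt-4"

os.environ["AZURE_OPENAI_VISION_MODEL"] = "my-vision-deployment"
assert os.getenv("AZURE_OPENAI_VISION_MODEL", "gpt-4") == "my-vision-deployment"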
4 changes: 2 additions & 2 deletions code/backend/batch/utilities/helpers/llm_helper.py
@@ -117,9 +117,9 @@ def get_chat_completion_with_functions(
function_call=function_call,
)

- def get_chat_completion(self, messages: list[dict]):
+ def get_chat_completion(self, messages: list[dict], model: str | None = None):
return self.openai_client.chat.completions.create(
- model=self.llm_model,
+ model=model or self.llm_model,
messages=messages,
)

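The optional model parameter keeps this change backwards compatible: `model or self.llm_model` means existing call sites keep the configured chat deployment, and only the caption path opts into the vision deployment. Roughly (illustrative call sites, not from the diff; `messages` and `env_helper` are assumed to exist):

helper = LLMHelper()

helper.get_chat_completion(messages)  # existing callers: configured default model
helper.get_chat_completion(messages, env_helper.AZURE_OPENAI_VISION_MODEL)  # caption path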
1 change: 1 addition & 0 deletions code/tests/functional/app_config.py
@@ -28,6 +28,7 @@ class AppConfig:
"AZURE_OPENAI_MAX_TOKENS": "1000",
"AZURE_OPENAI_MODEL": "some-openai-model",
"AZURE_OPENAI_MODEL_NAME": "some-openai-model-name",
"AZURE_OPENAI_VISION_MODEL": "some-openai-vision-model",
"AZURE_OPENAI_RESOURCE": "some-openai-resource",
"AZURE_OPENAI_STREAM": "True",
"AZURE_OPENAI_STOP_SEQUENCE": "",
27 changes: 27 additions & 0 deletions code/tests/functional/conftest.py
@@ -162,6 +162,33 @@ def setup_default_mocking(httpserver: HTTPServer, app_config: AppConfig):
}
)

httpserver.expect_request(
f"/openai/deployments/{app_config.get('AZURE_OPENAI_VISION_MODEL')}/chat/completions",
method="POST",
).respond_with_json(
{
"id": "chatcmpl-6v7mkQj980V1yBec6ETrKPRqFjNw9",
"object": "chat.completion",
"created": 1679072642,
"model": app_config.get("AZURE_OPENAI_VISION_MODEL"),
"usage": {
"prompt_tokens": 58,
"completion_tokens": 68,
"total_tokens": 126,
},
"choices": [
{
"message": {
"role": "assistant",
"content": "This is a caption for the image",
},
"finish_reason": "stop",
"index": 0,
}
],
}
)

httpserver.expect_request(
f"/indexes('{app_config.get('AZURE_SEARCH_CONVERSATIONS_LOG_INDEX')}')/docs/search.index",
method="POST",
@@ -103,6 +103,9 @@ def test_image_passed_to_computer_vision_to_generate_image_embeddings(
RequestMatcher(
path=COMPUTER_VISION_VECTORIZE_IMAGE_PATH,
method=COMPUTER_VISION_VECTORIZE_IMAGE_REQUEST_METHOD,
json={
"url": ANY,
},
query_string="api-version=2024-02-01&model-version=2023-04-15",
headers={
"Content-Type": "application/json",
@@ -115,7 +118,87 @@ def test_image_passed_to_computer_vision_to_generate_image_embeddings(
)[0]

assert request.get_json()["url"].startswith(
f"{app_config.get('AZURE_COMPUTER_VISION_ENDPOINT')}{app_config.get('AZURE_BLOB_CONTAINER_NAME')}/{FILE_NAME}"
f"{app_config.get('AZURE_STORAGE_ACCOUNT_ENDPOINT')}{app_config.get('AZURE_BLOB_CONTAINER_NAME')}/{FILE_NAME}"
)


def test_image_passed_to_llm_to_generate_caption(
message: QueueMessage, httpserver: HTTPServer, app_config: AppConfig
):
# when
batch_push_results.build().get_user_function()(message)

# then
request = verify_request_made(
mock_httpserver=httpserver,
request_matcher=RequestMatcher(
path=f"/openai/deployments/{app_config.get('AZURE_OPENAI_VISION_MODEL')}/chat/completions",
method="POST",
json={
"messages": [
{
"role": "system",
"content": """You are an assistant that generates rich descriptions of images.
You need to be accurate in the information you extract and detailed in the descriptions you generate.
Do not abbreviate anything and do not shorten sentences. Explain the image completely.
If you are provided with an image of a flow chart, describe the flow chart in detail.
If the image is mostly text, use OCR to extract the text as it is displayed in the image.""",
},
{
"role": "user",
"content": [
{
"text": "Describe this image in detail. Limit the response to 500 words.",
"type": "text",
},
{"image_url": ANY, "type": "image_url"},
],
},
],
"model": app_config.get("AZURE_OPENAI_VISION_MODEL"),
},
headers={
"Accept": "application/json",
"Content-Type": "application/json",
"Authorization": f"Bearer {app_config.get('AZURE_OPENAI_API_KEY')}",
"Api-Key": app_config.get("AZURE_OPENAI_API_KEY"),
},
query_string="api-version=2024-02-01",
times=1,
),
)[0]

assert request.get_json()["messages"][1]["content"][1]["image_url"].startswith(
f"{app_config.get('AZURE_STORAGE_ACCOUNT_ENDPOINT')}{app_config.get('AZURE_BLOB_CONTAINER_NAME')}/{FILE_NAME}"
)


def test_embeddings_generated_for_caption(
message: QueueMessage, httpserver: HTTPServer, app_config: AppConfig
):
# when
batch_push_results.build().get_user_function()(message)

# then
verify_request_made(
mock_httpserver=httpserver,
request_matcher=RequestMatcher(
path=f"/openai/deployments/{app_config.get('AZURE_OPENAI_EMBEDDING_MODEL')}/embeddings",
method="POST",
json={
"input": ["This is a caption for the image"],
"model": app_config.get("AZURE_OPENAI_EMBEDDING_MODEL"),
"encoding_format": "base64",
},
headers={
"Accept": "application/json",
"Content-Type": "application/json",
"Authorization": f"Bearer {app_config.get('AZURE_OPENAI_API_KEY')}",
"Api-Key": app_config.get("AZURE_OPENAI_API_KEY"),
},
query_string="api-version=2024-02-01",
times=1,
),
)
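One detail worth noting in this matcher: the request body pins encoding_format to "base64", the compact wire format the OpenAI Python client then decodes back to floats itself. If you ever need to decode such a payload by hand, a sketch (assuming float32 little-endian, which is what the service returns):

import base64
import struct

def decode_embedding(b64: str) -> list[float]:
    # The service packs float32 values little-endian; 4 bytes per component.
    raw = base64.b64decode(b64)
    return list(struct.unpack(f"<{len(raw) // 4}f", raw))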


@@ -343,8 +426,11 @@ def test_makes_correct_call_to_store_documents_in_search_index(
"value": [
{
"id": expected_id,
"content": "",
"content_vector": [],
"content": "This is a caption for the image",
"content_vector": [
0.018990106880664825,
-0.0073809814639389515,
],
"image_vector": [1.0, 2.0, 3.0],
"metadata": json.dumps(
{
66 changes: 56 additions & 10 deletions code/tests/utilities/helpers/test_push_embedder.py
@@ -22,8 +22,13 @@ def llm_helper_mock():
llm_helper.get_embedding_model.return_value.embed_query.return_value = [
0
] * 1536
mock_completion = llm_helper.get_chat_completion.return_value
choice = MagicMock()
choice.message.content = "This is a caption for an image"
mock_completion.choices = [choice]

llm_helper.generate_embeddings.return_value = [123]
- yield mock
+ yield llm_helper


@pytest.fixture(autouse=True)
@@ -129,7 +134,46 @@ def test_embed_file_advanced_image_processing_vectorizes_image(
)


def test_embed_file_advanced_image_processing_uses_vision_model_for_captioning(
llm_helper_mock,
):
# given
env_helper_mock = MagicMock()
env_helper_mock.AZURE_OPENAI_VISION_MODEL = "gpt-4"
push_embedder = PushEmbedder(MagicMock(), env_helper_mock)
source_url = "http://localhost:8080/some-file-name.jpg"

# when
push_embedder.embed_file(source_url, "some-file-name.jpg")

# then
llm_helper_mock.get_chat_completion.assert_called_once_with(
[
{
"role": "system",
"content": """You are an assistant that generates rich descriptions of images.
You need to be accurate in the information you extract and detailed in the descriptions you generate.
Do not abbreviate anything and do not shorten sentences. Explain the image completely.
If you are provided with an image of a flow chart, describe the flow chart in detail.
If the image is mostly text, use OCR to extract the text as it is displayed in the image.""",
},
{
"role": "user",
"content": [
{
"text": "Describe this image in detail. Limit the response to 500 words.",
"type": "text",
},
{"image_url": source_url, "type": "image_url"},
],
},
],
env_helper_mock.AZURE_OPENAI_VISION_MODEL,
)


def test_embed_file_advanced_image_processing_stores_embeddings_in_search_index(
llm_helper_mock,
azure_computer_vision_mock,
azure_search_helper_mock: MagicMock,
):
@@ -153,12 +197,16 @@ def test_embed_file_advanced_image_processing_stores_embeddings_in_search_index(
hash_key = hashlib.sha1(f"{host_path}_1".encode("utf-8")).hexdigest()
expected_id = f"doc_{hash_key}"

llm_helper_mock.generate_embeddings.assert_called_once_with(
"This is a caption for an image"
)

azure_search_helper_mock.return_value.get_search_client.return_value.upload_documents.assert_called_once_with(
[
{
"id": expected_id,
"content": "",
"content_vector": [],
"content": "This is a caption for an image",
"content_vector": [123],
"image_vector": image_embeddings,
"metadata": json.dumps(
{
@@ -265,7 +313,7 @@ def test_embed_file_generates_embeddings_for_documents(llm_helper_mock):
)

# then
- llm_helper_mock.return_value.generate_embeddings.assert_has_calls(
+ llm_helper_mock.generate_embeddings.assert_has_calls(
[call("some content"), call("some other content")]
)

@@ -291,7 +339,7 @@ def test_embed_file_stores_documents_in_search_index(
{
"id": expected_chunked_documents[0].id,
"content": expected_chunked_documents[0].content,
"content_vector": llm_helper_mock.return_value.generate_embeddings.return_value,
"content_vector": llm_helper_mock.generate_embeddings.return_value,
"metadata": json.dumps(
{
"id": expected_chunked_documents[0].id,
@@ -311,7 +359,7 @@
{
"id": expected_chunked_documents[1].id,
"content": expected_chunked_documents[1].content,
"content_vector": llm_helper_mock.return_value.generate_embeddings.return_value,
"content_vector": llm_helper_mock.generate_embeddings.return_value,
"metadata": json.dumps(
{
"id": expected_chunked_documents[1].id,
@@ -338,10 +386,8 @@ def test_embed_file_raises_exception_on_failure(
# given
push_embedder = PushEmbedder(MagicMock(), MagicMock())

- successful_indexing_result = MagicMock()
- successful_indexing_result.succeeded = True
- failed_indexing_result = MagicMock()
- failed_indexing_result.succeeded = False
+ successful_indexing_result = MagicMock(succeeded=True)
+ failed_indexing_result = MagicMock(succeeded=False)
azure_search_helper_mock.return_value.get_search_client.return_value.upload_documents.return_value = [
successful_indexing_result,
failed_indexing_result,
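The tidy-up at the end of this file relies on MagicMock accepting attribute values as constructor kwargs, which halves the setup. For example (illustrative, not repo code):

from unittest.mock import MagicMock

# Constructor kwargs become attributes, replacing the two-step setup.
result = MagicMock(succeeded=False)
assert result.succeeded is False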
2 changes: 1 addition & 1 deletion infra/main.bicep
@@ -108,7 +108,7 @@ param azureOpenAIModelCapacity int = 30
param useAdvancedImageProcessing bool = false

@description('Azure OpenAI Vision Model Deployment Name')
- param azureOpenAIVisionModel string = 'gpt-4-vision'
+ param azureOpenAIVisionModel string = 'gpt-4'

@description('Azure OpenAI Vision Model Name')
param azureOpenAIVisionModelName string = 'gpt-4'
Expand Down