Skip to content

Commit

Permalink
Run black against the code base
Browse files Browse the repository at this point in the history
  • Loading branch information
ross-p-smith committed Feb 8, 2024
1 parent 97eb155 commit 1701d36
Show file tree
Hide file tree
Showing 12 changed files with 390 additions and 195 deletions.
97 changes: 63 additions & 34 deletions code/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from flask import Flask, Response, request, jsonify
from dotenv import load_dotenv
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from azure.keyvault.secrets import SecretClient
import sys
from backend.batch.utilities.helpers.EnvHelper import EnvHelper

# Fixing MIME types for static files under Windows
mimetypes.add_type("application/javascript", ".js")
Expand All @@ -28,13 +28,14 @@
def static_file(path):
return app.send_static_file(path)

from backend.batch.utilities.helpers.EnvHelper import EnvHelper

env_helper: EnvHelper = EnvHelper()
AZURE_AUTH_TYPE = env_helper.AZURE_AUTH_TYPE
AZURE_SEARCH_KEY = env_helper.AZURE_SEARCH_KEY
AZURE_OPENAI_KEY = env_helper.AZURE_OPENAI_KEY
AZURE_SPEECH_KEY = env_helper.AZURE_SPEECH_KEY


@app.route("/api/config", methods=["GET"])
def get_config():
# Retrieve the environment variables or other configuration data
Expand Down Expand Up @@ -80,7 +81,9 @@ def get_config():
AZURE_OPENAI_MODEL_NAME = os.environ.get(
"AZURE_OPENAI_MODEL_NAME", "gpt-35-turbo"
) # Name of the model, e.g. 'gpt-35-turbo' or 'gpt-4'
AZURE_TOKEN_PROVIDER = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
AZURE_TOKEN_PROVIDER = get_bearer_token_provider(
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)

SHOULD_STREAM = True if AZURE_OPENAI_STREAM.lower() == "true" else False

Expand All @@ -105,9 +108,11 @@ def prepare_body_headers_with_data(request):
"temperature": AZURE_OPENAI_TEMPERATURE,
"max_tokens": AZURE_OPENAI_MAX_TOKENS,
"top_p": AZURE_OPENAI_TOP_P,
"stop": AZURE_OPENAI_STOP_SEQUENCE.split("|")
if AZURE_OPENAI_STOP_SEQUENCE
else None,
"stop": (
AZURE_OPENAI_STOP_SEQUENCE.split("|")
if AZURE_OPENAI_STOP_SEQUENCE
else None
),
"stream": SHOULD_STREAM,
"dataSources": [
{
Expand All @@ -117,30 +122,42 @@ def prepare_body_headers_with_data(request):
"key": AZURE_SEARCH_KEY,
"indexName": AZURE_SEARCH_INDEX,
"fieldsMapping": {
"contentField": AZURE_SEARCH_CONTENT_COLUMNS.split("|")
if AZURE_SEARCH_CONTENT_COLUMNS
else [],
"titleField": AZURE_SEARCH_TITLE_COLUMN
if AZURE_SEARCH_TITLE_COLUMN
else None,
"urlField": AZURE_SEARCH_URL_COLUMN
if AZURE_SEARCH_URL_COLUMN
else None,
"filepathField": AZURE_SEARCH_FILENAME_COLUMN
if AZURE_SEARCH_FILENAME_COLUMN
else None,
"contentField": (
AZURE_SEARCH_CONTENT_COLUMNS.split("|")
if AZURE_SEARCH_CONTENT_COLUMNS
else []
),
"titleField": (
AZURE_SEARCH_TITLE_COLUMN
if AZURE_SEARCH_TITLE_COLUMN
else None
),
"urlField": (
AZURE_SEARCH_URL_COLUMN if AZURE_SEARCH_URL_COLUMN else None
),
"filepathField": (
AZURE_SEARCH_FILENAME_COLUMN
if AZURE_SEARCH_FILENAME_COLUMN
else None
),
},
"inScope": True
if AZURE_SEARCH_ENABLE_IN_DOMAIN.lower() == "true"
else False,
"inScope": (
True
if AZURE_SEARCH_ENABLE_IN_DOMAIN.lower() == "true"
else False
),
"topNDocuments": AZURE_SEARCH_TOP_K,
"queryType": "semantic"
if AZURE_SEARCH_USE_SEMANTIC_SEARCH.lower() == "true"
else "simple",
"semanticConfiguration": AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG
if AZURE_SEARCH_USE_SEMANTIC_SEARCH.lower() == "true"
and AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG
else "",
"queryType": (
"semantic"
if AZURE_SEARCH_USE_SEMANTIC_SEARCH.lower() == "true"
else "simple"
),
"semanticConfiguration": (
AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG
if AZURE_SEARCH_USE_SEMANTIC_SEARCH.lower() == "true"
and AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG
else ""
),
"roleInformation": AZURE_OPENAI_SYSTEM_MESSAGE,
},
}
Expand Down Expand Up @@ -248,9 +265,17 @@ def stream_without_data(response):
def conversation_without_data(request):
azure_endpoint = f"https://{AZURE_OPENAI_RESOURCE}.openai.azure.com/"
if AZURE_AUTH_TYPE == "rbac":
openai_client = AzureOpenAI(azure_endpoint=azure_endpoint, api_version=AZURE_OPENAI_API_VERSION, azure_ad_token_provider=AZURE_TOKEN_PROVIDER)
openai_client = AzureOpenAI(
azure_endpoint=azure_endpoint,
api_version=AZURE_OPENAI_API_VERSION,
azure_ad_token_provider=AZURE_TOKEN_PROVIDER,
)
else:
openai_client = AzureOpenAI(azure_endpoint=azure_endpoint, api_version=AZURE_OPENAI_API_VERSION, api_key=AZURE_OPENAI_KEY)
openai_client = AzureOpenAI(
azure_endpoint=azure_endpoint,
api_version=AZURE_OPENAI_API_VERSION,
api_key=AZURE_OPENAI_KEY,
)

request_messages = request.json["messages"]
messages = [{"role": "system", "content": AZURE_OPENAI_SYSTEM_MESSAGE}]
Expand All @@ -265,9 +290,11 @@ def conversation_without_data(request):
temperature=float(AZURE_OPENAI_TEMPERATURE),
max_tokens=int(AZURE_OPENAI_MAX_TOKENS),
top_p=float(AZURE_OPENAI_TOP_P),
stop=AZURE_OPENAI_STOP_SEQUENCE.split("|")
if AZURE_OPENAI_STOP_SEQUENCE
else None,
stop=(
AZURE_OPENAI_STOP_SEQUENCE.split("|")
if AZURE_OPENAI_STOP_SEQUENCE
else None
),
stream=SHOULD_STREAM,
)

Expand Down Expand Up @@ -322,7 +349,9 @@ def conversation_azure_byod():

@app.route("/api/conversation/custom", methods=["GET", "POST"])
def conversation_custom():
from backend.batch.utilities.helpers.OrchestratorHelper import Orchestrator, OrchestrationSettings
from backend.batch.utilities.helpers.OrchestratorHelper import (
Orchestrator,
)

message_orchestrator = Orchestrator()

Expand Down
83 changes: 54 additions & 29 deletions code/backend/batch/utilities/helpers/AzureBlobStorageHelper.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,55 @@ def __init__(

self.auth_type = env_helper.AZURE_AUTH_TYPE
if self.auth_type == "rbac":
self.account_name = account_name if account_name else env_helper.AZURE_BLOB_ACCOUNT_NAME
self.container_name : str = container_name if container_name else env_helper.AZURE_BLOB_CONTAINER_NAME
self.account_name = (
account_name if account_name else env_helper.AZURE_BLOB_ACCOUNT_NAME
)
self.container_name: str = (
container_name
if container_name
else env_helper.AZURE_BLOB_CONTAINER_NAME
)
credential = DefaultAzureCredential()
account_url = f"https://{self.account_name}.blob.core.windows.net/"
self.blob_service_client = BlobServiceClient(account_url=account_url, credential=credential)
self.user_delegation_key = self.request_user_delegation_key(blob_service_client=self.blob_service_client)
self.blob_service_client = BlobServiceClient(
account_url=account_url, credential=credential
)
self.user_delegation_key = self.request_user_delegation_key(
blob_service_client=self.blob_service_client
)
self.account_key = None
else:
self.account_name = account_name if account_name else env_helper.AZURE_BLOB_ACCOUNT_NAME
self.account_key = account_key if account_key else env_helper.AZURE_BLOB_ACCOUNT_KEY
self.account_name = (
account_name if account_name else env_helper.AZURE_BLOB_ACCOUNT_NAME
)
self.account_key = (
account_key if account_key else env_helper.AZURE_BLOB_ACCOUNT_KEY
)
self.user_delegation_key = None
self.connect_str = f"DefaultEndpointsProtocol=https;AccountName={self.account_name};AccountKey={self.account_key};EndpointSuffix=core.windows.net"
self.container_name : str = container_name if container_name else env_helper.AZURE_BLOB_CONTAINER_NAME
self.blob_service_client : BlobServiceClient = BlobServiceClient.from_connection_string(self.connect_str)

def request_user_delegation_key(self, blob_service_client: BlobServiceClient) -> UserDelegationKey:
self.container_name: str = (
container_name
if container_name
else env_helper.AZURE_BLOB_CONTAINER_NAME
)
self.blob_service_client: BlobServiceClient = (
BlobServiceClient.from_connection_string(self.connect_str)
)

def request_user_delegation_key(
self, blob_service_client: BlobServiceClient
) -> UserDelegationKey:
# Get a user delegation key that's valid for 1 day
delegation_key_start_time = datetime.utcnow()
delegation_key_expiry_time = delegation_key_start_time + timedelta(days=1)

user_delegation_key = blob_service_client.get_user_delegation_key(
key_start_time=delegation_key_start_time,
key_expiry_time=delegation_key_expiry_time
key_expiry_time=delegation_key_expiry_time,
)
return user_delegation_key
def upload_file(self, bytes_data, file_name, content_type='application/pdf'):

def upload_file(self, bytes_data, file_name, content_type="application/pdf"):
# Create a blob client using the local file name as the name for the blob
blob_client = self.blob_service_client.get_blob_client(
container=self.container_name, blob=file_name
Expand Down Expand Up @@ -117,21 +139,22 @@ def get_all_files(self):
files.append(
{
"filename": blob.name,
"converted": blob.metadata.get("converted", "false") == "true"
if blob.metadata
else False,
"embeddings_added": blob.metadata.get(
"embeddings_added", "false"
)
== "true"
if blob.metadata
else False,
"converted": (
blob.metadata.get("converted", "false") == "true"
if blob.metadata
else False
),
"embeddings_added": (
blob.metadata.get("embeddings_added", "false") == "true"
if blob.metadata
else False
),
"fullpath": f"https://{self.account_name}.blob.core.windows.net/{self.container_name}/{blob.name}?{sas}",
"converted_filename": blob.metadata.get(
"converted_filename", ""
)
if blob.metadata
else "",
"converted_filename": (
blob.metadata.get("converted_filename", "")
if blob.metadata
else ""
),
"converted_path": "",
}
)
Expand All @@ -150,7 +173,9 @@ def get_all_files(self):

def upsert_blob_metadata(self, file_name, metadata):
if self.auth_type == "rbac":
blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=file_name)
blob_client = self.blob_service_client.get_blob_client(
container=self.container_name, blob=file_name
)
else:
blob_client = BlobServiceClient.from_connection_string(
self.connect_str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ def __init__(self) -> None:
)
if env_helper.AZURE_AUTH_TYPE == "rbac":
self.document_analysis_client = DocumentAnalysisClient(
endpoint=self.AZURE_FORM_RECOGNIZER_ENDPOINT,
credential=DefaultAzureCredential(),
endpoint=self.AZURE_FORM_RECOGNIZER_ENDPOINT,
credential=DefaultAzureCredential(),
headers={
"x-ms-useragent": "chat-with-your-data-solution-accelerator/1.0.0"
},
)
else:
self.AZURE_FORM_RECOGNIZER_KEY : str = env_helper.AZURE_FORM_RECOGNIZER_KEY
self.AZURE_FORM_RECOGNIZER_KEY: str = env_helper.AZURE_FORM_RECOGNIZER_KEY

self.document_analysis_client = DocumentAnalysisClient(
endpoint=self.AZURE_FORM_RECOGNIZER_ENDPOINT,
credential=AzureKeyCredential(self.AZURE_FORM_RECOGNIZER_KEY),
Expand Down
16 changes: 10 additions & 6 deletions code/backend/batch/utilities/helpers/AzureSearchHelper.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,11 @@ def get_vector_store(self):

return AzureSearch(
azure_search_endpoint=env_helper.AZURE_SEARCH_SERVICE,
azure_search_key=env_helper.AZURE_SEARCH_KEY
if env_helper.AZURE_AUTH_TYPE == "keys"
else None,
azure_search_key=(
env_helper.AZURE_SEARCH_KEY
if env_helper.AZURE_AUTH_TYPE == "keys"
else None
),
index_name=env_helper.AZURE_SEARCH_INDEX,
embedding_function=llm_helper.get_embedding_model().embed_query,
fields=fields,
Expand Down Expand Up @@ -139,9 +141,11 @@ def get_conversation_logger(self):

return AzureSearch(
azure_search_endpoint=env_helper.AZURE_SEARCH_SERVICE,
azure_search_key=env_helper.AZURE_SEARCH_KEY
if env_helper.AZURE_AUTH_TYPE == "keys"
else None,
azure_search_key=(
env_helper.AZURE_SEARCH_KEY
if env_helper.AZURE_AUTH_TYPE == "keys"
else None
),
index_name=env_helper.AZURE_SEARCH_CONVERSATIONS_LOG_INDEX,
embedding_function=llm_helper.get_embedding_model().embed_query,
fields=fields,
Expand Down
Loading

0 comments on commit 1701d36

Please sign in to comment.