Merge branch 'main' into mtrag
elronbandel authored Jan 23, 2025
2 parents dff8da3 + 45f0db8 commit 242103c
Showing 1 changed file with 19 additions and 32 deletions.
51 changes: 19 additions & 32 deletions src/unitxt/assistant/app.py
@@ -1,5 +1,4 @@
import datetime
import importlib
import json
import logging
import os
@@ -10,12 +9,25 @@
import pandas as pd
import streamlit as st
import torch
from litellm import AuthenticationError
from litellm.llms.watsonx.common_utils import IBMWatsonXMixin
from transformers import AutoTokenizer

logger = logging.getLogger("unitxt-assistance")


logger = logging.getLogger("unitxt-assistance")

original_validate_environment = IBMWatsonXMixin.validate_environment


def wrapped_validate_environment(self, *args, **kwargs):
    kwargs = {**kwargs, "headers": {}}
    return original_validate_environment(self, *args, **kwargs)


IBMWatsonXMixin.validate_environment = wrapped_validate_environment


@st.cache_resource
def load_data():
    current_file_dir = os.path.dirname(os.path.abspath(__file__))
@@ -25,36 +37,9 @@ def load_data():
    return metadata_df, embeddings


def get_embedding_with_retry(model, input, max_retries=3):
    """This function calls the litellm.embedding method and handles token expiration.
    It will retry the call up to `max_retries` times if an AuthenticationError is raised.
    """
    retries = 0
    actual_exception = None
    while retries < max_retries:
        try:
            return litellm.embedding(model=model, input=input)

        except AuthenticationError as e:
            actual_exception = e
            retries += 1
            logger.info(
                f"Authentication error: {e}. Retrying... ({retries}/{max_retries})"
            )
            importlib.reload(
                litellm
            )  # Reload the litellm module to clear any cached state

    # If all retries fail, raise an error
    raise Exception(
        f"Failed to get embedding after {max_retries} attempts. Exception: {actual_exception}"
    )


def search(query, metadata_df, embeddings, max_tokens=5000, min_text_length=50):
    # Generate embedding for the query using litellm
    response = get_embedding_with_retry(
    response = litellm.embedding(
        model="watsonx/intfloat/multilingual-e5-large",
        input=[query],
    )
@@ -119,8 +104,10 @@ def generate_response(messages, metadata_df, embeddings, model, max_tokens=500):

    system_prompt = (
        "Your job is to assist users with Unitxt Library and Catalog. "
        "Refuse to do anything else.\n\n"
        "# Answer only based on the following Information:\n\n" + context
        "Refuse to do anything else. "
"Based your answers on the information below add link to its origin Path with this format: https://www.unitxt.ai/en/latest/{path}.html"
"\n\n"
"# Answer only based on the following information:\n\n" + context
)

    messages = [
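Note on the change in this diff: the commit drops the get_embedding_with_retry helper, which retried litellm.embedding after an AuthenticationError and reloaded litellm to clear cached state, and instead monkey-patches IBMWatsonXMixin.validate_environment so every call starts with empty headers, presumably so litellm rebuilds the watsonx auth headers rather than reusing an expired token. Below is a minimal, self-contained sketch of that wrap-and-reassign pattern; the class and method names are stand-ins for illustration, not the real litellm internals.

class Client:
    def validate_environment(self, *args, **kwargs):
        # Stand-in: pretend this reuses whatever auth headers were passed in.
        return {"headers": kwargs.get("headers")}

original = Client.validate_environment

def wrapped(self, *args, **kwargs):
    # Force empty headers so the wrapped method rebuilds them from scratch.
    kwargs = {**kwargs, "headers": {}}
    return original(self, *args, **kwargs)

Client.validate_environment = wrapped

print(Client().validate_environment(headers={"Authorization": "Bearer stale"}))
# -> {'headers': {}}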
