Skip to content

Commit

Permalink
added chat session manager
Browse files Browse the repository at this point in the history
  • Loading branch information
adeelehsan committed Nov 20, 2024
1 parent 718ae78 commit 0a485cd
Show file tree
Hide file tree
Showing 4 changed files with 331 additions and 7 deletions.
74 changes: 74 additions & 0 deletions int_tests/test_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import time
import unittest
from datetime import timedelta
from pathlib import Path

from vectara import SearchCorporaParameters, KeyedSearchCorpus, ContextConfiguration, \
GenerationParameters, CitationParameters, ChatParameters
from vectara.client import ChatSessionManager
from vectara.factory import Factory
from vectara.managers import CreateCorpusRequest


class TestChat(unittest.TestCase):
    """Integration test for ``client.chat`` with session reuse via ``chat_id``.

    Each test creates a lab corpus, uploads one PDF into it, and replaces the
    client's session manager with a short-lived one (5 second expiry, 6 second
    cleanup interval).
    """

    def setUp(self):
        self.client = Factory().build()
        request = CreateCorpusRequest(name="int-test-upload-fern", key="int-test-upload-fern")
        create_response = self.client.lab_helper.create_lab_corpus(request, user_prefix=False)
        self.key = create_response.key
        # unittest does not run tearDown() when setUp() raises, so register the
        # corpus deletion as a cleanup right away: a failure further down (for
        # example the upload) must not leave the corpus behind.
        self.addCleanup(self.client.corpora.delete, corpus_key=self.key)

        self.client.session_manager = ChatSessionManager(
            session_expiry_time=timedelta(seconds=5),
            cleanup_interval_in_seconds=6,
        )

        self.search_params = SearchCorporaParameters(
            corpora=[
                KeyedSearchCorpus(
                    corpus_key=self.key,
                    lexical_interpolation=0,
                )
            ],
            offset=1,
            limit=1,
            context_configuration=ContextConfiguration(
                characters_before=1,
                characters_after=1,
                sentences_before=1,
                sentences_after=1,
            ),
        )
        self.generation_params = GenerationParameters(
            citations=CitationParameters(
                style="none",
            ),
            enable_factual_consistency_score=True,
        )
        self.chat_params = ChatParameters(store=True)

        path = Path("examples/01_getting_started/resources/arxiv/2409.05865v1.pdf")
        with open(path, "rb") as f:
            content = f.read()
        self.client.upload.file(
            self.key,
            file=(content, "application/pdf"),
            metadata={"test_sdk": "ok"},
            filename="test-document",
        )

    def test_chat(self):
        """Start a chat, then send a second turn passing only the ``chat_id``."""
        response = self.client.chat(
            query="what is vectara?",
            search=self.search_params,
            generation=self.generation_params,
            chat=self.chat_params,
        )

        self.assertIsNotNone(response.chat_id)
        self.assertIsNotNone(response.answer)

        # Second turn: no search/generation/chat parameters are passed, so the
        # client has to take them from the session stored for this chat_id.
        response = self.client.chat(
            query="what is vectara?",
            chat_id=response.chat_id,
        )

        self.assertIsNotNone(response.chat_id)
        self.assertIsNotNone(response.answer)

    def tearDown(self):
        # Signal the session manager installed in setUp so its background
        # cleanup loop exits instead of lingering for the rest of the run.
        # Corpus deletion is handled by the cleanup registered in setUp.
        self.client.session_manager.cleanup_event.set()
108 changes: 108 additions & 0 deletions int_tests/test_chat_session_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import unittest
from datetime import timedelta
import time

from vectara import SearchCorporaParameters, KeyedSearchCorpus, ContextConfiguration, CustomerSpecificReranker, \
GenerationParameters, ModelParameters, CitationParameters, ChatParameters
from vectara.client import ChatSessionManager
from vectara.core import RequestOptions


class TestChatSessionManagerIntegration(unittest.TestCase):
    """Exercises ``ChatSessionManager`` storage, expiry and background cleanup.

    The manager under test uses a 5 second session expiry and a 6 second
    cleanup interval, so the expiry tests sleep for 7 seconds.
    """

    def setUp(self):
        self.session_manager = ChatSessionManager(
            session_expiry_time=timedelta(seconds=5),
            cleanup_interval_in_seconds=6,
        )
        self.search_params = SearchCorporaParameters(
            corpora=[
                KeyedSearchCorpus(
                    corpus_key="test_corpus_key",
                    metadata_filter="",
                    lexical_interpolation=1,
                )
            ],
            offset=1,
            limit=1,
            context_configuration=ContextConfiguration(
                characters_before=1,
                characters_after=1,
                sentences_before=1,
                sentences_after=1,
                start_tag="%",
                end_tag="%",
            ),
            reranker=CustomerSpecificReranker(
                reranker_id="test_id",
                reranker_name="test",
            ),
        )
        self.generation_params = GenerationParameters(
            generation_preset_name="test",
            prompt_name="test",
            max_used_search_results=1,
            prompt_template="test",
            prompt_text="test",
            max_response_characters=1,
            response_language="test",
            model_parameters=ModelParameters(
                max_tokens=1,
                temperature=1.1,
                frequency_penalty=1.1,
                presence_penalty=1.1,
            ),
            citations=CitationParameters(
                style="none",
            ),
            enable_factual_consistency_score=True,
        )
        self.chat_params = ChatParameters(store=True)
        self.request_options = RequestOptions(timeout_in_seconds=100)
        self.request_timeout = 1
        self.request_timeout_millis = 1

    def test_create_and_retrieve_session(self):
        """A stored session returns every parameter under its own key."""
        chat_id = "test_chat"

        self.session_manager.create_session(
            chat_id,
            self.search_params,
            self.generation_params,
            self.chat_params,
            self.request_options,
            self.request_timeout,
            self.request_timeout_millis,
        )

        session = self.session_manager.get_session(chat_id)
        self.assertIsNotNone(session)
        self.assertEqual(session["search"], self.search_params)
        self.assertEqual(session["generation"], self.generation_params)
        self.assertEqual(session["chat"], self.chat_params)
        self.assertEqual(session["request_options"], self.request_options)
        self.assertEqual(session["request_timeout"], self.request_timeout)
        self.assertEqual(session["request_timeout_millis"], self.request_timeout_millis)

    def test_session_expiration(self):
        """A session is no longer returned once the expiry window has passed."""
        chat_id = "test_chat_expiration"

        self.session_manager.create_session(chat_id, self.search_params, self.generation_params, self.chat_params)

        time.sleep(7)

        session = self.session_manager.get_session(chat_id)
        self.assertIsNone(session)

    def test_clean_expired_sessions(self):
        """An explicit sweep leaves no expired sessions behind."""
        self.session_manager.create_session("chat1", self.search_params, self.generation_params, self.chat_params)
        self.session_manager.create_session("chat2", self.search_params, self.generation_params, self.chat_params)

        time.sleep(7)

        self.session_manager.clean_expired_sessions()

        self.assertIsNone(self.session_manager.get_session("chat1"))
        self.assertIsNone(self.session_manager.get_session("chat2"))

    def test_threaded_cleanup(self):
        """The session is gone after the background cleanup interval elapsed."""
        chat_id = "test_threaded_cleanup"
        self.session_manager.create_session(chat_id, self.search_params, self.generation_params, self.chat_params)

        time.sleep(7)

        session = self.session_manager.get_session(chat_id)
        self.assertIsNone(session)

    def tearDown(self):
        # The cleanup loop only exits once cleanup_event is set; joining
        # without signalling it would always time out and leave the thread
        # running for the rest of the test run.
        self.session_manager.cleanup_event.set()
        self.session_manager.cleanup_thread.join(timeout=1)
154 changes: 148 additions & 6 deletions src/vectara/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,73 @@
from .base_client import BaseVectara, AsyncBaseVectara
import logging
import threading
import typing

from datetime import datetime, timedelta

from . import SearchCorporaParameters, GenerationParameters, ChatParameters, ChatFullResponse
from .base_client import BaseVectara, AsyncBaseVectara, OMIT
from vectara.managers.corpus import CorpusManager
from vectara.managers.upload import UploadManager, UploadWrapper
from vectara.managers.upload import UploadManager
from vectara.utils import LabHelper

from typing import Union, Optional, Callable
import logging
from typing import Union

from .core import RequestOptions


class ChatSessionManager:
    """Thread-safe, in-memory store of the request parameters used per chat.

    Each session is a plain dict keyed by ``chat_id`` holding the ``search``,
    ``generation``, ``chat``, ``request_options``, ``request_timeout`` and
    ``request_timeout_millis`` values plus a ``created_at`` timestamp.

    Sessions older than ``session_expiry_time`` are removed by a daemon thread
    that sweeps every ``cleanup_interval_in_seconds``; ``get_session`` also
    drops an expired session on lookup, so a stale entry is never returned
    between two sweeps.
    """

    def __init__(self, session_expiry_time: timedelta = timedelta(days=7),
                 cleanup_interval_in_seconds: int = 43200):
        """Create the store and start the background cleanup thread.

        Parameters
        ----------
        session_expiry_time : timedelta
            Maximum age of a session, measured from its creation.
        cleanup_interval_in_seconds : int
            Seconds between background sweeps. The default (43200) is 12 hours.
        """
        self.sessions = {}
        self.session_expiry_time = session_expiry_time
        self.lock = threading.Lock()
        # Setting this event makes the cleanup loop exit; see stop().
        self.cleanup_event = threading.Event()
        self.cleanup_interval_in_seconds = cleanup_interval_in_seconds
        self.cleanup_thread = threading.Thread(target=self._run_cleanup, daemon=True)
        self.cleanup_thread.start()

    def create_session(self, chat_id: str,
                       search: SearchCorporaParameters,
                       generation: typing.Optional[GenerationParameters] = OMIT,
                       chat: typing.Optional[ChatParameters] = OMIT,
                       request_options: typing.Optional[RequestOptions] = None,
                       request_timeout: typing.Optional[int] = None,
                       request_timeout_millis: typing.Optional[int] = None,
                       ):
        """Store (or overwrite) the parameters for ``chat_id``.

        The session's ``created_at`` is set to the current time, so
        overwriting an existing session restarts its expiry window.
        """
        with self.lock:
            self.sessions[chat_id] = {
                "search": search,
                "generation": generation,
                "chat": chat,
                "created_at": datetime.now(),
                "request_timeout": request_timeout,
                "request_timeout_millis": request_timeout_millis,
                "request_options": request_options,
            }

    def get_session(self, chat_id: str):
        """Return the session dict for ``chat_id``, or ``None``.

        ``None`` is returned both when no session is stored and when the
        stored session is past ``session_expiry_time``; in the latter case the
        session is removed as well.
        """
        with self.lock:
            session = self.sessions.get(chat_id)
            if session is None:
                return None
            if self._is_expired(session, datetime.now()):
                # Do not wait for the next background sweep to drop it.
                del self.sessions[chat_id]
                return None
            return session

    def clean_expired_sessions(self):
        """Remove every session older than ``session_expiry_time``."""
        with self.lock:
            now = datetime.now()
            # Collect first: deleting while iterating the dict would raise.
            expired_chat_ids = [
                chat_id for chat_id, session in self.sessions.items()
                if self._is_expired(session, now)
            ]
            for chat_id in expired_chat_ids:
                del self.sessions[chat_id]

    def stop(self, timeout: typing.Optional[float] = None):
        """Signal the cleanup thread to exit and wait for it.

        Parameters
        ----------
        timeout : typing.Optional[float]
            Maximum seconds to wait for the thread; ``None`` waits until it
            has finished.
        """
        self.cleanup_event.set()
        # A thread cannot join itself; guard in case this is ever called from
        # within the cleanup thread.
        if self.cleanup_thread is not threading.current_thread():
            self.cleanup_thread.join(timeout)

    def _is_expired(self, session, now: datetime) -> bool:
        """Return True when ``session`` is older than the expiry window at ``now``."""
        return now - session["created_at"] > self.session_expiry_time

    def _run_cleanup(self):
        """Sweep expired sessions once per interval until ``cleanup_event`` is set."""
        # Event.wait() returns False on timeout (sweep again) and True once the
        # event is set (leave the loop).
        while not self.cleanup_event.wait(self.cleanup_interval_in_seconds):
            self.clean_expired_sessions()


class Vectara(BaseVectara):
"""
Expand All @@ -19,6 +82,7 @@ class Vectara(BaseVectara):
def __init__(self, *args,
**kwargs):
super().__init__(*args, **kwargs)
self.session_manager = ChatSessionManager()
self.logger = logging.getLogger(self.__class__.__name__)
self.corpus_manager: Union[None, CorpusManager] = None
self.upload_manager: Union[None, UploadManager] = None
Expand All @@ -33,5 +97,83 @@ def set_upload_manager(self, upload_manager: UploadManager) -> None:
def set_lab_helper(self, lab_helper: LabHelper) -> None:
    """Store ``lab_helper`` on this client as ``self.lab_helper``."""
    self.lab_helper = lab_helper

class AsyncVectara(AsyncBaseVectara):
pass
def chat(
    self,
    *,
    query: str,
    search: typing.Optional[SearchCorporaParameters] = OMIT,
    request_timeout: typing.Optional[int] = None,
    request_timeout_millis: typing.Optional[int] = None,
    generation: typing.Optional[GenerationParameters] = OMIT,
    chat: typing.Optional[ChatParameters] = OMIT,
    chat_id: typing.Optional[str] = None,
    request_options: typing.Optional[RequestOptions] = None,
) -> ChatFullResponse:
    """
    Start a chat, or continue one by ``chat_id`` using its stored parameters.

    When ``chat_id`` has a stored session, the turn is sent through
    ``self.chats.create_turns`` and the stored ``search``, ``generation``,
    ``chat``, ``request_timeout``, ``request_timeout_millis`` and
    ``request_options`` replace whatever was passed to this call.
    Otherwise a new chat is started through the base client's ``chat`` and,
    if the response carries a ``chat_id``, the parameters of this call are
    stored under it for later turns.

    Parameters
    ----------
    query : str
        The chat message or question.
    search : typing.Optional[SearchCorporaParameters]
        Retrieval parameters. Optional if chat_id refers to a stored session.
    request_timeout : typing.Optional[int]
        The API will make a best effort to complete the request in the specified seconds or time out.
    request_timeout_millis : typing.Optional[int]
        The API will make a best effort to complete the request in the specified milliseconds or time out.
    generation : typing.Optional[GenerationParameters]
        Parameters for response generation.
    chat : typing.Optional[ChatParameters]
        Chat-specific parameters. Optional if chat_id is provided.
    chat_id : typing.Optional[str]
        ID of the chat session to reuse stored session data.
    request_options : typing.Optional[RequestOptions]
        Request-specific configuration.

    Returns
    -------
    ChatFullResponse

    Raises
    ------
    ValueError
        If ``search`` is missing and no stored session supplies it.
    """
    session = self.session_manager.get_session(chat_id) if chat_id else None
    if session:
        return self.chats.create_turns(
            chat_id=chat_id,
            query=query,
            search=session.get("search"),
            generation=session.get("generation"),
            chat=session.get("chat"),
            request_options=session.get("request_options"),
            request_timeout=session.get("request_timeout"),
            request_timeout_millis=session.get("request_timeout_millis"),
        )

    # Compare against the OMIT sentinel explicitly: a bare truthiness test
    # only catches an omitted `search` if the sentinel happens to be falsy.
    if search is OMIT or not search:
        if chat_id:
            raise ValueError(
                f"No stored session found for chat_id {chat_id!r}; "
                "`search` parameter is required to start a new chat."
            )
        raise ValueError("`search` parameter is required when `chat_id` is not provided.")

    # NOTE(review): a chat_id without a stored session starts a NEW chat here
    # (the chat_id is not forwarded), as before - confirm this is intended.
    response = super().chat(
        query=query,
        search=search,
        request_timeout=request_timeout,
        request_timeout_millis=request_timeout_millis,
        generation=generation,
        chat=chat,
        request_options=request_options,
    )

    if response.chat_id:
        # Keyword arguments on purpose: create_session's positional order is
        # (..., request_options, request_timeout, request_timeout_millis), so
        # passing the timeouts first stores them under the wrong keys.
        self.session_manager.create_session(
            response.chat_id,
            search,
            generation=generation,
            chat=chat,
            request_options=request_options,
            request_timeout=request_timeout,
            request_timeout_millis=request_timeout_millis,
        )
    return response


class AsyncVectara(AsyncBaseVectara):
    # Empty subclass: nothing is added to or overridden from AsyncBaseVectara.
    pass
2 changes: 1 addition & 1 deletion src/vectara/utils/lab_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def create_lab_corpus(self, corpus: CreateCorpusRequest, user_prefix=True, usern

corpus_clone = corpus.copy()

name, key = self._build_lab_name_and_key(corpus.name, corpus.key)
name, key = self._build_lab_name_and_key(corpus.name, corpus.key, user_prefix)
corpus_clone.name = name
corpus_clone.key = key

Expand Down

0 comments on commit 0a485cd

Please sign in to comment.