diff --git a/agixt/Agent.py b/agixt/Agent.py
index 34d0a8f7880e..e563afb8c283 100644
--- a/agixt/Agent.py
+++ b/agixt/Agent.py
@@ -18,7 +18,6 @@
     UserOAuth,
     OAuthProvider,
     TaskItem,
-    Memory,
 )
 from Providers import Providers
 from Extensions import Extensions
diff --git a/agixt/Memories.py b/agixt/Memories.py
index c9d8b8c1a1ac..19da793060af 100644
--- a/agixt/Memories.py
+++ b/agixt/Memories.py
@@ -2,7 +2,7 @@
 import os
 import asyncio
 import sys
-from DB import Memory, get_session, DATABASE_TYPE
+from DB import Memory, Agent, User, get_session, DATABASE_TYPE
 import spacy
 import chromadb
 from chromadb.config import Settings
@@ -241,6 +241,21 @@ def get_base_collection_name(user: str, agent_name: str) -> str:
     return snake(f"{user}_{agent_name}")
 
 
+def get_agent_id(agent_name: str, email: str) -> str:
+    """
+    Gets the agent ID for the given agent name and user.
+    """
+    session = get_session()
+    try:
+        user = session.query(User).filter_by(email=email).first()
+        agent = session.query(Agent).filter_by(name=agent_name, user_id=user.id).first()
+        if agent:
+            return str(agent.id)
+        return None
+    finally:
+        session.close()
+
+
 class Memories:
     def __init__(
         self,
@@ -258,6 +273,7 @@ def __init__(
         if not user:
             user = "user"
         self.user = user
+        self.agent_id = get_agent_id(agent_name=agent_name, email=self.user)
         self.collection_name = get_base_collection_name(user, agent_name)
         self.collection_number = collection_number
         # Check if collection_number is a number, it might be a string
diff --git a/agixt/providers/default.py b/agixt/providers/default.py
index df611dde0581..afc0fb928025 100644
--- a/agixt/providers/default.py
+++ b/agixt/providers/default.py
@@ -1,6 +1,9 @@
 from providers.gpt4free import Gpt4freeProvider
 from providers.google import GoogleProvider
-from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2
+from onnxruntime import InferenceSession
+from tokenizers import Tokenizer
+from typing import List, cast, Union, Sequence
+import numpy.typing as npt
 from faster_whisper import WhisperModel
 import os
 import logging
@@ -11,6 +14,68 @@
 # tts: google
 # transcription: faster-whisper
 # translation: faster-whisper
+Vector = Union[Sequence[float], Sequence[int]]
+Embedding = Vector
+Embeddings = List[Embedding]
+Document = str
+Documents = List[Document]
+
+
+class ONNXMiniLM_L6_V2:
+    def __init__(self) -> None:
+        self.tokenizer = None
+        self.model = None
+
+    def _normalize(self, v: npt.NDArray) -> npt.NDArray:
+        norm = np.linalg.norm(v, axis=1)
+        norm[norm == 0] = 1e-12
+        return v / norm[:, np.newaxis]
+
+    def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray:
+        self.tokenizer = cast(Tokenizer, self.tokenizer)  # type: ignore
+        self.model = cast(InferenceSession, self.model)  # type: ignore
+        all_embeddings = []
+        for i in range(0, len(documents), batch_size):
+            batch = documents[i : i + batch_size]
+            encoded = [self.tokenizer.encode(d) for d in batch]
+            input_ids = np.array([e.ids for e in encoded])
+            attention_mask = np.array([e.attention_mask for e in encoded])
+            onnx_input = {
+                "input_ids": np.array(input_ids, dtype=np.int64),
+                "attention_mask": np.array(attention_mask, dtype=np.int64),
+                "token_type_ids": np.array(
+                    [np.zeros(len(e), dtype=np.int64) for e in input_ids],
+                    dtype=np.int64,
+                ),
+            }
+            model_output = self.model.run(None, onnx_input)
+            last_hidden_state = model_output[0]
+            # Perform mean pooling with attention weighting
+            input_mask_expanded = np.broadcast_to(
+                np.expand_dims(attention_mask, -1), last_hidden_state.shape
+            )
+            embeddings = np.sum(last_hidden_state * input_mask_expanded, 1) / np.clip(
+                input_mask_expanded.sum(1), a_min=1e-9, a_max=None
+            )
+            embeddings = self._normalize(embeddings).astype(np.float32)
+            all_embeddings.append(embeddings)
+        return np.concatenate(all_embeddings)
+
+    def _init_model_and_tokenizer(self) -> None:
+        if self.model is None and self.tokenizer is None:
+            self.tokenizer = Tokenizer.from_file(
+                os.path.join(os.getcwd(), "onnx", "tokenizer.json")
+            )
+            self.tokenizer.enable_truncation(max_length=256)
+            self.tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=256)
+            self.model = InferenceSession(
+                os.path.join(os.getcwd(), "onnx", "model.onnx")
+            )
+
+    def __call__(self, texts: Documents) -> Embeddings:
+        self._init_model_and_tokenizer()
+        res = cast(Embeddings, self._forward(texts).tolist())
+        return res
 
 
 class DefaultProvider:
@@ -35,7 +100,6 @@ def __init__(
             else kwargs["TRANSCRIPTION_MODEL"]
         )
         self.embedder = ONNXMiniLM_L6_V2()
-        self.embedder.DOWNLOAD_PATH = os.getcwd()
         self.chunk_size = 256
 
     @staticmethod
diff --git a/requirements.txt b/requirements.txt
index 26461c193d4b..fd9b9cb22e07 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -37,5 +37,4 @@ watchdog
 strawberry-graphql[fastapi]
 broadcaster
 gql
-pgvector
-sqlite-vss
\ No newline at end of file
+pgvector
\ No newline at end of file
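Usage sketches for the two main additions follow; both are illustrative only, and any name or value not shown in the hunks above is hypothetical.

The `get_agent_id` helper added to `agixt/Memories.py` looks up the agent row owned by a given user and returns its ID as a string, or `None` when no matching agent exists:

```python
from Memories import get_agent_id

# Hypothetical agent name and email; assumes matching User and Agent rows
# exist in the database behind DB.get_session().
agent_id = get_agent_id(agent_name="AGiXT", email="user@example.com")
print(agent_id)  # a UUID string, or None if the agent was not found
```

The vendored `ONNXMiniLM_L6_V2` in `agixt/providers/default.py` replaces chromadb's bundled embedding function, and the removed `DOWNLOAD_PATH` line means the model is no longer fetched automatically. The sketch below therefore assumes `onnx/tokenizer.json` and `onnx/model.onnx` (an all-MiniLM-L6-v2 export) already sit under the working directory, and that `numpy` is imported as `np` elsewhere in the module, since the visible hunks do not show that import:

```python
from providers.default import ONNXMiniLM_L6_V2

embedder = ONNXMiniLM_L6_V2()
# The first call lazily loads the tokenizer and ONNX session, mean-pools the
# last hidden state under the attention mask, and L2-normalizes each vector.
vectors = embedder(["hello world", "embeddings without the chromadb helper"])
print(len(vectors), len(vectors[0]))  # 2 documents, 384 dims for MiniLM-L6-v2
```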