remove unused import
Josh-XT committed Jan 14, 2025
1 parent 3270cac commit 104ae8b
Showing 4 changed files with 84 additions and 6 deletions.
agixt/Agent.py (1 change: 0 additions & 1 deletion)
@@ -18,7 +18,6 @@
     UserOAuth,
     OAuthProvider,
     TaskItem,
-    Memory,
 )
 from Providers import Providers
 from Extensions import Extensions
agixt/Memories.py (18 changes: 17 additions & 1 deletion)
@@ -2,7 +2,7 @@
 import os
 import asyncio
 import sys
-from DB import Memory, get_session, DATABASE_TYPE
+from DB import Memory, Agent, User, get_session, DATABASE_TYPE
 import spacy
 import chromadb
 from chromadb.config import Settings
@@ -241,6 +241,21 @@ def get_base_collection_name(user: str, agent_name: str) -> str:
     return snake(f"{user}_{agent_name}")
 
 
+def get_agent_id(agent_name: str, email: str) -> str:
+    """
+    Gets the agent ID for the given agent name and user.
+    """
+    session = get_session()
+    try:
+        user = session.query(User).filter_by(email=email).first()
+        agent = session.query(Agent).filter_by(name=agent_name, user_id=user.id).first()
+        if agent:
+            return str(agent.id)
+        return None
+    finally:
+        session.close()
+
+
 class Memories:
     def __init__(
         self,
@@ -258,6 +273,7 @@ def __init__(
         if not user:
             user = "user"
         self.user = user
+        self.agent_id = get_agent_id(agent_name=agent_name, email=self.user)
         self.collection_name = get_base_collection_name(user, agent_name)
         self.collection_number = collection_number
         # Check if collection_number is a number, it might be a string
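For reference, a minimal usage sketch of the new helper (illustrative, not part of the commit). It assumes the email matches an existing user: if it does not, user.id inside get_agent_id raises AttributeError, and if the user owns no agent by that name the helper returns None despite the -> str annotation.

# Hypothetical usage; the agent name and email are placeholders.
from Memories import get_agent_id

agent_id = get_agent_id(agent_name="AGiXT", email="user@example.com")
if agent_id is None:
    print("No agent by that name for this user")
else:
    print(f"Agent ID: {agent_id}")  # stringified UUID of the Agent row

Memories.__init__ now calls this helper with the user's email (the hunk above), so every Memories instance carries the ID of the agent that owns it.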
agixt/providers/default.py (68 changes: 66 additions & 2 deletions)
@@ -1,6 +1,9 @@
 from providers.gpt4free import Gpt4freeProvider
 from providers.google import GoogleProvider
-from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2
+from onnxruntime import InferenceSession
+from tokenizers import Tokenizer
+from typing import List, cast, Union, Sequence
+import numpy.typing as npt
 from faster_whisper import WhisperModel
 import os
 import logging
@@ -11,6 +14,68 @@
 # tts: google
 # transcription: faster-whisper
 # translation: faster-whisper
+Vector = Union[Sequence[float], Sequence[int]]
+Embedding = Vector
+Embeddings = List[Embedding]
+Document = str
+Documents = List[Document]
+
+
+class ONNXMiniLM_L6_V2:
+    def __init__(self) -> None:
+        self.tokenizer = None
+        self.model = None
+
+    def _normalize(self, v: npt.NDArray) -> npt.NDArray:
+        norm = np.linalg.norm(v, axis=1)
+        norm[norm == 0] = 1e-12
+        return v / norm[:, np.newaxis]
+
+    def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray:
+        self.tokenizer = cast(Tokenizer, self.tokenizer)  # type: ignore
+        self.model = cast(InferenceSession, self.model)  # type: ignore
+        all_embeddings = []
+        for i in range(0, len(documents), batch_size):
+            batch = documents[i : i + batch_size]
+            encoded = [self.tokenizer.encode(d) for d in batch]
+            input_ids = np.array([e.ids for e in encoded])
+            attention_mask = np.array([e.attention_mask for e in encoded])
+            onnx_input = {
+                "input_ids": np.array(input_ids, dtype=np.int64),
+                "attention_mask": np.array(attention_mask, dtype=np.int64),
+                "token_type_ids": np.array(
+                    [np.zeros(len(e), dtype=np.int64) for e in input_ids],
+                    dtype=np.int64,
+                ),
+            }
+            model_output = self.model.run(None, onnx_input)
+            last_hidden_state = model_output[0]
+            # Perform mean pooling with attention weighting
+            input_mask_expanded = np.broadcast_to(
+                np.expand_dims(attention_mask, -1), last_hidden_state.shape
+            )
+            embeddings = np.sum(last_hidden_state * input_mask_expanded, 1) / np.clip(
+                input_mask_expanded.sum(1), a_min=1e-9, a_max=None
+            )
+            embeddings = self._normalize(embeddings).astype(np.float32)
+            all_embeddings.append(embeddings)
+        return np.concatenate(all_embeddings)
+
+    def _init_model_and_tokenizer(self) -> None:
+        if self.model is None and self.tokenizer is None:
+            self.tokenizer = Tokenizer.from_file(
+                os.path.join(os.getcwd(), "onnx", "tokenizer.json")
+            )
+            self.tokenizer.enable_truncation(max_length=256)
+            self.tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=256)
+            self.model = InferenceSession(
+                os.path.join(os.getcwd(), "onnx", "model.onnx")
+            )
+
+    def __call__(self, texts: Documents) -> Embeddings:
+        self._init_model_and_tokenizer()
+        res = cast(Embeddings, self._forward(texts).tolist())
+        return res
 
 
 class DefaultProvider:
@@ -35,7 +100,6 @@ def __init__(
             else kwargs["TRANSCRIPTION_MODEL"]
         )
         self.embedder = ONNXMiniLM_L6_V2()
-        self.embedder.DOWNLOAD_PATH = os.getcwd()
         self.chunk_size = 256
 
     @staticmethod
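For context, this file replaces the chromadb import of ONNXMiniLM_L6_V2 with a local copy of the same embedding function (which is why the chromadb-specific DOWNLOAD_PATH assignment is deleted above). Below is a minimal usage sketch, illustrative and not part of the commit: it assumes the all-MiniLM-L6-v2 tokenizer.json and model.onnx already sit under ./onnx/ in the working directory, and that the module imports numpy as np somewhere outside the hunks shown.

# Hypothetical usage; the documents are placeholders.
embedder = ONNXMiniLM_L6_V2()  # lazy: nothing is loaded until the first call

docs = ["AGiXT is an AI automation platform.", "Embeddings map text to vectors."]
vectors = embedder(docs)  # loads ./onnx/tokenizer.json and model.onnx, then embeds

assert len(vectors) == len(docs)
# _normalize L2-normalizes each row, so cosine similarity between two
# embeddings reduces to a plain dot product.

The tokenizer truncates and pads to 256 tokens, which lines up with the chunk_size = 256 set in DefaultProvider.__init__ above.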
requirements.txt (3 changes: 1 addition & 2 deletions)
@@ -37,5 +37,4 @@ watchdog
 strawberry-graphql[fastapi]
 broadcaster
 gql
-pgvector
-sqlite-vss
+pgvector
