Skip to content

Commit

Permalink
Merge pull request #62 from iQuxLE/main
Browse files Browse the repository at this point in the history
DuckDBAdapter: Batch OpenAI calls
  • Loading branch information
cmungall authored Aug 21, 2024
2 parents 8b3bd77 + 5a61ac6 commit aa69288
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 32 deletions.
135 changes: 103 additions & 32 deletions src/curate_gpt/store/duckdb_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
from dataclasses import dataclass, field
from typing import Any, Callable, ClassVar, Dict, Iterable, Iterator, List, Mapping, Optional, Union

import llm
import duckdb
import numpy as np
import openai
Expand All @@ -33,6 +33,8 @@
IDS,
METADATAS,
MODEL_DIMENSIONS,
MODEL_MAP,
DEFAULT_MODEL,
MODELS,
OBJECT,
OPENAI_MODEL_DIMENSIONS,
Expand Down Expand Up @@ -174,7 +176,7 @@ def create_index(self, collection: str):
"""
self.conn.execute(create_index_sql)

def _embedding_function(self, texts: Union[str, List[str]], model: str = None) -> list:
def _embedding_function(self, texts: Union[str, List[str], List[List[str]]], model: str = None) -> list:
"""
Get the embeddings for the given texts using the specified model
:param texts: A single text or a list of texts to embed
Expand All @@ -192,12 +194,12 @@ def _embedding_function(self, texts: Union[str, List[str]], model: str = None) -
if model.startswith("openai:"):
self._initialize_openai_client()
openai_model = model.split(":", 1)[1]
if openai_model == "" or openai_model not in MODELS:
if openai_model == "" or openai_model not in MODEL_MAP.keys():
logger.info(
f"The model {openai_model} is not "
f"one of {MODELS}. Defaulting to {MODELS[1]}"
f"one of {[MODEL_MAP.keys()]}. Defaulting to {DEFAULT_MODEL}"
)
openai_model = MODELS[1]
openai_model = DEFAULT_MODEL

responses = [
self.openai_client.embeddings.create(input=text, model=openai_model)
Expand Down Expand Up @@ -320,33 +322,102 @@ def _process_objects(
cumulative_len = 0
sql_command = self._generate_sql_command(collection, method)
sql_command = sql_command.format(collection=collection)
for next_objs in chunk(objs, batch_size):
next_objs = list(next_objs)
logger.info("Processing batch of objects in DuckDB process_objects ...")
docs = [self._text(o, text_field) for o in next_objs]
docs_len = sum([len(d) for d in docs])
cumulative_len += docs_len
if self._is_openai(collection) and cumulative_len > 3000000:
logger.warning(f"Cumulative length = {cumulative_len}, pausing ...")
time.sleep(60)
cumulative_len = 0
metadatas = [self._dict(o) for o in next_objs]
ids = [self._id(o, id_field) for o in next_objs]
embeddings = self._embedding_function(docs, cm.model)
try:
self.conn.execute("BEGIN TRANSACTION;")
self.conn.executemany(
sql_command, list(zip(ids, metadatas, embeddings, docs, strict=False))
)
self.conn.execute("COMMIT;")
except Exception as e:
self.conn.execute("ROLLBACK;")
logger.error(
f"Transaction failed: {e}, default model: {self.default_model}, model used: {model}, len(embeddings): {len(embeddings[0])}"
)
raise
finally:
self.create_index(collection)
if not self._is_openai(collection):
for next_objs in chunk(objs, batch_size):
next_objs = list(next_objs)
docs = [self._text(o, text_field) for o in next_objs]
docs_len = sum([len(d) for d in docs])
metadatas = [self._dict(o) for o in next_objs]
ids = [self._id(o, id_field) for o in next_objs]
embeddings = self._embedding_function(docs, cm.model)
try:
self.conn.execute("BEGIN TRANSACTION;")
self.conn.executemany(
sql_command, list(zip(ids, metadatas, embeddings, docs, strict=False))
)
self.conn.execute("COMMIT;")
except Exception as e:
self.conn.execute("ROLLBACK;")
logger.error(f"Transaction failed: {e}, default model: {self.default_model}, model used: {model}, len(embeddings): {len(embeddings[0])}")
raise
finally:
self.create_index(collection)
else:
if model.startswith("openai:"):
openai_model = model.split(":", 1)[1]
if openai_model == "" or openai_model not in MODEL_MAP.keys():
logger.info(f"The model {openai_model} is not "
f"one of {MODEL_MAP.keys()}. Defaulting to {DEFAULT_MODEL}")
openai_model = DEFAULT_MODEL #ada 002
else:
logger.error(f"Something went wonky ## model: {model}")
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
for next_objs in chunk(objs, batch_size): # Existing chunking
next_objs = list(next_objs)
docs = [self._text(o, text_field) for o in next_objs]
docs_len = sum([len(d) for d in docs])
metadatas = [self._dict(o) for o in next_objs]
ids = [self._id(o, id_field) for o in next_objs]

tokenized_docs = [tokenizer.encode(doc) for doc in docs]
current_batch = []
current_token_count = 0
batch_embeddings = []

i = 0
while i < len(tokenized_docs):
doc_tokens = tokenized_docs[i]
# peek
if current_token_count + len(doc_tokens) <= 8192:
current_batch.append(doc_tokens)
current_token_count += len(doc_tokens)
i += 1
else:
if current_batch:
logger.info(f"Tokens: {current_token_count}")
texts = [tokenizer.decode(tokens) for tokens in current_batch]
short_name, _ = MODEL_MAP[openai_model]
embedding_model = llm.get_embedding_model(short_name)
embeddings = list(embedding_model.embed_multi(texts))
logger.info(f"Number of Documents in batch: {len(embeddings)}")
batch_embeddings.extend(embeddings)

if len(doc_tokens) > 8192:
logger.warning(
f"Document with ID {ids[i]} exceeds the token limit alone and will be skipped.")
# try:
# embeddings = OpenAIEmbeddings(model=model, tiktoken_model_name=model).embed_query(texts,
# embeddings.average model)
# batch_embeddings.extend(embeddings)
# skipping
i += 1
continue
else:
current_batch = []
current_token_count = 0

if current_batch:
logger.info(f"Last batch, token count: {current_token_count}")
texts = [tokenizer.decode(tokens) for tokens in current_batch]
short_name, _ = MODEL_MAP[openai_model]
embedding_model = llm.get_embedding_model(short_name)
embeddings = list(embedding_model.embed_multi(texts))
batch_embeddings.extend(embeddings)
logger.info(f"Trying to insert: {len(ids)} IDS, {len(metadatas)} METADATAS, {len(batch_embeddings)} EMBEDDINGS")
try:
self.conn.execute("BEGIN TRANSACTION;")
self.conn.executemany(
sql_command, list(zip(ids, metadatas, batch_embeddings, docs, strict=False))
)
self.conn.execute("COMMIT;")
except Exception as e:
self.conn.execute("ROLLBACK;")
logger.error(
f"Transaction failed: {e}, default model: {self.default_model}, model used: {model}, len(embeddings): {len(embeddings[0])}")
raise
finally:
self.create_index(collection)

def remove_collection(self, collection: str = None, exists_ok=False, **kwargs):
"""
Expand Down
11 changes: 11 additions & 0 deletions src/curate_gpt/store/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,14 @@
"text-embedding-3-large": 3072,
}
# Full names of the OpenAI embedding models this store recognizes.
MODELS = [
    "text-embedding-ada-002",
    "text-embedding-3-small",
    "text-embedding-3-large",
]

# Full OpenAI model name -> (short alias, vector dimension).
# The alias is the identifier handed to the embedding backend; the
# "-512"/"-256"/"-1024" suffixed keys appear to be reduced-dimension
# variants of the base models (dimension matches the suffix).
MODEL_MAP = dict(
    [
        ("text-embedding-ada-002", ("ada-002", 1536)),
        ("text-embedding-3-small", ("3-small", 1536)),
        ("text-embedding-3-large", ("3-large", 3072)),
        ("text-embedding-3-small-512", ("3-small-512", 512)),
        ("text-embedding-3-large-256", ("3-large-256", 256)),
        ("text-embedding-3-large-1024", ("3-large-1024", 1024)),
    ]
)

# Fallback model used when a requested model is empty or unrecognized.
DEFAULT_MODEL = "text-embedding-ada-002"

0 comments on commit aa69288

Please sign in to comment.