Skip to content

Commit

Permalink
Add support for jina-v3-embeddings model
Browse files (browse the repository at this point in the history)
  • Loading branch information
antas-marcin committed Oct 5, 2024
1 parent ef1ad13 commit c951920
Show file tree
Hide file tree
Showing 8 changed files with 234 additions and 105 deletions.
55 changes: 37 additions & 18 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@


app = FastAPI()
vec : Vectorizer
meta_config : Meta
logger = getLogger('uvicorn')
vec: Vectorizer
meta_config: Meta
logger = getLogger("uvicorn")


@app.on_event("startup")
Expand All @@ -20,16 +20,22 @@ def startup_event():
cuda_per_process_memory_fraction = 1.0
if "CUDA_PER_PROCESS_MEMORY_FRACTION" in os.environ:
try:
cuda_per_process_memory_fraction = float(os.getenv("CUDA_PER_PROCESS_MEMORY_FRACTION"))
cuda_per_process_memory_fraction = float(
os.getenv("CUDA_PER_PROCESS_MEMORY_FRACTION")
)
except ValueError:
logger.error(f"Invalid CUDA_PER_PROCESS_MEMORY_FRACTION (should be between 0.0-1.0)")
logger.error(
f"Invalid CUDA_PER_PROCESS_MEMORY_FRACTION (should be between 0.0-1.0)"
)
if 0.0 <= cuda_per_process_memory_fraction <= 1.0:
logger.info(f"CUDA_PER_PROCESS_MEMORY_FRACTION set to {cuda_per_process_memory_fraction}")
cuda_support=False
cuda_core=""
logger.info(
f"CUDA_PER_PROCESS_MEMORY_FRACTION set to {cuda_per_process_memory_fraction}"
)
cuda_support = False
cuda_core = ""

if cuda_env is not None and cuda_env == "true" or cuda_env == "1":
cuda_support=True
cuda_support = True
cuda_core = os.getenv("CUDA_CORE")
if cuda_core is None or cuda_core == "":
cuda_core = "cuda:0"
Expand All @@ -40,10 +46,15 @@ def startup_event():
# Batch text tokenization enabled by default
direct_tokenize = False
transformers_direct_tokenize = os.getenv("T2V_TRANSFORMERS_DIRECT_TOKENIZE")
if transformers_direct_tokenize is not None and transformers_direct_tokenize == "true" or transformers_direct_tokenize == "1":
if (
transformers_direct_tokenize is not None
and transformers_direct_tokenize == "true"
or transformers_direct_tokenize == "1"
):
direct_tokenize = True

model_dir = "./models/model"

def get_model_directory() -> (str, bool):
if os.path.exists(f"{model_dir}/model_name"):
with open(f"{model_dir}/model_name", "r") as f:
Expand All @@ -65,17 +76,27 @@ def log_info_about_onnx(onnx_runtime: bool):
if os.path.exists(f"{model_dir}/onnx_quantization_info"):
with open(f"{model_dir}/onnx_quantization_info", "r") as f:
onnx_quantization_info = f.read()
logger.info(f"Running ONNX vectorizer with quantized model for {onnx_quantization_info}")
logger.info(
f"Running ONNX vectorizer with quantized model for {onnx_quantization_info}"
)

model_name, use_sentence_transformer_vectorizer = get_model_directory()
onnx_runtime = get_onnx_runtime()
log_info_about_onnx(onnx_runtime)

meta_config = Meta(model_dir, model_name, use_sentence_transformer_vectorizer)
vec = Vectorizer(model_dir, cuda_support, cuda_core, cuda_per_process_memory_fraction,
meta_config.get_model_type(), meta_config.get_architecture(),
direct_tokenize, onnx_runtime, use_sentence_transformer_vectorizer,
model_name)
vec = Vectorizer(
model_dir,
cuda_support,
cuda_core,
cuda_per_process_memory_fraction,
meta_config.get_model_type(),
meta_config.get_architecture(),
direct_tokenize,
onnx_runtime,
use_sentence_transformer_vectorizer,
model_name,
)


@app.get("/.well-known/live", response_class=Response)
Expand All @@ -96,8 +117,6 @@ async def read_item(item: VectorInput, response: Response):
vector = await vec.vectorize(item.text, item.config)
return {"text": item.text, "vector": vector.tolist(), "dim": len(vector)}
except Exception as e:
logger.exception(
'Something went wrong while vectorizing data.'
)
logger.exception("Something went wrong while vectorizing data.")
response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
return {"error": str(e)}
2 changes: 1 addition & 1 deletion custom_prerequisites.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

import nltk

nltk.download('punkt')
nltk.download("punkt")
4 changes: 3 additions & 1 deletion download.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def quantization_config(onnx_cpu_arch: str):
os.remove(f"{model_dir}/model.onnx")
# Save information about ONNX runtime
save_to_file(f"{model_dir}/onnx_runtime", onnx_runtime)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(
model_name, trust_remote_code=trust_remote_code
)
tokenizer.save_pretrained(onnx_path)


Expand Down
17 changes: 9 additions & 8 deletions requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
requests==2.32.3
transformers==4.42.4
fastapi==0.112.0
uvicorn==0.30.5
transformers==4.44.2
fastapi==0.115.0
uvicorn==0.31.0
nltk==3.9.1
torch==2.4.0
torch==2.4.1
sentencepiece==0.2.0
sentence-transformers==3.0.1
optimum==1.21.2
onnxruntime==1.18.1
onnx==1.16.2
sentence-transformers==3.1.1
optimum==1.22.0
onnxruntime==1.19.2
onnx==1.17.0
numpy==1.26.4
einops==0.8.0
pytest
17 changes: 9 additions & 8 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
transformers==4.42.4
fastapi==0.112.0
uvicorn==0.30.5
transformers==4.44.2
fastapi==0.115.0
uvicorn==0.31.0
nltk==3.9.1
torch==2.4.0
torch==2.4.1
sentencepiece==0.2.0
sentence-transformers==3.0.1
optimum==1.21.2
onnxruntime==1.18.1
onnx==1.16.2
sentence-transformers==3.1.1
optimum==1.22.0
onnxruntime==1.19.2
onnx==1.17.0
numpy==1.26.4
einops==0.8.0
14 changes: 7 additions & 7 deletions smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

class SmokeTest(unittest.TestCase):
def setUp(self):
self.url = 'http://localhost:8000'
self.url = "http://localhost:8000"

for i in range(0, 100):
try:
res = requests.get(self.url + '/.well-known/ready')
res = requests.get(self.url + "/.well-known/ready")
if res.status_code == 204:
return
else:
Expand All @@ -21,25 +21,25 @@ def setUp(self):
raise Exception("did not start up")

def test_well_known_ready(self):
res = requests.get(self.url + '/.well-known/ready')
res = requests.get(self.url + "/.well-known/ready")

self.assertEqual(res.status_code, 204)

def test_well_known_live(self):
res = requests.get(self.url + '/.well-known/live')
res = requests.get(self.url + "/.well-known/live")

self.assertEqual(res.status_code, 204)

def test_meta(self):
res = requests.get(self.url + '/meta')
res = requests.get(self.url + "/meta")

self.assertEqual(res.status_code, 200)
self.assertIsInstance(res.json(), dict)

def test_vectorizing(self):
def try_to_vectorize(url):
print(f"url: {url}")
req_body = {'text': 'The London Eye is a ferris wheel at the River Thames.'}
req_body = {"text": "The London Eye is a ferris wheel at the River Thames."}

res = requests.post(url, json=req_body)
resBody = res.json()
Expand All @@ -49,7 +49,7 @@ def try_to_vectorize(url):
# below tests that what we deem a reasonable vector is returned. We are
# aware of 384 and 768 dim vectors, which should both fall in that
# range
self.assertTrue(len(resBody['vector']) > 100)
self.assertTrue(len(resBody["vector"]) > 100)
print(f"vector dimensions are: {len(resBody['vector'])}")

try_to_vectorize(self.url + "/vectors/")
Expand Down
31 changes: 18 additions & 13 deletions test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,15 @@


def wait_for_uvicorn_start():
url = 'http://localhost:8000/.well-known/ready'
url = "http://localhost:8000/.well-known/ready"

for i in range(0, 100):
try:
res = requests.get(url)
if res.status_code == 204:
return
else:
raise Exception(
"status code is {}".format(res.status_code))
raise Exception("status code is {}".format(res.status_code))
except Exception as e:
print("Attempt {}: {}".format(i, e))
time.sleep(2)
Expand All @@ -32,10 +31,15 @@ def run_server():
uvicorn.run(app)


@pytest.fixture(params=["t5-small",
"distilroberta-base",
"vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
"vblagoje/dpr-question_encoder-single-lfqa-wiki"], scope="function")
@pytest.fixture(
params=[
"t5-small",
"distilroberta-base",
"vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
"vblagoje/dpr-question_encoder-single-lfqa-wiki",
],
scope="function",
)
def server(request):
os.environ["MODEL_NAME"] = request.param
subprocess.call("python download.py", shell=True)
Expand All @@ -48,12 +52,12 @@ def server(request):

def test_vectorizing(server):
wait_for_uvicorn_start()
url = 'http://127.0.0.1:8000/vectors/'
req_body = {'text': 'The London Eye is a ferris wheel at the River Thames.'}
url = "http://127.0.0.1:8000/vectors/"
req_body = {"text": "The London Eye is a ferris wheel at the River Thames."}

res = requests.post(url, json=req_body)
resBody = res.json()
vectorized_text = resBody['vector']
vectorized_text = resBody["vector"]

assert 200 == res.status_code

Expand All @@ -66,14 +70,15 @@ def test_vectorizing(server):

# now let's try two sentences

req_body = {'text': 'The London Eye is a ferris wheel at the River Thames. Here is the second sentence.'}
req_body = {
"text": "The London Eye is a ferris wheel at the River Thames. Here is the second sentence."
}
res = requests.post(url, json=req_body)
resBody = res.json()
vectorized_text = resBody['vector']
vectorized_text = resBody["vector"]

assert 200 == res.status_code

assert type(vectorized_text) is list

assert 128 <= len(vectorized_text) <= 1024

Loading

0 comments on commit c951920

Please sign in to comment.