Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NaN Embedding Dimension Support #1792

Merged
merged 1 commit into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion py/core/base/providers/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
class EmbeddingConfig(ProviderConfig):
provider: str
base_model: str
base_dimension: int
base_dimension: int | float
rerank_model: Optional[str] = None
rerank_url: Optional[str] = None
batch_size: int = 1
Expand Down
2 changes: 2 additions & 0 deletions py/core/base/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
TextSplitter,
_decorate_vector_type,
_get_str_estimation_output,
_get_vector_column_str,
decrement_version,
deep_update,
format_search_results_for_llm,
Expand Down Expand Up @@ -39,5 +40,6 @@
"validate_uuid",
"deep_update",
"_decorate_vector_type",
"_get_vector_column_str",
"_get_str_estimation_output",
]
47 changes: 38 additions & 9 deletions py/core/database/chunks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import json
import logging
import math
import time
import uuid
from typing import Any, Optional, TypedDict
Expand Down Expand Up @@ -122,13 +123,18 @@ async def create_tables(self):
else f"vec_binary bit({self.dimension}),"
)

if self.dimension > 0:
vector_col = f"vec vector({self.dimension})"
else:
vector_col = "vec vector"

query = f"""
CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (
id UUID PRIMARY KEY,
document_id UUID,
owner_id UUID,
collection_ids UUID[],
vec vector({self.dimension}),
{vector_col},
{binary_col}
text TEXT,
metadata JSONB,
Expand All @@ -149,11 +155,15 @@ async def upsert(self, entry: VectorEntry) -> None:
"""
# Check the quantization type to determine which columns to use
if self.quantization_type == VectorQuantizationType.INT1:
bit_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)

# For quantized vectors, use vec_binary column
query = f"""
INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
(id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
VALUES ($1, $2, $3, $4, $5, $6::bit({self.dimension}), $7, $8)
VALUES ($1, $2, $3, $4, $5, $6::bit{bit_dim}, $7, $8)
ON CONFLICT (id) DO UPDATE SET
document_id = EXCLUDED.document_id,
owner_id = EXCLUDED.owner_id,
Expand Down Expand Up @@ -212,11 +222,15 @@ async def upsert_entries(self, entries: list[VectorEntry]) -> None:
Matches the table schema where vec_binary column only exists for INT1 quantization.
"""
if self.quantization_type == VectorQuantizationType.INT1:
bit_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)

# For quantized vectors, use vec_binary column
query = f"""
INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
(id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
VALUES ($1, $2, $3, $4, $5, $6::bit({self.dimension}), $7, $8)
VALUES ($1, $2, $3, $4, $5, $6::bit{bit_dim}, $7, $8)
ON CONFLICT (id) DO UPDATE SET
document_id = EXCLUDED.document_id,
owner_id = EXCLUDED.owner_id,
Expand Down Expand Up @@ -313,7 +327,10 @@ async def semantic_search(
)

# Use binary column and binary-specific distance measures for first stage
stage1_distance = f"{table_name}.vec_binary {binary_search_measure_repr} $1::bit({self.dimension})"
bit_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)
stage1_distance = f"{table_name}.vec_binary {binary_search_measure_repr} $1::bit{bit_dim}"
stage1_param = binary_query

cols.append(
Expand All @@ -331,6 +348,10 @@ async def semantic_search(
search_settings.filters, params, mode="where_clause"
)

vector_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)

# First stage: Get candidates using binary search
query = f"""
WITH candidates AS (
Expand All @@ -350,7 +371,7 @@ async def semantic_search(
collection_ids,
text,
{"metadata," if search_settings.include_metadatas else ""}
(vec <=> ${len(params) + 4}::vector({self.dimension})) as distance
(vec <=> ${len(params) + 4}::vector{vector_dim}) as distance
FROM candidates
ORDER BY distance
LIMIT ${len(params) + 3}
Expand All @@ -367,7 +388,10 @@ async def semantic_search(

else:
# Standard float vector handling
distance_calc = f"{table_name}.vec {search_settings.chunk_settings.index_measure.pgvector_repr} $1::vector({self.dimension})"
vector_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)
distance_calc = f"{table_name}.vec {search_settings.chunk_settings.index_measure.pgvector_repr} $1::vector{vector_dim}"
query_param = str(query_vector)

if search_settings.include_scores:
Expand Down Expand Up @@ -1048,19 +1072,24 @@ async def get_semantic_neighbors(
similarity_threshold: float = 0.5,
) -> list[dict[str, Any]]:
table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
vector_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)

query = f"""
WITH target_vector AS (
SELECT vec FROM {table_name}
SELECT vec::vector{vector_dim} FROM {table_name}
WHERE document_id = $1 AND id = $2
)
SELECT t.id, t.text, t.metadata, t.document_id, (t.vec <=> tv.vec) AS similarity
SELECT t.id, t.text, t.metadata, t.document_id, (t.vec::vector{vector_dim} <=> tv.vec) AS similarity
FROM {table_name} t, target_vector tv
WHERE (t.vec <=> tv.vec) >= $3
WHERE (t.vec::vector{vector_dim} <=> tv.vec) >= $3
AND t.document_id = $1
AND t.id != $2
ORDER BY similarity ASC
LIMIT $4
"""

results = await self.connection_manager.fetch_query(
query,
(str(document_id), str(id), similarity_threshold, limit),
Expand Down
15 changes: 13 additions & 2 deletions py/core/database/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import csv
import json
import logging
import math
import tempfile
from typing import IO, Any, Optional
from uuid import UUID
Expand Down Expand Up @@ -43,6 +44,12 @@ async def create_tables(self):
logger.info(
f"Creating table, if not exists: {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
)

vector_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)
vector_type = f"vector{vector_dim}"

try:
query = f"""
CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} (
Expand All @@ -53,7 +60,7 @@ async def create_tables(self):
metadata JSONB,
title TEXT,
summary TEXT NULL,
summary_embedding vector({self.dimension}) NULL,
summary_embedding {vector_type} NULL,
version TEXT,
size_in_bytes INT,
ingestion_status TEXT DEFAULT 'pending',
Expand Down Expand Up @@ -511,6 +518,10 @@ async def semantic_document_search(
where_clauses = ["summary_embedding IS NOT NULL"]
params: list[str | int | bytes] = [str(query_embedding)]

vector_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)

if search_settings.filters:
filter_condition, params = apply_filters(
search_settings.filters, params, mode="condition_only"
Expand All @@ -537,7 +548,7 @@ async def semantic_document_search(
updated_at,
summary,
summary_embedding,
(summary_embedding <=> $1::vector({self.dimension})) as semantic_distance
(summary_embedding <=> $1::vector{vector_dim}) as semantic_distance
FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
WHERE {where_clause}
ORDER BY semantic_distance ASC
Expand Down
15 changes: 9 additions & 6 deletions py/core/database/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import datetime
import json
import logging
import math
import os
import tempfile
import time
Expand Down Expand Up @@ -32,6 +33,7 @@
from core.base.utils import (
_decorate_vector_type,
_get_str_estimation_output,
_get_vector_column_str,
llm_cost_per_million_tokens,
)

Expand Down Expand Up @@ -75,8 +77,8 @@ def _get_parent_constraint(self, store_type: StoreType) -> str:

async def create_tables(self) -> None:
"""Create separate tables for graph and document entities."""
vector_column_str = _decorate_vector_type(
f"({self.dimension})", self.quantization_type
vector_column_str = _get_vector_column_str(
self.dimension, self.quantization_type
)

for store_type in StoreType:
Expand Down Expand Up @@ -527,9 +529,10 @@ async def create_tables(self) -> None:
for store_type in StoreType:
table_name = self._get_relationship_table_for_store(store_type)
parent_constraint = self._get_parent_constraint(store_type)
vector_column_str = _decorate_vector_type(
f"({self.dimension})", self.quantization_type
vector_column_str = _get_vector_column_str(
self.dimension, self.quantization_type
)

QUERY = f"""
CREATE TABLE IF NOT EXISTS {self._get_table_name(table_name)} (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
Expand Down Expand Up @@ -1011,8 +1014,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type") # type: ignore

async def create_tables(self) -> None:
vector_column_str = _decorate_vector_type(
f"({self.dimension})", self.quantization_type
vector_column_str = _get_vector_column_str(
self.dimension, self.quantization_type
)

query = f"""
Expand Down
5 changes: 5 additions & 0 deletions py/core/providers/embeddings/litellm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import math
import os
from copy import copy
from typing import Any
Expand Down Expand Up @@ -73,6 +74,10 @@ async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
texts = task["texts"]
kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))

if "dimensions" in kwargs and math.isnan(kwargs["dimensions"]):
kwargs.pop("dimensions")
logger.warning("Dropping nan dimensions from kwargs")

try:
response = await self.litellm_aembedding(
input=texts,
Expand Down
2 changes: 2 additions & 0 deletions py/shared/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .base_utils import (
_decorate_vector_type,
_get_str_estimation_output,
_get_vector_column_str,
decrement_version,
deep_update,
format_search_results_for_llm,
Expand Down Expand Up @@ -42,5 +43,6 @@
"TextSplitter",
# Vector utils
"_decorate_vector_type",
"_get_vector_column_str",
"_get_str_estimation_output",
]
18 changes: 18 additions & 0 deletions py/shared/utils/base_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import json
import logging
import math
from copy import deepcopy
from datetime import datetime
from typing import (
Expand Down Expand Up @@ -300,6 +301,23 @@ def _decorate_vector_type(
return f"{quantization_type.db_type}{input_str}"


def _get_vector_column_str(
    dimension: int | float, quantization_type: VectorQuantizationType
) -> str:
    """Return the SQL column-type string for a vector column.

    Embedding models that do not allow specifying a dimension are
    configured with ``float("nan")`` (or a non-positive value); in that
    case the dimension qualifier is omitted so Postgres accepts vectors
    of any length.

    Args:
        dimension: Embedding dimension. NaN or ``<= 0`` means
            "unspecified" and yields a dimensionless column type.
        quantization_type: Quantization used to select the db type
            (e.g. ``vector`` vs ``bit``) via ``_decorate_vector_type``.

    Returns:
        A decorated type string such as ``vector(512)`` or ``vector``.
    """
    if math.isnan(dimension) or dimension <= 0:
        vector_dim = ""  # Allows for Postgres to handle any dimension
    else:
        # Coerce to int so a float dimension like 512.0 renders as
        # "(512)" rather than "(512.0)", which would be invalid SQL.
        vector_dim = f"({int(dimension)})"
    return _decorate_vector_type(vector_dim, quantization_type)


def _get_str_estimation_output(x: tuple[Any, Any]) -> str:
if isinstance(x[0], int) and isinstance(x[1], int):
return " - ".join(map(str, x))
Expand Down
Loading