Skip to content

Commit

Permalink
Merge pull request #1241 from tibor-reiss/feat_1168_fetch_objects_by_ids
Browse files Browse the repository at this point in the history
feature: fetch_objects_by_ids
  • Loading branch information
dirkkul authored Aug 20, 2024
2 parents a89feaa + dcfb8b2 commit e974108
Show file tree
Hide file tree
Showing 9 changed files with 782 additions and 1 deletion.
99 changes: 99 additions & 0 deletions integration/test_collection_async.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import datetime
import uuid
from typing import Iterable

import pytest

import weaviate.classes as wvc
from weaviate.collections.classes.config import DataType, Property
from weaviate.collections.classes.data import DataObject
from weaviate.types import UUID

from .conftest import AsyncCollectionFactory, AsyncOpenAICollectionFactory

UUID1 = uuid.UUID("806827e0-2b31-43ca-9269-24fa95a221f9")
UUID2 = uuid.uuid4()
UUID3 = uuid.uuid4()

DATE1 = datetime.datetime.strptime("2012-02-09", "%Y-%m-%d").replace(tzinfo=datetime.timezone.utc)

Expand All @@ -32,6 +37,51 @@ async def test_fetch_objects(async_collection_factory: AsyncCollectionFactory) -
assert res.objects[0].properties["name"] == "John Doe"


@pytest.mark.asyncio
@pytest.mark.parametrize(
"ids, expected_len, expected",
[
([], 0, set()),
((), 0, set()),
(
[
UUID3,
],
1,
{
UUID3,
},
),
([UUID1, UUID2], 2, {UUID1, UUID2}),
((UUID1, UUID3), 2, {UUID1, UUID3}),
((UUID1, UUID3, UUID3), 2, {UUID1, UUID3}),
],
)
async def test_fetch_objects_by_ids(
async_collection_factory: AsyncCollectionFactory,
ids: Iterable[UUID],
expected_len: int,
expected: set,
) -> None:
collection = await async_collection_factory(
properties=[
Property(name="name", data_type=DataType.TEXT),
],
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
)
await collection.data.insert_many(
[
DataObject(properties={"name": "first"}, uuid=UUID1),
DataObject(properties={"name": "second"}, uuid=UUID2),
DataObject(properties={"name": "third"}, uuid=UUID3),
]
)

res = await collection.query.fetch_objects_by_ids(ids)
assert len(res.objects) == expected_len
assert {o.uuid for o in res.objects} == expected


@pytest.mark.asyncio
async def test_config_update(async_collection_factory: AsyncCollectionFactory) -> None:
collection = await async_collection_factory(
Expand Down Expand Up @@ -200,3 +250,52 @@ async def test_generate(async_openai_collection: AsyncOpenAICollectionFactory) -
assert len(res.objects) == 2
for obj in res.objects:
assert obj.generated is not None


@pytest.mark.asyncio
@pytest.mark.parametrize(
"ids, expected_len, expected",
[
([], 0, set()),
((), 0, set()),
(
[
UUID3,
],
1,
{
UUID3,
},
),
([UUID1, UUID2], 2, {UUID1, UUID2}),
((UUID1, UUID3), 2, {UUID1, UUID3}),
((UUID1, UUID3, UUID3), 2, {UUID1, UUID3}),
],
)
async def test_generate_by_ids(
async_openai_collection: AsyncOpenAICollectionFactory,
ids: Iterable[UUID],
expected_len: int,
expected: set,
) -> None:
collection = await async_openai_collection(
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
)
await collection.data.insert_many(
[
DataObject(properties={"text": "John Doe"}, uuid=UUID1),
DataObject(properties={"text": "Jane Doe"}, uuid=UUID2),
DataObject(properties={"text": "J. Doe"}, uuid=UUID3),
]
)
res = await collection.generate.fetch_objects_by_ids(
ids,
single_prompt="Who is this? {text}",
grouped_task="Who are these people?",
)
assert res is not None
assert res.generated is not None
assert len(res.objects) == expected_len
assert {o.uuid for o in res.objects} == expected
for obj in res.objects:
assert obj.generated is not None
49 changes: 48 additions & 1 deletion integration/test_collection_filter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import time
import uuid
from typing import Callable, List, Optional
from typing import Callable, Iterable, List, Optional

import pytest as pytest

Expand All @@ -21,6 +21,7 @@
)
from weaviate.collections.classes.grpc import MetadataQuery, QueryReference, Sort
from weaviate.collections.classes.internal import ReferenceToMulti
from weaviate.types import UUID

NOW = datetime.datetime.now(datetime.timezone.utc)
LATER = NOW + datetime.timedelta(hours=1)
Expand Down Expand Up @@ -548,6 +549,52 @@ def test_filter_id(collection_factory: CollectionFactory, weav_filter: _FilterVa
assert objects[0].uuid == UUID1


@pytest.mark.parametrize(
"ids, expected_len, expected",
[
([], 0, set()),
((), 0, set()),
(
[
UUID3,
],
1,
{
UUID3,
},
),
([UUID1, UUID2], 2, {UUID1, UUID2}),
((UUID1, UUID3), 2, {UUID1, UUID3}),
((UUID1, UUID3, UUID3), 2, {UUID1, UUID3}),
],
)
def test_filter_ids(
collection_factory: CollectionFactory,
ids: Iterable[UUID],
expected_len: int,
expected: set,
) -> None:
collection = collection_factory(
properties=[
Property(name="Name", data_type=DataType.TEXT),
],
vectorizer_config=Configure.Vectorizer.none(),
)

collection.data.insert_many(
[
DataObject(properties={"name": "first"}, uuid=UUID1),
DataObject(properties={"name": "second"}, uuid=UUID2),
DataObject(properties={"name": "third"}, uuid=UUID3),
]
)

objects = collection.query.fetch_objects_by_ids(ids).objects

assert len(objects) == expected_len
assert {o.uuid for o in objects} == expected


@pytest.mark.parametrize("path", ["_creationTimeUnix", "_lastUpdateTimeUnix"])
def test_filter_timestamp_direct_path(collection_factory: CollectionFactory, path: str) -> None:
collection = collection_factory(
Expand Down
6 changes: 6 additions & 0 deletions weaviate/collections/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
_FetchObjectsGenerateAsync,
_FetchObjectsGenerate,
)
from weaviate.collections.queries.fetch_objects_by_ids import (
_FetchObjectsByIDsGenerateAsync,
_FetchObjectsByIDsGenerate,
)
from weaviate.collections.queries.hybrid import _HybridGenerateAsync, _HybridGenerate
from weaviate.collections.queries.near_image import _NearImageGenerateAsync, _NearImageGenerate
from weaviate.collections.queries.near_media import _NearMediaGenerateAsync, _NearMediaGenerate
Expand All @@ -19,6 +23,7 @@ class _GenerateCollectionAsync(
Generic[TProperties, References],
_BM25GenerateAsync[TProperties, References],
_FetchObjectsGenerateAsync[TProperties, References],
_FetchObjectsByIDsGenerateAsync[TProperties, References],
_HybridGenerateAsync[TProperties, References],
_NearImageGenerateAsync[TProperties, References],
_NearMediaGenerateAsync[TProperties, References],
Expand All @@ -33,6 +38,7 @@ class _GenerateCollection(
Generic[TProperties, References],
_BM25Generate[TProperties, References],
_FetchObjectsGenerate[TProperties, References],
_FetchObjectsByIDsGenerate[TProperties, References],
_HybridGenerate[TProperties, References],
_NearImageGenerate[TProperties, References],
_NearMediaGenerate[TProperties, References],
Expand Down
9 changes: 9 additions & 0 deletions weaviate/collections/queries/fetch_objects_by_ids/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .generate import _FetchObjectsByIDsGenerateAsync, _FetchObjectsByIDsGenerate
from .query import _FetchObjectsByIDsQueryAsync, _FetchObjectsByIDsQuery

__all__ = [
"_FetchObjectsByIDsGenerate",
"_FetchObjectsByIDsGenerateAsync",
"_FetchObjectsByIDsQuery",
"_FetchObjectsByIDsQueryAsync",
]
75 changes: 75 additions & 0 deletions weaviate/collections/queries/fetch_objects_by_ids/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import Generic, Iterable, List, Optional

from weaviate import syncify
from weaviate.collections.classes.filters import Filter
from weaviate.collections.classes.grpc import METADATA, Sorting
from weaviate.collections.classes.internal import (
GenerativeReturnType,
_Generative,
ReturnProperties,
ReturnReferences,
_QueryOptions,
)
from weaviate.collections.classes.types import Properties, TProperties, References, TReferences
from weaviate.collections.queries.base import _Base
from weaviate.proto.v1 import search_get_pb2
from weaviate.types import UUID, INCLUDE_VECTOR


class _FetchObjectsByIDsGenerateAsync(
Generic[Properties, References], _Base[Properties, References]
):
async def fetch_objects_by_ids(
self,
ids: Iterable[UUID],
*,
single_prompt: Optional[str] = None,
grouped_task: Optional[str] = None,
grouped_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
after: Optional[UUID] = None,
sort: Optional[Sorting] = None,
include_vector: INCLUDE_VECTOR = False,
return_metadata: Optional[METADATA] = None,
return_properties: Optional[ReturnProperties[TProperties]] = None,
return_references: Optional[ReturnReferences[TReferences]] = None
) -> GenerativeReturnType[Properties, References, TProperties, TReferences]:
"""Special case of fetch_objects based on filters on uuid"""
if not ids:
res = search_get_pb2.SearchReply(results=None)
else:
res = await self._query.get(
limit=limit,
offset=offset,
after=after,
filters=Filter.any_of([Filter.by_id().equal(uuid) for uuid in ids]),
sort=sort,
return_metadata=self._parse_return_metadata(return_metadata, include_vector),
return_properties=self._parse_return_properties(return_properties),
return_references=self._parse_return_references(return_references),
generative=_Generative(
single=single_prompt,
grouped=grouped_task,
grouped_properties=grouped_properties,
),
)
return self._result_to_generative_query_return(
res,
_QueryOptions.from_input(
return_metadata,
return_properties,
include_vector,
self._references,
return_references,
),
return_properties,
return_references,
)


@syncify.convert
class _FetchObjectsByIDsGenerate(
Generic[Properties, References], _FetchObjectsByIDsGenerateAsync[Properties, References]
):
pass
Loading

0 comments on commit e974108

Please sign in to comment.