Skip to content

Commit

Permalink
ok
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasht86 committed Jan 7, 2025
1 parent da45e8e commit 9c39825
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 38 deletions.
4 changes: 2 additions & 2 deletions tests/unit/test_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def test_wand(self):
condition = qb.wand("keywords", {"apple": 10, "banana": 20})
q = qb.select("*").from_("fruits").where(condition)
expected = (
'select * from fruits where wand(keywords, {"apple":10, "banana":20})'
'select * from fruits where wand(keywords, {"apple": 10, "banana": 20})'
)
self.assertEqual(q, expected)
return q
Expand All @@ -319,7 +319,7 @@ def test_wand_annotations(self):
annotations={"scoreThreshold": 0.13, "targetHits": 7},
)
q = qb.select("*").from_("fruits").where(condition)
expected = 'select * from fruits where ({scoreThreshold: 0.13, targetHits: 7}wand(description, {"a":1, "b":2}))'
expected = 'select * from fruits where ({scoreThreshold: 0.13, targetHits: 7}wand(description, {"a": 1, "b": 2}))'
self.assertEqual(q, expected)
return q

Expand Down
119 changes: 83 additions & 36 deletions vespa/querybuilder/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,31 +37,50 @@ def __or__(self, other: Any) -> Condition:
def __repr__(self) -> str:
return self.name

def contains(
self, value: Any, annotations: Optional[Dict[str, Any]] = None
) -> Condition:
value_str = self._format_value(value)
def _build_annotated_expression(
self,
operation: str,
value_str: str,
annotations: Optional[Dict[str, Any]] = None,
**kwargs,
) -> str:
"""Helper method to build annotated expressions.
Args:
operation: The operation name (e.g. 'contains', 'matches')
value_str: The formatted value string
annotations: Optional annotations dictionary
**kwargs: Additional keyword arguments to merge with annotations
"""
if kwargs:
annotations = annotations or {}
annotations.update(kwargs)

if annotations:
annotations_str = ",".join(
f"{k}:{self._format_annotation_value(v)}"
for k, v in annotations.items()
)
return Condition(f"{self.name} contains({{{annotations_str}}}{value_str})")
else:
return Condition(f"{self.name} contains {value_str}")
return f"{self.name} {operation}({{{annotations_str}}}{value_str})"
return f"{self.name} {operation} {value_str}"

def contains(
self, value: Any, annotations: Optional[Dict[str, Any]] = None, **kwargs
) -> Condition:
value_str = self._format_value(value)
expr = self._build_annotated_expression(
"contains", value_str, annotations, **kwargs
)
return Condition(expr)

def matches(
self, value: Any, annotations: Optional[Dict[str, Any]] = None
self, value: Any, annotations: Optional[Dict[str, Any]] = None, **kwargs
) -> Condition:
value_str = self._format_value(value)
if annotations:
annotations_str = ",".join(
f"{k}:{self._format_annotation_value(v)}"
for k, v in annotations.items()
)
return Condition(f"{self.name} matches({{{annotations_str}}}{value_str})")
else:
return Condition(f"{self.name} matches {value_str}")
expr = self._build_annotated_expression(
"matches", value_str, annotations, **kwargs
)
return Condition(expr)

def in_(self, *values) -> Condition:
values_str = ", ".join(
Expand All @@ -70,9 +89,15 @@ def in_(self, *values) -> Condition:
return Condition(f"{self.name} in ({values_str})")

def in_range(
self, start: Any, end: Any, annotations: Optional[Dict[str, Any]] = None
self,
start: Any,
end: Any,
annotations: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Condition:
if annotations:
if kwargs:
annotations.update(kwargs)
annotations_str = ",".join(
f"{k}:{self._format_annotation_value(v)}"
for k, v in annotations.items()
Expand Down Expand Up @@ -631,7 +656,7 @@ def userQuery(value: str = "") -> Condition:
@staticmethod
def dotProduct(
field: str,
weights: Dict[str, float],
weights: Union[List[float], Dict[str, float], str],
annotations: Optional[Dict[str, Any]] = None,
) -> Condition:
"""Creates a dot product calculation condition.
Expand All @@ -640,14 +665,16 @@ def dotProduct(
Args:
field (str): Field containing vectors
weights (Dict[str, float]): Feature weights to apply
weights (Union[List[float], Dict[str, float], str]):
Either list of numeric weights or dict mapping elements to weights or a parameter substitution string starting with '@'
annotations (Optional[Dict]): Optional modifiers like label
Returns:
Condition: A dot product calculation condition
Examples:
>>> import vespa.querybuilder as qb
>>> # Using dict weights with annotation
>>> condition = qb.dotProduct(
... "weightedset_field",
... {"feature1": 1, "feature2": 2},
Expand All @@ -656,6 +683,16 @@ def dotProduct(
>>> query = qb.select("*").from_("sd1").where(condition)
>>> str(query)
'select * from sd1 where ({label:"myDotProduct"}dotProduct(weightedset_field, {"feature1": 1, "feature2": 2}))'
>>> # Using list weights
>>> condition = qb.dotProduct("weightedset_field", [0.4, 0.6])
>>> query = qb.select("*").from_("sd1").where(condition)
>>> str(query)
'select * from sd1 where dotProduct(weightedset_field, [0.4, 0.6])'
>>> # Using parameter substitution
>>> condition = qb.dotProduct("weightedset_field", "@myweights")
>>> query = qb.select("*").from_("sd1").where(condition).add_parameter("myweights", [0.4, 0.6])
>>> str(query)
'select * from sd1 where dotProduct(weightedset_field, "@myweights")&myweights=[0.4, 0.6]'
"""
weights_str = json.dumps(weights)
expr = f"dotProduct({field}, {weights_str})"
Expand All @@ -669,7 +706,7 @@ def dotProduct(
@staticmethod
def weightedSet(
field: str,
weights: Dict[str, float],
weights: Union[List[float], Dict[str, float], str],
annotations: Optional[Dict[str, Any]] = None,
) -> Condition:
"""Creates a weighted set condition.
Expand All @@ -678,14 +715,16 @@ def weightedSet(
Args:
field (str): Field containing weighted set data
weights (Dict[str, float]): Weights to apply to the set elements
weights (Union[List[float], Dict[str, float], str]):
Either list of numeric weights or dict mapping elements to weights or a parameter substitution string starting with
annotations (Optional[Dict]): Optional annotations like targetNumHits
Returns:
Condition: A weighted set condition
Examples:
>>> import vespa.querybuilder as qb
>>> # using map weights
>>> condition = qb.weightedSet(
... "weightedset_field",
... {"element1": 1, "element2": 2},
Expand All @@ -694,6 +733,16 @@ def weightedSet(
>>> query = qb.select("*").from_("sd1").where(condition)
>>> str(query)
'select * from sd1 where ({targetNumHits:10}weightedSet(weightedset_field, {"element1": 1, "element2": 2}))'
>>> # using list weights
>>> condition = qb.weightedSet("weightedset_field", [0.4, 0.6])
>>> query = qb.select("*").from_("sd1").where(condition)
>>> str(query)
'select * from sd1 where weightedSet(weightedset_field, [0.4, 0.6])'
>>> # using parameter substitution
>>> condition = qb.weightedSet("weightedset_field", "@myweights")
>>> query = qb.select("*").from_("sd1").where(condition).add_parameter("myweights", [0.4, 0.6])
>>> str(query)
'select * from sd1 where weightedSet(weightedset_field, "@myweights")&myweights=[0.4, 0.6]'
"""
weights_str = json.dumps(weights)
expr = f"weightedSet({field}, {weights_str})"
Expand Down Expand Up @@ -733,16 +782,18 @@ def nonEmpty(condition: Union[Condition, QueryField]) -> Condition:

@staticmethod
def wand(
field: str, weights, annotations: Optional[Dict[str, Any]] = None
field: str,
weights: Union[List[float], Dict[str, float], str],
annotations: Optional[Dict[str, Any]] = None,
) -> Condition:
"""Creates a Weighted AND (WAND) operator for efficient top-k retrieval.
For more information, see https://docs.vespa.ai/en/reference/query-language-reference.html#wand
Args:
field (str): Field name to search
weights (Union[List[float], Dict[str, float]]):
Either list of numeric weights or dict mapping terms to weights
weights (Union[List[float], Dict[str, float], str]):
Either list of numeric weights or dict mapping terms to weights or a parameter substitution string starting with '@'
annotations (Optional[Dict[str, Any]]): Optional annotations like targetHits
Returns:
Expand All @@ -765,16 +816,9 @@ def wand(
... )
>>> query = qb.select("*").from_("sd1").where(condition)
>>> str(query)
'select * from sd1 where ({targetHits: 100}wand(title, {"hello":0.3, "world":0.7}))'
'select * from sd1 where ({targetHits: 100}wand(title, {"hello": 0.3, "world": 0.7}))'
"""
if isinstance(weights, list):
weights_str = "[" + ", ".join(str(item) for item in weights) + "]"
elif isinstance(weights, dict):
weights_str = (
"{" + ", ".join(f'"{k}":{v}' for k, v in weights.items()) + "}"
)
else:
raise ValueError("Invalid weights for wand")
weights_str = json.dumps(weights)
expr = f"wand({field}, {weights_str})"
if annotations:
annotations_str = ", ".join(
Expand Down Expand Up @@ -966,7 +1010,7 @@ def phrase(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition:
def near(
*terms,
distance: Optional[int] = None,
annotations: Dict[str, Any] = {},
annotations: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Condition:
"""Creates a near operator for proximity search.
Expand All @@ -991,7 +1035,8 @@ def near(
"""
terms_str = ", ".join(f'"{term}"' for term in terms)
expr = f"near({terms_str})"
# if kwargs - add to annotations
if annotations is None:
annotations = {}
if kwargs:
annotations.update(kwargs)
if distance is not None:
Expand All @@ -1008,7 +1053,7 @@ def near(
def onear(
*terms,
distance: Optional[int] = None,
annotations: Dict[str, Any] = {},
annotations: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Condition:
"""Creates an ordered near operator for ordered proximity search.
Expand All @@ -1032,6 +1077,8 @@ def onear(
"""
terms_str = ", ".join(f'"{term}"' for term in terms)
expr = f"onear({terms_str})"
if annotations is None:
annotations = {}
if kwargs:
annotations.update(kwargs)
if distance is not None:
Expand Down

0 comments on commit 9c39825

Please sign in to comment.