diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index 53bd7481..ffb96f77 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -300,7 +300,7 @@ def test_wand(self): condition = qb.wand("keywords", {"apple": 10, "banana": 20}) q = qb.select("*").from_("fruits").where(condition) expected = ( - 'select * from fruits where wand(keywords, {"apple":10, "banana":20})' + 'select * from fruits where wand(keywords, {"apple": 10, "banana": 20})' ) self.assertEqual(q, expected) return q @@ -319,7 +319,7 @@ def test_wand_annotations(self): annotations={"scoreThreshold": 0.13, "targetHits": 7}, ) q = qb.select("*").from_("fruits").where(condition) - expected = 'select * from fruits where ({scoreThreshold: 0.13, targetHits: 7}wand(description, {"a":1, "b":2}))' + expected = 'select * from fruits where ({scoreThreshold: 0.13, targetHits: 7}wand(description, {"a": 1, "b": 2}))' self.assertEqual(q, expected) return q diff --git a/vespa/querybuilder/builder/builder.py b/vespa/querybuilder/builder/builder.py index 1032f1d1..5fb29653 100644 --- a/vespa/querybuilder/builder/builder.py +++ b/vespa/querybuilder/builder/builder.py @@ -37,31 +37,50 @@ def __or__(self, other: Any) -> Condition: def __repr__(self) -> str: return self.name - def contains( - self, value: Any, annotations: Optional[Dict[str, Any]] = None - ) -> Condition: - value_str = self._format_value(value) + def _build_annotated_expression( + self, + operation: str, + value_str: str, + annotations: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> str: + """Helper method to build annotated expressions. + + Args: + operation: The operation name (e.g. 
'contains', 'matches') + value_str: The formatted value string + annotations: Optional annotations dictionary + **kwargs: Additional keyword arguments to merge with annotations + """ + if kwargs: + annotations = annotations or {} + annotations.update(kwargs) + if annotations: annotations_str = ",".join( f"{k}:{self._format_annotation_value(v)}" for k, v in annotations.items() ) - return Condition(f"{self.name} contains({{{annotations_str}}}{value_str})") - else: - return Condition(f"{self.name} contains {value_str}") + return f"{self.name} {operation}({{{annotations_str}}}{value_str})" + return f"{self.name} {operation} {value_str}" + + def contains( + self, value: Any, annotations: Optional[Dict[str, Any]] = None, **kwargs + ) -> Condition: + value_str = self._format_value(value) + expr = self._build_annotated_expression( + "contains", value_str, annotations, **kwargs + ) + return Condition(expr) def matches( - self, value: Any, annotations: Optional[Dict[str, Any]] = None + self, value: Any, annotations: Optional[Dict[str, Any]] = None, **kwargs ) -> Condition: value_str = self._format_value(value) - if annotations: - annotations_str = ",".join( - f"{k}:{self._format_annotation_value(v)}" - for k, v in annotations.items() - ) - return Condition(f"{self.name} matches({{{annotations_str}}}{value_str})") - else: - return Condition(f"{self.name} matches {value_str}") + expr = self._build_annotated_expression( + "matches", value_str, annotations, **kwargs + ) + return Condition(expr) def in_(self, *values) -> Condition: values_str = ", ".join( @@ -70,9 +89,15 @@ def in_(self, *values) -> Condition: return Condition(f"{self.name} in ({values_str})") def in_range( - self, start: Any, end: Any, annotations: Optional[Dict[str, Any]] = None + self, + start: Any, + end: Any, + annotations: Optional[Dict[str, Any]] = None, + **kwargs, ) -> Condition: if annotations: + if kwargs: + annotations.update(kwargs) annotations_str = ",".join( 
f"{k}:{self._format_annotation_value(v)}" for k, v in annotations.items() @@ -631,7 +656,7 @@ def userQuery(value: str = "") -> Condition: @staticmethod def dotProduct( field: str, - weights: Dict[str, float], + weights: Union[List[float], Dict[str, float], str], annotations: Optional[Dict[str, Any]] = None, ) -> Condition: """Creates a dot product calculation condition. @@ -640,7 +665,8 @@ def dotProduct( Args: field (str): Field containing vectors - weights (Dict[str, float]): Feature weights to apply + weights (Union[List[float], Dict[str, float], str]): + Either list of numeric weights or dict mapping elements to weights or a parameter substitution string starting with '@' annotations (Optional[Dict]): Optional modifiers like label Returns: @@ -648,6 +674,7 @@ def dotProduct( Examples: >>> import vespa.querybuilder as qb + >>> # Using dict weights with annotation >>> condition = qb.dotProduct( ... "weightedset_field", ... {"feature1": 1, "feature2": 2}, @@ -656,6 +683,16 @@ def dotProduct( >>> query = qb.select("*").from_("sd1").where(condition) >>> str(query) 'select * from sd1 where ({label:"myDotProduct"}dotProduct(weightedset_field, {"feature1": 1, "feature2": 2}))' + >>> # Using list weights + >>> condition = qb.dotProduct("weightedset_field", [0.4, 0.6]) + >>> query = qb.select("*").from_("sd1").where(condition) + >>> str(query) + 'select * from sd1 where dotProduct(weightedset_field, [0.4, 0.6])' + >>> # Using parameter substitution + >>> condition = qb.dotProduct("weightedset_field", "@myweights") + >>> query = qb.select("*").from_("sd1").where(condition).add_parameter("myweights", [0.4, 0.6]) + >>> str(query) + 'select * from sd1 where dotProduct(weightedset_field, "@myweights")&myweights=[0.4, 0.6]' """ weights_str = json.dumps(weights) expr = f"dotProduct({field}, {weights_str})" @@ -669,7 +706,7 @@ def dotProduct( @staticmethod def weightedSet( field: str, - weights: Dict[str, float], + weights: Union[List[float], Dict[str, float], str], 
annotations: Optional[Dict[str, Any]] = None, ) -> Condition: """Creates a weighted set condition. @@ -678,7 +715,8 @@ def weightedSet( Args: field (str): Field containing weighted set data - weights (Dict[str, float]): Weights to apply to the set elements + weights (Union[List[float], Dict[str, float], str]): + Either list of numeric weights or dict mapping elements to weights or a parameter substitution string starting with '@' annotations (Optional[Dict]): Optional annotations like targetNumHits Returns: @@ -686,6 +724,7 @@ def weightedSet( Examples: >>> import vespa.querybuilder as qb + >>> # using map weights >>> condition = qb.weightedSet( ... "weightedset_field", ... {"element1": 1, "element2": 2}, @@ -694,6 +733,16 @@ def weightedSet( >>> query = qb.select("*").from_("sd1").where(condition) >>> str(query) 'select * from sd1 where ({targetNumHits:10}weightedSet(weightedset_field, {"element1": 1, "element2": 2}))' + >>> # using list weights + >>> condition = qb.weightedSet("weightedset_field", [0.4, 0.6]) + >>> query = qb.select("*").from_("sd1").where(condition) + >>> str(query) + 'select * from sd1 where weightedSet(weightedset_field, [0.4, 0.6])' + >>> # using parameter substitution + >>> condition = qb.weightedSet("weightedset_field", "@myweights") + >>> query = qb.select("*").from_("sd1").where(condition).add_parameter("myweights", [0.4, 0.6]) + >>> str(query) + 'select * from sd1 where weightedSet(weightedset_field, "@myweights")&myweights=[0.4, 0.6]' """ weights_str = json.dumps(weights) expr = f"weightedSet({field}, {weights_str})" @@ -733,7 +782,9 @@ def nonEmpty(condition: Union[Condition, QueryField]) -> Condition: @staticmethod def wand( - field: str, weights, annotations: Optional[Dict[str, Any]] = None + field: str, + weights: Union[List[float], Dict[str, float], str], + annotations: Optional[Dict[str, Any]] = None, ) -> Condition: """Creates a Weighted AND (WAND) operator for efficient top-k retrieval. 
@@ -741,8 +792,8 @@ def wand( Args: field (str): Field name to search - weights (Union[List[float], Dict[str, float]]): - Either list of numeric weights or dict mapping terms to weights + weights (Union[List[float], Dict[str, float], str]): + Either list of numeric weights or dict mapping terms to weights or a parameter substitution string starting with '@' annotations (Optional[Dict[str, Any]]): Optional annotations like targetHits Returns: @@ -765,16 +816,9 @@ def wand( ... ) >>> query = qb.select("*").from_("sd1").where(condition) >>> str(query) - 'select * from sd1 where ({targetHits: 100}wand(title, {"hello":0.3, "world":0.7}))' + 'select * from sd1 where ({targetHits: 100}wand(title, {"hello": 0.3, "world": 0.7}))' """ - if isinstance(weights, list): - weights_str = "[" + ", ".join(str(item) for item in weights) + "]" - elif isinstance(weights, dict): - weights_str = ( - "{" + ", ".join(f'"{k}":{v}' for k, v in weights.items()) + "}" - ) - else: - raise ValueError("Invalid weights for wand") + weights_str = json.dumps(weights) expr = f"wand({field}, {weights_str})" if annotations: annotations_str = ", ".join( @@ -966,7 +1010,7 @@ def phrase(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: def near( *terms, distance: Optional[int] = None, - annotations: Dict[str, Any] = {}, + annotations: Optional[Dict[str, Any]] = None, **kwargs, ) -> Condition: """Creates a near operator for proximity search. @@ -991,7 +1035,8 @@ def near( """ terms_str = ", ".join(f'"{term}"' for term in terms) expr = f"near({terms_str})" - # if kwargs - add to annotations + if annotations is None: + annotations = {} if kwargs: annotations.update(kwargs) if distance is not None: @@ -1008,7 +1053,7 @@ def near( def onear( *terms, distance: Optional[int] = None, - annotations: Dict[str, Any] = {}, + annotations: Optional[Dict[str, Any]] = None, **kwargs, ) -> Condition: """Creates an ordered near operator for ordered proximity search. 
@@ -1032,6 +1077,8 @@ def onear( """ terms_str = ", ".join(f'"{term}"' for term in terms) expr = f"onear({terms_str})" + if annotations is None: + annotations = {} if kwargs: annotations.update(kwargs) if distance is not None: