Skip to content

Commit

Permalink
[Fix #2163] Add contains, containsAny and contaisAll support for vari…
Browse files Browse the repository at this point in the history
…ables field (#2164)

* [Fix #2163] Add contains support

* [Fix #2163] ContainsAny&ConstainsAll
  • Loading branch information
fjtirado authored Dec 17, 2024
1 parent 3c78fe9 commit 846de1c
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ private AttributeFilter<?> mapJsonArgument(String attribute, String key, Object
case LIKE:
return jsonFilter(like(sb.toString(), value.toString()));
case CONTAINS_ALL:
return filterValueList(value, val -> containsAll(sb.toString(), val));
return jsonFilter(filterValueList(value, val -> containsAll(sb.toString(), val)));
case CONTAINS_ANY:
return filterValueList(value, val -> containsAny(sb.toString(), val));
return jsonFilter(filterValueList(value, val -> containsAny(sb.toString(), val)));
case EQUAL:
default:
return jsonFilter(equalTo(sb.toString(), value));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ void testJsonMapperContains() {
jsonFilter(contains("variables.workflowdata.number", 1)));
}

@Test
void testJsonMapperContainsAny() {
assertThat(mapper.mapJsonArgument("variables").apply(Map.of("workflowdata", Map.of("number", Map.of("containsAny", List.of(1, 2, 3)))))).containsExactly(
jsonFilter(containsAny("variables.workflowdata.number", List.of(1, 2, 3))));
}

@Test
void testJsonMapperContainsAll() {
assertThat(mapper.mapJsonArgument("variables").apply(Map.of("workflowdata", Map.of("number", Map.of("containsAll", List.of(1, 2, 3)))))).containsExactly(
jsonFilter(containsAll("variables.workflowdata.number", List.of(1, 2, 3))));
}

@Test
void testJsonMapperLike() {
assertThat(mapper.mapJsonArgument("variables").apply(Map.of("workflowdata", Map.of("number", Map.of("like", "kk"))))).containsExactly(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.kie.kogito.index.postgresql;

import java.util.Iterator;
import java.util.List;

import org.hibernate.dialect.function.StandardSQLFunction;
import org.hibernate.query.ReturnableType;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.type.BasicTypeReference;
import org.hibernate.type.SqlTypes;

public class ContainsSQLFunction extends StandardSQLFunction {

static final String CONTAINS_NAME = "contains";
static final String CONTAINS_ALL_NAME = "containsAll";
static final String CONTAINS_ANY_NAME = "containsAny";

static final String CONTAINS_SEQ = "??";
static final String CONTAINS_ALL_SEQ = "??&";
static final String CONTAINS_ANY_SEQ = "??|";

private final String operator;

private static final BasicTypeReference<Boolean> RETURN_TYPE = new BasicTypeReference<>("boolean", Boolean.class, SqlTypes.BOOLEAN);

ContainsSQLFunction(String name, String operator) {
super(name, RETURN_TYPE);
this.operator = operator;
}

@Override
public void render(
SqlAppender sqlAppender,
List<? extends SqlAstNode> args,
ReturnableType<?> returnType,
SqlAstTranslator<?> translator) {
int size = args.size();
if (size < 2) {
throw new IllegalArgumentException("Function " + getName() + " requires at least two arguments");
}
Iterator<? extends SqlAstNode> iter = args.iterator();
iter.next().accept(translator);
sqlAppender.append(' ');
sqlAppender.append(operator);
sqlAppender.append(' ');
if (size == 2) {
iter.next().accept(translator);
} else {
sqlAppender.append("array[");
do {
iter.next().accept(translator);
sqlAppender.append(iter.hasNext() ? ',' : ']');
} while (iter.hasNext());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.kie.kogito.index.postgresql;

import org.hibernate.boot.model.FunctionContributions;
import org.hibernate.boot.model.FunctionContributor;
import org.hibernate.query.sqm.function.SqmFunctionRegistry;

import static org.kie.kogito.index.postgresql.ContainsSQLFunction.*;

public class CustomFunctionsContributor implements FunctionContributor {

@Override
public void contributeFunctions(FunctionContributions functionContributions) {
SqmFunctionRegistry registry = functionContributions.getFunctionRegistry();
registry.register(CONTAINS_NAME, new ContainsSQLFunction(CONTAINS_NAME, CONTAINS_SEQ));
registry.register(CONTAINS_ANY_NAME, new ContainsSQLFunction(CONTAINS_ANY_NAME, CONTAINS_ANY_SEQ));
registry.register(CONTAINS_ALL_NAME, new ContainsSQLFunction(CONTAINS_ALL_NAME, CONTAINS_ALL_SEQ));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.kie.kogito.persistence.api.query.AttributeFilter;

Expand Down Expand Up @@ -72,10 +73,24 @@ public static Predicate buildPredicate(AttributeFilter<?> filter, Root<?> root,
values = (List<Object>) filter.getValue();
isString = values.get(0) instanceof String;
return buildPathExpression(builder, root, filter.getAttribute(), isString).in(values.stream().map(o -> buildObjectExpression(builder, o, isString)).collect(Collectors.toList()));
case CONTAINS:
return builder.isTrue(
builder.function(ContainsSQLFunction.CONTAINS_NAME, Boolean.class, buildPathExpression(builder, root, filter.getAttribute(), false), builder.literal(filter.getValue())));
case CONTAINS_ANY:
return containsPredicate(filter, root, builder, ContainsSQLFunction.CONTAINS_ANY_NAME);
case CONTAINS_ALL:
return containsPredicate(filter, root, builder, ContainsSQLFunction.CONTAINS_ALL_NAME);
}
throw new UnsupportedOperationException("Filter " + filter + " is not supported");
}

private static Predicate containsPredicate(AttributeFilter<?> filter, Root<?> root, CriteriaBuilder builder, String name) {
return builder.isTrue(
builder.function(name, Boolean.class,
Stream.concat(Stream.of(buildPathExpression(builder, root, filter.getAttribute(), false)), ((List<?>) filter.getValue()).stream().map(o -> builder.literal(o)))
.toArray(Expression[]::new)));
}

private static Expression buildObjectExpression(CriteriaBuilder builder, Object value, boolean isString) {
return isString ? builder.literal(value) : builder.function("to_jsonb", Object.class, builder.literal(value));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
org.kie.kogito.index.postgresql.CustomFunctionsContributor
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,17 @@ void testProcessInstanceVariables() {
processInstanceId);
queryAndAssert(assertWithId(), storage, singletonList(or(List.of(jsonFilter(notNull("variables.traveller.aliases")), jsonFilter(lessThan("variables.traveller.age", 22))))), null, null, null,
processInstanceId);
// TODO add support for json contains (requires writing dialect extension on hibernate)
//queryAndAssert(assertWithId(), storage, singletonList(jsonFilter(contains("variables.traveller.aliases", "TheRealThing"))), null, null, null,
// processInstanceId);
//queryAndAssert(assertEmpty(), storage, singletonList(jsonFilter(contains("variables.traveller.aliases", "TheDummyThing"))), null, null, null,
// processInstanceId);
queryAndAssert(assertWithId(), storage, singletonList(jsonFilter(contains("variables.traveller.aliases", "TheRealThing"))), null, null, null,
processInstanceId);
queryAndAssert(assertNotId(), storage, singletonList(jsonFilter(contains("variables.traveller.aliases", "TheDummyThing"))), null, null, null,
processInstanceId);
queryAndAssert(assertWithId(), storage, singletonList(jsonFilter(containsAny("variables.traveller.aliases", List.of("TheRealThing", "TheDummyThing")))), null, null, null,
processInstanceId);
queryAndAssert(assertNotId(), storage, singletonList(jsonFilter(containsAny("variables.traveller.aliases", List.of("TheRedPandaThing", "TheDummyThing")))), null, null, null,
processInstanceId);
queryAndAssert(assertWithId(), storage, singletonList(jsonFilter(containsAll("variables.traveller.aliases", List.of("Super", "Astonishing", "TheRealThing")))), null, null, null,
processInstanceId);
queryAndAssert(assertNotId(), storage, singletonList(jsonFilter(containsAll("variables.traveller.aliases", List.of("Super", "TheDummyThing")))), null, null, null,
processInstanceId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ public static <T> AttributeFilter<List<T>> in(String attribute, List<T> values)
return new AttributeFilter<>(attribute, FilterCondition.IN, values);
}

public static AttributeFilter<List<String>> containsAny(String attribute, List<String> values) {
public static <T> AttributeFilter<List<T>> containsAny(String attribute, List<T> values) {
return new AttributeFilter<>(attribute, FilterCondition.CONTAINS_ANY, values);
}

public static AttributeFilter<List<String>> containsAll(String attribute, List<String> values) {
public static <T> AttributeFilter<List<T>> containsAll(String attribute, List<T> values) {
return new AttributeFilter<>(attribute, FilterCondition.CONTAINS_ALL, values);
}

Expand Down

0 comments on commit 846de1c

Please sign in to comment.