Skip to content

Commit

Permalink
Add handling for Optional types
Browse files Browse the repository at this point in the history
  • Loading branch information
Tarjei400 committed Nov 4, 2024
1 parent bd88ad9 commit d3303c9
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ public class ToolProcessor {
Object.class);
private static final Logger log = Logger.getLogger(ToolProcessor.class);

public static final DotName OPTIONAL = DotName.createSimple("java.util.Optional");
public static final DotName OPTIONAL_INT = DotName.createSimple("java.util.OptionalInt");
public static final DotName OPTIONAL_LONG = DotName.createSimple("java.util.OptionalLong");
public static final DotName OPTIONAL_DOUBLE = DotName.createSimple("java.util.OptionalDouble");

@BuildStep
public void telemetry(Capabilities capabilities, BuildProducer<AdditionalBeanBuildItem> additionalBeanProducer) {
var addOpenTelemetrySpan = capabilities.isPresent(Capability.OPENTELEMETRY_TRACER);
Expand Down Expand Up @@ -488,11 +493,19 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(Type type, IndexView
ClassInfo classInfo = index.getClassByName(type.name());

List<String> required = new ArrayList<>();

if (classInfo != null) {
for (FieldInfo field : classInfo.fields()) {
String fieldName = field.name();
Type fieldType = field.type();

Iterable<JsonSchemaProperty> fieldSchema = toJsonSchemaProperties(field.type(), index, null);
// Check if the field is Optional and unwrap it if necessary
boolean isOptional = isJavaOptionalType(fieldType);
if (isOptional) {
fieldType = unwrapOptionalType(fieldType); // Unwrap the Optional type
}

Iterable<JsonSchemaProperty> fieldSchema = toJsonSchemaProperties(fieldType, index, null);
Map<String, Object> fieldDescription = new HashMap<>();

for (JsonSchemaProperty fieldProperty : fieldSchema) {
Expand All @@ -506,6 +519,10 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(Type type, IndexView
fieldDescription.put("description", String.join(",", descriptionValue));
}
}
if (!isOptional) {
required.add(fieldName);
}

properties.put(fieldName, fieldDescription);
}
}
Expand All @@ -517,10 +534,39 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(Type type, IndexView
throw new IllegalArgumentException("Unsupported type: " + type);
}

private boolean isJavaOptionalType(Type type) {
DotName typeName = type.name();
return typeName.equals(DotName.createSimple("java.util.Optional"))
|| typeName.equals(DotName.createSimple("java.util.OptionalInt"))
|| typeName.equals(DotName.createSimple("java.util.OptionalLong"))
|| typeName.equals(DotName.createSimple("java.util.OptionalDouble"));
}

private Type unwrapOptionalType(Type optionalType) {
if (optionalType.kind() == Type.Kind.PARAMETERIZED_TYPE) {
ParameterizedType parameterizedType = optionalType.asParameterizedType();
return parameterizedType.arguments().get(0);
}
return optionalType;
}

private boolean isComplexType(Type type) {
return type.kind() == Type.Kind.CLASS || type.kind() == Type.Kind.PARAMETERIZED_TYPE;
}

private boolean isOptionalField(FieldInfo field, IndexView index) {
Type fieldType = field.type();
DotName fieldTypeName = fieldType.name();

if (OPTIONAL.equals(fieldTypeName) || OPTIONAL_INT.equals(fieldTypeName) || OPTIONAL_LONG.equals(fieldTypeName)
|| OPTIONAL_DOUBLE.equals(fieldTypeName)) {
return true;
}

return fieldTypeName.toString().endsWith("?"); //Check for kotlin nullable type

}

private Iterable<JsonSchemaProperty> removeNulls(JsonSchemaProperty... properties) {
return stream(properties)
.filter(Objects::nonNull)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import dev.langchain4j.service.Result;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.TypeUtils;
//import dev.langchain4j.service.output.OutputParser;
import dev.langchain4j.service.output.ServiceOutputParser;
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;
import io.smallrye.mutiny.Multi;
Expand All @@ -23,13 +22,17 @@ public class QuarkusServiceOutputParser extends ServiceOutputParser {

@Override
public String outputFormatInstructions(Type returnType) {
Class<?> rawClass = getRawClass(returnType);
boolean isOptional = isJavaOptional(returnType);
Type actualType = isOptional ? unwrapOptionalType(returnType) : returnType;

Class<?> rawClass = getRawClass(actualType);

if (rawClass != String.class && rawClass != AiMessage.class && rawClass != TokenStream.class
&& rawClass != Response.class && !Multi.class.equals(rawClass)) {
try {
var schema = this.toJsonSchema(returnType);
return "You must answer strictly with json according to the following json schema format: " + schema;
return "You must answer strictly with json according to the following json schema format. Use description metadata to fill data properly: "
+ schema;
} catch (Exception e) {
return "";
}
Expand Down Expand Up @@ -77,7 +80,10 @@ private String extractJsonBlock(String text) {

public String toJsonSchema(Type type) throws Exception {
Map<String, Object> schema = new HashMap<>();
Class<?> rawClass = getRawClass(type);
boolean isOptional = isJavaOptional(type);
Type actualType = isOptional ? unwrapOptionalType(type) : type;

Class<?> rawClass = getRawClass(actualType);

if (type instanceof WildcardType wildcardType) {
Type boundType = wildcardType.getUpperBounds().length > 0 ? wildcardType.getUpperBounds()[0]
Expand All @@ -104,22 +110,64 @@ public String toJsonSchema(Type type) throws Exception {
schema.put("type", "object");
Map<String, Object> properties = new HashMap<>();

List<String> required = new ArrayList<>();
for (Field field : rawClass.getDeclaredFields()) {
field.setAccessible(true);
Map<String, Object> fieldSchema = toJsonSchemaMap(field.getGenericType());
properties.put(field.getName(), fieldSchema);
if (field.isAnnotationPresent(Description.class)) {
Description description = field.getAnnotation(Description.class);
fieldSchema.put("description", description.value());
try {
field.setAccessible(true);
Type fieldType = field.getGenericType();

// Check if the field is Optional and unwrap it if necessary
boolean fieldIsOptional = isJavaOptional(fieldType);
Type fieldActualType = fieldIsOptional ? unwrapOptionalType(fieldType) : fieldType;

Map<String, Object> fieldSchema = toJsonSchemaMap(fieldActualType);
properties.put(field.getName(), fieldSchema);

if (field.isAnnotationPresent(Description.class)) {
Description description = field.getAnnotation(Description.class);
fieldSchema.put("description", String.join(",", description.value()));
}

// Only add to required if it is not Optional
if (!fieldIsOptional) {
required.add(field.getName());
} else {
fieldSchema.put("nullable", true); // Mark as nullable in the JSON schema
}

} catch (Exception e) {

}

}
schema.put("properties", properties);
if (!required.isEmpty()) {
schema.put("required", required);
}
}
if (isOptional) {
schema.put("nullable", true);
}

ObjectMapper mapper = new ObjectMapper();
return mapper.writeValueAsString(schema); // Convert the schema map to a JSON string
}

private boolean isJavaOptional(Type type) {
if (type instanceof ParameterizedType) {
Type rawType = ((ParameterizedType) type).getRawType();
return rawType == Optional.class || rawType == OptionalInt.class || rawType == OptionalLong.class
|| rawType == OptionalDouble.class;
}
return false;
}

private Type unwrapOptionalType(Type optionalType) {
if (optionalType instanceof ParameterizedType) {
return ((ParameterizedType) optionalType).getActualTypeArguments()[0];
}
return optionalType;
}

private Class<?> getRawClass(Type type) {
if (type instanceof Class<?>) {
return (Class<?>) type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

import jakarta.annotation.PreDestroy;
Expand Down Expand Up @@ -36,12 +37,12 @@ public static class TestData {
Integer bar;

@Description("Foo description for structured output")
Double baz;
Optional<Double> baz;

TestData(String foo, Integer bar, Double baz) {
this.foo = foo;
this.bar = bar;
this.baz = baz;
this.baz = Optional.of(baz);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
Expand Down Expand Up @@ -29,12 +30,12 @@ public static class TestData {
Integer bar;

@Description("Foo description for structured output")
Double baz;
Optional<Double> baz;

TestData(String foo, Integer bar, Double baz) {
this.foo = foo;
this.bar = bar;
this.baz = baz;
this.baz = Optional.of(baz);
}
}

Expand Down

0 comments on commit d3303c9

Please sign in to comment.