diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java index 4fcd861e5..2718cc486 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java @@ -83,6 +83,11 @@ public class ToolProcessor { Object.class); private static final Logger log = Logger.getLogger(ToolProcessor.class); + public static final DotName OPTIONAL = DotName.createSimple("java.util.Optional"); + public static final DotName OPTIONAL_INT = DotName.createSimple("java.util.OptionalInt"); + public static final DotName OPTIONAL_LONG = DotName.createSimple("java.util.OptionalLong"); + public static final DotName OPTIONAL_DOUBLE = DotName.createSimple("java.util.OptionalDouble"); + @BuildStep public void telemetry(Capabilities capabilities, BuildProducer additionalBeanProducer) { var addOpenTelemetrySpan = capabilities.isPresent(Capability.OPENTELEMETRY_TRACER); @@ -488,11 +493,18 @@ private Iterable toJsonSchemaProperties(Type type, IndexView ClassInfo classInfo = index.getClassByName(type.name()); List required = new ArrayList<>(); + if (classInfo != null) { for (FieldInfo field : classInfo.fields()) { String fieldName = field.name(); + Type fieldType = field.type(); - Iterable fieldSchema = toJsonSchemaProperties(field.type(), index, null); + boolean isOptional = isJavaOptionalType(fieldType); + if (isOptional) { + fieldType = unwrapOptionalType(fieldType); + } + + Iterable fieldSchema = toJsonSchemaProperties(fieldType, index, null); Map fieldDescription = new HashMap<>(); for (JsonSchemaProperty fieldProperty : fieldSchema) { @@ -506,6 +518,10 @@ private Iterable toJsonSchemaProperties(Type type, IndexView fieldDescription.put("description", String.join(",", descriptionValue)); } } + if (!isOptional) { + required.add(fieldName); + } + properties.put(fieldName, fieldDescription); } } @@ -517,10 +533,39 @@ private Iterable toJsonSchemaProperties(Type type, IndexView throw new IllegalArgumentException("Unsupported type: " + type); } + private boolean isJavaOptionalType(Type type) { + DotName typeName = type.name(); + return typeName.equals(DotName.createSimple("java.util.Optional")) + || typeName.equals(DotName.createSimple("java.util.OptionalInt")) + || typeName.equals(DotName.createSimple("java.util.OptionalLong")) + || typeName.equals(DotName.createSimple("java.util.OptionalDouble")); + } + + private Type unwrapOptionalType(Type optionalType) { + if (optionalType.kind() == Type.Kind.PARAMETERIZED_TYPE) { + ParameterizedType parameterizedType = optionalType.asParameterizedType(); + return parameterizedType.arguments().get(0); + } + return optionalType; + } + private boolean isComplexType(Type type) { return type.kind() == Type.Kind.CLASS || type.kind() == Type.Kind.PARAMETERIZED_TYPE; } + private boolean isOptionalField(FieldInfo field, IndexView index) { + Type fieldType = field.type(); + DotName fieldTypeName = fieldType.name(); + + if (OPTIONAL.equals(fieldTypeName) || OPTIONAL_INT.equals(fieldTypeName) || OPTIONAL_LONG.equals(fieldTypeName) + || OPTIONAL_DOUBLE.equals(fieldTypeName)) { + return true; + } + + return false; + + } + private Iterable removeNulls(JsonSchemaProperty... properties) { return stream(properties) .filter(Objects::nonNull) diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/QuarkusServiceOutputParser.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/QuarkusServiceOutputParser.java index d801450e8..80a7551ba 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/QuarkusServiceOutputParser.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/QuarkusServiceOutputParser.java @@ -13,7 +13,6 @@ import dev.langchain4j.service.Result; import dev.langchain4j.service.TokenStream; import dev.langchain4j.service.TypeUtils; -//import dev.langchain4j.service.output.OutputParser; import dev.langchain4j.service.output.ServiceOutputParser; import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory; import io.smallrye.mutiny.Multi; @@ -23,13 +22,17 @@ public class QuarkusServiceOutputParser extends ServiceOutputParser { @Override public String outputFormatInstructions(Type returnType) { - Class rawClass = getRawClass(returnType); + boolean isOptional = isJavaOptional(returnType); + Type actualType = isOptional ? unwrapOptionalType(returnType) : returnType; + + Class rawClass = getRawClass(actualType); if (rawClass != String.class && rawClass != AiMessage.class && rawClass != TokenStream.class && rawClass != Response.class && !Multi.class.equals(rawClass)) { try { var schema = this.toJsonSchema(returnType); - return "You must answer strictly with json according to the following json schema format: " + schema; + return "You must answer strictly with json according to the following json schema format. Use description metadata to fill data properly: " + + schema; } catch (Exception e) { return ""; } @@ -77,7 +80,10 @@ private String extractJsonBlock(String text) { public String toJsonSchema(Type type) throws Exception { Map schema = new HashMap<>(); - Class rawClass = getRawClass(type); + boolean isOptional = isJavaOptional(type); + Type actualType = isOptional ? unwrapOptionalType(type) : type; + + Class rawClass = getRawClass(actualType); if (type instanceof WildcardType wildcardType) { Type boundType = wildcardType.getUpperBounds().length > 0 ? wildcardType.getUpperBounds()[0] @@ -104,22 +110,64 @@ public String toJsonSchema(Type type) throws Exception { schema.put("type", "object"); Map properties = new HashMap<>(); + List required = new ArrayList<>(); for (Field field : rawClass.getDeclaredFields()) { - field.setAccessible(true); - Map fieldSchema = toJsonSchemaMap(field.getGenericType()); - properties.put(field.getName(), fieldSchema); - if (field.isAnnotationPresent(Description.class)) { - Description description = field.getAnnotation(Description.class); - fieldSchema.put("description", description.value()); + try { + field.setAccessible(true); + Type fieldType = field.getGenericType(); + + // Check if the field is Optional and unwrap it if necessary + boolean fieldIsOptional = isJavaOptional(fieldType); + Type fieldActualType = fieldIsOptional ? unwrapOptionalType(fieldType) : fieldType; + + Map fieldSchema = toJsonSchemaMap(fieldActualType); + properties.put(field.getName(), fieldSchema); + + if (field.isAnnotationPresent(Description.class)) { + Description description = field.getAnnotation(Description.class); + fieldSchema.put("description", String.join(",", description.value())); + } + + // Only add to required if it is not Optional + if (!fieldIsOptional) { + required.add(field.getName()); + } else { + fieldSchema.put("nullable", true); // Mark as nullable in the JSON schema + } + + } catch (Exception e) { + } + } schema.put("properties", properties); + if (!required.isEmpty()) { + schema.put("required", required); + } + } + if (isOptional) { + schema.put("nullable", true); } - ObjectMapper mapper = new ObjectMapper(); return mapper.writeValueAsString(schema); // Convert the schema map to a JSON string } + private boolean isJavaOptional(Type type) { + if (type instanceof ParameterizedType) { + Type rawType = ((ParameterizedType) type).getRawType(); + return rawType == Optional.class || rawType == OptionalInt.class || rawType == OptionalLong.class + || rawType == OptionalDouble.class; + } + return false; + } + + private Type unwrapOptionalType(Type optionalType) { + if (optionalType instanceof ParameterizedType) { + return ((ParameterizedType) optionalType).getActualTypeArguments()[0]; + } + return optionalType; + } + private Class getRawClass(Type type) { if (type instanceof Class) { return (Class) type; diff --git a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java index 17b3f1602..d33033033 100644 --- a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java +++ b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java @@ -2,6 +2,7 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import jakarta.annotation.PreDestroy; @@ -36,12 +37,12 @@ public static class TestData { Integer bar; @Description("Foo description for structured output") - Double baz; + Optional baz; TestData(String foo, Integer bar, Double baz) { this.foo = foo; this.bar = bar; - this.baz = baz; + this.baz = Optional.of(baz); } } diff --git a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/EntityMappedResource.java b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/EntityMappedResource.java index 84e2aff77..81828b9d7 100644 --- a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/EntityMappedResource.java +++ b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/EntityMappedResource.java @@ -2,6 +2,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Optional; import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; @@ -29,12 +30,12 @@ public static class TestData { Integer bar; @Description("Foo description for structured output") - Double baz; + Optional baz; TestData(String foo, Integer bar, Double baz) { this.foo = foo; this.bar = bar; - this.baz = baz; + this.baz = Optional.of(baz); } }