Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Unified schema API remove name field #119799

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -111,18 +111,14 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalFloat(topP);
}

public record Message(
Content content,
String role,
@Nullable String name,
@Nullable String toolCallId,
@Nullable List<ToolCall> toolCalls
) implements Writeable {
public record Message(Content content, String role, @Nullable String toolCallId, @Nullable List<ToolCall> toolCalls)
implements
Writeable {

@SuppressWarnings("unchecked")
static final ConstructingObjectParser<Message, Void> PARSER = new ConstructingObjectParser<>(
Message.class.getSimpleName(),
args -> new Message((Content) args[0], (String) args[1], (String) args[2], (String) args[3], (List<ToolCall>) args[4])
args -> new Message((Content) args[0], (String) args[1], (String) args[2], (List<ToolCall>) args[3])
);

static {
Expand All @@ -133,7 +129,6 @@ public record Message(
ObjectParser.ValueType.VALUE_ARRAY
);
PARSER.declareString(constructorArg(), new ParseField("role"));
PARSER.declareString(optionalConstructorArg(), new ParseField("name"));
PARSER.declareString(optionalConstructorArg(), new ParseField("tool_call_id"));
PARSER.declareObjectArray(optionalConstructorArg(), ToolCall.PARSER::apply, new ParseField("tool_calls"));
}
Expand All @@ -155,7 +150,6 @@ public Message(StreamInput in) throws IOException {
in.readOptionalNamedWriteable(Content.class),
in.readString(),
in.readOptionalString(),
in.readOptionalString(),
in.readOptionalCollectionAsList(ToolCall::new)
);
}
Expand All @@ -164,7 +158,6 @@ public Message(StreamInput in) throws IOException {
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalNamedWriteable(content);
out.writeString(role);
out.writeOptionalString(name);
out.writeOptionalString(toolCallId);
out.writeOptionalCollection(toolCalls);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ public void testParseAllFields() throws IOException {
"type": "string"
}
],
"name": "a name",
"tool_call_id": "100",
"tool_calls": [
{
Expand Down Expand Up @@ -83,7 +82,6 @@ public void testParseAllFields() throws IOException {
List.of(new UnifiedCompletionRequest.ContentObject("some text", "string"))
),
"user",
"a name",
"100",
List.of(
new UnifiedCompletionRequest.ToolCall(
Expand Down Expand Up @@ -155,7 +153,6 @@ public void testParsing() throws IOException {
new UnifiedCompletionRequest.ContentString("What is the weather like in Boston today?"),
"user",
null,
null,
null
)
),
Expand Down Expand Up @@ -200,7 +197,6 @@ public static UnifiedCompletionRequest.Message randomMessage() {
randomContent(),
randomAlphaOfLength(10),
randomAlphaOfLengthOrNull(10),
randomAlphaOfLengthOrNull(10),
randomToolCallListOrNull()
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,7 @@ public UnifiedChatInput(List<String> inputs, String roleValue, boolean stream) {

private static List<UnifiedCompletionRequest.Message> convertToMessages(List<String> inputs, String roleValue) {
return inputs.stream()
.map(
value -> new UnifiedCompletionRequest.Message(
new UnifiedCompletionRequest.ContentString(value),
roleValue,
null,
null,
null
)
)
.map(value -> new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString(value), roleValue, null, null))
.toList();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}

builder.field(ROLE_FIELD, message.role());
if (message.name() != null) {
builder.field(NAME_FIELD, message.name());
}
if (message.toolCallId() != null) {
builder.field(TOOL_CALL_ID_FIELD, message.toolCallId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,8 @@ public void testConvertsStringInputToMessages() {
Matchers.is(
UnifiedCompletionRequest.of(
List.of(
new UnifiedCompletionRequest.Message(
new UnifiedCompletionRequest.ContentString("hello"),
"a role",
null,
null,
null
),
new UnifiedCompletionRequest.Message(
new UnifiedCompletionRequest.ContentString("awesome"),
"a role",
null,
null,
null
)
new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "a role", null, null),
new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("awesome"), "a role", null, null)
)
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ public void testModelUserFieldsSerialization() throws IOException {
new UnifiedCompletionRequest.ContentString("Hello, world!"),
ROLE,
null,
null,
null
);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ public void testModelUserFieldsSerialization() throws IOException {
new UnifiedCompletionRequest.ContentString("Hello, world!"),
ROLE,
null,
null,
null
);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ public void testBasicSerialization() throws IOException {
new UnifiedCompletionRequest.ContentString("Hello, world!"),
ROLE,
null,
null,
null
);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
Expand Down Expand Up @@ -78,7 +77,6 @@ public void testSerializationWithAllFields() throws IOException {
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
new UnifiedCompletionRequest.ContentString("Hello, world!"),
ROLE,
"name",
"tool_call_id",
Collections.singletonList(
new UnifiedCompletionRequest.ToolCall(
Expand Down Expand Up @@ -127,7 +125,6 @@ public void testSerializationWithAllFields() throws IOException {
{
"content": "Hello, world!",
"role": "user",
"name": "name",
"tool_call_id": "tool_call_id",
"tool_calls": [
{
Expand Down Expand Up @@ -189,7 +186,6 @@ public void testSerializationWithNullOptionalFields() throws IOException {
new UnifiedCompletionRequest.ContentString("Hello, world!"),
ROLE,
null,
null,
null
);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
Expand Down Expand Up @@ -240,7 +236,6 @@ public void testSerializationWithEmptyLists() throws IOException {
new UnifiedCompletionRequest.ContentString("Hello, world!"),
ROLE,
null,
null,
Collections.emptyList() // empty toolCalls list
);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
Expand Down Expand Up @@ -290,7 +285,6 @@ public void testSerializationWithNestedObjects() throws IOException {
Random random = Randomness.get();

String randomContent = "Hello, world! " + random.nextInt(1000);
String randomName = "name" + random.nextInt(1000);
String randomToolCallId = "tool_call_id" + random.nextInt(1000);
String randomArguments = "arguments" + random.nextInt(1000);
String randomFunctionName = "function_name" + random.nextInt(1000);
Expand All @@ -303,7 +297,6 @@ public void testSerializationWithNestedObjects() throws IOException {
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
new UnifiedCompletionRequest.ContentString(randomContent),
ROLE,
randomName,
randomToolCallId,
Collections.singletonList(
new UnifiedCompletionRequest.ToolCall(
Expand Down Expand Up @@ -357,7 +350,6 @@ public void testSerializationWithNestedObjects() throws IOException {
{
"content": "%s",
"role": "user",
"name": "%s",
"tool_call_id": "%s",
"tool_calls": [
{
Expand Down Expand Up @@ -416,7 +408,6 @@ public void testSerializationWithNestedObjects() throws IOException {
}
""",
randomContent,
randomName,
randomToolCallId,
randomArguments,
randomFunctionName,
Expand Down Expand Up @@ -449,11 +440,10 @@ public void testSerializationWithDifferentContentTypes() throws IOException {
new UnifiedCompletionRequest.ContentString(randomContentString),
ROLE,
null,
null,
null
);

UnifiedCompletionRequest.Message messageWithObjects = new UnifiedCompletionRequest.Message(contentObjects, ROLE, null, null, null);
UnifiedCompletionRequest.Message messageWithObjects = new UnifiedCompletionRequest.Message(contentObjects, ROLE, null, null);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
messageList.add(messageWithString);
messageList.add(messageWithObjects);
Expand Down Expand Up @@ -502,7 +492,6 @@ public void testSerializationWithSpecialCharacters() throws IOException {
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
new UnifiedCompletionRequest.ContentString("Hello, world! \n \"Special\" characters: \t \\ /"),
ROLE,
"name\nwith\nnewlines",
"tool_call_id\twith\ttabs",
Collections.singletonList(
new UnifiedCompletionRequest.ToolCall(
Expand Down Expand Up @@ -541,7 +530,6 @@ public void testSerializationWithSpecialCharacters() throws IOException {
{
"content": "Hello, world! \\n \\"Special\\" characters: \\t \\\\ /",
"role": "user",
"name": "name\\nwith\\nnewlines",
"tool_call_id": "tool_call_id\\twith\\ttabs",
"tool_calls": [
{
Expand Down Expand Up @@ -571,7 +559,6 @@ public void testSerializationWithBooleanFields() throws IOException {
new UnifiedCompletionRequest.ContentString("Hello, world!"),
ROLE,
null,
null,
null
);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
Expand Down Expand Up @@ -641,7 +628,6 @@ public void testSerializationWithoutContentField() throws IOException {
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
null,
"assistant",
"name\nwith\nnewlines",
"tool_call_id\twith\ttabs",
Collections.singletonList(
new UnifiedCompletionRequest.ToolCall(
Expand Down Expand Up @@ -669,7 +655,6 @@ public void testSerializationWithoutContentField() throws IOException {
"messages": [
{
"role": "assistant",
"name": "name\\nwith\\nnewlines",
"tool_call_id": "tool_call_id\\twith\\ttabs",
"tool_calls": [
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public void testOverridingModelId() {
);

var request = new UnifiedCompletionRequest(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("message"), "user", null, null, null)),
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("message"), "user", null, null)),
"new_model_id",
null,
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -967,9 +967,7 @@ public void testUnifiedCompletionInfer() throws Exception {
service.unifiedCompletionInfer(
model,
UnifiedCompletionRequest.of(
List.of(
new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null, null)
)
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
),
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public void testOverrideWith_NullMap() {
public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() {
var model = createChatCompletionModel("url", "org", "api_key", "model_name", "user");
var request = new UnifiedCompletionRequest(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null, null)),
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
"different_model",
null,
null,
Expand All @@ -70,7 +70,7 @@ public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() {
public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() {
var model = createChatCompletionModel("url", "org", "api_key", "model_name", "user");
var request = new UnifiedCompletionRequest(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null, null)),
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
null, // not overriding model
null,
null,
Expand Down