[Inference API] Fix Azure AI Studio Integration for Completions and Embeddings #119818

Draft · wants to merge 5 commits into base: main
Changes from 2 commits
@@ -32,10 +32,6 @@ public AzureAiStudioChatCompletionRequest(AzureAiStudioChatCompletionModel model
this.stream = stream;
}

public boolean isRealtimeEndpoint() {
return isRealtimeEndpoint;
}

@Override
public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(this.uri);
@@ -71,11 +67,12 @@ private AzureAiStudioChatCompletionRequestEntity createRequestEntity() {
var serviceSettings = completionModel.getServiceSettings();
return new AzureAiStudioChatCompletionRequestEntity(
input,
serviceSettings.endpointType(),
serviceSettings.deploymentType(),
serviceSettings.model(),
taskSettings.temperature(),
taskSettings.topP(),
taskSettings.doSample(),
taskSettings.maxNewTokens(),
taskSettings.maxTokens(),
isStreaming()
);
}
@@ -10,61 +10,52 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioDeploymentType;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.INPUT_DATA_OBJECT;
import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.INPUT_STRING_ARRAY;
import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.MESSAGES_ARRAY;
import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.MESSAGE_CONTENT;
import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.PARAMETERS_OBJECT;
import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.ROLE;
import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.STREAM;
import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.USER_ROLE;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.MODEL_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.DO_SAMPLE_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.MAX_NEW_TOKENS_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.MAX_TOKENS_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TEMPERATURE_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_P_FIELD;

public record AzureAiStudioChatCompletionRequestEntity(
List<String> messages,
AzureAiStudioEndpointType endpointType,
AzureAiStudioDeploymentType deploymentType,
@Nullable String model,
@Nullable Double temperature,
@Nullable Double topP,
@Nullable Boolean doSample,
@Nullable Integer maxNewTokens,
@Nullable Integer maxTokens,
boolean stream
) implements ToXContentObject {

public AzureAiStudioChatCompletionRequestEntity {
Objects.requireNonNull(messages);
Objects.requireNonNull(endpointType);
Objects.requireNonNull(deploymentType);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();

if (endpointType == AzureAiStudioEndpointType.TOKEN) {
createPayAsYouGoRequest(builder, params);
} else {
createRealtimeRequest(builder, params);
if (deploymentType == AzureAiStudioDeploymentType.AZURE_AI_MODEL_INFERENCE_SERVICE) {
builder.field(MODEL_FIELD, model);
}

if (stream) {
builder.field(STREAM, true);
}

builder.endObject();
return builder;
}

private void createRealtimeRequest(XContentBuilder builder, Params params) throws IOException {
builder.startObject(INPUT_DATA_OBJECT);
builder.startArray(INPUT_STRING_ARRAY);
builder.startArray(MESSAGES_ARRAY);

for (String message : messages) {
addMessageContentObject(builder, message);
@@ -75,18 +66,8 @@ private void createRealtimeRequest(XContentBuilder builder, Params params) throws IOException {
addRequestParameters(builder);

builder.endObject();
}

private void createPayAsYouGoRequest(XContentBuilder builder, Params params) throws IOException {
builder.startArray(MESSAGES_ARRAY);

for (String message : messages) {
addMessageContentObject(builder, message);
}

builder.endArray();

addRequestParameters(builder);
return builder;
}

private void addMessageContentObject(XContentBuilder builder, String message) throws IOException {
@@ -99,12 +80,10 @@ private void addMessageContentObject(XContentBuilder builder, String message) throws IOException {
}

private void addRequestParameters(XContentBuilder builder) throws IOException {
if (temperature == null && topP == null && doSample == null && maxNewTokens == null) {
if (temperature == null && topP == null && doSample == null && maxTokens == null) {
return;
}

builder.startObject(PARAMETERS_OBJECT);

if (temperature != null) {
builder.field(TEMPERATURE_FIELD, temperature);
}
@@ -117,10 +96,9 @@ private void addRequestParameters(XContentBuilder builder) throws IOException {
builder.field(DO_SAMPLE_FIELD, doSample);
}

if (maxNewTokens != null) {
builder.field(MAX_NEW_TOKENS_FIELD, maxNewTokens);
if (maxTokens != null) {
builder.field(MAX_TOKENS_FIELD, maxTokens);
}

builder.endObject();
}
}
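For reference, a minimal usage sketch of the updated entity; the message text, model name, and parameter values below are illustrative assumptions, not values taken from this PR:

import java.util.List;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioDeploymentType;

// Construct the entity the way createRequestEntity() now does: deploymentType and model
// replace the old endpointType, and maxTokens replaces maxNewTokens.
var entity = new AzureAiStudioChatCompletionRequestEntity(
    List.of("Hello"),                                              // messages
    AzureAiStudioDeploymentType.AZURE_AI_MODEL_INFERENCE_SERVICE,  // deploymentType
    "mistral-small",                                               // model (assumed name)
    0.7,                                                           // temperature
    null,                                                          // topP
    null,                                                          // doSample
    512,                                                           // maxTokens
    false                                                          // stream
);

// Serializes to an OpenAI-style chat completion body; the "model" field is only
// written when the deployment type is the Azure AI Model Inference Service.
String requestBody = Strings.toString(entity);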
@@ -36,13 +36,18 @@ public AzureAiStudioEmbeddingsRequest(Truncator truncator, Truncator.TruncationR
public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(this.uri);

var user = embeddingsModel.getTaskSettings().user();
var dimensions = embeddingsModel.getServiceSettings().dimensions();
var dimensionsSetByUser = embeddingsModel.getServiceSettings().dimensionsSetByUser();
var deploymentType = embeddingsModel.getServiceSettings().deploymentType();
var model = embeddingsModel.getServiceSettings().model();

ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new AzureAiStudioEmbeddingsRequestEntity(truncationResult.input(), user, dimensions, dimensionsSetByUser))
.getBytes(StandardCharsets.UTF_8)
ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(
new AzureAiStudioEmbeddingsRequestEntity(deploymentType,
model,
truncationResult.input(),
dimensions,
dimensionsSetByUser))
.getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);

@@ -10,36 +10,39 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioDeploymentType;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.MODEL_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.DIMENSIONS_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.INPUT_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.USER_FIELD;

public record AzureAiStudioEmbeddingsRequestEntity(
AzureAiStudioDeploymentType deploymentType,
@Nullable String model,
List<String> input,
@Nullable String user,
@Nullable Integer dimensions,
boolean dimensionsSetByUser
) implements ToXContentObject {

public AzureAiStudioEmbeddingsRequestEntity {
Objects.requireNonNull(input);
Objects.requireNonNull(deploymentType);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();

builder.field(INPUT_FIELD, input);

if (user != null) {
builder.field(USER_FIELD, user);
if (deploymentType == AzureAiStudioDeploymentType.AZURE_AI_MODEL_INFERENCE_SERVICE) {
builder.field(MODEL_FIELD, model);
}

builder.field(INPUT_FIELD, input);

if (dimensionsSetByUser && dimensions != null) {
builder.field(DIMENSIONS_FIELD, dimensions);
}
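The analogous sketch for the embeddings entity; the input text and model name are illustrative assumptions:

// Field order follows the updated record: deploymentType, model, input, dimensions, dimensionsSetByUser.
// The user field is no longer part of this entity.
var embeddingsEntity = new AzureAiStudioEmbeddingsRequestEntity(
    AzureAiStudioDeploymentType.AZURE_AI_MODEL_INFERENCE_SERVICE,  // deploymentType
    "text-embedding-model",                                        // model (assumed name)
    List.of("some text to embed"),                                 // input
    null,                                                          // dimensions
    false                                                          // dimensionsSetByUser
);

// "model" is only emitted for the Azure AI Model Inference Service deployment type,
// and "dimensions" only when the user set it explicitly.
String body = Strings.toString(embeddingsEntity);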
@@ -7,44 +7,35 @@

package org.elasticsearch.xpack.inference.external.request.azureaistudio;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpEntityEnclosingRequestBase;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioDeploymentType;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
import org.elasticsearch.xpack.inference.external.request.RequestUtils;

import java.net.URI;

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.API_KEY_HEADER;

public abstract class AzureAiStudioRequest implements Request {

protected final URI uri;
protected final String inferenceEntityId;

protected final boolean isOpenAiRequest;
protected final boolean isRealtimeEndpoint;

protected AzureAiStudioRequest(AzureAiStudioModel model) {
this.uri = model.uri();
this.inferenceEntityId = model.getInferenceEntityId();
this.isOpenAiRequest = (model.provider() == AzureAiStudioProvider.OPENAI);
this.isRealtimeEndpoint = (model.endpointType() == AzureAiStudioEndpointType.REALTIME);
}

protected void setAuthHeader(HttpEntityEnclosingRequestBase request, AzureAiStudioModel model) {
var apiKey = model.getSecretSettings().apiKey();

if (isOpenAiRequest) {
if (model.deploymentType() == AzureAiStudioDeploymentType.AZURE_AI_MODEL_INFERENCE_SERVICE) {
request.setHeader(API_KEY_HEADER, apiKey.toString());
} else if (model.deploymentType() == AzureAiStudioDeploymentType.SERVERLESS_API) {
request.setHeader(RequestUtils.createAuthBearerHeader(apiKey));
} else {
if (isRealtimeEndpoint) {
request.setHeader(createAuthBearerHeader(apiKey));
} else {
request.setHeader(HttpHeaders.AUTHORIZATION, apiKey.toString());
}
// default to api-key header
request.setHeader(API_KEY_HEADER, apiKey.toString());
}
}
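To make the header selection concrete, a small sketch of what the branches produce; the URL and key are placeholders, not values from this PR:

import org.apache.http.client.methods.HttpPost;

HttpPost post = new HttpPost("https://example-endpoint.example.com");

// AZURE_AI_MODEL_INFERENCE_SERVICE (and the default branch): api-key header
post.setHeader("api-key", "my-api-key");

// SERVERLESS_API: standard bearer token, as produced by createAuthBearerHeader
post.setHeader("Authorization", "Bearer my-api-key");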

@@ -10,9 +10,6 @@
public final class AzureAiStudioRequestFields {
public static final String API_KEY_HEADER = "api-key";
public static final String MESSAGES_ARRAY = "messages";
public static final String INPUT_DATA_OBJECT = "input_data";
public static final String INPUT_STRING_ARRAY = "input_string";
public static final String PARAMETERS_OBJECT = "parameters";
public static final String MESSAGE_CONTENT = "content";
public static final String ROLE = "role";
public static final String USER_ROLE = "user";
@@ -17,6 +17,7 @@ public class AzureOpenAiUtils {
public static final String COMPLETIONS_PATH = "completions";
public static final String API_VERSION_PARAMETER = "api-version";
public static final String API_KEY_HEADER = "api-key";
public static final String BEARER_PREFIX = "Bearer ";

private AzureOpenAiUtils() {}
}
@@ -7,69 +7,25 @@

package org.elasticsearch.xpack.inference.external.response.azureaistudio;

import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioChatCompletionRequest;
import org.elasticsearch.xpack.inference.external.response.BaseResponseEntity;
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity;

import java.io.IOException;
import java.util.List;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;

public class AzureAiStudioChatCompletionResponseEntity extends BaseResponseEntity {

@Override
protected InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
if (request instanceof AzureAiStudioChatCompletionRequest asChatCompletionRequest) {
if (asChatCompletionRequest.isRealtimeEndpoint()) {
return parseRealtimeEndpointResponse(response);
}

// we can use the OpenAI chat completion type if it's not a realtime endpoint
// we can use the OpenAI chat completion type as it is the same as Azure AI Studio's format
return OpenAiChatCompletionResponseEntity.fromResponse(request, response);
}

return null;
}

private ChatCompletionResults parseRealtimeEndpointResponse(HttpResult response) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
moveToFirstToken(jsonParser);

XContentParser.Token token = jsonParser.currentToken();
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);

while (token != null && token != XContentParser.Token.END_OBJECT) {
if (token != XContentParser.Token.FIELD_NAME) {
token = jsonParser.nextToken();
continue;
}

var currentName = jsonParser.currentName();
if (currentName == null || currentName.equalsIgnoreCase("output") == false) {
token = jsonParser.nextToken();
continue;
}

token = jsonParser.nextToken();
ensureExpectedToken(XContentParser.Token.VALUE_STRING, token, jsonParser);
String content = jsonParser.text();

return new ChatCompletionResults(List.of(new ChatCompletionResults.Result(content)));
}

throw new IllegalStateException("Reached an invalid state while parsing the Azure AI Studio completion response");
}
}
}
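With the realtime parsing path removed, the only response shape handled here is the OpenAI-style chat completion body. The example below is an assumed illustration of that format, not a captured Azure AI Studio response:

// Parsed by OpenAiChatCompletionResponseEntity.fromResponse; field values are illustrative.
String exampleResponseBody = """
    {
      "choices": [
        {
          "finish_reason": "stop",
          "index": 0,
          "message": { "role": "assistant", "content": "Hello there, how may I assist you today?" }
        }
      ],
      "model": "mistral-small",
      "object": "chat.completion",
      "usage": { "completion_tokens": 10, "prompt_tokens": 5, "total_tokens": 15 }
    }
    """;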
@@ -18,7 +18,6 @@
public class AzureAiStudioEmbeddingsResponseEntity extends BaseResponseEntity {
@Override
protected InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
// expected response type is the same as the Open AI Embeddings
return OpenAiEmbeddingsResponseEntity.fromResponse(request, response);
}
}
@@ -8,13 +8,11 @@
package org.elasticsearch.xpack.inference.services.azureaistudio;

public class AzureAiStudioConstants {
public static final String EMBEDDINGS_URI_PATH = "/v1/embeddings";
public static final String COMPLETIONS_URI_PATH = "/v1/chat/completions";

// common service settings fields
public static final String TARGET_FIELD = "target";
public static final String ENDPOINT_TYPE_FIELD = "endpoint_type";
public static final String PROVIDER_FIELD = "provider";
public static final String DEPLOYMENT_TYPE_FIELD = "deployment_type";
public static final String MODEL_FIELD = "model";
public static final String API_KEY_FIELD = "api_key";

// embeddings service and request settings
@@ -30,7 +28,6 @@ public class AzureAiStudioConstants {
public static final String TOP_P_FIELD = "top_p";
public static final String DO_SAMPLE_FIELD = "do_sample";
public static final String MAX_TOKENS_FIELD = "max_tokens";
public static final String MAX_NEW_TOKENS_FIELD = "max_new_tokens";

public static final Double MIN_TEMPERATURE_TOP_P = 0.0;
public static final Double MAX_TEMPERATURE_TOP_P = 2.0;