Skip to content

Commit

Permalink
fix: improve cancel logic in openAi model
Browse files Browse the repository at this point in the history
Cancelling a request was not working correctly.
With these changes, the cancellation token is properly taken into account.
  • Loading branch information
eneufeld committed Jan 13, 2025
1 parent d8022a1 commit 914bc53
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ export class CodeCompletionAgentImpl implements CodeCompletionAgent {
this.logger.error('No prompt found for code-completion-agent');
return undefined;
}

// since we do not actually hold complete conversations, the request/response pair is considered a session
const sessionId = generateUuid();
const requestId = generateUuid();
Expand Down Expand Up @@ -144,7 +143,12 @@ export class CodeCompletionAgentImpl implements CodeCompletionAgent {
items: [{ insertText: completionText }],
enableForwardStability: true,
};
} finally {
} catch (e) {
if (!token.isCancellationRequested) {
console.error(e.message, e);
}
}
finally {
progress.cancel();
}
}
Expand Down
16 changes: 15 additions & 1 deletion packages/ai-core/src/browser/frontend-language-model-registry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ export class LanguageModelDelegateClientImpl
return this.receiver.toolCall(requestId, toolId, args_string);
}

// Forwards a stream error reported by the backend delegate to the local
// receiver, which rejects the pending consumer of stream `id`.
error(id: string, error: Error): void {
this.receiver.error(id, error);
}

languageModelAdded(metadata: LanguageModelMetaData): void {
this.receiver.languageModelAdded(metadata);
}
Expand All @@ -74,6 +78,7 @@ interface StreamState {
id: string;
tokens: (LanguageModelStreamResponsePart | undefined)[];
resolve?: (_: unknown) => void;
reject?: (_: unknown) => void;
}

@injectable()
Expand Down Expand Up @@ -249,8 +254,9 @@ export class FrontendLanguageModelRegistryImpl
yield token;
}
} else {
await new Promise(resolve => {
await new Promise((resolve, reject) => {
state.resolve = resolve;
state.reject = reject;
});
}
}
Expand Down Expand Up @@ -286,6 +292,14 @@ export class FrontendLanguageModelRegistryImpl
throw new Error(`Could not find a tool for ${toolId}!`);
}

error(id: string, error: Error): void {
if (!this.streams.has(id)) {
throw new Error('Somehow we got a callback for a non existing stream!');
}
const streamState = this.streams.get(id)!;
streamState.reject?.(error);
}

override async selectLanguageModels(request: LanguageModelSelector): Promise<LanguageModel[]> {
await this.initialized;
const userSettings = (await this.settingsService.getAgentSettings(request.agent))?.languageModelRequirements?.find(req => req.purpose === request.purpose);
Expand Down
1 change: 1 addition & 0 deletions packages/ai-core/src/common/language-model-delegate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ export const LanguageModelDelegateClient = Symbol('LanguageModelDelegateClient')
export interface LanguageModelDelegateClient {
/** Invoked when the model requests a tool invocation; resolves with the tool's result. */
toolCall(requestId: string, toolId: string, args_string: string): Promise<unknown>;
/** Delivers the next part of stream `id`; `undefined` marks the end of the stream. */
send(id: string, token: LanguageModelStreamResponsePart | undefined): void;
/** Reports a failure on stream `id` so the consumer's pending promise is rejected. */
error(id: string, error: Error): void;
}
export const LanguageModelRegistryFrontendDelegate = Symbol('LanguageModelRegistryFrontendDelegate');
export interface LanguageModelRegistryFrontendDelegate {
Expand Down
17 changes: 12 additions & 5 deletions packages/ai-core/src/node/language-model-frontend-delegate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ export class LanguageModelFrontendDelegateImpl implements LanguageModelFrontendD
const delegate = {
streamId: generateUuid(),
};
this.sendTokens(delegate.streamId, response.stream);
this.sendTokens(delegate.streamId, response.stream, cancellationToken);
return delegate;
}
this.logger.error(
Expand All @@ -105,12 +105,19 @@ export class LanguageModelFrontendDelegateImpl implements LanguageModelFrontendD
return response;
}

protected sendTokens(id: string, stream: AsyncIterable<LanguageModelStreamResponsePart>): void {
protected sendTokens(id: string, stream: AsyncIterable<LanguageModelStreamResponsePart>, cancellationToken?: CancellationToken): void {
(async () => {
for await (const token of stream) {
this.frontendDelegateClient.send(id, token);
try {
for await (const token of stream) {
this.frontendDelegateClient.send(id, token);
}
} catch (e) {
if (!cancellationToken?.isCancellationRequested) {
this.frontendDelegateClient.error(id, e);
}
} finally {
this.frontendDelegateClient.send(id, undefined);
}
this.frontendDelegateClient.send(id, undefined);
})();
}
}
32 changes: 25 additions & 7 deletions packages/ai-openai/src/node/openai-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ export class OpenAiModel implements LanguageModel {
if (request.response_format?.type === 'json_schema' && this.supportsStructuredOutput()) {
return this.handleStructuredOutputRequest(openai, request);
}
if (cancellationToken?.isCancellationRequested) {
return { text: '' };
}

let runner: ChatCompletionStream;
const tools = this.createTools(request);
Expand All @@ -95,42 +98,57 @@ export class OpenAiModel implements LanguageModel {

let runnerEnd = false;

let resolve: (part: LanguageModelStreamResponsePart) => void;
let resolve: ((part: LanguageModelStreamResponsePart) => void) | undefined;
runner.on('error', error => {
console.error('Error in OpenAI chat completion stream:', error);
runnerEnd = true;
resolve({ content: error.message });
resolve?.({ content: error.message });
});
// we need to also listen for the emitted errors, as otherwise any error actually thrown by the API will not be caught
runner.emitted('error').then(error => {
console.error('Error in OpenAI chat completion stream:', error);
runnerEnd = true;
resolve({ content: error.message });
resolve?.({ content: error.message });
});
runner.emitted('abort').then(() => {
// do nothing, as the abort event is only emitted when the runner is aborted by us
// cancel async iterator
runnerEnd = true;
});
runner.on('message', message => {
if (message.role === 'tool') {
resolve({ tool_calls: [{ id: message.tool_call_id, finished: true, result: this.getCompletionContent(message) }] });
resolve?.({ tool_calls: [{ id: message.tool_call_id, finished: true, result: this.getCompletionContent(message) }] });
}
console.debug('Received Open AI message', JSON.stringify(message));
});
runner.once('end', () => {
runnerEnd = true;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
resolve(runner.finalChatCompletion as any);
resolve?.(runner.finalChatCompletion as any);
});
if (cancellationToken?.isCancellationRequested) {
return { text: '' };
}
const asyncIterator = {
async *[Symbol.asyncIterator](): AsyncIterator<LanguageModelStreamResponsePart> {
runner.on('chunk', chunk => {
if (chunk.choices[0]?.delta) {
if (cancellationToken?.isCancellationRequested) {
resolve = undefined;
return;
}
if (resolve && chunk.choices[0]?.delta) {
resolve({ ...chunk.choices[0]?.delta });
}
});
while (!runnerEnd) {
if (cancellationToken?.isCancellationRequested) {
throw new Error('Iterator canceled');
}
const promise = new Promise<LanguageModelStreamResponsePart>((res, rej) => {
resolve = res;
cancellationToken?.onCancellationRequested(() => {
rej(new Error('Canceled'));
runnerEnd = true; // Stop the iterator
});
});
yield promise;
}
Expand Down

0 comments on commit 914bc53

Please sign in to comment.