From 914bc5333eb5818b30de39178af79726012d5468 Mon Sep 17 00:00:00 2001 From: Eugen Neufeld Date: Fri, 10 Jan 2025 13:23:10 +0100 Subject: [PATCH 1/2] fix: improve cancel logic in openAi model The cancel of a request was not working correctly. With the changes the cancelToken is better taken into account. --- .../src/browser/code-completion-agent.ts | 8 +++-- .../frontend-language-model-registry.ts | 16 +++++++++- .../src/common/language-model-delegate.ts | 1 + .../node/language-model-frontend-delegate.ts | 17 +++++++--- .../src/node/openai-language-model.ts | 32 +++++++++++++++---- 5 files changed, 59 insertions(+), 15 deletions(-) diff --git a/packages/ai-code-completion/src/browser/code-completion-agent.ts b/packages/ai-code-completion/src/browser/code-completion-agent.ts index abb550cbb3711..7763cc52a5e1a 100644 --- a/packages/ai-code-completion/src/browser/code-completion-agent.ts +++ b/packages/ai-code-completion/src/browser/code-completion-agent.ts @@ -109,7 +109,6 @@ export class CodeCompletionAgentImpl implements CodeCompletionAgent { this.logger.error('No prompt found for code-completion-agent'); return undefined; } - // since we do not actually hold complete conversions, the request/response pair is considered a session const sessionId = generateUuid(); const requestId = generateUuid(); @@ -144,7 +143,12 @@ export class CodeCompletionAgentImpl implements CodeCompletionAgent { items: [{ insertText: completionText }], enableForwardStability: true, }; - } finally { + } catch (e) { + if (!token.isCancellationRequested) { + console.error(e.message, e); + } + } + finally { progress.cancel(); } } diff --git a/packages/ai-core/src/browser/frontend-language-model-registry.ts b/packages/ai-core/src/browser/frontend-language-model-registry.ts index 90b80a0688451..3b1b0bb7c0c8f 100644 --- a/packages/ai-core/src/browser/frontend-language-model-registry.ts +++ b/packages/ai-core/src/browser/frontend-language-model-registry.ts @@ -61,6 +61,10 @@ export class LanguageModelDelegateClientImpl return this.receiver.toolCall(requestId, toolId, args_string); } + error(id: string, error: Error): void { + this.receiver.error(id, error); + } + languageModelAdded(metadata: LanguageModelMetaData): void { this.receiver.languageModelAdded(metadata); } @@ -74,6 +78,7 @@ interface StreamState { id: string; tokens: (LanguageModelStreamResponsePart | undefined)[]; resolve?: (_: unknown) => void; + reject?: (_: unknown) => void; } @injectable() @@ -249,8 +254,9 @@ export class FrontendLanguageModelRegistryImpl yield token; } } else { - await new Promise(resolve => { + await new Promise((resolve, reject) => { state.resolve = resolve; + state.reject = reject; }); } } @@ -286,6 +292,14 @@ export class FrontendLanguageModelRegistryImpl throw new Error(`Could not find a tool for ${toolId}!`); } + error(id: string, error: Error): void { + if (!this.streams.has(id)) { + throw new Error('Somehow we got a callback for a non existing stream!'); + } + const streamState = this.streams.get(id)!; + streamState.reject?.(error); + } + override async selectLanguageModels(request: LanguageModelSelector): Promise { await this.initialized; const userSettings = (await this.settingsService.getAgentSettings(request.agent))?.languageModelRequirements?.find(req => req.purpose === request.purpose); diff --git a/packages/ai-core/src/common/language-model-delegate.ts b/packages/ai-core/src/common/language-model-delegate.ts index 5edbfe4b18ac7..b1f9bd70dfd34 100644 --- a/packages/ai-core/src/common/language-model-delegate.ts +++ b/packages/ai-core/src/common/language-model-delegate.ts @@ -21,6 +21,7 @@ export const LanguageModelDelegateClient = Symbol('LanguageModelDelegateClient') export interface LanguageModelDelegateClient { toolCall(requestId: string, toolId: string, args_string: string): Promise; send(id: string, token: LanguageModelStreamResponsePart | undefined): void; + error(id: string, error: Error): void; } export const LanguageModelRegistryFrontendDelegate = Symbol('LanguageModelRegistryFrontendDelegate'); export interface LanguageModelRegistryFrontendDelegate { diff --git a/packages/ai-core/src/node/language-model-frontend-delegate.ts b/packages/ai-core/src/node/language-model-frontend-delegate.ts index 0255c0111aa6a..c4b2ad312e88f 100644 --- a/packages/ai-core/src/node/language-model-frontend-delegate.ts +++ b/packages/ai-core/src/node/language-model-frontend-delegate.ts @@ -95,7 +95,7 @@ export class LanguageModelFrontendDelegateImpl implements LanguageModelFrontendD const delegate = { streamId: generateUuid(), }; - this.sendTokens(delegate.streamId, response.stream); + this.sendTokens(delegate.streamId, response.stream, cancellationToken); return delegate; } this.logger.error( @@ -105,12 +105,19 @@ export class LanguageModelFrontendDelegateImpl implements LanguageModelFrontendD return response; } - protected sendTokens(id: string, stream: AsyncIterable): void { + protected sendTokens(id: string, stream: AsyncIterable, cancellationToken?: CancellationToken): void { (async () => { - for await (const token of stream) { - this.frontendDelegateClient.send(id, token); + try { + for await (const token of stream) { + this.frontendDelegateClient.send(id, token); + } + } catch (e) { + if (!cancellationToken?.isCancellationRequested) { + this.frontendDelegateClient.error(id, e); + } + } finally { + this.frontendDelegateClient.send(id, undefined); } - this.frontendDelegateClient.send(id, undefined); })(); } } diff --git a/packages/ai-openai/src/node/openai-language-model.ts b/packages/ai-openai/src/node/openai-language-model.ts index dec9b61784fc2..c173c7e2f68fa 100644 --- a/packages/ai-openai/src/node/openai-language-model.ts +++ b/packages/ai-openai/src/node/openai-language-model.ts @@ -69,6 +69,9 @@ export class OpenAiModel implements LanguageModel { if (request.response_format?.type === 'json_schema' && this.supportsStructuredOutput()) { return this.handleStructuredOutputRequest(openai, request); } + if (cancellationToken?.isCancellationRequested) { + return { text: '' }; + } let runner: ChatCompletionStream; const tools = this.createTools(request); @@ -95,42 +98,57 @@ export class OpenAiModel implements LanguageModel { let runnerEnd = false; - let resolve: (part: LanguageModelStreamResponsePart) => void; + let resolve: ((part: LanguageModelStreamResponsePart) => void) | undefined; runner.on('error', error => { console.error('Error in OpenAI chat completion stream:', error); runnerEnd = true; - resolve({ content: error.message }); + resolve?.({ content: error.message }); }); // we need to also listen for the emitted errors, as otherwise any error actually thrown by the API will not be caught runner.emitted('error').then(error => { console.error('Error in OpenAI chat completion stream:', error); runnerEnd = true; - resolve({ content: error.message }); + resolve?.({ content: error.message }); }); runner.emitted('abort').then(() => { - // do nothing, as the abort event is only emitted when the runner is aborted by us + // cancel async iterator + runnerEnd = true; }); runner.on('message', message => { if (message.role === 'tool') { - resolve({ tool_calls: [{ id: message.tool_call_id, finished: true, result: this.getCompletionContent(message) }] }); + resolve?.({ tool_calls: [{ id: message.tool_call_id, finished: true, result: this.getCompletionContent(message) }] }); } console.debug('Received Open AI message', JSON.stringify(message)); }); runner.once('end', () => { runnerEnd = true; // eslint-disable-next-line @typescript-eslint/no-explicit-any - resolve(runner.finalChatCompletion as any); + resolve?.(runner.finalChatCompletion as any); }); + if (cancellationToken?.isCancellationRequested) { + return { text: '' }; + } const asyncIterator = { async *[Symbol.asyncIterator](): AsyncIterator { runner.on('chunk', chunk => { - if (chunk.choices[0]?.delta) { + if (cancellationToken?.isCancellationRequested) { + resolve = undefined; + return; + } + if (resolve && chunk.choices[0]?.delta) { resolve({ ...chunk.choices[0]?.delta }); } }); while (!runnerEnd) { + if (cancellationToken?.isCancellationRequested) { + throw new Error('Iterator canceled'); + } const promise = new Promise((res, rej) => { resolve = res; + cancellationToken?.onCancellationRequested(() => { + rej(new Error('Canceled')); + runnerEnd = true; // Stop the iterator + }); }); yield promise; } From c2af3065e930e7c8fa2ed2657f85c45469e762b5 Mon Sep 17 00:00:00 2001 From: Eugen Neufeld Date: Mon, 20 Jan 2025 17:28:17 +0100 Subject: [PATCH 2/2] review comments --- .../src/browser/frontend-language-model-registry.ts | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/packages/ai-core/src/browser/frontend-language-model-registry.ts b/packages/ai-core/src/browser/frontend-language-model-registry.ts index 3b1b0bb7c0c8f..6239b93b1a169 100644 --- a/packages/ai-core/src/browser/frontend-language-model-registry.ts +++ b/packages/ai-core/src/browser/frontend-language-model-registry.ts @@ -292,9 +292,14 @@ export class FrontendLanguageModelRegistryImpl throw new Error(`Could not find a tool for ${toolId}!`); } + // called by backend via the "delegate client" with the error to use for rejection error(id: string, error: Error): void { if (!this.streams.has(id)) { - throw new Error('Somehow we got a callback for a non existing stream!'); + const newStreamState = { + id, + tokens: [], + }; + this.streams.set(id, newStreamState); } const streamState = this.streams.get(id)!; streamState.reject?.(error);