diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index fdb117c..58a381d 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -1,5 +1,6 @@ -import { AnyZodObject, ZodLiteral, ZodObject, z } from "zod"; +import { AnyZodObject, ZodLiteral, ZodObject, ZodOptional, z } from "zod"; import { + BaseRequestParamsSchema, ErrorCode, JSONRPCError, JSONRPCNotification, @@ -312,18 +313,42 @@ export class Protocol< /** * Registers a handler to invoke when this protocol object receives a request with the given method. * + * The handler receives a second callback parameter that can be used to emit progress notifications. + * * Note that this will replace any previous request handler for the same method. */ setRequestHandler< T extends ZodObject<{ method: ZodLiteral; + params: ZodOptional; }>, >( requestSchema: T, - handler: (request: z.infer) => SendResultT | Promise, + handler: ( + request: z.infer, + progress: (progress: Progress) => void, + ) => SendResultT | Promise, ): void { - this._requestHandlers.set(requestSchema.shape.method.value, (request) => - Promise.resolve(handler(requestSchema.parse(request))), + this._requestHandlers.set( + requestSchema.shape.method.value, + async (request) => { + const parsedRequest = requestSchema.parse(request); + const progressToken = parsedRequest.params?._meta?.progressToken; + const progressHandler = + progressToken !== undefined + ? (progress: Progress) => + // Sending directly on the transport to avoid typing conflicts + this._transport + ?.send({ + jsonrpc: "2.0", + method: "notifications/progress", + params: { ...progress, progressToken }, + }) + .catch((error: Error) => this._onerror(error)) + : () => {}; + + return handler(parsedRequest, progressHandler); + }, ); } diff --git a/src/types.ts b/src/types.ts index 8fd4786..3345e96 100644 --- a/src/types.ts +++ b/src/types.ts @@ -15,7 +15,7 @@ export const ProgressTokenSchema = z.union([z.string(), z.number().int()]); */ export const CursorSchema = z.string(); -const BaseRequestParamsSchema = z +export const BaseRequestParamsSchema = z .object({ _meta: z.optional( z