diff --git a/src/api/resources/tts/client/Client.ts b/src/api/resources/tts/client/Client.ts index 38627be..a4e3be5 100644 --- a/src/api/resources/tts/client/Client.ts +++ b/src/api/resources/tts/client/Client.ts @@ -2,13 +2,13 @@ * This file was auto-generated by Fern from our API Definition. */ -import * as environments from "../../../../environments"; -import * as core from "../../../../core"; -import * as Cartesia from "../../../index"; import * as stream from "stream"; -import * as serializers from "../../../../serialization/index"; import urlJoin from "url-join"; +import * as core from "../../../../core"; +import * as environments from "../../../../environments"; import * as errors from "../../../../errors/index"; +import * as serializers from "../../../../serialization/index"; +import * as Cartesia from "../../../index"; export declare namespace Tts { interface Options { diff --git a/src/wrapper/Websocket.ts b/src/wrapper/Websocket.ts index c3c7c86..a1ed12d 100644 --- a/src/wrapper/Websocket.ts +++ b/src/wrapper/Websocket.ts @@ -1,23 +1,23 @@ +import Emittery from "emittery"; +import { humanId } from "human-id"; +import qs from "qs"; +import type { RawEncoding, WebSocketRequest, WebSocketStreamOptions, WordTimestamps } from "../api"; +import { Tts } from "../api/resources/tts/client/Client"; import * as core from "../core"; +import { Options, ReconnectingWebSocket } from "../core/websocket"; import * as environments from "../environments"; import * as serializers from "../serialization"; -import Emittery from "emittery"; -import { humanId } from "human-id"; -import { ReconnectingWebSocket, Options } from "../core/websocket"; -import type { RawEncoding, WebSocketStreamOptions, WebSocketTtsRequest, WordTimestamps } from "../api"; +import Source from "./source"; import { base64ToArray, - resolveOutputFormat, ConnectionEventData, createMessageHandlerForContextId, + EmitteryCallbacks, getEmitteryCallbacks, isSentinel, + resolveOutputFormat, WebSocketOptions, - EmitteryCallbacks, } from "./utils"; -import { Tts } from "../api/resources/tts/client/Client"; -import Source from "./source"; -import qs from "qs"; export default class Websocket { socket?: ReconnectingWebSocket; @@ -43,93 +43,127 @@ export default class Websocket { * @returns An Emittery instance that emits messages from the WebSocket. */ async send( - inputs: WebSocketTtsRequest, + request: WebSocketRequest, { timeout = 0 }: WebSocketStreamOptions = {} - ): Promise< - EmitteryCallbacks<{ - message: string; - timestamps: WordTimestamps; - }> & { source: Source; stop: unknown } - > { + ): Promise< + | (EmitteryCallbacks<{ message: string; timestamps: WordTimestamps }> & { + source: Source; + stop: () => void; + }) + | { status: "cancelled"; contextId: string } + > { if (!this.#isConnected) { - throw new Error("Not connected to WebSocket. Call .connect() first."); + throw new Error("Not connected to WebSocket. Call .connect() first."); } - - if (!inputs.contextId) { - inputs.contextId = this.#generateId(); - } - if (!inputs.outputFormat) { - inputs.outputFormat = resolveOutputFormat( - this.#container as "raw" | "wav" | "mp3", - this.#encoding as RawEncoding, - this.#sampleRate + + if ("cancel" in request) { + this.socket?.send( + JSON.stringify( + serializers.WebSocketRequest.jsonOrThrow(request, { + unrecognizedObjectKeys: "strip", + }) + ) + ); + + return { + status: "cancelled", + contextId: request.contextId, + }; + } else if ("transcript" in request && "modelId" in request && "voice" in request && "outputFormat" in request) { + if (!request.contextId) { + request.contextId = this.#generateId(); + } + if (!request.outputFormat) { + request.outputFormat = resolveOutputFormat( + this.#encoding as RawEncoding, + this.#sampleRate ); - } - - this.socket?.send( - JSON.stringify(serializers.WebSocketTtsRequest.jsonOrThrow(inputs, { unrecognizedObjectKeys: "strip" })) - ); - - const emitter = new Emittery<{ + } + + this.socket?.send( + JSON.stringify( + serializers.WebSocketRequest.jsonOrThrow(request, { + unrecognizedObjectKeys: "strip", + }) + ) + ); + + const emitter = new Emittery<{ message: string; timestamps: WordTimestamps; - }>(); - const source = new Source({ + }>(); + const source = new Source({ sampleRate: this.#sampleRate, encoding: this.#encoding, container: this.#container, - }); - // Used to signal that the stream is complete, either because the - // WebSocket has closed, or because the stream has finished. - const streamCompleteController = new AbortController(); - // Set a timeout. - let timeoutId: ReturnType | null = null; - if (timeout > 0) { - timeoutId = setTimeout(streamCompleteController.abort, timeout); - } - const handleMessage = createMessageHandlerForContextId(inputs.contextId, async ({ chunk, message, data }) => { - emitter.emit("message", message); - if (data.type === "timestamps" && data.wordTimestamps) { + }); + + const streamCompleteController = new AbortController(); + + let timeoutId: ReturnType | null = null; + if (timeout > 0) { + timeoutId = setTimeout(() => { + streamCompleteController.abort(); + }, timeout); + } + + const handleMessage = createMessageHandlerForContextId( + request.contextId, + async ({ chunk, message, data }) => { + emitter.emit("message", message); + + if (data.type === "timestamps" && data.wordTimestamps) { emitter.emit("timestamps", data.wordTimestamps); return; - } - if (isSentinel(chunk)) { + } + + if (isSentinel(chunk)) { await source.close(); streamCompleteController.abort(); return; - } - if (timeoutId) { + } + + if (timeoutId) { clearTimeout(timeoutId); - timeoutId = setTimeout(streamCompleteController.abort, timeout); + timeoutId = setTimeout(() => { + streamCompleteController.abort(); + }, timeout); + } + + if (chunk) { + await source.enqueue(base64ToArray([chunk], this.#encoding)); + } } - if (!chunk) { - return; - } - await source.enqueue(base64ToArray([chunk], this.#encoding)); - }); - this.socket?.addEventListener("message", handleMessage); - this.socket?.addEventListener("close", () => { + ); + + this.socket?.addEventListener("message", handleMessage); + this.socket?.addEventListener("close", () => { streamCompleteController.abort(); - }); - this.socket?.addEventListener("error", () => { + }); + this.socket?.addEventListener("error", () => { streamCompleteController.abort(); - }); - streamCompleteController.signal.addEventListener("abort", () => { + }); + + streamCompleteController.signal.addEventListener("abort", () => { source.close(); if (timeoutId) { - clearTimeout(timeoutId); + clearTimeout(timeoutId); } emitter.clearListeners(); - }); - - return { + this.socket?.removeEventListener("message", handleMessage); + }); + + return { source, ...getEmitteryCallbacks(emitter), stop: streamCompleteController.abort.bind(streamCompleteController), - }; - } + }; + } + + throw new Error(`Unknown request type: ${(request as any).type}`); + } - continue(inputs: WebSocketTtsRequest) { + continue(inputs: WebSocketRequest) { if (!this.#isConnected) { throw new Error("Not connected to WebSocket. Call .connect() first."); } @@ -137,9 +171,8 @@ export default class Websocket { if (!inputs.contextId) { throw new Error("context_id is required to continue a context."); } - if (!inputs.outputFormat) { + if ("transcript" in inputs && !inputs.outputFormat) { inputs.outputFormat = resolveOutputFormat( - this.#container as "raw" | "wav" | "mp3", this.#encoding as RawEncoding, this.#sampleRate ); @@ -148,7 +181,7 @@ export default class Websocket { this.socket?.send( JSON.stringify({ continue: true, - ...serializers.WebSocketTtsRequest.jsonOrThrow(inputs, { unrecognizedObjectKeys: "strip" }), + ...serializers.WebSocketRequest.jsonOrThrow(inputs, { unrecognizedObjectKeys: "strip" }), }) ); } diff --git a/src/wrapper/utils.ts b/src/wrapper/utils.ts index e7b1f11..028086a 100644 --- a/src/wrapper/utils.ts +++ b/src/wrapper/utils.ts @@ -1,6 +1,6 @@ import base64 from "base64-js"; import type Emittery from "emittery"; -import type { OutputFormat, RawEncoding, WebSocketResponse, WebSocketTtsRequest } from "../api"; +import type { RawEncoding, WebSocketRawOutputFormat, WebSocketResponse } from "../api"; export type EmitteryCallbacks = { on: Emittery["on"]; @@ -56,33 +56,14 @@ export const ENCODING_MAP: Record = { * @returns The output format for the WebSocket request. */ export function resolveOutputFormat( - container: "raw" | "wav" | "mp3", encoding: RawEncoding, sampleRate: number -): OutputFormat { - switch (container) { - case "wav": - return { - container: "wav", - encoding, - sampleRate, - } as OutputFormat.Wav; - case "raw": - return { - container: "raw", - encoding, - sampleRate, - } as OutputFormat.Raw; - case "mp3": - return { - container: "mp3", - encoding, - sampleRate, - bitRate: 128, - } as OutputFormat.Mp3; - default: - throw new Error(`Unsupported container type: ${container}`); - } +): WebSocketRawOutputFormat { + return { + container: "raw", + encoding, + sampleRate, + }; } /**