Skip to content

Commit

Permalink
fix: use WebsocketRequest in ws endpoint.
Browse files Browse the repository at this point in the history
  • Loading branch information
eyw520 committed Jan 22, 2025
1 parent bd38032 commit b898e72
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 103 deletions.
8 changes: 4 additions & 4 deletions src/api/resources/tts/client/Client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
* This file was auto-generated by Fern from our API Definition.
*/

import * as environments from "../../../../environments";
import * as core from "../../../../core";
import * as Cartesia from "../../../index";
import * as stream from "stream";
import * as serializers from "../../../../serialization/index";
import urlJoin from "url-join";
import * as core from "../../../../core";
import * as environments from "../../../../environments";
import * as errors from "../../../../errors/index";
import * as serializers from "../../../../serialization/index";
import * as Cartesia from "../../../index";

export declare namespace Tts {
interface Options {
Expand Down
179 changes: 106 additions & 73 deletions src/wrapper/Websocket.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
import Emittery from "emittery";
import { humanId } from "human-id";
import qs from "qs";
import type { RawEncoding, WebSocketRequest, WebSocketStreamOptions, WordTimestamps } from "../api";
import { Tts } from "../api/resources/tts/client/Client";
import * as core from "../core";
import { Options, ReconnectingWebSocket } from "../core/websocket";
import * as environments from "../environments";
import * as serializers from "../serialization";
import Emittery from "emittery";
import { humanId } from "human-id";
import { ReconnectingWebSocket, Options } from "../core/websocket";
import type { RawEncoding, WebSocketStreamOptions, WebSocketTtsRequest, WordTimestamps } from "../api";
import Source from "./source";
import {
base64ToArray,
resolveOutputFormat,
ConnectionEventData,
createMessageHandlerForContextId,
EmitteryCallbacks,
getEmitteryCallbacks,
isSentinel,
resolveOutputFormat,
WebSocketOptions,
EmitteryCallbacks,
} from "./utils";
import { Tts } from "../api/resources/tts/client/Client";
import Source from "./source";
import qs from "qs";

export default class Websocket {
socket?: ReconnectingWebSocket;
Expand All @@ -43,103 +43,136 @@ export default class Websocket {
* @returns An Emittery instance that emits messages from the WebSocket.
*/
async send(
inputs: WebSocketTtsRequest,
request: WebSocketRequest,
{ timeout = 0 }: WebSocketStreamOptions = {}
): Promise<
EmitteryCallbacks<{
message: string;
timestamps: WordTimestamps;
}> & { source: Source; stop: unknown }
> {
): Promise<
| (EmitteryCallbacks<{ message: string; timestamps: WordTimestamps }> & {
source: Source;
stop: () => void;
})
| { status: "cancelled"; contextId: string }
> {
if (!this.#isConnected) {
throw new Error("Not connected to WebSocket. Call .connect() first.");
throw new Error("Not connected to WebSocket. Call .connect() first.");
}

if (!inputs.contextId) {
inputs.contextId = this.#generateId();
}
if (!inputs.outputFormat) {
inputs.outputFormat = resolveOutputFormat(
this.#container as "raw" | "wav" | "mp3",
this.#encoding as RawEncoding,
this.#sampleRate

if ("cancel" in request) {
this.socket?.send(
JSON.stringify(
serializers.WebSocketRequest.jsonOrThrow(request, {
unrecognizedObjectKeys: "strip",
})
)
);

return {
status: "cancelled",
contextId: request.contextId,
};
} else if ("transcript" in request && "modelId" in request && "voice" in request && "outputFormat" in request) {
if (!request.contextId) {
request.contextId = this.#generateId();
}
if (!request.outputFormat) {
request.outputFormat = resolveOutputFormat(
this.#encoding as RawEncoding,
this.#sampleRate
);
}

this.socket?.send(
JSON.stringify(serializers.WebSocketTtsRequest.jsonOrThrow(inputs, { unrecognizedObjectKeys: "strip" }))
);

const emitter = new Emittery<{
}

this.socket?.send(
JSON.stringify(
serializers.WebSocketRequest.jsonOrThrow(request, {
unrecognizedObjectKeys: "strip",
})
)
);

const emitter = new Emittery<{
message: string;
timestamps: WordTimestamps;
}>();
const source = new Source({
}>();
const source = new Source({
sampleRate: this.#sampleRate,
encoding: this.#encoding,
container: this.#container,
});
// Used to signal that the stream is complete, either because the
// WebSocket has closed, or because the stream has finished.
const streamCompleteController = new AbortController();
// Set a timeout.
let timeoutId: ReturnType<typeof setTimeout> | null = null;
if (timeout > 0) {
timeoutId = setTimeout(streamCompleteController.abort, timeout);
}
const handleMessage = createMessageHandlerForContextId(inputs.contextId, async ({ chunk, message, data }) => {
emitter.emit("message", message);
if (data.type === "timestamps" && data.wordTimestamps) {
});

const streamCompleteController = new AbortController();

let timeoutId: ReturnType<typeof setTimeout> | null = null;
if (timeout > 0) {
timeoutId = setTimeout(() => {
streamCompleteController.abort();
}, timeout);
}

const handleMessage = createMessageHandlerForContextId(
request.contextId,
async ({ chunk, message, data }) => {
emitter.emit("message", message);

if (data.type === "timestamps" && data.wordTimestamps) {
emitter.emit("timestamps", data.wordTimestamps);
return;
}
if (isSentinel(chunk)) {
}

if (isSentinel(chunk)) {
await source.close();
streamCompleteController.abort();
return;
}
if (timeoutId) {
}

if (timeoutId) {
clearTimeout(timeoutId);
timeoutId = setTimeout(streamCompleteController.abort, timeout);
timeoutId = setTimeout(() => {
streamCompleteController.abort();
}, timeout);
}

if (chunk) {
await source.enqueue(base64ToArray([chunk], this.#encoding));
}
}
if (!chunk) {
return;
}
await source.enqueue(base64ToArray([chunk], this.#encoding));
});
this.socket?.addEventListener("message", handleMessage);
this.socket?.addEventListener("close", () => {
);

this.socket?.addEventListener("message", handleMessage);
this.socket?.addEventListener("close", () => {
streamCompleteController.abort();
});
this.socket?.addEventListener("error", () => {
});
this.socket?.addEventListener("error", () => {
streamCompleteController.abort();
});
streamCompleteController.signal.addEventListener("abort", () => {
});

streamCompleteController.signal.addEventListener("abort", () => {
source.close();
if (timeoutId) {
clearTimeout(timeoutId);
clearTimeout(timeoutId);
}
emitter.clearListeners();
});

return {
this.socket?.removeEventListener("message", handleMessage);
});

return {
source,
...getEmitteryCallbacks(emitter),
stop: streamCompleteController.abort.bind(streamCompleteController),
};
}
};
}

throw new Error(`Unknown request type: ${(request as any).type}`);
}

continue(inputs: WebSocketTtsRequest) {
continue(inputs: WebSocketRequest) {
if (!this.#isConnected) {
throw new Error("Not connected to WebSocket. Call .connect() first.");
}

if (!inputs.contextId) {
throw new Error("context_id is required to continue a context.");
}
if (!inputs.outputFormat) {
if ("transcript" in inputs && !inputs.outputFormat) {
inputs.outputFormat = resolveOutputFormat(
this.#container as "raw" | "wav" | "mp3",
this.#encoding as RawEncoding,
this.#sampleRate
);
Expand All @@ -148,7 +181,7 @@ export default class Websocket {
this.socket?.send(
JSON.stringify({
continue: true,
...serializers.WebSocketTtsRequest.jsonOrThrow(inputs, { unrecognizedObjectKeys: "strip" }),
...serializers.WebSocketRequest.jsonOrThrow(inputs, { unrecognizedObjectKeys: "strip" }),
})
);
}
Expand Down
33 changes: 7 additions & 26 deletions src/wrapper/utils.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import base64 from "base64-js";
import type Emittery from "emittery";
import type { OutputFormat, RawEncoding, WebSocketResponse, WebSocketTtsRequest } from "../api";
import type { RawEncoding, WebSocketRawOutputFormat, WebSocketResponse } from "../api";

export type EmitteryCallbacks<T> = {
on: Emittery<T>["on"];
Expand Down Expand Up @@ -56,33 +56,14 @@ export const ENCODING_MAP: Record<RawEncoding, EncodingInfo> = {
* @returns The output format for the WebSocket request.
*/
export function resolveOutputFormat(
container: "raw" | "wav" | "mp3",
encoding: RawEncoding,
sampleRate: number
): OutputFormat {
switch (container) {
case "wav":
return {
container: "wav",
encoding,
sampleRate,
} as OutputFormat.Wav;
case "raw":
return {
container: "raw",
encoding,
sampleRate,
} as OutputFormat.Raw;
case "mp3":
return {
container: "mp3",
encoding,
sampleRate,
bitRate: 128,
} as OutputFormat.Mp3;
default:
throw new Error(`Unsupported container type: ${container}`);
}
): WebSocketRawOutputFormat {
return {
container: "raw",
encoding,
sampleRate,
};
}

/**
Expand Down

0 comments on commit b898e72

Please sign in to comment.