Skip to content

Commit

Permalink
Make Schema a discriminated union
Browse files Browse the repository at this point in the history
This leverages the type system to better describe the API's requirements for schemas. For example, rather than saying that any schema might have an optional `items` property, we're able to express that `items` is required on array schemas and forbidden on all others.

More info on discriminated unions: https://www.typescriptlang.org/docs/handbook/2/narrowing.html#discriminated-unions
  • Loading branch information
rictic committed Nov 13, 2024
1 parent 6ec2c27 commit 0c417fb
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 65 deletions.
5 changes: 5 additions & 0 deletions .changeset/young-rivers-shout.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@google/generative-ai": minor
---

The schema types are now more specific, using a [discriminated union](https://www.typescriptlang.org/docs/handbook/2/narrowing.html#discriminated-unions) based on the 'type' field to more accurately define which fields are allowed.
86 changes: 69 additions & 17 deletions common/api-review/generative-ai-server.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,33 @@
```ts

// Warning: (ae-incompatible-release-tags) The symbol "ArraySchema" is marked as @public, but its signature references "BaseSchema" which is marked as @internal
//
// @public
export interface ArraySchema extends BaseSchema {
items: Schema;
maxItems?: number;
minItems?: number;
// (undocumented)
type: typeof SchemaType.ARRAY;
}

// Warning: (ae-internal-missing-underscore) The name "BaseSchema" should be prefixed with an underscore because the declaration is marked as @internal
//
// @internal
export interface BaseSchema {
description?: string;
nullable?: boolean;
}

// Warning: (ae-incompatible-release-tags) The symbol "BooleanSchema" is marked as @public, but its signature references "BaseSchema" which is marked as @internal
//
// @public
export interface BooleanSchema extends BaseSchema {
// (undocumented)
type: typeof SchemaType.BOOLEAN;
}

// @public
export interface CachedContent extends CachedContentBase {
createTime?: string;
Expand Down Expand Up @@ -286,8 +313,7 @@ export interface FunctionDeclarationSchema {
}

// @public
export interface FunctionDeclarationSchemaProperty extends Schema {
}
export type FunctionDeclarationSchemaProperty = Schema;

// @public
export interface FunctionDeclarationsTool {
Expand Down Expand Up @@ -368,6 +394,15 @@ export interface InlineDataPart {
text?: never;
}

// Warning: (ae-incompatible-release-tags) The symbol "IntegerSchema" is marked as @public, but its signature references "BaseSchema" which is marked as @internal
//
// @public
export interface IntegerSchema extends BaseSchema {
format?: "int32" | "int64";
// (undocumented)
type: typeof SchemaType.INTEGER;
}

// @public (undocumented)
export interface ListCacheResponse {
// (undocumented)
Expand All @@ -392,6 +427,27 @@ export interface ListParams {
pageToken?: string;
}

// Warning: (ae-incompatible-release-tags) The symbol "NumberSchema" is marked as @public, but its signature references "BaseSchema" which is marked as @internal
//
// @public
export interface NumberSchema extends BaseSchema {
format?: "float" | "double";
// (undocumented)
type: typeof SchemaType.NUMBER;
}

// Warning: (ae-incompatible-release-tags) The symbol "ObjectSchema" is marked as @public, but its signature references "BaseSchema" which is marked as @internal
//
// @public
export interface ObjectSchema extends BaseSchema {
properties: {
[k: string]: Schema;
};
required?: string[];
// (undocumented)
type: typeof SchemaType.OBJECT;
}

// @public
export enum Outcome {
OUTCOME_DEADLINE_EXCEEDED = "outcome_deadline_exceeded",
Expand All @@ -413,8 +469,7 @@ export interface RequestOptions {
}

// @public
export interface ResponseSchema extends Schema {
}
export type ResponseSchema = Schema;

// @public
export interface RpcStatus {
Expand All @@ -424,19 +479,7 @@ export interface RpcStatus {
}

// @public
export interface Schema {
description?: string;
enum?: string[];
example?: unknown;
format?: string;
items?: Schema;
nullable?: boolean;
properties?: {
[k: string]: Schema;
};
required?: string[];
type?: SchemaType;
}
export type Schema = StringSchema | NumberSchema | IntegerSchema | BooleanSchema | ArraySchema | ObjectSchema;

// @public
export enum SchemaType {
Expand All @@ -453,6 +496,15 @@ export interface SingleRequestOptions extends RequestOptions {
signal?: AbortSignal;
}

// Warning: (ae-incompatible-release-tags) The symbol "StringSchema" is marked as @public, but its signature references "BaseSchema" which is marked as @internal
//
// @public
export interface StringSchema extends BaseSchema {
enum?: string[];
// (undocumented)
type: typeof SchemaType.STRING;
}

// @public
export interface TextPart {
// (undocumented)
Expand Down
86 changes: 69 additions & 17 deletions common/api-review/generative-ai.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@
```ts

// Warning: (ae-incompatible-release-tags) The symbol "ArraySchema" is marked as @public, but its signature references "BaseSchema" which is marked as @internal
//
// @public
export interface ArraySchema extends BaseSchema {
items: Schema;
maxItems?: number;
minItems?: number;
// (undocumented)
type: typeof SchemaType.ARRAY;
}

// @public
export interface BaseParams {
// (undocumented)
Expand All @@ -12,6 +23,14 @@ export interface BaseParams {
safetySettings?: SafetySetting[];
}

// Warning: (ae-internal-missing-underscore) The name "BaseSchema" should be prefixed with an underscore because the declaration is marked as @internal
//
// @internal
export interface BaseSchema {
description?: string;
nullable?: boolean;
}

// @public
export interface BatchEmbedContentsRequest {
// (undocumented)
Expand All @@ -34,6 +53,14 @@ export enum BlockReason {
SAFETY = "SAFETY"
}

// Warning: (ae-incompatible-release-tags) The symbol "BooleanSchema" is marked as @public, but its signature references "BaseSchema" which is marked as @internal
//
// @public
export interface BooleanSchema extends BaseSchema {
// (undocumented)
type: typeof SchemaType.BOOLEAN;
}

// @public
export interface CachedContent extends CachedContentBase {
createTime?: string;
Expand Down Expand Up @@ -355,8 +382,7 @@ export interface FunctionDeclarationSchema {
}

// @public
export interface FunctionDeclarationSchemaProperty extends Schema {
}
export type FunctionDeclarationSchemaProperty = Schema;

// @public
export interface FunctionDeclarationsTool {
Expand Down Expand Up @@ -645,6 +671,15 @@ export interface InlineDataPart {
text?: never;
}

// Warning: (ae-incompatible-release-tags) The symbol "IntegerSchema" is marked as @public, but its signature references "BaseSchema" which is marked as @internal
//
// @public
export interface IntegerSchema extends BaseSchema {
format?: "int32" | "int64";
// (undocumented)
type: typeof SchemaType.INTEGER;
}

// @public
export interface LogprobsCandidate {
logProbability: number;
Expand Down Expand Up @@ -672,6 +707,27 @@ export interface ModelParams extends BaseParams {
tools?: Tool[];
}

// Warning: (ae-incompatible-release-tags) The symbol "NumberSchema" is marked as @public, but its signature references "BaseSchema" which is marked as @internal
//
// @public
export interface NumberSchema extends BaseSchema {
format?: "float" | "double";
// (undocumented)
type: typeof SchemaType.NUMBER;
}

// Warning: (ae-incompatible-release-tags) The symbol "ObjectSchema" is marked as @public, but its signature references "BaseSchema" which is marked as @internal
//
// @public
export interface ObjectSchema extends BaseSchema {
properties: {
[k: string]: Schema;
};
required?: string[];
// (undocumented)
type: typeof SchemaType.OBJECT;
}

// @public
export enum Outcome {
OUTCOME_DEADLINE_EXCEEDED = "outcome_deadline_exceeded",
Expand Down Expand Up @@ -706,8 +762,7 @@ export interface RequestOptions {
}

// @public
export interface ResponseSchema extends Schema {
}
export type ResponseSchema = Schema;

// @public
export interface RetrievalMetadata {
Expand All @@ -731,19 +786,7 @@ export interface SafetySetting {
}

// @public
export interface Schema {
description?: string;
enum?: string[];
example?: unknown;
format?: string;
items?: Schema;
nullable?: boolean;
properties?: {
[k: string]: Schema;
};
required?: string[];
type?: SchemaType;
}
export type Schema = StringSchema | NumberSchema | IntegerSchema | BooleanSchema | ArraySchema | ObjectSchema;

// @public
export enum SchemaType {
Expand Down Expand Up @@ -779,6 +822,15 @@ export interface StartChatParams extends BaseParams {
tools?: Tool[];
}

// Warning: (ae-incompatible-release-tags) The symbol "StringSchema" is marked as @public, but its signature references "BaseSchema" which is marked as @internal
//
// @public
export interface StringSchema extends BaseSchema {
enum?: string[];
// (undocumented)
type: typeof SchemaType.STRING;
}

// @public
export enum TaskType {
// (undocumented)
Expand Down
22 changes: 11 additions & 11 deletions src/models/generative-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
FunctionCallingMode,
HarmBlockThreshold,
HarmCategory,
ObjectSchema,
SchemaType,
} from "../../types";
import { getMockResponse } from "../../test-utils/mock-response";
Expand Down Expand Up @@ -60,7 +61,6 @@ describe("GenerativeModel", () => {
properties: {
testField: {
type: SchemaType.STRING,
properties: {},
},
},
},
Expand Down Expand Up @@ -93,7 +93,8 @@ describe("GenerativeModel", () => {
SchemaType.OBJECT,
);
expect(
genModel.generationConfig?.responseSchema.properties.testField.type,
(genModel.generationConfig?.responseSchema as ObjectSchema).properties
.testField.type,
).to.equal(SchemaType.STRING);
expect(genModel.generationConfig?.presencePenalty).to.equal(0.6);
expect(genModel.generationConfig?.frequencyPenalty).to.equal(0.5);
Expand Down Expand Up @@ -172,7 +173,6 @@ describe("GenerativeModel", () => {
properties: {
testField: {
type: SchemaType.STRING,
properties: {},
},
},
},
Expand Down Expand Up @@ -206,7 +206,6 @@ describe("GenerativeModel", () => {
properties: {
newTestField: {
type: SchemaType.STRING,
properties: {},
},
},
},
Expand Down Expand Up @@ -332,16 +331,17 @@ describe("GenerativeModel", () => {
properties: {
testField: {
type: SchemaType.STRING,
properties: {},
},
},
},
},
systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
});
expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
expect(genModel.generationConfig.responseSchema.properties.testField).to
.exist;
expect(
(genModel.generationConfig.responseSchema as ObjectSchema).properties
.testField,
).to.exist;
const mockResponse = getMockResponse(
"unary-success-basic-reply-short.json",
);
Expand Down Expand Up @@ -372,7 +372,6 @@ describe("GenerativeModel", () => {
properties: {
testField: {
type: SchemaType.STRING,
properties: {},
},
},
},
Expand All @@ -381,8 +380,10 @@ describe("GenerativeModel", () => {
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
});
expect(genModel.generationConfig.responseSchema.properties.testField).to
.exist;
expect(
(genModel.generationConfig.responseSchema as ObjectSchema).properties
.testField,
).to.exist;
expect(genModel.tools?.length).to.equal(1);
expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
FunctionCallingMode.NONE,
Expand All @@ -403,7 +404,6 @@ describe("GenerativeModel", () => {
properties: {
newTestField: {
type: SchemaType.STRING,
properties: {},
},
},
},
Expand Down
Loading

0 comments on commit 0c417fb

Please sign in to comment.