diff --git a/src/compound/tagged_union.ts b/src/compound/tagged_union.ts index d1f21f7..fa247b2 100644 --- a/src/compound/tagged_union.ts +++ b/src/compound/tagged_union.ts @@ -11,6 +11,7 @@ type FindDiscriminant = (variant: V) => D; type Keys = Exclude; +/** Union for when the inner type's don't write their own discriminant */ export class TaggedUnion< T extends Record>, V extends ValueOf<{ [K in keyof T]: InnerType }> = ValueOf< diff --git a/src/compound/tagged_union_test.ts b/src/compound/tagged_union_test.ts index 4a64201..af9b8d3 100644 --- a/src/compound/tagged_union_test.ts +++ b/src/compound/tagged_union_test.ts @@ -41,10 +41,10 @@ Deno.test({ dt.setBigUint64(0, 0n); await t.step("Write Packed", () => { - type.write(32, dt); + type.writePacked(32, dt); assertEquals( new Uint8Array(ab).subarray(0, 5), - Uint8Array.of(0, 0, 0, 0, 32), + Uint8Array.of(0, 32, 0, 0, 0), ); }); diff --git a/src/compound/union.ts b/src/compound/union.ts new file mode 100644 index 0000000..ca4dd14 --- /dev/null +++ b/src/compound/union.ts @@ -0,0 +1,67 @@ +import { u8 } from "../primitives/mod.ts"; +import { + type InnerType, + type Options, + UnsizedType, + type ValueOf, +} from "../types/mod.ts"; +import { getBiggestAlignment } from "../util.ts"; + +type FindDiscriminant = (variant: V) => D; + +type Keys = Exclude; + +/** Union for when the inner type's do write their own discriminant */ +export class Union< + T extends Record>, + V extends ValueOf<{ [K in keyof T]: InnerType }>, +> extends UnsizedType { + #record: T; + #variantFinder: FindDiscriminant>; + #discriminant = u8; + + constructor( + input: T, + variantFinder: FindDiscriminant>, + ) { + super(getBiggestAlignment(input)); + this.#record = input; + this.#variantFinder = variantFinder; + } + + readPacked(dt: DataView, options: Options = { byteOffset: 0 }): V { + const discriminant = this.#discriminant.readPacked(dt, { + byteOffset: options.byteOffset, + }); + const codec = this.#record[discriminant]; + if (!codec) throw new TypeError("Unknown discriminant"); + return codec.readPacked(dt, options) as V; + } + + read(dt: DataView, options: Options = { byteOffset: 0 }): V { + const discriminant = this.#discriminant.read(dt, { + byteOffset: options.byteOffset, + }); + const codec = this.#record[discriminant]; + if (!codec) throw new TypeError("Unknown discriminant"); + return codec.readPacked(dt, options) as V; + } + + writePacked( + variant: V, + dt: DataView, + options: Options = { byteOffset: 0 }, + ): void { + const discriminant = this.#variantFinder(variant); + const codec = this.#record[discriminant]; + if (!codec) throw new TypeError("Unknown discriminant"); + codec.writePacked(variant, dt, options); + } + + write(variant: V, dt: DataView, options: Options = { byteOffset: 0 }): void { + const discriminant = this.#variantFinder(variant); + const codec = this.#record[discriminant]; + if (!codec) throw new TypeError("Unknown discriminant"); + codec.write(variant, dt, options); + } +} diff --git a/src/compound/union_test.ts b/src/compound/union_test.ts new file mode 100644 index 0000000..3505866 --- /dev/null +++ b/src/compound/union_test.ts @@ -0,0 +1,57 @@ +import { u32le, u8 } from "../mod.ts"; +import { assertEquals, assertThrows } from "../../test_deps.ts"; +import { Union } from "./union.ts"; + +Deno.test({ + name: "Union", + fn: async (t) => { + const ab = new ArrayBuffer(8); + const dt = new DataView(ab); + const type = new Union({ + 0: u32le, + 1: u8, + 2: u8, + }, (a) => a === 32 ? 0 : 1); + + await t.step("Read", () => { + dt.setUint8(0, 1); + dt.setUint8(1, 11); + dt.setUint8(2, 22); + dt.setUint8(4, 33); + const result = type.read(dt); + assertEquals(result, 1); + }); + + await t.step("Read Packed", () => { + dt.setUint8(0, 1); + dt.setUint8(1, 11); + dt.setUint8(2, 22); + dt.setUint8(4, 33); + const result = type.readPacked(dt); + assertEquals(result, 1); + }); + + dt.setBigUint64(0, 0n); + + await t.step("Write", () => { + type.write(32, dt); + assertEquals(new Uint32Array(ab), Uint32Array.of(32, 0)); + }); + + dt.setBigUint64(0, 0n); + + await t.step("Write Packed", () => { + type.writePacked(32, dt); + assertEquals( + new Uint8Array(ab).subarray(0, 5), + Uint8Array.of(32, 0, 0, 0, 0), + ); + }); + + await t.step("OOB Read", () => { + assertThrows(() => { + type.read(dt, { byteOffset: 9 }); + }, RangeError); + }); + }, +});