diff --git a/frontend/src/components/data-table/__test__/columns.test.tsx b/frontend/src/components/data-table/__test__/columns.test.tsx index 4e5861e3e29..c890adc3917 100644 --- a/frontend/src/components/data-table/__test__/columns.test.tsx +++ b/frontend/src/components/data-table/__test__/columns.test.tsx @@ -3,6 +3,7 @@ import { expect, test } from "vitest"; import { uniformSample } from "../uniformSample"; import { UrlDetector } from "../url-detector"; import { render } from "@testing-library/react"; +import { inferFieldTypes } from "../columns"; test("uniformSample", () => { const items = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]; @@ -31,3 +32,89 @@ test("UrlDetector renders URLs as hyperlinks", () => { expect(link).toBeTruthy(); expect(link?.href).toBe("https://example.com/"); }); + +test("inferFieldTypes", () => { + const data = [ + { + a: 1, + b: "foo", + c: null, + d: { mime: "text/csv" }, + e: [1, 2, 3], + f: true, + g: false, + h: new Date(), + }, + ]; + const fieldTypes = inferFieldTypes(data); + expect(fieldTypes).toMatchInlineSnapshot(` + { + "a": [ + "number", + "number", + ], + "b": [ + "string", + "string", + ], + "c": [ + "unknown", + "unknown", + ], + "d": [ + "unknown", + "unknown", + ], + "e": [ + "unknown", + "unknown", + ], + "f": [ + "boolean", + "boolean", + ], + "g": [ + "boolean", + "boolean", + ], + "h": [ + "datetime", + "datetime", + ], + } + `); +}); + +test("inferFieldTypes with nulls", () => { + const data = [{ a: 1, b: null }]; + const fieldTypes = inferFieldTypes(data); + expect(fieldTypes).toMatchInlineSnapshot(` + { + "a": [ + "number", + "number", + ], + "b": [ + "unknown", + "unknown", + ], + } + `); +}); + +test("inferFieldTypes with mimetypes", () => { + const data = [{ a: { mime: "text/csv" }, b: { mime: "image/png" } }]; + const fieldTypes = inferFieldTypes(data); + expect(fieldTypes).toMatchInlineSnapshot(` + { + "a": [ + "unknown", + "unknown", + ], + "b": [ + "unknown", + "unknown", + ], + } + `); +}); diff --git a/frontend/src/components/data-table/chart-spec-model.tsx b/frontend/src/components/data-table/chart-spec-model.tsx index 97bc5ff3fad..94f001a172e 100644 --- a/frontend/src/components/data-table/chart-spec-model.tsx +++ b/frontend/src/components/data-table/chart-spec-model.tsx @@ -3,6 +3,7 @@ import type { TopLevelFacetedUnitSpec } from "@/plugins/impl/data-explorer/queri import { mint, orange, slate } from "@radix-ui/colors"; import type { ColumnHeaderSummary, FieldTypes } from "./types"; import { asURL } from "@/utils/url"; +import { parseCsvData } from "@/plugins/impl/vega/loader"; const MAX_BAR_HEIGHT = 24; // px const MAX_BAR_WIDTH = 28; // px @@ -33,6 +34,16 @@ export class ColumnChartSpecModel { includeCharts: boolean; }, ) { + // Support CSV data as a string + const isCsv = + typeof this.data === "string" && + !this.data.startsWith("./@file") && + !this.data.startsWith("/@file") && + !this.data.startsWith("data:text/csv"); + if (isCsv) { + this.data = parseCsvData(this.data) as T[]; + } + this.columnSummaries = new Map(summaries.map((s) => [s.column, s])); } @@ -48,14 +59,12 @@ export class ColumnChartSpecModel { if (!this.data) { return null; } - if (typeof this.data !== "string") { - return null; - } - const base: Omit = { - data: { - url: asURL(this.data).href, - } as TopLevelFacetedUnitSpec["data"], + data: (typeof this.data === "string" + ? { + url: asURL(this.data).href, + } + : { values: this.data }) as TopLevelFacetedUnitSpec["data"], background: "transparent", config: { view: { diff --git a/frontend/src/components/data-table/columns.tsx b/frontend/src/components/data-table/columns.tsx index c9aaa934255..2f82d689651 100644 --- a/frontend/src/components/data-table/columns.tsx +++ b/frontend/src/components/data-table/columns.tsx @@ -8,32 +8,46 @@ import { } from "./column-header"; import { Checkbox } from "../ui/checkbox"; import { MimeCell } from "./mime-cell"; -import { uniformSample } from "./uniformSample"; import type { DataType } from "@/core/kernel/messages"; import { TableColumnSummary } from "./column-summary"; import type { FilterType } from "./filters"; import type { FieldTypesWithExternalType } from "./types"; import { UrlDetector } from "./url-detector"; -import { Arrays } from "@/utils/arrays"; import { cn } from "@/utils/cn"; +import { Objects } from "@/utils/objects"; +import { uniformSample } from "./uniformSample"; -interface ColumnInfo { - key: string; - type: "primitive" | "mime"; +function inferDataType(value: unknown): DataType { + if (typeof value === "string") { + return "string"; + } + if (typeof value === "number") { + return "number"; + } + if (value instanceof Date) { + return "datetime"; + } + if (typeof value === "boolean") { + return "boolean"; + } + if (value == null) { + return "unknown"; + } + return "unknown"; } -function getColumnInfo(items: T[]): ColumnInfo[] { +export function inferFieldTypes(items: T[]): FieldTypesWithExternalType { // No items if (items.length === 0) { - return Arrays.EMPTY; + return {}; } // Not an object if (typeof items[0] !== "object") { - return Arrays.EMPTY; + return {}; } - const keys = new Map(); + const fieldTypes: FieldTypesWithExternalType = {}; // This can be slow for large datasets, // so only sample 10 evenly distributed rows @@ -44,67 +58,58 @@ function getColumnInfo(items: T[]): ColumnInfo[] { // We will be a bit defensive and assume values are not homogeneous. // If any is a mimetype, then we will treat it as a mimetype (i.e. not sortable) Object.entries(item as object).forEach(([key, value], idx) => { - const currentValue = keys.get(key); + const currentValue = fieldTypes[key]; if (!currentValue) { // Set for the first time - keys.set(key, { - key, - type: isPrimitiveOrNullish(value) ? "primitive" : "mime", - }); + const dtype = inferDataType(value); + fieldTypes[key] = [dtype, dtype]; } - // If we have a value, and it is a primitive, we could possibly upgrade it to a mime - if ( - currentValue && - currentValue.type === "primitive" && - !isPrimitiveOrNullish(value) - ) { - keys.set(key, { - key, - type: "mime", - }); + + // If its not null, override the type + if (value != null) { + // This can be lossy as we infer take the last seen type + const dtype = inferDataType(value); + fieldTypes[key] = [dtype, dtype]; } }); }); - return [...keys.values()]; + return fieldTypes; } export const NAMELESS_COLUMN_PREFIX = "__m_column__"; export function generateColumns({ - items, rowHeaders, selection, fieldTypes, textJustifyColumns, wrappedColumns, }: { - items: T[]; rowHeaders: string[]; selection: "single" | "multi" | null; - fieldTypes?: FieldTypesWithExternalType; + fieldTypes: FieldTypesWithExternalType; textJustifyColumns?: Record; wrappedColumns?: string[]; }): Array> { - const columnInfo = getColumnInfo(items); const rowHeadersSet = new Set(rowHeaders); - const columns = columnInfo.map( - (info, idx): ColumnDef => ({ - id: info.key || `${NAMELESS_COLUMN_PREFIX}${idx}`, + const columns = Objects.entries(fieldTypes).map( + ([key, types], idx): ColumnDef => ({ + id: key || `${NAMELESS_COLUMN_PREFIX}${idx}`, // Use an accessorFn instead of an accessorKey because column names // may have periods in them ... // https://github.com/TanStack/table/issues/1671 accessorFn: (row) => { // eslint-disable-next-line @typescript-eslint/no-explicit-any - return (row as any)[info.key]; + return (row as any)[key]; }, header: ({ column }) => { const dtype = column.columnDef.meta?.dtype; const headerWithType = (
- {info.key} + {key} {dtype && ( {dtype} )} @@ -112,7 +117,7 @@ export function generateColumns({ ); // Row headers have no summaries - if (rowHeadersSet.has(info.key)) { + if (rowHeadersSet.has(key)) { return ( ); @@ -120,22 +125,23 @@ export function generateColumns({ return ( } + summary={} /> ); }, cell: ({ column, renderValue, getValue }) => { // Row headers are bold - if (rowHeadersSet.has(info.key)) { + if (rowHeadersSet.has(key)) { return {String(renderValue())}; } const value = getValue(); - const justify = textJustifyColumns?.[info.key]; - const wrapped = wrappedColumns?.includes(info.key); + const justify = textJustifyColumns?.[key]; + const wrapped = wrappedColumns?.includes(key); const format = column.getColumnFormatting?.(); if (format) { @@ -166,16 +172,13 @@ export function generateColumns({
); }, - // Only enable sorting for primitive types and non-row headers - enableSorting: info.type === "primitive" && !rowHeadersSet.has(info.key), // Remove any default filtering filterFn: undefined, meta: { - type: info.type, - rowHeader: rowHeadersSet.has(info.key), - filterType: getFilterTypeForFieldType(fieldTypes?.[info.key]?.[0]), - dtype: fieldTypes?.[info.key]?.[1], - dataType: fieldTypes?.[info.key]?.[0], + rowHeader: rowHeadersSet.has(key), + filterType: getFilterTypeForFieldType(types[0]), + dtype: types[1], + dataType: types[0], }, }), ); diff --git a/frontend/src/components/data-table/filters.ts b/frontend/src/components/data-table/filters.ts index 14b422212fe..ca0d058f9c3 100644 --- a/frontend/src/components/data-table/filters.ts +++ b/frontend/src/components/data-table/filters.ts @@ -10,7 +10,6 @@ import type { RowData } from "@tanstack/react-table"; declare module "@tanstack/react-table" { //allows us to define custom properties for our columns interface ColumnMeta { - type?: "primitive" | "mime"; rowHeader?: boolean; dtype?: string; dataType?: DataType; diff --git a/frontend/src/components/datasets/icons.tsx b/frontend/src/components/datasets/icons.tsx index d72b19e4f51..922ccfb3c30 100644 --- a/frontend/src/components/datasets/icons.tsx +++ b/frontend/src/components/datasets/icons.tsx @@ -6,7 +6,6 @@ import { CalendarIcon, HashIcon, TypeIcon, - ListOrderedIcon, type LucideIcon, CalendarClockIcon, ClockIcon, @@ -23,6 +22,6 @@ export const DATA_TYPE_ICON: Record = { datetime: CalendarClockIcon, number: HashIcon, string: TypeIcon, - integer: ListOrderedIcon, + integer: HashIcon, unknown: CurlyBracesIcon, }; diff --git a/frontend/src/components/editor/file-tree/renderers.tsx b/frontend/src/components/editor/file-tree/renderers.tsx index aa46377ebac..e74ad2a6ccb 100644 --- a/frontend/src/components/editor/file-tree/renderers.tsx +++ b/frontend/src/components/editor/file-tree/renderers.tsx @@ -1,5 +1,8 @@ /* Copyright 2024 Marimo. All rights reserved. */ -import { generateColumns } from "@/components/data-table/columns"; +import { + generateColumns, + inferFieldTypes, +} from "@/components/data-table/columns"; import { DataTable } from "@/components/data-table/data-table"; import { parseCsvData } from "@/plugins/impl/vega/loader"; import { Objects } from "@/utils/objects"; @@ -17,14 +20,15 @@ export const CsvViewer: React.FC<{ contents: string }> = ({ contents }) => { pageIndex: 0, pageSize: PAGE_SIZE, }); + const fieldTypes = useMemo(() => inferFieldTypes(data), [data]); const columns = useMemo( () => - generateColumns({ - items: data, + generateColumns({ rowHeaders: Arrays.EMPTY, selection: null, + fieldTypes, }), - [data], + [fieldTypes], ); return ( diff --git a/frontend/src/plugins/impl/DataTablePlugin.tsx b/frontend/src/plugins/impl/DataTablePlugin.tsx index f2e705a6a78..2d36a8598c9 100644 --- a/frontend/src/plugins/impl/DataTablePlugin.tsx +++ b/frontend/src/plugins/impl/DataTablePlugin.tsx @@ -2,7 +2,10 @@ import { memo, useEffect, useMemo, useState } from "react"; import { z } from "zod"; import { DataTable } from "../../components/data-table/data-table"; -import { generateColumns } from "../../components/data-table/columns"; +import { + generateColumns, + inferFieldTypes, +} from "../../components/data-table/columns"; import { Labeled } from "./common/labeled"; import { Alert, AlertTitle } from "@/components/ui/alert"; import { rpc } from "../core/rpc"; @@ -462,19 +465,19 @@ const DataTableComponent = ({ ); }, [fieldTypes, columnSummaries]); + const fieldTypesOrInferred = fieldTypes ?? inferFieldTypes(data); + const columns = useMemo( () => generateColumns({ - items: data, rowHeaders: rowHeaders, selection, - fieldTypes: fieldTypes ?? {}, + fieldTypes: fieldTypesOrInferred, textJustifyColumns, wrappedColumns, }), /* eslint-disable react-hooks/exhaustive-deps */ [ - data, useDeepCompareMemoize([ selection, fieldTypes, @@ -483,7 +486,6 @@ const DataTableComponent = ({ wrappedColumns, ]), ], - /* eslint-enable react-hooks/exhaustive-deps */ ); const rowSelection = useMemo( diff --git a/frontend/src/stories/data-table.stories.tsx b/frontend/src/stories/data-table.stories.tsx index 6de94389f22..03a2cc8f157 100644 --- a/frontend/src/stories/data-table.stories.tsx +++ b/frontend/src/stories/data-table.stories.tsx @@ -1,6 +1,9 @@ /* Copyright 2024 Marimo. All rights reserved. */ import { DataTable } from "@/components/data-table/data-table"; -import { generateColumns } from "@/components/data-table/columns"; +import { + generateColumns, + inferFieldTypes, +} from "@/components/data-table/columns"; import { Functions } from "@/utils/functions"; export default { @@ -25,7 +28,7 @@ export const Default = { }, ]} columns={generateColumns({ - items: [ + fieldTypes: inferFieldTypes([ { first_name: "Michael", last_name: "Scott", @@ -34,7 +37,7 @@ export const Default = { first_name: "Dwight", last_name: "Schrute", }, - ], + ]), rowHeaders: [], selection: null, })} @@ -52,7 +55,7 @@ export const Empty1 = { setPaginationState={Functions.NOOP} data={[]} columns={generateColumns({ - items: [ + fieldTypes: inferFieldTypes([ { first_name: "Michael", last_name: "Scott", @@ -61,7 +64,7 @@ export const Empty1 = { first_name: "Dwight", last_name: "Schrute", }, - ], + ]), rowHeaders: [], selection: null, })} @@ -101,7 +104,7 @@ export const Pagination = { }, ]} columns={generateColumns({ - items: [ + fieldTypes: inferFieldTypes([ { first_name: "Michael", last_name: "Scott", @@ -110,7 +113,7 @@ export const Pagination = { first_name: "Dwight", last_name: "Schrute", }, - ], + ]), rowHeaders: [], selection: null, })}