From b4370c04cabb9e0b80254869ee520ba0176be891 Mon Sep 17 00:00:00 2001 From: ghanse <163584195+ghanse@users.noreply.github.com> Date: Thu, 19 Dec 2024 16:28:03 -0500 Subject: [PATCH] Add stock ticker dataset (#303) * Add stock ticker dataset * Add multi-table sales order dataset * Fix code formatting * Fix code formatting * Add min and max lat-lon options --------- Co-authored-by: Ronan Stokes <42389040+ronanstokes-db@users.noreply.github.com> --- dbldatagen/datasets/__init__.py | 4 + dbldatagen/datasets/basic_geometries.py | 24 +- dbldatagen/datasets/basic_stock_ticker.py | 103 ++++ .../multi_table_sales_order_provider.py | 577 ++++++++++++++++++ tests/test_standard_dataset_providers.py | 239 ++++++-- 5 files changed, 896 insertions(+), 51 deletions(-) create mode 100644 dbldatagen/datasets/basic_stock_ticker.py create mode 100644 dbldatagen/datasets/multi_table_sales_order_provider.py diff --git a/dbldatagen/datasets/__init__.py b/dbldatagen/datasets/__init__.py index 17f5b212..145a9148 100644 --- a/dbldatagen/datasets/__init__.py +++ b/dbldatagen/datasets/__init__.py @@ -1,16 +1,20 @@ from .dataset_provider import DatasetProvider, dataset_definition from .basic_geometries import BasicGeometriesProvider from .basic_process_historian import BasicProcessHistorianProvider +from .basic_stock_ticker import BasicStockTickerProvider from .basic_telematics import BasicTelematicsProvider from .basic_user import BasicUserProvider from .benchmark_groupby import BenchmarkGroupByProvider +from .multi_table_sales_order_provider import MultiTableSalesOrderProvider from .multi_table_telephony_provider import MultiTableTelephonyProvider __all__ = ["dataset_provider", "basic_geometries", "basic_process_historian", + "basic_stock_ticker", "basic_telematics", "basic_user", "benchmark_groupby", + "multi_table_sales_order_provider", "multi_table_telephony_provider" ] diff --git a/dbldatagen/datasets/basic_geometries.py b/dbldatagen/datasets/basic_geometries.py index 1673bfd0..02bc5806 100644 --- a/dbldatagen/datasets/basic_geometries.py +++ b/dbldatagen/datasets/basic_geometries.py @@ -29,10 +29,18 @@ class BasicGeometriesProvider(DatasetProvider.NoAssociatedDatasetsMixin, Dataset """ MIN_LOCATION_ID = 1000000 MAX_LOCATION_ID = 9223372036854775807 + DEFAULT_MIN_LAT = -90.0 + DEFAULT_MAX_LAT = 90.0 + DEFAULT_MIN_LON = -180.0 + DEFAULT_MAX_LON = 180.0 COLUMN_COUNT = 2 ALLOWED_OPTIONS = [ "geometryType", "maxVertices", + "minLatitude", + "maxLatitude", + "minLongitude", + "maxLongitude", "random" ] @@ -45,6 +53,10 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions generateRandom = options.get("random", False) geometryType = options.get("geometryType", "point") maxVertices = options.get("maxVertices", 1 if geometryType == "point" else 3) + minLatitude = options.get("minLatitude", self.DEFAULT_MIN_LAT) + maxLatitude = options.get("maxLatitude", self.DEFAULT_MAX_LAT) + minLongitude = options.get("minLongitude", self.DEFAULT_MIN_LON) + maxLongitude = options.get("maxLongitude", self.DEFAULT_MAX_LON) assert tableName is None or tableName == DatasetProvider.DEFAULT_TABLE_NAME, "Invalid table name" if rows is None or rows < 0: @@ -62,9 +74,9 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions if maxVertices > 1: w.warn('Ignoring property maxVertices for point geometries') df_spec = ( - df_spec.withColumn("lat", "float", minValue=-90.0, maxValue=90.0, + df_spec.withColumn("lat", "float", minValue=minLatitude, maxValue=maxLatitude, step=1e-5, 
random=generateRandom, omit=True) - .withColumn("lon", "float", minValue=-180.0, maxValue=180.0, + .withColumn("lon", "float", minValue=minLongitude, maxValue=maxLongitude, step=1e-5, random=generateRandom, omit=True) .withColumn("wkt", "string", expr="concat('POINT(', lon, ' ', lat, ')')") ) @@ -75,9 +87,9 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions j = 0 while j < maxVertices: df_spec = ( - df_spec.withColumn(f"lat_{j}", "float", minValue=-90.0, maxValue=90.0, + df_spec.withColumn(f"lat_{j}", "float", minValue=minLatitude, maxValue=maxLatitude, step=1e-5, random=generateRandom, omit=True) - .withColumn(f"lon_{j}", "float", minValue=-180.0, maxValue=180.0, + .withColumn(f"lon_{j}", "float", minValue=minLongitude, maxValue=maxLongitude, step=1e-5, random=generateRandom, omit=True) ) j = j + 1 @@ -93,9 +105,9 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions j = 0 while j < maxVertices: df_spec = ( - df_spec.withColumn(f"lat_{j}", "float", minValue=-90.0, maxValue=90.0, + df_spec.withColumn(f"lat_{j}", "float", minValue=minLatitude, maxValue=maxLatitude, step=1e-5, random=generateRandom, omit=True) - .withColumn(f"lon_{j}", "float", minValue=-180.0, maxValue=180.0, + .withColumn(f"lon_{j}", "float", minValue=minLongitude, maxValue=maxLongitude, step=1e-5, random=generateRandom, omit=True) ) j = j + 1 diff --git a/dbldatagen/datasets/basic_stock_ticker.py b/dbldatagen/datasets/basic_stock_ticker.py new file mode 100644 index 00000000..7f67576d --- /dev/null +++ b/dbldatagen/datasets/basic_stock_ticker.py @@ -0,0 +1,103 @@ +from random import random + +from .dataset_provider import DatasetProvider, dataset_definition + + +@dataset_definition(name="basic/stock_ticker", + summary="Stock ticker dataset", + autoRegister=True, + supportsStreaming=True) +class BasicStockTickerProvider(DatasetProvider.NoAssociatedDatasetsMixin, DatasetProvider): + """ + Basic Stock Ticker Dataset + ======================== + + This is a basic stock ticker dataset with time-series `symbol`, `open`, `close`, `high`, `low`, + `adj_close`, and `volume` values. + + It takes the following options when retrieving the table: + - rows : number of rows to generate + - partitions: number of partitions to use + - numSymbols: number of unique stock ticker symbols + - startDate: first date for stock ticker data + - endDate: last date for stock ticker data + + As the data specification is a DataGenerator object, you can add further columns to the data set and + add constraints (when the feature is available) + + Note that this dataset does not use any features that would prevent it from being used as a source for a + streaming dataframe, and so the flag `supportsStreaming` is set to True. 
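+
+    Example usage (a minimal sketch; assumes an active SparkSession named `spark` and mirrors the
+    retrieval pattern used in this change's unit tests):
+
+        import dbldatagen as dg
+
+        df = (dg.Datasets(spark, "basic/stock_ticker")
+              .get(rows=1000, numSymbols=10, startDate="2024-10-01")
+              .build())
+        df.show()
+
+    The object returned by `get(...)` is a `DataGenerator`, so further columns can be added to the
+    specification before calling `build()`.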
+ + """ + DEFAULT_NUM_SYMBOLS = 100 + DEFAULT_START_DATE = "2024-10-01" + COLUMN_COUNT = 8 + ALLOWED_OPTIONS = [ + "numSymbols", + "startDate" + ] + + @DatasetProvider.allowed_options(options=ALLOWED_OPTIONS) + def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions=-1, **options): + import dbldatagen as dg + + numSymbols = options.get("numSymbols", self.DEFAULT_NUM_SYMBOLS) + startDate = options.get("startDate", self.DEFAULT_START_DATE) + + assert tableName is None or tableName == DatasetProvider.DEFAULT_TABLE_NAME, "Invalid table name" + if rows is None or rows < 0: + rows = DatasetProvider.DEFAULT_ROWS + if partitions is None or partitions < 0: + partitions = self.autoComputePartitions(rows, self.COLUMN_COUNT) + if numSymbols <= 0: + raise ValueError("'numSymbols' must be > 0") + + df_spec = ( + dg.DataGenerator(sparkSession=sparkSession, rows=rows, + partitions=partitions, randomSeedMethod="hash_fieldname") + .withColumn("symbol_id", "long", minValue=676, maxValue=676 + numSymbols - 1) + .withColumn("rand_value", "float", minValue=0.0, maxValue=1.0, step=0.1, + baseColumn="symbol_id", omit=True) + .withColumn("symbol", "string", + expr="""concat_ws('', transform(split(conv(symbol_id, 10, 26), ''), + x -> case when x < 10 then char(ascii(x) - 48 + 65) else char(ascii(x) + 10) end))""") + .withColumn("days_from_start_date", "int", expr=f"floor(id / {numSymbols})", omit=True) + .withColumn("post_date", "date", expr=f"date_add(cast('{startDate}' as date), days_from_start_date)") + .withColumn("start_value", "decimal(11,2)", + values=[1.0 + 199.0 * random() for _ in range(int(numSymbols / 10))], omit=True) + .withColumn("growth_rate", "float", values=[-0.1 + 0.35 * random() for _ in range(int(numSymbols / 10))], + baseColumn="symbol_id") + .withColumn("volatility", "float", values=[0.0075 * random() for _ in range(int(numSymbols / 10))], + baseColumn="symbol_id", omit=True) + .withColumn("prev_modifier_sign", "float", + expr=f"case when sin((id - {numSymbols}) % 17) > 0 then -1.0 else 1.0 end""", + omit=True) + .withColumn("modifier_sign", "float", + expr="case when sin(id % 17) > 0 then -1.0 else 1.0 end", + omit=True) + .withColumn("open_base", "decimal(11,2)", + expr=f"""start_value + + (volatility * prev_modifier_sign * start_value * sin((id - {numSymbols}) % 17)) + + (growth_rate * start_value * (days_from_start_date - 1) / 365)""", + omit=True) + .withColumn("close_base", "decimal(11,2)", + expr="""start_value + + (volatility * start_value * sin(id % 17)) + + (growth_rate * start_value * days_from_start_date / 365)""", + omit=True) + .withColumn("high_base", "decimal(11,2)", + expr="greatest(open_base, close_base) + rand() * volatility * open_base", + omit=True) + .withColumn("low_base", "decimal(11,2)", + expr="least(open_base, close_base) - rand() * volatility * open_base", + omit=True) + .withColumn("open", "decimal(11,2)", expr="greatest(open_base, 0.0)") + .withColumn("close", "decimal(11,2)", expr="greatest(close_base, 0.0)") + .withColumn("high", "decimal(11,2)", expr="greatest(high_base, 0.0)") + .withColumn("low", "decimal(11,2)", expr="greatest(low_base, 0.0)") + .withColumn("dividend", "decimal(4,2)", expr="0.05 * rand_value * close", omit=True) + .withColumn("adj_close", "decimal(11,2)", expr="greatest(close - dividend, 0.0)") + .withColumn("volume", "long", minValue=100000, maxValue=5000000, random=True) + ) + + return df_spec diff --git a/dbldatagen/datasets/multi_table_sales_order_provider.py 
b/dbldatagen/datasets/multi_table_sales_order_provider.py new file mode 100644 index 00000000..d20d3e03 --- /dev/null +++ b/dbldatagen/datasets/multi_table_sales_order_provider.py @@ -0,0 +1,577 @@ +from .dataset_provider import DatasetProvider, dataset_definition + + +@dataset_definition(name="multi_table/sales_order", summary="Multi-table sales order dataset", supportsStreaming=True, + autoRegister=True, + tables=["customers", "carriers", "catalog_items", "base_orders", "base_order_line_items", + "base_order_shipments", "base_invoices"], + associatedDatasets=["orders", "order_line_items", "order_shipments", "invoices"]) +class MultiTableSalesOrderProvider(DatasetProvider): + """ Generates a multi-table sales order scenario + + See [https://databrickslabs.github.io/dbldatagen/public_docs/multi_table_data.html] + + It generates one of several possible tables: + + customers - which model customers + carriers - which model shipping carriers + catalog_items - which model items in a sales catalog + base_orders - which model basic sales order data without relations + base_order_line_items - which model basic sales order line item data without relations + base_order_shipments - which model basic sales order shipment data without relations + base_invoices - which model basic invoice data without relations + + Once the above tables have been computed, you can retrieve the combined tables for: + + orders - which model complete sales orders + order_line_items - which model complete sales order line items + order_shipments - which model complete sales order shipments + invoices - which model complete invoices + + using `Datasets(...).getCombinedTable("orders")`, `Datasets(...).getCombinedTable("order_line_items")`, + `Datasets(...).getCombinedTable("order_shipments")`, `Datasets(...).getCombinedTable("invoices")` + + The following options are supported: + - numCustomers - number of unique customers + - numCarriers - number of unique shipping carriers + - numCatalogItems - number of unique catalog items + - numOrders - number of unique orders + - lineItemsPerOrder - number of line items per order + - startDate - earliest order date + - endDate - latest order date + - dummyValues - number of dummy values to widen the tables + + While it is possible to specify the number of rows explicitly when getting each table generator, the default will + be to compute the number of rows from these options. 
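+
+    Example usage (a minimal sketch mirroring the unit tests added in this change; it assumes an active
+    SparkSession named `spark`, and the option values shown are illustrative only):
+
+        import dbldatagen as dg
+
+        salesOrderData = dg.Datasets(spark, "multi_table/sales_order")
+        options = {"numCustomers": 100, "numOrders": 1000, "numCarriers": 10, "numCatalogItems": 100}
+
+        dfCustomers = salesOrderData.get(table="customers", **options).build()
+        dfCarriers = salesOrderData.get(table="carriers", **options).build()
+        dfCatalogItems = salesOrderData.get(table="catalog_items", **options).build()
+        dfBaseOrders = salesOrderData.get(table="base_orders", **options).build()
+        dfBaseOrderLineItems = salesOrderData.get(table="base_order_line_items", **options).build()
+        dfBaseOrderShipments = salesOrderData.get(table="base_order_shipments", **options).build()
+        dfBaseInvoices = salesOrderData.get(table="base_invoices", **options).build()
+
+        dfOrders = salesOrderData.getSummaryDataset(
+            table="orders",
+            customers=dfCustomers,
+            carriers=dfCarriers,
+            catalogItems=dfCatalogItems,
+            baseOrders=dfBaseOrders,
+            baseOrderLineItems=dfBaseOrderLineItems,
+            baseOrderShipments=dfBaseOrderShipments,
+            baseInvoices=dfBaseInvoices)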
+ """ + MAX_LONG = 9223372036854775807 + DEFAULT_NUM_CUSTOMERS = 1_000 + DEFAULT_NUM_CARRIERS = 100 + DEFAULT_NUM_CATALOG_ITEMS = 1_000 + DEFAULT_NUM_ORDERS = 100_000 + DEFAULT_LINE_ITEMS_PER_ORDER = 3 + DEFAULT_START_DATE = "2024-01-01" + DEFAULT_END_DATE = "2025-01-01" + CUSTOMER_MIN_VALUE = 10_000 + CARRIER_MIN_VALUE = 100 + CATALOG_ITEM_MIN_VALUE = 10_000 + ORDER_MIN_VALUE = 10_000_000 + ORDER_LINE_ITEM_MIN_VALUE = 100_000_000 + SHIPMENT_MIN_VALUE = 10_000_000 + INVOICE_MIN_VALUE = 1_000_000 + + def getCustomers(self, sparkSession, *, rows, partitions, numCustomers, dummyValues): + import dbldatagen as dg + + # Validate the options: + if numCustomers is None or numCustomers < 0: + numCustomers = self.DEFAULT_NUM_CUSTOMERS + if rows is None or rows < 0: + rows = numCustomers + if partitions is None or partitions < 0: + partitions = self.autoComputePartitions(rows, 9 + dummyValues) + + # Create the base data generation spec: + customers_data_spec = ( + dg.DataGenerator(sparkSession, rows=rows, partitions=partitions) + .withColumn("customer_id", "integer", minValue=self.CUSTOMER_MIN_VALUE, uniqueValues=numCustomers) + .withColumn("customer_name", "string", prefix="CUSTOMER", baseColumn="customer_id") + .withColumn("sic_code", "integer", minValue=100, maxValue=9_995, random=True) + .withColumn("num_employees", "integer", minValue=1, maxValue=10_000, random=True) + .withColumn("region", "string", values=["AMER", "EMEA", "APAC", "NONE"], random=True) + .withColumn("phone_number", "string", template="ddd-ddd-dddd") + .withColumn("email_user_name", "string", + values=["billing", "procurement", "office", "purchasing", "buyer"], omit=True) + .withColumn("email_address", "string", expr="concat(email_user_name, '@', lower(customer_name), '.com')") + .withColumn("payment_terms", "string", values=["DUE_ON_RECEIPT", "NET30", "NET60", "NET120"]) + .withColumn("created_on", "date", begin="2000-01-01", end=self.DEFAULT_START_DATE, interval="1 DAY") + .withColumn("created_by", "integer", minValue=1_000, maxValue=9_999, random=True) + .withColumn("is_updated", "boolean", expr="rand() > 0.75", omit=True) + .withColumn("updated_after_days", "integer", minValue=0, maxValue=1_000, random=True, omit=True) + .withColumn("updated_on", "date", expr="""case when is_updated then created_on + else date_add(created_on, updated_after_days) end""") + .withColumn("updated_by_user", "integer", minValue=1_000, maxValue=9_999, random=True, omit=True) + .withColumn("updated_by", "integer", expr="case when is_updated then updated_by_user else created_by end") + ) + + # Add dummy values if they were requested: + if dummyValues > 0: + customers_data_spec = customers_data_spec.withColumn("dummy", "long", random=True, numColumns=dummyValues, + minValue=1, maxValue=self.MAX_LONG) + + return customers_data_spec + + def getCarriers(self, sparkSession, *, rows, partitions, numCarriers, dummyValues): + import dbldatagen as dg + + # Validate the options: + if numCarriers is None or numCarriers < 0: + numCarriers = self.DEFAULT_NUM_CARRIERS + if rows is None or rows < 0: + rows = numCarriers + if partitions is None or partitions < 0: + partitions = self.autoComputePartitions(rows, 9 + dummyValues) + + # Create the base data generation spec: + carriers_data_spec = ( + dg.DataGenerator(sparkSession, rows=rows, partitions=partitions) + .withColumn("carrier_id", "integer", minValue=self.CARRIER_MIN_VALUE, uniqueValues=numCarriers) + .withColumn("carrier_name", "string", prefix="CARRIER", baseColumn="carrier_id") + 
.withColumn("phone_number", "string", template="ddd-ddd-dddd") + .withColumn("email_user_name", "string", + values=["shipping", "parcel", "logistics", "carrier"], omit=True) + .withColumn("email_address", "string", expr="concat(email_user_name, '@', lower(carrier_name), '.com')") + .withColumn("created_on", "date", begin="2000-01-01", end=self.DEFAULT_START_DATE, interval="1 DAY") + .withColumn("created_by", "integer", minValue=1_000, maxValue=9_999, random=True) + .withColumn("is_updated", "boolean", expr="rand() > 0.75", omit=True) + .withColumn("updated_after_days", "integer", minValue=0, maxValue=1_000, random=True, omit=True) + .withColumn("updated_on", "date", expr="""case when is_updated then created_on + else date_add(created_on, updated_after_days) end""") + .withColumn("updated_by_user", "integer", minValue=1_000, maxValue=9_999, random=True, omit=True) + .withColumn("updated_by", "integer", expr="case when is_updated then updated_by_user else created_by end") + ) + + # Add dummy values if they were requested: + if dummyValues > 0: + carriers_data_spec = carriers_data_spec.withColumn("dummy", "long", random=True, numColumns=dummyValues, + minValue=1, maxValue=self.MAX_LONG) + + return carriers_data_spec + + def getCatalogItems(self, sparkSession, *, rows, partitions, numCatalogItems, dummyValues): + import dbldatagen as dg + + # Validate the options: + if numCatalogItems is None or numCatalogItems < 0: + numCatalogItems = self.DEFAULT_NUM_CATALOG_ITEMS + if rows is None or rows < 0: + rows = numCatalogItems + if partitions is None or partitions < 0: + partitions = self.autoComputePartitions(rows, 9 + dummyValues) + + # Create the base data generation spec: + catalog_items_data_spec = ( + dg.DataGenerator(sparkSession, rows=rows, partitions=partitions) + .withColumn("catalog_item_id", "integer", minValue=self.CATALOG_ITEM_MIN_VALUE, + uniqueValues=numCatalogItems) + .withColumn("item_name", "string", prefix="ITEM", baseColumn="catalog_item_id") + .withColumn("unit_price", "decimal(8,2)", minValue=1.50, maxValue=500.0, random=True) + .withColumn("discount_rate", "decimal(3,2)", minValue=0.00, maxValue=9.99, random=True) + .withColumn("min_inventory_qty", "integer", minValue=0, maxValue=10_000, random=True) + .withColumn("inventory_qty_range", "integer", minValue=0, maxValue=10_000, random=True, omit=True) + .withColumn("max_inventory_qty", "integer", expr="min_inventory_qty + inventory_qty_range") + .withColumn("created_on", "date", begin="2000-01-01", end=self.DEFAULT_START_DATE, interval="1 DAY") + .withColumn("created_by", "integer", minValue=1_000, maxValue=9_999, random=True) + .withColumn("is_updated", "boolean", expr="rand() > 0.75", omit=True) + .withColumn("updated_after_days", "integer", minValue=0, maxValue=1_000, random=True, omit=True) + .withColumn("updated_on", "date", expr="""case when is_updated then created_on + else date_add(created_on, updated_after_days) end""") + .withColumn("updated_by_user", "integer", minValue=1_000, maxValue=9_999, random=True, omit=True) + .withColumn("updated_by", "integer", expr="case when is_updated then updated_by_user else created_by end") + ) + + # Add dummy values if they were requested: + if dummyValues > 0: + catalog_items_data_spec = ( + catalog_items_data_spec.withColumn("dummy", "long", random=True, + numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG)) + + return catalog_items_data_spec + + def getBaseOrders(self, sparkSession, *, rows, partitions, numOrders, numCustomers, startDate, + endDate, dummyValues): + import 
dbldatagen as dg + + # Validate the options: + if numOrders is None or numOrders < 0: + numOrders = self.DEFAULT_NUM_ORDERS + if numCustomers is None or numCustomers < 0: + numCustomers = self.DEFAULT_NUM_CUSTOMERS + if startDate is None: + startDate = self.DEFAULT_START_DATE + if endDate is None: + endDate = self.DEFAULT_END_DATE + if rows is None or rows < 0: + rows = numOrders + if partitions is None or partitions < 0: + partitions = self.autoComputePartitions(rows, 9 + dummyValues) + + # Create the base data generation spec: + base_orders_data_spec = ( + dg.DataGenerator(sparkSession, rows=rows, partitions=partitions) + .withColumn("order_id", "integer", minValue=self.ORDER_MIN_VALUE, uniqueValues=numOrders) + .withColumn("order_title", "string", prefix="ORDER", baseColumn="order_id") + .withColumn("customer_id", "integer", minValue=self.CUSTOMER_MIN_VALUE, + maxValue=self.CUSTOMER_MIN_VALUE + numCustomers, random=True) + .withColumn("purchase_order_number", "string", template="KKKK-KKKK-DDDD-KKKK", random=True) + .withColumn("order_open_date", "date", begin=startDate, end=endDate, + interval="1 DAY", random=True) + .withColumn("order_open_to_close_days", "integer", minValue=0, maxValue=30, random=True, omit=True) + .withColumn("order_close_date", "date", expr=f"""least(cast('{endDate}' as date), + date_add(order_open_date, order_open_to_close_days))""") + .withColumn("sales_rep_id", "integer", minValue=1_000, maxValue=9_999, random=True) + .withColumn("sales_group_id", "integer", minValue=100, maxValue=999, random=True) + .withColumn("created_on", "date", expr="order_open_date") + .withColumn("created_by", "integer", minValue=1_000, maxValue=9_999, random=True) + .withColumn("updated_after_days", "integer", minValue=0, maxValue=5, random=True, omit=True) + .withColumn("updated_on", "date", expr="date_add(order_close_date, updated_after_days)") + .withColumn("updated_by", "integer", minValue=1_000, maxValue=9_999, random=True) + ) + + # Add dummy values if they were requested: + if dummyValues > 0: + base_orders_data_spec = base_orders_data_spec.withColumn( + "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG) + + return base_orders_data_spec + + def getBaseOrderLineItems(self, sparkSession, *, rows, partitions, numOrders, numCatalogItems, + lineItemsPerOrder, dummyValues): + import dbldatagen as dg + + # Validate the options: + if numOrders is None or numOrders < 0: + numOrders = self.DEFAULT_NUM_ORDERS + if numCatalogItems is None or numCatalogItems < 0: + numCatalogItems = self.DEFAULT_NUM_CATALOG_ITEMS + if lineItemsPerOrder is None or lineItemsPerOrder < 0: + lineItemsPerOrder = self.DEFAULT_LINE_ITEMS_PER_ORDER + if rows is None or rows < 0: + rows = numOrders + if partitions is None or partitions < 0: + partitions = self.autoComputePartitions(rows, 9 + dummyValues) + + # Create the base data generation spec: + base_order_line_items_data_spec = ( + dg.DataGenerator(sparkSession, rows=rows, partitions=partitions) + .withColumn("order_line_item_id", "integer", minValue=self.ORDER_LINE_ITEM_MIN_VALUE, + uniqueValues=numOrders * lineItemsPerOrder) + .withColumn("order_id", "integer", minValue=self.ORDER_MIN_VALUE, maxValue=self.ORDER_MIN_VALUE + numOrders, + uniqueValues=numOrders, random=True) + .withColumn("catalog_item_id", "integer", minValue=self.CATALOG_ITEM_MIN_VALUE, + maxValue=self.CATALOG_ITEM_MIN_VALUE + numCatalogItems, uniqueValues=numCatalogItems, + random=True) + .withColumn("has_discount", "boolean", expr="rand() > 0.9") + 
.withColumn("units", "integer", minValue=1, maxValue=100, random=True) + .withColumn("added_after_order_creation_days", "integer", minValue=0, maxValue=30, random=True) + ) + + # Add dummy values if they were requested: + if dummyValues > 0: + base_order_line_items_data_spec = base_order_line_items_data_spec.withColumn( + "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG) + + return base_order_line_items_data_spec + + def getBaseOrderShipments(self, sparkSession, *, rows, partitions, numOrders, numCarriers, dummyValues): + import dbldatagen as dg + + # Validate the options: + if numOrders is None or numOrders < 0: + numOrders = self.DEFAULT_NUM_ORDERS + if numCarriers is None or numCarriers < 0: + numCarriers = self.DEFAULT_NUM_CARRIERS + if rows is None or rows < 0: + rows = numOrders + if partitions is None or partitions < 0: + partitions = self.autoComputePartitions(rows, 9 + dummyValues) + + # Create the base data generation spec: + base_order_shipments_data_spec = ( + dg.DataGenerator(sparkSession, rows=rows, partitions=partitions) + .withColumn("order_shipment_id", "integer", minValue=self.ORDER_MIN_VALUE, uniqueValues=numOrders) + .withColumn("order_id", "integer", minValue=self.ORDER_MIN_VALUE, maxValue=self.ORDER_MIN_VALUE + numOrders, + uniqueValues=numOrders, random=True) + .withColumn("carrier_id", "integer", minValue=self.CARRIER_MIN_VALUE, + maxValue=self.CARRIER_MIN_VALUE + numCarriers, uniqueValues=numCarriers, random=True) + .withColumn("house_number", "integer", minValue=1, maxValue=9999, random=True, omit=True) + .withColumn("street_number", "integer", minValue=1, maxValue=150, random=True, omit=True) + .withColumn("street_direction", "string", values=["", "N", "S", "E", "W", "NW", "NE", "SW", "SE"], + random=True) + .withColumn("ship_to_address_line", "string", expr="""concat_ws(' ', house_number, street_direction, + street_number, 'ST')""") + .withColumn("ship_to_country_code", "string", values=["US", "CA"], weights=[8, 3], random=True) + .withColumn("order_open_to_ship_days", "integer", minValue=0, maxValue=30, random=True) + .withColumn("estimated_transit_days", "integer", minValue=1, maxValue=5, random=True) + .withColumn("actual_transit_days", "integer", expr="greatest(1, estimated_transit_days - ceil(3*rand()))") + .withColumn("receipt_on_delivery", "boolean", expr="rand() > 0.7") + .withColumn("method", "string", values=["GROUND", "AIR"], weights=[7, 4], random=True) + ) + + # Add dummy values if they were requested: + if dummyValues > 0: + base_order_shipments_data_spec = base_order_shipments_data_spec.withColumn( + "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG) + + return base_order_shipments_data_spec + + def getBaseInvoices(self, sparkSession, *, rows, partitions, numOrders, dummyValues): + import dbldatagen as dg + + # Validate the options: + if numOrders is None or numOrders < 0: + numOrders = self.DEFAULT_NUM_ORDERS + if rows is None or rows < 0: + rows = numOrders + if partitions is None or partitions < 0: + partitions = self.autoComputePartitions(rows, 9 + dummyValues) + + # Create the base data generation spec: + base_invoices_data_spec = ( + dg.DataGenerator(sparkSession, rows=rows, partitions=partitions) + .withColumn("invoice_id", "integer", minValue=self.INVOICE_MIN_VALUE, uniqueValues=numOrders) + .withColumn("order_id", "integer", minValue=self.ORDER_MIN_VALUE, maxValue=self.ORDER_MIN_VALUE + numOrders, + uniqueValues=numOrders, random=True) + 
.withColumn("house_number", "integer", minValue=1, maxValue=9999, random=True, omit=True) + .withColumn("street_number", "integer", minValue=1, maxValue=150, random=True, omit=True) + .withColumn("street_direction", "string", values=["", "N", "S", "E", "W", "NW", "NE", "SW", "SE"], + random=True) + .withColumn("bill_to_address_line", "string", expr="""concat_ws(' ', house_number, street_direction, + street_number, 'ST')""") + .withColumn("bill_to_country_code", "string", values=["US", "CA"], weights=[8, 3], random=True) + .withColumn("order_close_to_invoice_days", "integer", minValue=0, maxValue=5, random=True) + .withColumn("order_close_to_create_days", "integer", minValue=0, maxValue=2, random=True) + .withColumn("created_by", "integer", minValue=1_000, maxValue=9_999, random=True) + .withColumn("is_updated", "boolean", expr="rand() > 0.75") + .withColumn("updated_after_days", "integer", minValue=0, maxValue=5, random=True) + .withColumn("updated_by_user", "integer", minValue=1_000, maxValue=9_999, random=True, omit=True) + .withColumn("updated_by", "integer", expr="case when is_updated then updated_by_user else created_by end") + ) + + # Add dummy values if they were requested: + if dummyValues > 0: + base_invoices_data_spec = base_invoices_data_spec.withColumn( + "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG) + + return base_invoices_data_spec + + @DatasetProvider.allowed_options(options=["numCustomers", "numCarriers", "numCatalogItems", "numOrders", + "lineItemsPerOrder", "startDate", "endDate", "dummyValues"]) + def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions=-1, **options): + # Get the option values: + numCustomers = options.get("numCustomers", self.DEFAULT_NUM_CUSTOMERS) + numCarriers = options.get("numCarriers", self.DEFAULT_NUM_CARRIERS) + numCatalogItems = options.get("numCatalogItems", self.DEFAULT_NUM_CATALOG_ITEMS) + numOrders = options.get("numOrders", self.DEFAULT_NUM_ORDERS) + lineItemsPerOrder = options.get("lineItemsPerOrder", self.DEFAULT_LINE_ITEMS_PER_ORDER) + startDate = options.get("startDate", self.DEFAULT_START_DATE) + endDate = options.get("endDate", self.DEFAULT_END_DATE) + dummyValues = options.get("dummyValues", 0) + + # Get table generation specs for the base tables: + spec = None + if tableName == "customers": + spec = self.getCustomers( + sparkSession, + rows=rows, + partitions=partitions, + numCustomers=numCustomers, + dummyValues=dummyValues + ) + elif tableName == "carriers": + spec = self.getCarriers( + sparkSession, + rows=rows, + partitions=partitions, + numCarriers=numCarriers, + dummyValues=dummyValues + ) + elif tableName == "catalog_items": + spec = self.getCatalogItems( + sparkSession, + rows=rows, + partitions=partitions, + numCatalogItems=numCatalogItems, + dummyValues=dummyValues + ) + elif tableName == "base_orders": + spec = self.getBaseOrders( + sparkSession, + rows=rows, + partitions=partitions, + numOrders=numOrders, + numCustomers=numCustomers, + startDate=startDate, + endDate=endDate, + dummyValues=dummyValues + ) + elif tableName == "base_order_line_items": + spec = self.getBaseOrderLineItems( + sparkSession, + rows=rows, + partitions=partitions, + numOrders=numOrders, + numCatalogItems=numCatalogItems, + lineItemsPerOrder=lineItemsPerOrder, + dummyValues=dummyValues + ) + elif tableName == "base_order_shipments": + spec = self.getBaseOrderShipments( + sparkSession, + rows=rows, + partitions=partitions, + numOrders=numOrders, + numCarriers=numCarriers, + 
dummyValues=dummyValues
+            )
+        elif tableName == "base_invoices":
+            spec = self.getBaseInvoices(
+                sparkSession,
+                rows=rows,
+                partitions=partitions,
+                numOrders=numOrders,
+                dummyValues=dummyValues
+            )
+        if spec is not None:
+            return spec
+        raise ValueError("tableName must be 'customers', 'carriers', 'catalog_items', 'base_orders',"
+                         "'base_order_line_items', 'base_order_shipments', 'base_invoices'")
+
+    @DatasetProvider.allowed_options(options=[
+        "customers",
+        "carriers",
+        "catalogItems",
+        "baseOrders",
+        "baseOrderLineItems",
+        "baseOrderShipments",
+        "baseInvoices"
+    ])
+    def getAssociatedDataset(self, sparkSession, *, tableName=None, rows=-1, partitions=-1, **options):
+        from pyspark.sql import DataFrame
+        import pyspark.sql.functions as F
+
+        dfCustomers = options.get("customers", None)
+        assert dfCustomers is not None and issubclass(type(dfCustomers), DataFrame), \
+            "Option `customers` should be a dataframe of customer records"
+
+        dfCarriers = options.get("carriers", None)
+        assert dfCarriers is not None and issubclass(type(dfCarriers), DataFrame), \
+            "Option `carriers` should be a dataframe of carrier records"
+
+        dfCatalogItems = options.get("catalogItems", None)
+        assert dfCatalogItems is not None and issubclass(type(dfCatalogItems), DataFrame), \
+            "Option `catalogItems` should be a dataframe of catalog item records"
+
+        dfBaseOrders = options.get("baseOrders", None)
+        assert dfBaseOrders is not None and issubclass(type(dfBaseOrders), DataFrame), \
+            "Option `baseOrders` should be a dataframe of base order records"
+
+        dfBaseOrderLineItems = options.get("baseOrderLineItems", None)
+        assert dfBaseOrderLineItems is not None and issubclass(type(dfBaseOrderLineItems), DataFrame), \
+            "Option `baseOrderLineItems` should be a dataframe of base order line item records"
+
+        dfBaseOrderShipments = options.get("baseOrderShipments", None)
+        assert dfBaseOrderShipments is not None and issubclass(type(dfBaseOrderShipments), DataFrame), \
+            "Option `baseOrderShipments` should be a dataframe of base order shipment records"
+
+        dfBaseInvoices = options.get("baseInvoices", None)
+        assert dfBaseInvoices is not None and issubclass(type(dfBaseInvoices), DataFrame), \
+            "Option `baseInvoices` should be a dataframe of base invoice records"
+
+        if tableName == "orders":
+            dfOrderTotals = (
+                dfBaseOrderLineItems.alias("a")
+                .join(dfCatalogItems.alias("b"), on="catalog_item_id")
+                .selectExpr("a.order_id as order_id",
+                            "a.order_line_item_id as order_line_item_id",
+                            """case when a.has_discount then (b.unit_price * (1 - (b.discount_rate / 100)))
+                            else b.unit_price end as unit_price""",
+                            "a.units as units")
+                .selectExpr("order_id", "order_line_item_id", "unit_price * units as total_price")
+                .groupBy("order_id")
+                .agg(F.count("order_line_item_id").alias("num_line_items"),
+                     F.sum("total_price").alias("order_total"))
+            )
+            return (
+                dfBaseOrders.alias("a")
+                .join(dfOrderTotals.alias("b"), on="order_id")
+                .join(dfCustomers.alias("c"), on="customer_id")
+                .join(dfBaseOrderShipments.alias("d"), on="order_id")
+                .selectExpr(
+                    "a.order_id",
+                    "concat(c.customer_name, ' ', a.order_title) AS order_title",
+                    "a.customer_id",
+                    "b.order_total",
+                    "b.num_line_items",
+                    "a.order_open_date",
+                    "a.order_close_date",
+                    "a.sales_rep_id",
+                    "a.sales_group_id",
+                    "a.created_on",
+                    "a.created_by",
+                    "a.updated_on",
+                    "a.updated_by")
+            )
+
+        if tableName == "order_line_items":
+            return (
+                dfBaseOrderLineItems.alias("a")
+                .join(dfBaseOrders.alias("b"), on="order_id")
+                .join(dfCatalogItems.alias("c"),
on="catalog_item_id")
+                .selectExpr(
+                    "a.order_line_item_id",
+                    "a.order_id",
+                    "a.catalog_item_id",
+                    "a.units",
+                    "c.unit_price",
+                    "a.units * c.unit_price as gross_price",
+                    """case when a.has_discount then a.units * c.unit_price * (1 - (c.discount_rate / 100))
+                    else a.units * c.unit_price end as net_price""",
+                    "date_add(b.created_on, a.added_after_order_creation_days) as created_on",
+                    "b.created_by")
+            )
+
+        if tableName == "order_shipments":
+            return (
+                dfBaseOrderShipments.alias("a")
+                .join(dfBaseOrders.alias("b"), on="order_id")
+                .selectExpr(
+                    "a.order_shipment_id",
+                    "a.order_id",
+                    "a.carrier_id",
+                    "a.method",
+                    "a.ship_to_address_line",
+                    "a.ship_to_country_code",
+                    """least(b.order_close_date,
+                    date_add(b.order_open_date, a.order_open_to_ship_days)) as order_shipment_date""",
+                    "a.estimated_transit_days",
+                    "a.actual_transit_days",
+                    "b.created_on",
+                    "b.created_by")
+            )
+
+        if tableName == "invoices":
+            dfOrderTotals = (
+                dfBaseOrderLineItems.alias("a")
+                .join(dfCatalogItems.alias("b"), on="catalog_item_id")
+                .selectExpr("a.order_id as order_id",
+                            "a.order_line_item_id as order_line_item_id",
+                            """case when a.has_discount then (b.unit_price * (1 - (b.discount_rate / 100)))
+                            else b.unit_price end as unit_price""",
+                            "a.units as units")
+                .selectExpr("order_id", "order_line_item_id", "unit_price * units as total_price")
+                .groupBy("order_id")
+                .agg(F.count("order_line_item_id").alias("num_line_items"), F.sum("total_price").alias("order_total"))
+            )
+            return (
+                dfBaseInvoices.alias("a")
+                .join(dfBaseOrders.alias("b"), on="order_id")
+                .join(dfCustomers.alias("c"), on="customer_id")
+                .join(dfOrderTotals.alias("d"), on="order_id")
+                .selectExpr(
+                    "a.invoice_id",
+                    "a.order_id",
+                    "b.purchase_order_number",
+                    "d.order_total",
+                    "b.customer_id",
+                    "c.payment_terms",
+                    "a.bill_to_address_line",
+                    "a.bill_to_country_code",
+                    "date_add(b.order_close_date, a.order_close_to_invoice_days) as invoice_date",
+                    "date_add(b.order_close_date, a.order_close_to_create_days) as created_on",
+                    "a.created_by",
+                    """case when a.is_updated then
+                    date_add(b.order_close_date, a.order_close_to_create_days + a.updated_after_days)
+                    else date_add(b.order_close_date, a.order_close_to_create_days) end as updated_on""",
+                    "case when a.is_updated then a.updated_by else a.created_by end as updated_by")
+            )
diff --git a/tests/test_standard_dataset_providers.py b/tests/test_standard_dataset_providers.py
index bb952b40..1123bf8c 100644
--- a/tests/test_standard_dataset_providers.py
+++ b/tests/test_standard_dataset_providers.py
@@ -1,5 +1,6 @@
+from datetime import date
+from contextlib import nullcontext as does_not_raise
 import pytest
-
 import dbldatagen as dg

 spark = dg.SparkSingleton.getLocalInstance("unit tests")
@@ -8,56 +9,62 @@ class TestStandardDatasetProviders:

     # BASIC GEOMETRIES tests:
-    @pytest.mark.parametrize("providerName, providerOptions", [
-        ("basic/geometries",
-         {"rows": 50, "partitions": 4, "random": False, "geometryType": "point", "maxVertices": 1}),
-        ("basic/geometries",
-         {"rows": 100, "partitions": -1, "random": False, "geometryType": "point", "maxVertices": 2}),
-        ("basic/geometries",
-         {"rows": -1, "partitions": 4, "random": True, "geometryType": "point"}),
-        ("basic/geometries", {}),
-        ("basic/geometries",
-         {"rows": 5000, "partitions": -1, "random": True, "geometryType": "lineString"}),
-        ("basic/geometries",
-         {"rows": -1, "partitions": -1, "random": False, "geometryType": "lineString", "maxVertices": 2}),
-        ("basic/geometries",
-         {"rows": -1, "partitions": 4,
"random": True, "geometryType": "lineString", "maxVertices": 1}), - ("basic/geometries", - {"rows": 5000, "partitions": 4, "geometryType": "lineString", "maxVertices": 2}), - ("basic/geometries", - {"rows": 5000, "partitions": -1, "random": False, "geometryType": "polygon"}), - ("basic/geometries", - {"rows": -1, "partitions": -1, "random": True, "geometryType": "polygon", "maxVertices": 3}), - ("basic/geometries", - {"rows": -1, "partitions": 4, "random": True, "geometryType": "polygon", "maxVertices": 2}), - ("basic/geometries", - {"rows": 5000, "partitions": 4, "geometryType": "polygon", "maxVertices": 5}), + @pytest.mark.parametrize("providerName, providerOptions, expectation", [ + ("basic/geometries", {}, does_not_raise()), + ("basic/geometries", {"rows": 50, "partitions": 4, "random": False, + "geometryType": "point", "maxVertices": 1}, does_not_raise()), + ("basic/geometries", {"rows": 100, "partitions": -1, "random": False, + "geometryType": "point", "maxVertices": 2}, does_not_raise()), + ("basic/geometries", {"rows": -1, "partitions": 4, "random": True, + "geometryType": "point"}, does_not_raise()), + ("basic/geometries", {"rows": 5000, "partitions": -1, "random": True, + "geometryType": "lineString"}, does_not_raise()), + ("basic/geometries", {"rows": -1, "partitions": -1, "random": False, + "geometryType": "lineString", "maxVertices": 2}, does_not_raise()), + ("basic/geometries", {"rows": -1, "partitions": 4, "random": True, + "geometryType": "lineString", "maxVertices": 1}, does_not_raise()), + ("basic/geometries", {"rows": 5000, "partitions": 4, + "geometryType": "lineString", "maxVertices": 2}, does_not_raise()), + ("basic/geometries", {"rows": 5000, "partitions": -1, "random": False, + "geometryType": "polygon"}, does_not_raise()), + ("basic/geometries", {"rows": -1, "partitions": -1, "random": True, + "geometryType": "polygon", "maxVertices": 3}, does_not_raise()), + ("basic/geometries", {"rows": -1, "partitions": 4, "random": True, + "geometryType": "polygon", "maxVertices": 2}, does_not_raise()), + ("basic/geometries", {"rows": 5000, "partitions": 4, + "geometryType": "polygon", "maxVertices": 5}, does_not_raise()), + ("basic/geometries", + {"rows": 5000, "partitions": 4, "geometryType": "polygon", "minLatitude": 45.0, + "maxLatitude": 50.0, "minLongitude": -85.0, "maxLongitude": -75.0}, does_not_raise()), + ("basic/geometries", + {"rows": -1, "partitions": -1, "geometryType": "multipolygon"}, pytest.raises(ValueError)) ]) - def test_basic_geometries_retrieval(self, providerName, providerOptions): - ds = dg.Datasets(spark, providerName).get(**providerOptions) - assert ds is not None + def test_basic_geometries_retrieval(self, providerName, providerOptions, expectation): + with expectation: + ds = dg.Datasets(spark, providerName).get(**providerOptions) + assert ds is not None - df = ds.build() - assert df.count() >= 0 - assert "wkt" in df.columns + df = ds.build() + assert df.count() >= 0 + assert "wkt" in df.columns - geometryType = providerOptions.get("geometryType", None) - row = df.first().asDict() - if geometryType == "point" or geometryType is None: - assert "POINT" in row["wkt"] + geometryType = providerOptions.get("geometryType", None) + row = df.first().asDict() + if geometryType == "point" or geometryType is None: + assert "POINT" in row["wkt"] - if geometryType == "lineString": - assert "LINESTRING" in row["wkt"] + if geometryType == "lineString": + assert "LINESTRING" in row["wkt"] - if geometryType == "polygon": - assert "POLYGON" in row["wkt"] + if 
geometryType == "polygon": + assert "POLYGON" in row["wkt"] - random = providerOptions.get("random", None) - if random: - print("") - leadingRows = df.limit(100).collect() - ids = [r.location_id for r in leadingRows] - assert ids != sorted(ids) + random = providerOptions.get("random", None) + if random: + print("") + leadingRows = df.limit(100).collect() + ids = [r.location_id for r in leadingRows] + assert ids != sorted(ids) # BASIC PROCESS HISTORIAN tests: @pytest.mark.parametrize("providerName, providerOptions", [ @@ -112,6 +119,61 @@ def test_basic_process_historian_retrieval(self, providerName, providerOptions): ids = [r.device_id for r in leadingRows] assert ids != sorted(ids) + # BASIC STOCK TICKER tests: + @pytest.mark.parametrize("providerName, providerOptions, expectation", [ + ("basic/stock_ticker", + {"rows": 50, "partitions": 4, "numSymbols": 5, "startDate": "2024-01-01"}, does_not_raise()), + ("basic/stock_ticker", + {"rows": 100, "partitions": -1, "numSymbols": 5, "startDate": "2024-01-01"}, does_not_raise()), + ("basic/stock_ticker", + {"rows": -1, "partitions": 4, "numSymbols": 10, "startDate": "2024-01-01"}, does_not_raise()), + ("basic/stock_ticker", {}, does_not_raise()), + ("basic/stock_ticker", + {"rows": 5000, "partitions": -1, "numSymbols": 50, "startDate": "2024-01-01"}, does_not_raise()), + ("basic/stock_ticker", + {"rows": 5000, "partitions": 4, "numSymbols": 50}, does_not_raise()), + ("basic/stock_ticker", + {"rows": 5000, "partitions": 4, "startDate": "2024-01-01"}, does_not_raise()), + ("basic/stock_ticker", + {"rows": 5000, "partitions": 4, "numSymbols": 100, "startDate": "2024-01-01"}, does_not_raise()), + ("basic/stock_ticker", + {"rows": 1000, "partitions": -1, "numSymbols": 100, "startDate": "2025-01-01"}, does_not_raise()), + ("basic/stock_ticker", + {"rows": 1000, "partitions": -1, "numSymbols": 10, "startDate": "2020-01-01"}, does_not_raise()), + ("basic/stock_ticker", + {"rows": 50, "partitions": 2, "numSymbols": 0, "startDate": "2020-06-04"}, pytest.raises(ValueError)), + ("basic/stock_ticker", + {"rows": 500, "numSymbols": 12, "startDate": "2025-06-04"}, does_not_raise()), + ("basic/stock_ticker", + {"rows": 10, "partitions": 1, "numSymbols": -1, "startDate": "2009-01-02"}, pytest.raises(ValueError)), + ("basic/stock_ticker", + {"partitions": 2, "numSymbols": 20, "startDate": "2021-01-01"}, does_not_raise()), + ("basic/stock_ticker", + {"rows": 50, "partitions": 2, "numSymbols": 2}, does_not_raise()), + ]) + def test_basic_stock_ticker_retrieval(self, providerName, providerOptions, expectation): + with expectation: + ds = dg.Datasets(spark, providerName).get(**providerOptions) + assert ds is not None + df = ds.build() + assert df.count() >= 0 + + if "numSymbols" in providerOptions: + assert df.selectExpr("symbol").distinct().count() == providerOptions.get("numSymbols") + + if "startDate" in providerOptions: + assert df.selectExpr("min(post_date) as min_post_date") \ + .collect()[0] \ + .asDict()["min_post_date"] == date.fromisoformat(providerOptions.get("startDate")) + + assert df.where("""open < 0.0 + or close < 0.0 + or high < 0.0 + or low < 0.0 + or adj_close < 0.0""").count() == 0 + + assert df.where("high < low").count() == 0 + # BASIC TELEMATICS tests: @pytest.mark.parametrize("providerName, providerOptions", [ ("basic/telematics", @@ -266,6 +328,92 @@ def test_benchmark_groupby_retrieval(self, providerName, providerOptions): vals = [r.v3 for r in leadingRows] assert vals != sorted(vals) + # MULTI-TABLE SALES ORDER tests: + 
@pytest.mark.parametrize("providerName, providerOptions, expectation", [ + ("multi_table/sales_order", {"rows": 50, "partitions": 4}, does_not_raise()), + ("multi_table/sales_order", {"rows": -1, "partitions": 4}, does_not_raise()), + ("multi_table/sales_order", {}, does_not_raise()), + ("multi_table/sales_order", {"rows": 100, "partitions": -1}, does_not_raise()), + ("multi_table/sales_order", {"rows": 5000, "dummyValues": 4}, does_not_raise()), + ("multi_table/sales_order", {"rows": 100, "partitions": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "customers", "numCustomers": 100}, does_not_raise()), + ("multi_table/sales_order", {"table": "customers", "numCustomers": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "customers", "rows": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "customers", "rows": -1, "partitions": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "carriers", "numCarriers": 50}, does_not_raise()), + ("multi_table/sales_order", {"table": "carriers", "numCarriers": -1, "dummyValues": 2}, does_not_raise()), + ("multi_table/sales_order", {"table": "carriers", "numCustomers": 100}, does_not_raise()), + ("multi_table/sales_order", {"table": "carriers", "rows": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "carriers", "rows": -1, "partitions": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "catalog_items", "numCatalogItems": -1, + "dummyValues": 5}, does_not_raise()), + ("multi_table/sales_order", {"table": "catalog_items", "numCatalogItems": 100, + "numCustomers": 1000}, does_not_raise()), + ("multi_table/sales_order", {"table": "catalog_items", "rows": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "catalog_items", "rows": -1, "partitions": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "base_orders", "rows": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "base_orders", "rows": -1, "partitions": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "base_orders", "numOrders": -1, "numCustomers": -1, "startDate": None, + "endDate": None}, does_not_raise()), + ("multi_table/sales_order", {"table": "base_orders", "numOrders": 1000, "numCustomers": 10, + "dummyValues": 2}, does_not_raise()), + ("multi_table/sales_order", {"table": "base_order_line_items", "rows": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "base_order_line_items", "rows": -1, "partitions": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "base_order_line_items", "numOrders": 1000, + "dummyValues": 5}, does_not_raise()), + ("multi_table/sales_order", {"table": "base_order_line_items", "numOrders": -1, "numCatalogItems": -1, + "lineItemsPerOrder": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "base_order_shipments", "rows": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "base_order_shipments", "rows": -1, "partitions": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "base_order_shipments", "numOrders": 1000, + "numCarriers": 10}, does_not_raise()), + ("multi_table/sales_order", {"table": "base_order_shipments", "numOrders": -1, "numCarriers": -1, + "dummyValues": 2}, does_not_raise()), + ("multi_table/sales_order", {"table": "base_invoices", "rows": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "base_invoices", "rows": -1, "partitions": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "base_invoices", "numOrders": 1000, + 
"numCustomers": 10}, does_not_raise()), + ("multi_table/sales_order", {"table": "base_invoices", "numOrders": -1, "dummyValues": 2}, does_not_raise()), + ("multi_table/sales_order", {"table": "invalid_table_name"}, pytest.raises(ValueError)) + ]) + def test_multi_table_sales_order_retrieval(self, providerName, providerOptions, expectation): + with expectation: + ds = dg.Datasets(spark, providerName).get(**providerOptions) + assert ds is not None, f"""expected to get dataset specification for provider `{providerName}` + with options: {providerOptions} + """ + df = ds.build() + assert df.limit(100).count() >= 0 + + def test_full_multitable_sales_order_sequence(self): + multiTableDataSet = dg.Datasets(spark, "multi_table/sales_order") + options = {"numCustomers": 100, "numOrders": 1000, "numCarriers": 10, "numCatalogItems": 100, + "startDate": "2024-01-01", "endDate": "2024-12-31", "lineItemsPerOrder": 3} + dfCustomers = multiTableDataSet.get(table="customers", **options).build() + dfCarriers = multiTableDataSet.get(table="carriers", **options).build() + dfCatalogItems = multiTableDataSet.get(table="catalog_items", **options).build() + dfBaseOrders = multiTableDataSet.get(table="base_orders", **options).build() + dfBaseOrderLineItems = multiTableDataSet.get(table="base_order_line_items", **options).build() + dfBaseOrderShipments = multiTableDataSet.get(table="base_order_shipments", **options).build() + dfBaseInvoices = multiTableDataSet.get(table="base_invoices", **options).build() + + tables = ["orders", "order_line_items", "order_shipments", "invoices"] + for table in tables: + df = multiTableDataSet.getSummaryDataset( + table=table, + customers=dfCustomers, + carriers=dfCarriers, + catalogItems=dfCatalogItems, + baseOrders=dfBaseOrders, + baseOrderLineItems=dfBaseOrderLineItems, + baseOrderShipments=dfBaseOrderShipments, + baseInvoices=dfBaseInvoices + ) + + assert df is not None + assert df.count() >= 0 + assert df + # MULTI-TABLE TELEPHONY tests: @pytest.mark.parametrize("providerName, providerOptions", [ ("multi_table/telephony", {"rows": 50, "partitions": 4, "random": False}), @@ -277,6 +425,7 @@ def test_benchmark_groupby_retrieval(self, providerName, providerOptions): ("multi_table/telephony", {"table": 'plans', "numPlans": 100}), ("multi_table/telephony", {"table": 'plans'}), ("multi_table/telephony", {"table": 'customers', "numPlans": 100, "numCustomers": 1000}), + ("multi_table/telephony", {"table": 'customers', "numPlans": 100, "numCustomers": 1000}), ("multi_table/telephony", {"table": 'customers'}), ("multi_table/telephony", {"table": 'deviceEvents', "numPlans": 100, "numCustomers": 1000}), ("multi_table/telephony", {"table": 'deviceEvents'}),