diff --git a/dbldatagen/datasets/multi_table_sales_order_provider.py b/dbldatagen/datasets/multi_table_sales_order_provider.py index 8c8f2d0e..7b855df6 100644 --- a/dbldatagen/datasets/multi_table_sales_order_provider.py +++ b/dbldatagen/datasets/multi_table_sales_order_provider.py @@ -245,7 +245,7 @@ def getBaseOrderLineItems(self, sparkSession, *, rows, partitions, numOrders, nu base_order_line_items_data_spec = ( dg.DataGenerator(sparkSession, rows=rows, partitions=partitions) .withColumn("order_line_item_id", "integer", minValue=self.ORDER_LINE_ITEM_MIN_VALUE, - uniqueValues=numOrders*lineItemsPerOrder) + uniqueValues=numOrders * lineItemsPerOrder) .withColumn("order_id", "integer", minValue=self.ORDER_MIN_VALUE, maxValue=self.ORDER_MIN_VALUE + numOrders, uniqueValues=numOrders, random=True) .withColumn("catalog_item_id", "integer", minValue=self.CATALOG_ITEM_MIN_VALUE, @@ -359,8 +359,9 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions dummyValues = options.get("dummyValues", 0) # Get table generation specs for the base tables: + spec = None if tableName == "customers": - return self.getCustomers( + spec = self.getCustomers( sparkSession, rows=rows, partitions=partitions, @@ -368,7 +369,7 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions dummyValues=dummyValues ) elif tableName == "carriers": - return self.getCarriers( + spec = self.getCarriers( sparkSession, rows=rows, partitions=partitions, @@ -376,7 +377,7 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions dummyValues=dummyValues ) elif tableName == "catalog_items": - return self.getCatalogItems( + spec = self.getCatalogItems( sparkSession, rows=rows, partitions=partitions, @@ -384,7 +385,7 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions dummyValues=dummyValues ) elif tableName == "base_orders": - return self.getBaseOrders( + spec = self.getBaseOrders( sparkSession, 
rows=rows, partitions=partitions, @@ -395,7 +396,7 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions dummyValues=dummyValues ) elif tableName == "base_order_line_items": - return self.getBaseOrderLineItems( + spec = self.getBaseOrderLineItems( sparkSession, rows=rows, partitions=partitions, @@ -405,7 +406,7 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions dummyValues=dummyValues ) elif tableName == "base_order_shipments": - return self.getBaseOrderShipments( + spec = self.getBaseOrderShipments( sparkSession, rows=rows, partitions=partitions, @@ -414,13 +415,15 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions dummyValues=dummyValues ) elif tableName == "base_invoices": - return self.getBaseInvoices( + spec = self.getBaseInvoices( sparkSession, rows=rows, partitions=partitions, numOrders=numOrders, dummyValues=dummyValues ) + if spec is not None: + return spec raise ValueError("tableName must be 'customers', 'carriers', 'catalog_items', 'base_orders'," "'base_order_line_items', 'base_order_shipments', 'base_invoices'") @@ -548,8 +551,7 @@ def getAssociatedDataset(self, sparkSession, *, tableName=None, rows=-1, partiti "a.units as units") .selectExpr("order_id", "order_line_item_id", "unit_price * units as total_price") .groupBy("order_id") - .agg(F.count("order_line_item_id").alias("num_line_items"), - F.sum("total_price").alias("order_total")) + .agg(F.count("order_line_item_id").alias("num_line_items"), F.sum("total_price").alias("order_total")) ) return ( dfBaseInvoices.alias("a") diff --git a/tests/test_standard_dataset_providers.py b/tests/test_standard_dataset_providers.py index 009e98fb..9cce393d 100644 --- a/tests/test_standard_dataset_providers.py +++ b/tests/test_standard_dataset_providers.py @@ -1,7 +1,7 @@ -import pytest from datetime import date -import dbldatagen as dg from contextlib import nullcontext as does_not_raise +import pytest +import 
dbldatagen as dg spark = dg.SparkSingleton.getLocalInstance("unit tests")