Skip to content

Commit

Permalink
Fix code formatting
Browse files (browse the repository at this point in the history)
  • Loading branch information
ghanse committed Nov 21, 2024
1 parent d2d4e6d commit 562d93f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
22 changes: 12 additions & 10 deletions dbldatagen/datasets/multi_table_sales_order_provider.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -245,7 +245,7 @@ def getBaseOrderLineItems(self, sparkSession, *, rows, partitions, numOrders, nu
base_order_line_items_data_spec = (
dg.DataGenerator(sparkSession, rows=rows, partitions=partitions)
.withColumn("order_line_item_id", "integer", minValue=self.ORDER_LINE_ITEM_MIN_VALUE,
uniqueValues=numOrders*lineItemsPerOrder)
uniqueValues=numOrders * lineItemsPerOrder)
.withColumn("order_id", "integer", minValue=self.ORDER_MIN_VALUE, maxValue=self.ORDER_MIN_VALUE + numOrders,
uniqueValues=numOrders, random=True)
.withColumn("catalog_item_id", "integer", minValue=self.CATALOG_ITEM_MIN_VALUE,
Expand Down Expand Up @@ -359,32 +359,33 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
dummyValues = options.get("dummyValues", 0)

# Get table generation specs for the base tables:
spec = None
if tableName == "customers":
return self.getCustomers(
spec = self.getCustomers(
sparkSession,
rows=rows,
partitions=partitions,
numCustomers=numCustomers,
dummyValues=dummyValues
)
elif tableName == "carriers":
return self.getCarriers(
spec = self.getCarriers(
sparkSession,
rows=rows,
partitions=partitions,
numCarriers=numCarriers,
dummyValues=dummyValues
)
elif tableName == "catalog_items":
return self.getCatalogItems(
spec = self.getCatalogItems(
sparkSession,
rows=rows,
partitions=partitions,
numCatalogItems=numCatalogItems,
dummyValues=dummyValues
)
elif tableName == "base_orders":
return self.getBaseOrders(
spec = self.getBaseOrders(
sparkSession,
rows=rows,
partitions=partitions,
Expand All @@ -395,7 +396,7 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
dummyValues=dummyValues
)
elif tableName == "base_order_line_items":
return self.getBaseOrderLineItems(
spec = self.getBaseOrderLineItems(
sparkSession,
rows=rows,
partitions=partitions,
Expand All @@ -405,7 +406,7 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
dummyValues=dummyValues
)
elif tableName == "base_order_shipments":
return self.getBaseOrderShipments(
spec = self.getBaseOrderShipments(
sparkSession,
rows=rows,
partitions=partitions,
Expand All @@ -414,13 +415,15 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
dummyValues=dummyValues
)
elif tableName == "base_invoices":
return self.getBaseInvoices(
spec = self.getBaseInvoices(
sparkSession,
rows=rows,
partitions=partitions,
numOrders=numOrders,
dummyValues=dummyValues
)
if spec is not None:
return spec
raise ValueError("tableName must be 'customers', 'carriers', 'catalog_items', 'base_orders',"
"'base_order_line_items', 'base_order_shipments', 'base_invoices'")

Expand Down Expand Up @@ -548,8 +551,7 @@ def getAssociatedDataset(self, sparkSession, *, tableName=None, rows=-1, partiti
"a.units as units")
.selectExpr("order_id", "order_line_item_id", "unit_price * units as total_price")
.groupBy("order_id")
.agg(F.count("order_line_item_id").alias("num_line_items"),
F.sum("total_price").alias("order_total"))
.agg(F.count("order_line_item_id").alias("num_line_items"),F.sum("total_price").alias("order_total"))
)
return (
dfBaseInvoices.alias("a")
Expand Down
4 changes: 2 additions & 2 deletions tests/test_standard_dataset_providers.py
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
import pytest
from datetime import date
import dbldatagen as dg
from contextlib import nullcontext as does_not_raise
import pytest
import dbldatagen as dg

spark = dg.SparkSingleton.getLocalInstance("unit tests")

Expand Down

0 comments on commit 562d93f

Please sign in to comment.