From 43b49ac956918186c2d0ebfde6deabf660be38de Mon Sep 17 00:00:00 2001 From: Alex Mueller Date: Wed, 18 May 2022 13:54:02 -0500 Subject: [PATCH] LEXIO-38099: union operator (#10) * LEXIO-38099: added UnionStatement * LEXIO-38099: added union tests * Version bumped to 0.10.0 * LEXIO-38099: cruft update Co-authored-by: ns-circle-ci --- .cruft.json | 2 +- pyproject.toml | 2 +- pysaql/__init__.py | 2 +- pysaql/stream.py | 66 +++++++++++++++++++++++++++++++++++++-- tests/unit/test_stream.py | 36 ++++++++++++++++++++- 5 files changed, 101 insertions(+), 7 deletions(-) diff --git a/.cruft.json b/.cruft.json index dbc7876..b8bfe84 100644 --- a/.cruft.json +++ b/.cruft.json @@ -1,6 +1,6 @@ { "template": "https://github.com/NarrativeScience/cookiecutter-python-lib", - "commit": "e090078902c7effbcde4bb9da97dd41eefee28b7", + "commit": "06d791b4e3ac2362c595a9bcf0617f84e546ec3c", "checkout": null, "context": { "cookiecutter": { diff --git a/pyproject.toml b/pyproject.toml index fedaaf4..c49c58c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pysaql" -version = "0.9.0" +version = "0.10.0" description = "Python SAQL query builder" authors = ["Jonathan Drake "] license = "BSD-3-Clause" diff --git a/pysaql/__init__.py b/pysaql/__init__.py index 4f4eccc..64b562a 100644 --- a/pysaql/__init__.py +++ b/pysaql/__init__.py @@ -1,3 +1,3 @@ """Python SAQL query builder""" -__version__ = "0.9.0" +__version__ = "0.10.0" diff --git a/pysaql/stream.py b/pysaql/stream.py index acbfbb2..c879b9f 100644 --- a/pysaql/stream.py +++ b/pysaql/stream.py @@ -63,9 +63,14 @@ def increment_id(self, incr: int) -> int: statement.stream._id += incr + i max_id = max(max_id, statement.stream._id) i += 1 - elif isinstance(statement, CogroupStatement): - # For cogroup statements, leave the left-most (first) branch alone - for (stream, _) in statement.streams[1:]: + elif isinstance(statement, (CogroupStatement, UnionStatement)): + # For cogroup and union statements, leave the left-most (first) branch alone + if isinstance(statement, CogroupStatement): + streams = [stream for (stream, _) in statement.streams[1:]] + else: + streams = list(statement.streams[1:]) + + for stream in streams: stream.increment_id(incr + i) max_id = max(max_id, stream._id) i += 1 @@ -398,6 +403,40 @@ def __str__(self) -> str: return "\n".join(lines) +class UnionStatement(StreamStatement): + """Statement to combine (union) two or more streams with the same structure into one""" + + def __init__( + self, + stream: Stream, + streams: Sequence[Stream], + ) -> None: + """Initializer + + Args: + stream: Stream containing this statement + streams: Streams that will be combined + + """ + super().__init__() + self.stream = stream + if not streams or len(streams) < 2: + raise ValueError("At least two streams are required") + self.streams = streams + + def __str__(self) -> str: + """Cast this union statement to a string""" + lines = [] + stream_refs = [] + + for stream in self.streams: + lines.append(str(stream)) + stream_refs.append(stream.ref) + + lines.append(f"{self.stream.ref} = union {', '.join(stream_refs)};") + return "\n".join(lines) + + class FillStatement(StreamStatement): """Statement to fill a data stream with missing dates""" @@ -471,3 +510,24 @@ def cogroup( # We'll use the ID of the first stream as the basis for incrementing. stream.increment_id(streams[0][0]._id) return stream + + +def union(*streams: Stream) -> Stream: + """Union data from two or more data streams into a single data stream + + Each stream should have the same field names and structure. The streams do + not need to be from the same dataset. + + Args: + streams: Streams that will be unioned together + + Returns: + a new stream + + """ + stream = Stream() + stream.add_statement(UnionStatement(stream, streams)) + # Increment stream IDs for all streams contained in this union statement. + # We'll use the ID of the first stream as the basis for incrementing. + stream.increment_id(streams[0]._id) + return stream diff --git a/tests/unit/test_stream.py b/tests/unit/test_stream.py index 4b82220..d03d359 100644 --- a/tests/unit/test_stream.py +++ b/tests/unit/test_stream.py @@ -4,7 +4,7 @@ from pysaql.enums import FillDateTypeString, JoinType, Order from pysaql.scalar import field -from pysaql.stream import cogroup, load, Stream +from pysaql.stream import cogroup, load, Stream, union def test_load(): @@ -200,3 +200,37 @@ def test_fill__partition(): str(stream) == """q0 = fill q0 by (dateCols=('Year', 'Month', "Y-M"), partition='Type');""" ) + + +def test_union(): + """Should return a unioned stream""" + + q0 = load("q0_dataset") + q1 = load("q1_dataset") + q2 = load("q2_dataset") + q3 = load("q3_dataset") + + u0 = union(q0, q1) + u1 = union(u0, q2, q3) + + assert str(u1).split("\n") == [ + """q0 = load "q0_dataset";""", + """q1 = load "q1_dataset";""", + """q2 = union q0, q1;""", + """q3 = load "q2_dataset";""", + """q4 = load "q3_dataset";""", + """q5 = union q2, q3, q4;""", + ] + + +def test_union__no_streams(): + """Should raise ValueError when no streams are provided""" + with pytest.raises(ValueError): + union() + + +def test_union__one_streams(): + """Should raise ValueError when a single streams is provided""" + with pytest.raises(ValueError): + q0 = load("q0_dataset") + union(q0)