Skip to content

Commit

Permalink
LEXIO-38099: union operator (#10)
Browse files Browse the repository at this point in the history
* LEXIO-38099: added UnionStatement

* LEXIO-38099: added union tests

* Version bumped to 0.10.0

* LEXIO-38099: cruft update

Co-authored-by: ns-circle-ci <devops-team+circleci@narrativescience.com>
  • Loading branch information
ns-alexmueller and ns-circle-ci authored May 18, 2022
1 parent 1a16ee2 commit 43b49ac
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .cruft.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"template": "https://github.com/NarrativeScience/cookiecutter-python-lib",
"commit": "e090078902c7effbcde4bb9da97dd41eefee28b7",
"commit": "06d791b4e3ac2362c595a9bcf0617f84e546ec3c",
"checkout": null,
"context": {
"cookiecutter": {
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pysaql"
version = "0.9.0"
version = "0.10.0"
description = "Python SAQL query builder"
authors = ["Jonathan Drake <jon.drake@salesforce.com>"]
license = "BSD-3-Clause"
Expand Down
2 changes: 1 addition & 1 deletion pysaql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Python SAQL query builder"""

__version__ = "0.9.0"
__version__ = "0.10.0"
66 changes: 63 additions & 3 deletions pysaql/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,14 @@ def increment_id(self, incr: int) -> int:
statement.stream._id += incr + i
max_id = max(max_id, statement.stream._id)
i += 1
elif isinstance(statement, CogroupStatement):
# For cogroup statements, leave the left-most (first) branch alone
for (stream, _) in statement.streams[1:]:
elif isinstance(statement, (CogroupStatement, UnionStatement)):
# For cogroup and union statements, leave the left-most (first) branch alone
if isinstance(statement, CogroupStatement):
streams = [stream for (stream, _) in statement.streams[1:]]
else:
streams = list(statement.streams[1:])

for stream in streams:
stream.increment_id(incr + i)
max_id = max(max_id, stream._id)
i += 1
Expand Down Expand Up @@ -398,6 +403,40 @@ def __str__(self) -> str:
return "\n".join(lines)


class UnionStatement(StreamStatement):
    """Statement that merges (unions) two or more same-structured streams into one"""

    def __init__(
        self,
        stream: Stream,
        streams: Sequence[Stream],
    ) -> None:
        """Initializer

        Args:
            stream: Stream containing this statement
            streams: Streams that will be combined; at least two are required

        Raises:
            ValueError: if fewer than two streams are provided

        """
        super().__init__()
        self.stream = stream
        # An empty/None value short-circuits before ``len`` is ever called.
        has_enough_streams = bool(streams) and len(streams) >= 2
        if not has_enough_streams:
            raise ValueError("At least two streams are required")
        self.streams = streams

    def __str__(self) -> str:
        """Cast this union statement to a string

        The output contains the string form of every member stream, in order,
        followed by a single ``union`` line that assigns the combined result
        to the owning stream's reference.
        """
        # Render each member and read its reference in the same pass.
        rendered = [(str(member), member.ref) for member in self.streams]
        output = [text for text, _ in rendered]
        joined_refs = ", ".join(ref for _, ref in rendered)
        output.append(f"{self.stream.ref} = union {joined_refs};")
        return "\n".join(output)


class FillStatement(StreamStatement):
"""Statement to fill a data stream with missing dates"""

Expand Down Expand Up @@ -471,3 +510,24 @@ def cogroup(
# We'll use the ID of the first stream as the basis for incrementing.
stream.increment_id(streams[0][0]._id)
return stream


def union(*streams: Stream) -> Stream:
    """Combine two or more data streams into a single data stream

    Each stream should have the same field names and structure. The streams do
    not need to be from the same dataset.

    Args:
        streams: Streams that will be unioned together

    Returns:
        a new stream

    """
    combined = Stream()
    statement = UnionStatement(combined, streams)
    combined.add_statement(statement)
    # Every stream referenced by the union statement needs its ID shifted;
    # the first stream's ID serves as the base offset for that shift.
    base_id = streams[0]._id
    combined.increment_id(base_id)
    return combined
36 changes: 35 additions & 1 deletion tests/unit/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pysaql.enums import FillDateTypeString, JoinType, Order
from pysaql.scalar import field
from pysaql.stream import cogroup, load, Stream
from pysaql.stream import cogroup, load, Stream, union


def test_load():
Expand Down Expand Up @@ -200,3 +200,37 @@ def test_fill__partition():
str(stream)
== """q0 = fill q0 by (dateCols=('Year', 'Month', "Y-M"), partition='Type');"""
)


def test_union():
    """Should return a unioned stream, including when unions are nested"""
    dataset_names = ("q0_dataset", "q1_dataset", "q2_dataset", "q3_dataset")
    first, second, third, fourth = [load(name) for name in dataset_names]

    inner = union(first, second)
    outer = union(inner, third, fourth)

    # The inner union takes q2, so the remaining loads are emitted as q3/q4.
    expected_lines = [
        'q0 = load "q0_dataset";',
        'q1 = load "q1_dataset";',
        "q2 = union q0, q1;",
        'q3 = load "q2_dataset";',
        'q4 = load "q3_dataset";',
        "q5 = union q2, q3, q4;",
    ]
    assert str(outer).split("\n") == expected_lines


def test_union__no_streams():
    """Should raise ValueError when union is called without any streams"""
    no_streams = ()
    with pytest.raises(ValueError):
        union(*no_streams)


def test_union__one_streams():
    """Should raise ValueError when only a single stream is provided"""
    # Build the fixture outside the ``raises`` block so that only the call
    # under test can satisfy the expected exception; a ValueError from the
    # setup must fail the test rather than silently pass it.
    q0 = load("q0_dataset")
    with pytest.raises(ValueError):
        union(q0)

0 comments on commit 43b49ac

Please sign in to comment.