# feat: add sqlglot to parse sql dataflow (#3310)
## 📝 Summary

<!--
Provide a concise summary of what this pull request is addressing.

If this PR fixes any issues, list them here by number (e.g., Fixes
#123).
-->
Fixes #3103. Adds sqlglot as an optional dependency to handle parsing of
complex SQL references. Also adds support for DML statements (INSERT,
UPDATE, and DELETE).

## 🔍 Description of Changes

<!--
Detail the specific changes made in this pull request. Explain the
problem addressed and how it was resolved. If applicable, provide before
and after comparisons, screenshots, or any relevant details to help
reviewers understand the changes easily.
-->

- Set the sqlglot version floor as low as possible (likely >=23.4, to stay
compatible with the ibis library). A short sketch of the sqlglot-based
parsing approach follows this list.
- If we want to fully support DML statements, I could add more tests (some
are already included here).
- `find_sql_defs()` can also be modified to use sqlglot in the future; that
is not included in this PR.
- Small change: synchronized the duckdb version requirement across all
installation extras (some specified 1.0 and some 1.1).
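
For reviewers unfamiliar with sqlglot, here is a minimal, simplified sketch of the approach taken in `marimo/_ast/sql_visitor.py` below (the `table_refs` helper name and the sample query are illustrative only; the real implementation additionally handles catalog/schema qualifiers): parse the statement into an AST, walk every table node for DML statements, and use `build_scope` for SELECTs so that CTE names are excluded automatically.

```python
# Illustrative sketch only -- not marimo's exact implementation.
from sqlglot import exp, parse
from sqlglot.optimizer.scope import build_scope


def table_refs(sql: str) -> list[str]:
    refs: list[str] = []
    for expression in parse(sql, dialect="duckdb"):
        if expression is None:
            continue
        # DML statements (INSERT/UPDATE/DELETE) are not scoped SELECTs,
        # so walk every table node directly.
        if expression.find(exp.Update, exp.Insert, exp.Delete):
            refs.extend(t.name for t in expression.find_all(exp.Table))
        # build_scope resolves which sources a SELECT actually reads from,
        # which means CTE names are skipped without manual bookkeeping.
        elif (root := build_scope(expression)) is not None:
            for scope in root.traverse():
                for _alias, (_node, source) in scope.selected_sources.items():
                    if isinstance(source, exp.Table):
                        refs.append(source.name)
    return list(dict.fromkeys(refs))  # dedupe, preserving order


print(table_refs("WITH t AS (SELECT 1) SELECT * FROM t JOIN table1 ON TRUE"))
# expected: ['table1'] -- the CTE name is not reported as a reference
```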

## 📋 Checklist

- [x] I have read the [contributor
guidelines](https://github.com/marimo-team/marimo/blob/main/CONTRIBUTING.md).
- [ ] For large changes, or changes that affect the public API: this
change was discussed or approved through an issue, on
[Discord](https://marimo.io/discord?ref=pr), or the community
[discussions](https://github.com/marimo-team/marimo/discussions) (Please
provide a link if applicable).
- [x] I have added tests for the changes made.
- [x] I have run the code and verified that it works as expected.

## 📜 Reviewers

<!--
Tag potential reviewers from the community or maintainers who might be
interested in reviewing this pull request.

Your PR will be reviewed more quickly if you can figure out the right
person to tag with @ -->

@akshayka OR @mscolnick
Light2Dark authored Dec 31, 2024
1 parent f8dc9f5 commit be4d05c
Showing 4 changed files with 183 additions and 111 deletions.
143 changes: 45 additions & 98 deletions marimo/_ast/sql_visitor.py
@@ -301,8 +301,6 @@ def find_sql_defs(sql_statement: str) -> SQLDefs:
     )
 
 
-# TODO(akshayka): there are other kinds of refs to find; this should be
-# find_sql_refs
 def find_sql_refs(
     sql_statement: str,
 ) -> list[str]:
@@ -315,109 +313,58 @@ def find_sql_refs(
     Returns:
         A list of table and schema names referenced in the statement.
     """
-    if not DependencyManager.duckdb.has():
+
+    # Use sqlglot to parse ast (https://github.com/tobymao/sqlglot/blob/main/posts/ast_primer.md)
+    if not DependencyManager.sqlglot.has():
         return []
 
-    import duckdb
+    from sqlglot import exp, parse
+    from sqlglot.errors import ParseError
+    from sqlglot.optimizer.scope import build_scope
 
-    tokens = duckdb.tokenize(sql_statement)
-    token_extractor = TokenExtractor(
-        sql_statement=sql_statement, tokens=tokens
-    )
     refs: list[str] = []
-    cte_names: set[str] = set()
-    i = 0
 
-    # First pass - collect CTE names
-    while i < len(tokens):
-        if token_extractor.is_keyword(i, "with"):
-            i += 1
-            # Handle optional parenthesis after WITH
-            if token_extractor.token_str(i) == "(":
-                i += 1
-            while i < len(tokens):
-                if token_extractor.is_keyword(i, "select"):
-                    break
-                if (
-                    token_extractor.token_str(i) == ","
-                    or token_extractor.token_str(i) == "("
-                ):
-                    i += 1
-                    continue
-                cte_name = token_extractor.strip_quotes(
-                    token_extractor.token_str(i)
-                )
-                if not token_extractor.is_keyword(i, "as"):
-                    cte_names.add(cte_name)
-                i += 1
-                if token_extractor.is_keyword(i, "as"):
-                    break
-        i += 1
+    def append_refs_from_table(table: exp.Table) -> None:
+        if table.catalog == "memory":
+            # Default in-memory catalog, only include table name
+            refs.append(table.name)
+        else:
+            # We skip schema if there is a catalog
+            # Because it may be called "public" or "main" across all catalogs
+            # and they aren't referenced in the code
+            if table.catalog:
+                refs.append(table.catalog)
+            elif table.db:
+                refs.append(table.db)  # schema
+
+            if table.name:
+                refs.append(table.name)
 
-    # Second pass - collect references excluding CTEs
-    i = 0
-    while i < len(tokens):
-        if token_extractor.is_keyword(i, "from") or token_extractor.is_keyword(
-            i, "join"
-        ):
-            i += 1
-            if i < len(tokens):
-                # Skip over opening parenthesis for subqueries
-                if token_extractor.token_str(i) == "(":
-                    continue
-
-                # Get table name parts, this could be:
-                # - catalog.schema.table
-                # - catalog.table (this is shorthand for catalog.main.table)
-                # - table
-
-                parts: List[str] = []
-                while i < len(tokens):
-                    part = token_extractor.strip_quotes(
-                        token_extractor.token_str(i)
-                    )
-                    parts.append(part)
-                    # next token is a dot, so we continue getting parts
-                    if (
-                        i + 1 < len(tokens)
-                        and token_extractor.token_str(i + 1) == "."
-                    ):
-                        i += 2
-                        continue
-                    break
-
-                if len(parts) == 3:
-                    # If its the default in-memory catalog,
-                    # only add the table name
-                    if parts[0] == "memory":
-                        refs.append(parts[2])
-                    else:
-                        # Just add the catalog and table, skip schema
-                        refs.extend([parts[0], parts[2]])
-                elif len(parts) == 2:
-                    # If its the default in-memory catalog, only add the table
-                    if parts[0] == "memory":
-                        refs.append(parts[1])
-                    else:
-                        # It's a catalog and table, add both
-                        refs.extend(parts)
-                elif len(parts) == 1:
-                    # It's a table, make sure it's not a CTE
-                    if parts[0] not in cte_names:
-                        refs.append(parts[0])
-                else:
-                    LOGGER.warning(
-                        "Unexpected number of parts in SQL reference: %s",
-                        parts,
-                    )
+    try:
+        expression_list = parse(sql_statement, dialect="duckdb")
+    except ParseError as e:
+        LOGGER.error(f"Unable to parse SQL. Error: {e}")
+        return []
+
+    for expression in expression_list:
+        if expression is None:
+            continue
 
-            i -= 1  # Compensate for outer loop increment
-        i += 1
+        if is_dml := bool(expression.find(exp.Update, exp.Insert, exp.Delete)):
+            for table in expression.find_all(exp.Table):
+                append_refs_from_table(table)
 
-    # Re-use find_sql_defs to find referenced schemas and catalogs during creation.
-    defs = find_sql_defs(sql_statement)
-    refs.extend(defs.reffed_schemas)
-    refs.extend(defs.reffed_catalogs)
+        # build_scope only works for select statements
+        if root := build_scope(expression):
+            if is_dml:
+                LOGGER.warning(
+                    "Scopes should not exist for dml's, may need rework if this occurs"
+                )
 
-    # Remove duplicates while preserving order
+            for scope in root.traverse():  # type: ignore
+                for _alias, (_node, source) in scope.selected_sources.items():
+                    if isinstance(source, exp.Table):
+                        append_refs_from_table(source)
+
+    # remove duplicates while preserving order
     return list(dict.fromkeys(refs))
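
A note on the `append_refs_from_table` helper above: sqlglot exposes the dotted parts of a qualified table reference as `catalog`, `db` (the schema), and `name` on the `exp.Table` node, which is what the catalog/schema-skipping logic relies on. A quick illustration (the identifiers are made up):

```python
from sqlglot import exp, parse_one

table = parse_one(
    "SELECT * FROM my_catalog.my_schema.my_table", dialect="duckdb"
).find(exp.Table)

# The three dotted parts map to catalog / db (schema) / name.
assert table is not None
assert (table.catalog, table.db, table.name) == (
    "my_catalog",
    "my_schema",
    "my_table",
)
```

So for a fully qualified `catalog.schema.table` reference, the helper records the catalog and the table but skips the schema, matching the assertion in `test_find_sql_refs_with_catalog` below.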
1 change: 1 addition & 0 deletions marimo/_dependencies/dependencies.py
@@ -147,6 +147,7 @@ class DependencyManager:
     numpy = Dependency("numpy")
     altair = Dependency("altair", min_version="5.3.0", max_version="6.0.0")
     duckdb = Dependency("duckdb")
+    sqlglot = Dependency("sqlglot")
     pillow = Dependency("PIL")
     plotly = Dependency("plotly")
     bokeh = Dependency("bokeh")
21 changes: 14 additions & 7 deletions pyproject.toml
@@ -75,16 +75,22 @@ marimo = "marimo._cli.cli:main"
 homepage = "https://github.com/marimo-team/marimo"
 
 [project.optional-dependencies]
-sql = ["duckdb >= 1.0.0", "polars[pyarrow] >= 1.9.0"]
+sql = [
+    "duckdb>=1.0.0",
+    "polars[pyarrow]>=1.9.0",
+    "sqlglot>=23.4"
+]
 
 # List of deps that are recommended for most users
 # in order to unlock all features in marimo
 recommended = [
-    "duckdb>=1.1.0", # SQL cells
+    "duckdb>=1.0.0", # SQL cells
     "altair>=5.4.0", # Plotting in datasource viewer
-    "polars>=1.9.0", # SQL output back in Python
+    "polars[pyarrow]>=1.9.0", # SQL output back in Python
+    "sqlglot>=23.4", # SQL cells parsing
     "openai>=1.41.1", # AI features
     "ruff", # Formatting
     "nbformat>=5.7.0", # Export as IPYNB
 ]
 
 dev = [
@@ -95,6 +101,7 @@ dev = [
     "opentelemetry-sdk~=1.26.0",
     # For SQL
     "duckdb>=1.0.0",
+    "sqlglot>=23.4",
     # For linting
     "ruff~=0.6.1",
     # For AI
129 changes: 123 additions & 6 deletions tests/_ast/test_sql_visitor.py
@@ -14,6 +14,7 @@
 from marimo._dependencies.dependencies import DependencyManager
 
 HAS_DUCKDB = DependencyManager.duckdb.has()
+HAS_SQLGLOT = DependencyManager.sqlglot.has()
 
 
 def test_execute_with_string_literal() -> None:
@@ -469,7 +470,7 @@ def test_find_sql_defs_duckdb_not_available() -> None:
     assert find_sql_defs("CREATE TABLE test (id INT);") == SQLDefs()
 
 
-@pytest.mark.skipif(not HAS_DUCKDB, reason="Missing DuckDB")
+@pytest.mark.skipif(not HAS_SQLGLOT, reason="Missing sqlglot")
 class TestFindSQLRefs:
     @staticmethod
     def test_find_sql_refs_simple() -> None:
@@ -509,8 +510,6 @@ def test_find_sql_refs_with_schema() -> None:
     @staticmethod
     def test_find_sql_refs_with_catalog() -> None:
         # Skip the schema if it's coming from a catalog
-        # Why? Because it may be called "public" or "main" across all catalogs
-        # and they aren't referenced in the code
         sql = "SELECT * FROM my_catalog.my_schema.my_table;"
         assert find_sql_refs(sql) == ["my_catalog", "my_table"]
 
@@ -566,7 +565,6 @@ def test_find_sql_refs_with_quoted_names() -> None:
         assert find_sql_refs(sql) == ["My Table", "Weird.Name"]
 
     @staticmethod
-    @pytest.mark.xfail(reason="Multiple CTEs are not supported")
     def test_find_sql_refs_with_multiple_ctes() -> None:
         sql = """
         WITH
@@ -578,7 +576,6 @@ def test_find_sql_refs_with_multiple_ctes() -> None:
         assert find_sql_refs(sql) == ["table1", "table2"]
 
     @staticmethod
-    @pytest.mark.xfail(reason="Nested joins are not supported")
     def test_find_sql_refs_with_nested_joins() -> None:
         sql = """
         SELECT * FROM t1
@@ -593,7 +590,7 @@ def test_find_sql_refs_with_lateral_join() -> None:
         SELECT * FROM employees,
         LATERAL (SELECT * FROM departments WHERE departments.id = employees.dept_id) dept;
         """
-        assert find_sql_refs(sql) == ["employees", "departments"]
+        assert find_sql_refs(sql) == ["departments", "employees"]
 
     @staticmethod
     def test_find_sql_refs_with_schema_switching() -> None:
@@ -614,3 +611,123 @@ def test_find_sql_refs_with_complex_subqueries() -> None:
             ) t2;
         """
         assert find_sql_refs(sql) == ["deeply", "table", "another_table"]
+
+    @staticmethod
+    def test_find_sql_refs_nested_intersect() -> None:
+        sql = """
+        SELECT * FROM table1
+        WHERE id IN (
+            SELECT id FROM table2
+            UNION
+            SELECT id FROM table3
+            INTERSECT
+            SELECT id FROM table4
+        );
+        """
+        assert find_sql_refs(sql) == ["table2", "table3", "table4", "table1"]
+
+    @staticmethod
+    def test_find_sql_refs_with_alias() -> None:
+        sql = "SELECT * FROM employees AS e;"
+        assert find_sql_refs(sql) == ["employees"]
+
+    @staticmethod
+    def test_find_sql_refs_comment() -> None:
+        sql = """
+        -- comment
+        SELECT * FROM table1;
+        -- comment
+        """
+        assert find_sql_refs(sql) == ["table1"]
+
+    @staticmethod
+    def test_find_sql_refs_ddl() -> None:
+        # we are not referencing any table hence no refs
+        sql = "CREATE TABLE t1 (id int);"
+        assert find_sql_refs(sql) == []
+
+    @staticmethod
+    def test_find_sql_refs_ddl_with_reference() -> None:
+        sql = """
+        CREATE TABLE table2 AS
+        WITH x AS (
+            SELECT * from table1
+        )
+        SELECT * FROM x;
+        """
+        assert find_sql_refs(sql) == ["table1"]
+
+    @staticmethod
+    def test_find_sql_refs_update() -> None:
+        sql = "UPDATE my_schema.table1 SET id = 1"
+        assert find_sql_refs(sql) == ["my_schema", "table1"]
+
+    @staticmethod
+    def test_find_sql_refs_insert() -> None:
+        sql = "INSERT INTO my_schema.table1 (id INT) VALUES (1,2);"
+        assert find_sql_refs(sql) == ["my_schema", "table1"]
+
+    @staticmethod
+    def test_find_sql_refs_delete() -> None:
+        sql = "DELETE FROM my_schema.table1 WHERE true;"
+        assert find_sql_refs(sql) == ["my_schema", "table1"]
+
+    @staticmethod
+    def test_find_sql_refs_multi_dml() -> None:
+        sql = """
+        INSERT INTO table1 (id INT) VALUES (1,2);
+        DELETE FROM table2 WHERE true;
+        UPDATE table3 SET id = 1;
+        """
+        assert find_sql_refs(sql) == ["table1", "table2", "table3"]
+
+    @staticmethod
+    def test_find_sql_refs_multiple_selects_in_update() -> None:
+        sql = """
+        UPDATE schema1.table1
+        SET table1.column1 = (
+            SELECT table2.column2 FROM schema2.table2
+        ),
+        table1.column3 = (
+            SELECT table3.column3 FROM table3
+        )
+        WHERE EXISTS (
+            SELECT 1 FROM table2
+        )
+        AND table1.column4 IN (
+            SELECT table4.column4 FROM table4
+        );
+        """
+        assert find_sql_refs(sql) == [
+            "schema1",
+            "table1",
+            "schema2",
+            "table2",
+            "table3",
+            "table4",
+        ]
+
+    @staticmethod
+    def test_find_sql_refs_select_in_insert() -> None:
+        sql = """
+        INSERT INTO table1 (column1, column2)
+        SELECT column1, column2 FROM table2
+        WHERE column3 = 'value';
+        """
+        assert find_sql_refs(sql) == ["table1", "table2"]
+
+    @staticmethod
+    def test_find_sql_refs_select_in_delete() -> None:
+        sql = """
+        DELETE FROM table1
+        WHERE column1 IN (
+            SELECT column1 FROM table2
+            WHERE column2 = 'value'
+        );
+        """
+        assert find_sql_refs(sql) == ["table1", "table2"]
+
+    @staticmethod
+    def test_find_sql_refs_invalid_sql() -> None:
+        sql = "SELECT * FROM"
+        assert find_sql_refs(sql) == []

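As a quick end-to-end check of the new DML support (assuming sqlglot is installed and marimo is running from this branch), `find_sql_refs` can also be exercised directly; the expected outputs mirror assertions from the tests above:

```python
from marimo._ast.sql_visitor import find_sql_refs

# DML statements are now parsed for references too.
print(find_sql_refs("UPDATE my_schema.table1 SET id = 1"))
# ['my_schema', 'table1']

print(
    find_sql_refs(
        "INSERT INTO table1 (column1, column2) "
        "SELECT column1, column2 FROM table2 WHERE column3 = 'value';"
    )
)
# ['table1', 'table2']
```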