diff --git a/marimo/_ast/sql_visitor.py b/marimo/_ast/sql_visitor.py index 9ad3a62e3b8..837e012dc81 100644 --- a/marimo/_ast/sql_visitor.py +++ b/marimo/_ast/sql_visitor.py @@ -301,8 +301,6 @@ def find_sql_defs(sql_statement: str) -> SQLDefs: ) -# TODO(akshayka): there are other kinds of refs to find; this should be -# find_sql_refs def find_sql_refs( sql_statement: str, ) -> list[str]: @@ -315,109 +313,58 @@ def find_sql_refs( Returns: A list of table and schema names referenced in the statement. """ - if not DependencyManager.duckdb.has(): + + # Use sqlglot to parse ast (https://github.com/tobymao/sqlglot/blob/main/posts/ast_primer.md) + if not DependencyManager.sqlglot.has(): return [] - import duckdb + from sqlglot import exp, parse + from sqlglot.errors import ParseError + from sqlglot.optimizer.scope import build_scope - tokens = duckdb.tokenize(sql_statement) - token_extractor = TokenExtractor( - sql_statement=sql_statement, tokens=tokens - ) refs: list[str] = [] - cte_names: set[str] = set() - i = 0 - # First pass - collect CTE names - while i < len(tokens): - if token_extractor.is_keyword(i, "with"): - i += 1 - # Handle optional parenthesis after WITH - if token_extractor.token_str(i) == "(": - i += 1 - while i < len(tokens): - if token_extractor.is_keyword(i, "select"): - break - if ( - token_extractor.token_str(i) == "," - or token_extractor.token_str(i) == "(" - ): - i += 1 - continue - cte_name = token_extractor.strip_quotes( - token_extractor.token_str(i) - ) - if not token_extractor.is_keyword(i, "as"): - cte_names.add(cte_name) - i += 1 - if token_extractor.is_keyword(i, "as"): - break - i += 1 + def append_refs_from_table(table: exp.Table) -> None: + if table.catalog == "memory": + # Default in-memory catalog, only include table name + refs.append(table.name) + else: + # We skip schema if there is a catalog + # Because it may be called "public" or "main" across all catalogs + # and they aren't referenced in the code + if table.catalog: + refs.append(table.catalog) + elif table.db: + refs.append(table.db) # schema + + if table.name: + refs.append(table.name) + + try: + expression_list = parse(sql_statement, dialect="duckdb") + except ParseError as e: + LOGGER.error(f"Unable to parse SQL. Error: {e}") + return [] - # Second pass - collect references excluding CTEs - i = 0 - while i < len(tokens): - if token_extractor.is_keyword(i, "from") or token_extractor.is_keyword( - i, "join" - ): - i += 1 - if i < len(tokens): - # Skip over opening parenthesis for subqueries - if token_extractor.token_str(i) == "(": - continue - - # Get table name parts, this could be: - # - catalog.schema.table - # - catalog.table (this is shorthand for catalog.main.table) - # - table - - parts: List[str] = [] - while i < len(tokens): - part = token_extractor.strip_quotes( - token_extractor.token_str(i) - ) - parts.append(part) - # next token is a dot, so we continue getting parts - if ( - i + 1 < len(tokens) - and token_extractor.token_str(i + 1) == "." - ): - i += 2 - continue - break - - if len(parts) == 3: - # If its the default in-memory catalog, - # only add the table name - if parts[0] == "memory": - refs.append(parts[2]) - else: - # Just add the catalog and table, skip schema - refs.extend([parts[0], parts[2]]) - elif len(parts) == 2: - # If its the default in-memory catalog, only add the table - if parts[0] == "memory": - refs.append(parts[1]) - else: - # It's a catalog and table, add both - refs.extend(parts) - elif len(parts) == 1: - # It's a table, make sure it's not a CTE - if parts[0] not in cte_names: - refs.append(parts[0]) - else: - LOGGER.warning( - "Unexpected number of parts in SQL reference: %s", - parts, - ) + for expression in expression_list: + if expression is None: + continue - i -= 1 # Compensate for outer loop increment - i += 1 + if is_dml := bool(expression.find(exp.Update, exp.Insert, exp.Delete)): + for table in expression.find_all(exp.Table): + append_refs_from_table(table) + + # build_scope only works for select statements + if root := build_scope(expression): + if is_dml: + LOGGER.warning( + "Scopes should not exist for dml's, may need rework if this occurs" + ) - # Re-use find_sql_defs to find referenced schemas and catalogs during creation. - defs = find_sql_defs(sql_statement) - refs.extend(defs.reffed_schemas) - refs.extend(defs.reffed_catalogs) + for scope in root.traverse(): # type: ignore + for _alias, (_node, source) in scope.selected_sources.items(): + if isinstance(source, exp.Table): + append_refs_from_table(source) - # Remove duplicates while preserving order + # remove duplicates while preserving order return list(dict.fromkeys(refs)) diff --git a/marimo/_dependencies/dependencies.py b/marimo/_dependencies/dependencies.py index 808a7ded49a..cf2eef406c5 100644 --- a/marimo/_dependencies/dependencies.py +++ b/marimo/_dependencies/dependencies.py @@ -147,6 +147,7 @@ class DependencyManager: numpy = Dependency("numpy") altair = Dependency("altair", min_version="5.3.0", max_version="6.0.0") duckdb = Dependency("duckdb") + sqlglot = Dependency("sqlglot") pillow = Dependency("PIL") plotly = Dependency("plotly") bokeh = Dependency("bokeh") diff --git a/pyproject.toml b/pyproject.toml index 74bec2c56ba..2555db23d29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,16 +75,22 @@ marimo = "marimo._cli.cli:main" homepage = "https://github.com/marimo-team/marimo" [project.optional-dependencies] -sql = ["duckdb >= 1.0.0", "polars[pyarrow] >= 1.9.0"] +sql = [ + "duckdb>=1.0.0", + "polars[pyarrow]>=1.9.0", + "sqlglot>=23.4" +] + # List of deps that are recommended for most users # in order to unlock all features in marimo recommended = [ - "duckdb>=1.1.0", # SQL cells - "altair>=5.4.0", # Plotting in datasource viewer - "polars>=1.9.0", # SQL output back in Python - "openai>=1.41.1", # AI features - "ruff", # Formatting - "nbformat>=5.7.0", # Export as IPYNB + "duckdb>=1.0.0", # SQL cells + "altair>=5.4.0", # Plotting in datasource viewer + "polars[pyarrow]>=1.9.0", # SQL output back in Python + "sqlglot>=23.4", # SQL cells parsing + "openai>=1.41.1", # AI features + "ruff", # Formatting + "nbformat>=5.7.0", # Export as IPYNB ] dev = [ @@ -95,6 +101,7 @@ dev = [ "opentelemetry-sdk~=1.26.0", # For SQL "duckdb>=1.0.0", + "sqlglot>=23.4", # For linting "ruff~=0.6.1", # For AI diff --git a/tests/_ast/test_sql_visitor.py b/tests/_ast/test_sql_visitor.py index aa888892e3b..d1aa2910694 100644 --- a/tests/_ast/test_sql_visitor.py +++ b/tests/_ast/test_sql_visitor.py @@ -14,6 +14,7 @@ from marimo._dependencies.dependencies import DependencyManager HAS_DUCKDB = DependencyManager.duckdb.has() +HAS_SQLGLOT = DependencyManager.sqlglot.has() def test_execute_with_string_literal() -> None: @@ -469,7 +470,7 @@ def test_find_sql_defs_duckdb_not_available() -> None: assert find_sql_defs("CREATE TABLE test (id INT);") == SQLDefs() -@pytest.mark.skipif(not HAS_DUCKDB, reason="Missing DuckDB") +@pytest.mark.skipif(not HAS_SQLGLOT, reason="Missing sqlglot") class TestFindSQLRefs: @staticmethod def test_find_sql_refs_simple() -> None: @@ -509,8 +510,6 @@ def test_find_sql_refs_with_schema() -> None: @staticmethod def test_find_sql_refs_with_catalog() -> None: # Skip the schema if it's coming from a catalog - # Why? Because it may be called "public" or "main" across all catalogs - # and they aren't referenced in the code sql = "SELECT * FROM my_catalog.my_schema.my_table;" assert find_sql_refs(sql) == ["my_catalog", "my_table"] @@ -566,7 +565,6 @@ def test_find_sql_refs_with_quoted_names() -> None: assert find_sql_refs(sql) == ["My Table", "Weird.Name"] @staticmethod - @pytest.mark.xfail(reason="Multiple CTEs are not supported") def test_find_sql_refs_with_multiple_ctes() -> None: sql = """ WITH @@ -578,7 +576,6 @@ def test_find_sql_refs_with_multiple_ctes() -> None: assert find_sql_refs(sql) == ["table1", "table2"] @staticmethod - @pytest.mark.xfail(reason="Nested joins are not supported") def test_find_sql_refs_with_nested_joins() -> None: sql = """ SELECT * FROM t1 @@ -593,7 +590,7 @@ def test_find_sql_refs_with_lateral_join() -> None: SELECT * FROM employees, LATERAL (SELECT * FROM departments WHERE departments.id = employees.dept_id) dept; """ - assert find_sql_refs(sql) == ["employees", "departments"] + assert find_sql_refs(sql) == ["departments", "employees"] @staticmethod def test_find_sql_refs_with_schema_switching() -> None: @@ -614,3 +611,123 @@ def test_find_sql_refs_with_complex_subqueries() -> None: ) t2; """ assert find_sql_refs(sql) == ["deeply", "table", "another_table"] + + @staticmethod + def test_find_sql_refs_nested_intersect() -> None: + sql = """ + SELECT * FROM table1 + WHERE id IN ( + SELECT id FROM table2 + UNION + SELECT id FROM table3 + INTERSECT + SELECT id FROM table4 + ); + """ + assert find_sql_refs(sql) == ["table2", "table3", "table4", "table1"] + + @staticmethod + def test_find_sql_refs_with_alias() -> None: + sql = "SELECT * FROM employees AS e;" + assert find_sql_refs(sql) == ["employees"] + + @staticmethod + def test_find_sql_refs_comment() -> None: + sql = """ + -- comment + SELECT * FROM table1; + -- comment + """ + assert find_sql_refs(sql) == ["table1"] + + @staticmethod + def test_find_sql_refs_ddl() -> None: + # we are not referencing any table hence no refs + sql = "CREATE TABLE t1 (id int);" + assert find_sql_refs(sql) == [] + + @staticmethod + def test_find_sql_refs_ddl_with_reference() -> None: + sql = """ + CREATE TABLE table2 AS + WITH x AS ( + SELECT * from table1 + ) + SELECT * FROM x; + """ + assert find_sql_refs(sql) == ["table1"] + + @staticmethod + def test_find_sql_refs_update() -> None: + sql = "UPDATE my_schema.table1 SET id = 1" + assert find_sql_refs(sql) == ["my_schema", "table1"] + + @staticmethod + def test_find_sql_refs_insert() -> None: + sql = "INSERT INTO my_schema.table1 (id INT) VALUES (1,2);" + assert find_sql_refs(sql) == ["my_schema", "table1"] + + @staticmethod + def test_find_sql_refs_delete() -> None: + sql = "DELETE FROM my_schema.table1 WHERE true;" + assert find_sql_refs(sql) == ["my_schema", "table1"] + + @staticmethod + def test_find_sql_refs_multi_dml() -> None: + sql = """ + INSERT INTO table1 (id INT) VALUES (1,2); + DELETE FROM table2 WHERE true; + UPDATE table3 SET id = 1; + """ + assert find_sql_refs(sql) == ["table1", "table2", "table3"] + + @staticmethod + def test_find_sql_refs_multiple_selects_in_update() -> None: + sql = """ + UPDATE schema1.table1 + SET table1.column1 = ( + SELECT table2.column2 FROM schema2.table2 + ), + table1.column3 = ( + SELECT table3.column3 FROM table3 + ) + WHERE EXISTS ( + SELECT 1 FROM table2 + ) + AND table1.column4 IN ( + SELECT table4.column4 FROM table4 + ); + """ + assert find_sql_refs(sql) == [ + "schema1", + "table1", + "schema2", + "table2", + "table3", + "table4", + ] + + @staticmethod + def test_find_sql_refs_select_in_insert() -> None: + sql = """ + INSERT INTO table1 (column1, column2) + SELECT column1, column2 FROM table2 + WHERE column3 = 'value'; + """ + assert find_sql_refs(sql) == ["table1", "table2"] + + @staticmethod + def test_find_sql_refs_select_in_delete() -> None: + sql = """ + DELETE FROM table1 + WHERE column1 IN ( + SELECT column1 FROM table2 + WHERE column2 = 'value' + ); + """ + assert find_sql_refs(sql) == ["table1", "table2"] + + @staticmethod + def test_find_sql_refs_invalid_sql() -> None: + sql = "SELECT * FROM" + assert find_sql_refs(sql) == []