Skip to content
This repository has been archived by the owner on Jan 2, 2024. It is now read-only.

Feature/#815 - Filter database data nodes should not read all data #820

Merged
merged 3 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 72 additions & 14 deletions src/taipy/core/data/_abstract_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import urllib.parse
from abc import abstractmethod
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, Set, Tuple, Union

import modin.pandas as modin_pd
import numpy as np
Expand All @@ -24,6 +24,7 @@
from taipy.config.common.scope import Scope

from .._version._version_manager_factory import _VersionManagerFactory
from ..data.operator import JoinOperator, Operator
from ..exceptions.exceptions import MissingRequiredProperty, UnknownDatabaseEngine
from ._abstract_tabular import _AbstractTabularDataNode
from .data_node import DataNode
Expand Down Expand Up @@ -198,6 +199,15 @@ def _conn_string(self) -> str:

raise UnknownDatabaseEngine(f"Unknown engine: {engine}")

def filter(self, operators: Optional[Union[List, Tuple]] = None, join_operator=JoinOperator.AND):
    """Read only the rows matching *operators*, pushing the filter into the SQL query.

    Dispatches on the configured exposed type; each reader builds a filtered
    read query instead of loading the whole table and filtering in memory.
    """
    exposed_type = self.properties[self.__EXPOSED_TYPE_PROPERTY]
    if exposed_type == self.__EXPOSED_TYPE_PANDAS:
        reader = self._read_as_pandas_dataframe
    elif exposed_type == self.__EXPOSED_TYPE_MODIN:
        reader = self._read_as_modin_dataframe
    elif exposed_type == self.__EXPOSED_TYPE_NUMPY:
        reader = self._read_as_numpy
    else:
        # Any other exposed type is treated as a custom class.
        reader = self._read_as
    return reader(operators=operators, join_operator=join_operator)

def _read(self):
if self.properties[self.__EXPOSED_TYPE_PROPERTY] == self.__EXPOSED_TYPE_PANDAS:
return self._read_as_pandas_dataframe()
Expand All @@ -207,32 +217,76 @@ def _read(self):
return self._read_as_numpy()
return self._read_as()

def _read_as(self):
def _read_as(self, operators: Optional[Union[List, Tuple]] = None, join_operator=JoinOperator.AND):
    """Read the (optionally filtered) result set as instances of the configured custom class."""
    custom_class = self.properties[self.__EXPOSED_TYPE_PROPERTY]
    # Build the query up front; it does not need an open connection.
    query = text(self._get_read_query(operators, join_operator))
    with self._get_engine().connect() as connection:
        rows = connection.execute(query)
        return [custom_class(**row) for row in rows]

def _read_as_numpy(self) -> np.ndarray:
return self._read_as_pandas_dataframe().to_numpy()
def _read_as_numpy(
    self, operators: Optional[Union[List, Tuple]] = None, join_operator=JoinOperator.AND
) -> np.ndarray:
    """Read the (optionally filtered) result set as a NumPy array."""
    dataframe = self._read_as_pandas_dataframe(operators=operators, join_operator=join_operator)
    return dataframe.to_numpy()

def _read_as_pandas_dataframe(self, columns: Optional[List[str]] = None):
def _read_as_pandas_dataframe(
    self,
    columns: Optional[List[str]] = None,
    operators: Optional[Union[List, Tuple]] = None,
    join_operator=JoinOperator.AND,
):
    """Read the (optionally filtered) result set as a pandas DataFrame.

    When *columns* is a non-empty list, only those columns are kept.
    """
    # The query is identical with or without a column selection, so build it once.
    query = text(self._get_read_query(operators, join_operator))
    with self._get_engine().connect() as conn:
        dataframe = pd.DataFrame(conn.execute(query))
    return dataframe[columns] if columns else dataframe

def _read_as_modin_dataframe(self, columns: Optional[List[str]] = None):
def _read_as_modin_dataframe(
    self,
    columns: Optional[List[str]] = None,
    operators: Optional[Union[List, Tuple]] = None,
    join_operator=JoinOperator.AND,
):
    """Read the (optionally filtered) result set as a Modin DataFrame.

    When *columns* is a non-empty list, only those columns are kept.
    """
    query = self._get_read_query(operators, join_operator)
    dataframe = modin_pd.read_sql_query(query, con=self._get_engine())
    return dataframe[columns] if columns else dataframe

@abstractmethod
def _get_read_query(self):
raise NotImplementedError
def _get_read_query(self, operators: Optional[Union[List, Tuple]] = None, join_operator=JoinOperator.AND):
    """Build the SQL read query, appending a WHERE clause for *operators*.

    Parameters:
        operators: A single ``(key, value, Operator)`` tuple or a list of them.
            ``None``/empty means no filtering.
        join_operator: How multiple conditions are combined (AND / OR).

    Returns:
        The base read query, with a WHERE clause appended when operators are given.

    Raises:
        NotImplementedError: If *join_operator* is neither AND nor OR.
    """
    query = self._get_base_read_query()

    if not operators:
        return query

    if not isinstance(operators, list):
        # A single (key, value, operator) tuple is accepted for convenience.
        operators = [operators]

    # Map each filter operator to its SQL comparison symbol.
    sql_symbols = {
        Operator.EQUAL: "=",
        Operator.NOT_EQUAL: "<>",
        Operator.GREATER_THAN: ">",
        Operator.GREATER_OR_EQUAL: ">=",
        Operator.LESS_THAN: "<",
        Operator.LESS_OR_EQUAL: "<=",
    }

    conditions = []
    for key, value, operator in operators:
        symbol = sql_symbols.get(operator)
        if symbol is None:
            # Unknown operators are silently skipped, as in the original if-chain.
            continue
        # Escape single quotes so a value containing "'" cannot terminate the
        # literal and inject SQL or break the query.
        # NOTE(review): keys are interpolated as-is — they are assumed to come
        # from trusted node configuration, not user input; parameterized
        # queries would be safer for the values too. TODO confirm.
        escaped_value = str(value).replace("'", "''")
        conditions.append(f"{key} {symbol} '{escaped_value}'")

    if join_operator == JoinOperator.AND:
        joiner = " AND "
    elif join_operator == JoinOperator.OR:
        joiner = " OR "
    else:
        raise NotImplementedError(f"Join operator {join_operator} not implemented.")

    return f"{query} WHERE {joiner.join(conditions)}"

@abstractmethod
def _do_write(self, data, engine, connection) -> None:
def _get_base_read_query(self) -> str:
raise NotImplementedError

def _write(self, data) -> None:
Expand All @@ -248,6 +302,10 @@ def _write(self, data) -> None:
else:
transaction.commit()

@abstractmethod
def _do_write(self, data, engine, connection) -> None:
raise NotImplementedError

def __setattr__(self, key: str, value) -> None:
if key in self.__ENGINE_PROPERTIES:
self._engine = None
Expand Down
3 changes: 0 additions & 3 deletions src/taipy/core/data/data_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import modin.pandas as modin_pd
import networkx as nx
import numpy as np
import pandas as pd

from taipy.config.common._validate_id import _validate_id
from taipy.config.common.scope import Scope
Expand Down
41 changes: 36 additions & 5 deletions src/taipy/core/data/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@

from datetime import datetime, timedelta
from inspect import isclass
from typing import Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from taipy.config.common.scope import Scope

from .._version._version_manager_factory import _VersionManagerFactory
from ..common._mongo_connector import _connect_mongodb
from ..data.operator import JoinOperator, Operator
from ..exceptions.exceptions import InvalidCustomDocument, MissingRequiredProperty
from .data_node import DataNode
from .data_node_id import DataNodeId, Edit
Expand Down Expand Up @@ -175,19 +176,49 @@ def _check_custom_document(self, custom_document):
def storage_type(cls) -> str:
return cls.__STORAGE_TYPE

def filter(self, operators: Optional[Union[List, Tuple]] = None, join_operator=JoinOperator.AND):
    """Read only the documents matching *operators*, filtering inside the Mongo query."""
    return [self._decoder(document) for document in self._read_by_query(operators, join_operator)]

def _read(self):
    """Read every document from the collection, decoded via the configured decoder."""
    return [self._decoder(document) for document in self._read_by_query()]

def _read_by_query(self):
def _read_by_query(self, operators: Optional[Union[List, Tuple]] = None, join_operator=JoinOperator.AND):
    """Query the Mongo collection, optionally filtered by *operators*.

    Parameters:
        operators: A single ``(key, value, Operator)`` tuple or a list of them.
            ``None``/empty means the whole collection is returned.
        join_operator: How multiple conditions are combined (AND / OR).

    Returns:
        A pymongo cursor over the matching documents.

    Raises:
        NotImplementedError: If *join_operator* is neither AND nor OR.

    NOTE(review): the previous docstring claimed the ``_id`` field is excluded,
    but no projection is applied here — confirm against callers.
    """
    if not operators:
        return self.collection.find()

    if not isinstance(operators, list):
        # A single (key, value, operator) tuple is accepted for convenience.
        operators = [operators]

    # Map each comparison operator to its Mongo query operator.
    mongo_operators = {
        Operator.NOT_EQUAL: "$ne",
        Operator.GREATER_THAN: "$gt",
        Operator.GREATER_OR_EQUAL: "$gte",
        Operator.LESS_THAN: "$lt",
        Operator.LESS_OR_EQUAL: "$lte",
    }

    conditions = []
    for key, value, operator in operators:
        if operator == Operator.EQUAL:
            # Equality has no explicit Mongo operator: {key: value}.
            conditions.append({key: value})
        elif operator in mongo_operators:
            conditions.append({key: {mongo_operators[operator]: value}})

    if join_operator == JoinOperator.AND:
        query = {"$and": conditions}
    elif join_operator == JoinOperator.OR:
        query = {"$or": conditions}
    else:
        raise NotImplementedError(f"Join operator {join_operator} is not supported.")

    # Fixed: a leftover unconditional `return self.collection.find()` sat here,
    # making the filtered query below unreachable and defeating the filter.
    return self.collection.find(query)

def _write(self, data) -> None:
"""Check data against a collection of types to handle insertion on the database."""

if not isinstance(data, list):
data = [data]

Expand Down
2 changes: 1 addition & 1 deletion src/taipy/core/data/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __init__(
def storage_type(cls) -> str:
return cls.__STORAGE_TYPE

def _get_read_query(self):
def _get_base_read_query(self) -> str:
    """Return the user-configured read query.

    NOTE(review): ``properties.get`` returns ``None`` when the read-query
    property is missing, despite the ``-> str`` annotation — confirm callers
    tolerate that.
    """
    return self.properties.get(self.__READ_QUERY_KEY)

def _do_write(self, data, engine, connection) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/taipy/core/data/sql_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
def storage_type(cls) -> str:
return cls.__STORAGE_TYPE

def _get_read_query(self):
def _get_base_read_query(self) -> str:
    """Return a query selecting every column of the configured table.

    NOTE(review): the table name is interpolated directly into the SQL string;
    it is assumed to come from trusted node configuration, not user input —
    confirm.
    """
    return f"SELECT * FROM {self.properties[self.__TABLE_KEY]}"

def _do_write(self, data, engine, connection) -> None:
Expand Down
12 changes: 6 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ def tmp_sqlite_db_file_path(tmpdir_factory):
file_extension = ".db"
db = create_engine("sqlite:///" + os.path.join(fn.strpath, f"{db_name}{file_extension}"))
conn = db.connect()
conn.execute(text("CREATE TABLE foo (foo int, bar int);"))
conn.execute(text("INSERT INTO foo (foo, bar) VALUES (1, 2);"))
conn.execute(text("INSERT INTO foo (foo, bar) VALUES (3, 4);"))
conn.execute(text("CREATE TABLE example (foo int, bar int);"))
conn.execute(text("INSERT INTO example (foo, bar) VALUES (1, 2);"))
conn.execute(text("INSERT INTO example (foo, bar) VALUES (3, 4);"))
conn.commit()
conn.close()
db.dispose()
Expand All @@ -162,9 +162,9 @@ def tmp_sqlite_sqlite3_file_path(tmpdir_factory):

db = create_engine("sqlite:///" + os.path.join(fn.strpath, f"{db_name}{file_extension}"))
conn = db.connect()
conn.execute(text("CREATE TABLE foo (foo int, bar int);"))
conn.execute(text("INSERT INTO foo (foo, bar) VALUES (1, 2);"))
conn.execute(text("INSERT INTO foo (foo, bar) VALUES (3, 4);"))
conn.execute(text("CREATE TABLE example (foo int, bar int);"))
conn.execute(text("INSERT INTO example (foo, bar) VALUES (1, 2);"))
conn.execute(text("INSERT INTO example (foo, bar) VALUES (3, 4);"))
conn.commit()
conn.close()
db.dispose()
Expand Down
14 changes: 14 additions & 0 deletions tests/core/data/test_mongo_data_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from dataclasses import dataclass
from datetime import datetime
from unittest.mock import patch

import mongomock
import pymongo
Expand Down Expand Up @@ -339,3 +340,16 @@ def test_filter(self, properties):
{"bar": 2},
{},
]

@mongomock.patch(servers=(("localhost", 27017),))
@pytest.mark.parametrize("properties", __properties)
def test_filter_does_not_read_all_entities(self, properties):
    """filter() must push the filter down to Mongo instead of reading everything."""
    mongo_dn = MongoCollectionDataNode("foo", Scope.SCENARIO, properties=properties)

    # MongoCollectionDataNode.filter() should not call the MongoCollectionDataNode._read() method
    with patch.object(MongoCollectionDataNode, "_read") as read_mock:
        mongo_dn.filter(("foo", 1, Operator.EQUAL))
        mongo_dn.filter(("bar", 2, Operator.NOT_EQUAL))
        mongo_dn.filter([("bar", 1, Operator.EQUAL), ("bar", 2, Operator.EQUAL)], JoinOperator.OR)

    # Fixed: `read_mock["_read"].call_count` indexed the mock, creating a fresh
    # child mock whose call_count is always 0 — the assertion passed vacuously.
    # Assert on the patched _read mock itself.
    read_mock.assert_not_called()
Loading
Loading