Skip to content

Commit

Permalink
Graphene v3 (tests) (#317)
Browse files Browse the repository at this point in the history
Co-authored-by: Jonathan Ehwald <github@ehwald.info>
Co-authored-by: Zbigniew Siciarz <zbigniew@siciarz.net>
Co-authored-by: Cole Lin <colelin26@gmail.com>
  • Loading branch information
4 people authored Sep 21, 2021
1 parent cba727c commit d6dd67e
Show file tree
Hide file tree
Showing 21 changed files with 196 additions and 220 deletions.
50 changes: 25 additions & 25 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,31 @@ jobs:
strategy:
max-parallel: 10
matrix:
sql-alchemy: ["1.2", "1.3"]
sql-alchemy: ["1.2", "1.3", "1.4"]
python-version: ["3.6", "3.7", "3.8", "3.9"]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install tox tox-gh-actions
- name: Test with tox
run: tox
env:
SQLALCHEMY: ${{ matrix.sql-alchemy }}
TOXENV: ${{ matrix.toxenv }}
- name: Upload coverage.xml
if: ${{ matrix.sql-alchemy == '1.3' && matrix.python-version == '3.9' }}
uses: actions/upload-artifact@v2
with:
name: graphene-sqlalchemy-coverage
path: coverage.xml
if-no-files-found: error
- name: Upload coverage.xml to codecov
if: ${{ matrix.sql-alchemy == '1.3' && matrix.python-version == '3.9' }}
uses: codecov/codecov-action@v1
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install tox tox-gh-actions
- name: Test with tox
run: tox
env:
SQLALCHEMY: ${{ matrix.sql-alchemy }}
TOXENV: ${{ matrix.toxenv }}
- name: Upload coverage.xml
if: ${{ matrix.sql-alchemy == '1.4' && matrix.python-version == '3.9' }}
uses: actions/upload-artifact@v2
with:
name: graphene-sqlalchemy-coverage
path: coverage.xml
if-no-files-found: error
- name: Upload coverage.xml to codecov
if: ${{ matrix.sql-alchemy == '1.4' && matrix.python-version == '3.9' }}
uses: codecov/codecov-action@v1
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,6 @@ target/
# Databases
*.sqlite3
.vscode

# mypy cache
.mypy_cache/
2 changes: 1 addition & 1 deletion graphene_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .fields import SQLAlchemyConnectionField
from .utils import get_query, get_session

__version__ = "2.3.0"
__version__ = "3.0.0b1"

__all__ = [
"__version__",
Expand Down
49 changes: 33 additions & 16 deletions graphene_sqlalchemy/batching.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import aiodataloader
import sqlalchemy
from promise import dataloader, promise
from sqlalchemy.orm import Session, strategies
from sqlalchemy.orm.query import QueryContext

from .utils import is_sqlalchemy_version_less_than


def get_batch_resolver(relationship_prop):

# Cache this across `batch_load_fn` calls
# This is so SQL string generation is cached under-the-hood via `bakery`
selectin_loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),))

class RelationshipLoader(dataloader.DataLoader):
class RelationshipLoader(aiodataloader.DataLoader):
cache = False

def batch_load_fn(self, parents): # pylint: disable=method-hidden
async def batch_load_fn(self, parents):
"""
Batch loads the relationships of all the parents as one SQL statement.
Expand Down Expand Up @@ -52,21 +54,36 @@ def batch_load_fn(self, parents): # pylint: disable=method-hidden
states = [(sqlalchemy.inspect(parent), True) for parent in parents]

# For our purposes, the query_context will only used to get the session
query_context = QueryContext(session.query(parent_mapper.entity))

selectin_loader._load_for_path(
query_context,
parent_mapper._path_registry,
states,
None,
child_mapper,
)

return promise.Promise.resolve([getattr(parent, relationship_prop.key) for parent in parents])
query_context = None
if is_sqlalchemy_version_less_than('1.4'):
query_context = QueryContext(session.query(parent_mapper.entity))
else:
parent_mapper_query = session.query(parent_mapper.entity)
query_context = parent_mapper_query._compile_context()

if is_sqlalchemy_version_less_than('1.4'):
selectin_loader._load_for_path(
query_context,
parent_mapper._path_registry,
states,
None,
child_mapper
)
else:
selectin_loader._load_for_path(
query_context,
parent_mapper._path_registry,
states,
None,
child_mapper,
None
)

return [getattr(parent, relationship_prop.key) for parent in parents]

loader = RelationshipLoader()

def resolve(root, info, **args):
return loader.load(root)
async def resolve(root, info, **args):
return await loader.load(root)

return resolve
17 changes: 11 additions & 6 deletions graphene_sqlalchemy/converter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from enum import EnumMeta
from functools import singledispatch

from singledispatch import singledispatch
from sqlalchemy import types
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import interfaces, strategies
Expand All @@ -21,6 +20,11 @@
except ImportError:
ChoiceType = JSONType = ScalarListType = TSVectorType = object

try:
from sqlalchemy_utils.types.choice import EnumTypeImpl
except ImportError:
EnumTypeImpl = object


is_selectin_available = getattr(strategies, 'SelectInLoader', None)

Expand Down Expand Up @@ -110,9 +114,9 @@ def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, conn


def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs):
if 'type' not in field_kwargs:
if 'type_' not in field_kwargs:
# TODO The default type should be dependent on the type of the property propety.
field_kwargs['type'] = String
field_kwargs['type_'] = String

return Field(
resolver=resolver,
Expand Down Expand Up @@ -156,7 +160,8 @@ def inner(fn):

def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs):
column = column_prop.columns[0]
field_kwargs.setdefault('type', convert_sqlalchemy_type(getattr(column, "type", None), column, registry))

field_kwargs.setdefault('type_', convert_sqlalchemy_type(getattr(column, "type", None), column, registry))
field_kwargs.setdefault('required', not is_column_nullable(column))
field_kwargs.setdefault('description', get_column_doc(column))

Expand Down Expand Up @@ -221,7 +226,7 @@ def convert_enum_to_enum(type, column, registry=None):
@convert_sqlalchemy_type.register(ChoiceType)
def convert_choice_to_enum(type, column, registry=None):
name = "{}_{}".format(column.table.name, column.name).upper()
if isinstance(type.choices, EnumMeta):
if isinstance(type.type_impl, EnumTypeImpl):
# type.choices may be Enum/IntEnum, in ChoiceType both presented as EnumMeta
# do not use from_enum here because we can have more than one enum column in table
return Enum(name, list((v.name, v.value) for v in type.choices))
Expand Down
3 changes: 1 addition & 2 deletions graphene_sqlalchemy/enums.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import six
from sqlalchemy.orm import ColumnProperty
from sqlalchemy.types import Enum as SQLAlchemyEnumType

Expand Down Expand Up @@ -63,7 +62,7 @@ def enum_for_field(obj_type, field_name):
if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyObjectType):
raise TypeError(
"Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type))
if not field_name or not isinstance(field_name, six.string_types):
if not field_name or not isinstance(field_name, str):
raise TypeError(
"Expected a field name, but got: {!r}".format(field_name))
registry = obj_type._meta.registry
Expand Down
66 changes: 40 additions & 26 deletions graphene_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,29 @@
import enum
import warnings
from functools import partial

import six
from promise import Promise, is_thenable
from sqlalchemy.orm.query import Query

from graphene import NonNull
from graphene.relay import Connection, ConnectionField
from graphene.relay.connection import PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice
from graphene.relay.connection import connection_adapter, page_info_adapter
from graphql_relay.connection.arrayconnection import \
connection_from_array_slice

from .batching import get_batch_resolver
from .utils import get_query
from .utils import EnumValue, get_query


class UnsortedSQLAlchemyConnectionField(ConnectionField):
@property
def type(self):
from .types import SQLAlchemyObjectType

_type = super(ConnectionField, self).type
nullable_type = get_nullable_type(_type)
type_ = super(ConnectionField, self).type
nullable_type = get_nullable_type(type_)
if issubclass(nullable_type, Connection):
return _type
return type_
assert issubclass(nullable_type, SQLAlchemyObjectType), (
"SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}"
).format(nullable_type.__name__)
Expand All @@ -31,7 +32,7 @@ def type(self):
), "The type {} doesn't have a connection".format(
nullable_type.__name__
)
assert _type == nullable_type, (
assert type_ == nullable_type, (
"Passing a SQLAlchemyObjectType instance is deprecated. "
"Pass the connection type instead accessible via SQLAlchemyObjectType.connection"
)
Expand All @@ -53,15 +54,19 @@ def resolve_connection(cls, connection_type, model, info, args, resolved):
_len = resolved.count()
else:
_len = len(resolved)
connection = connection_from_list_slice(
resolved,
args,

def adjusted_connection_adapter(edges, pageInfo):
return connection_adapter(connection_type, edges, pageInfo)

connection = connection_from_array_slice(
array_slice=resolved,
args=args,
slice_start=0,
list_length=_len,
list_slice_length=_len,
connection_type=connection_type,
pageinfo_type=PageInfo,
array_length=_len,
array_slice_length=_len,
connection_type=adjusted_connection_adapter,
edge_type=connection_type.Edge,
page_info_type=page_info_adapter,
)
connection.iterable = resolved
connection.length = _len
Expand All @@ -77,7 +82,7 @@ def connection_resolver(cls, resolver, connection_type, model, root, info, **arg

return on_resolve(resolved)

def get_resolver(self, parent_resolver):
def wrap_resolve(self, parent_resolver):
return partial(
self.connection_resolver,
parent_resolver,
Expand All @@ -88,8 +93,8 @@ def get_resolver(self, parent_resolver):

# TODO Rename this to SortableSQLAlchemyConnectionField
class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
def __init__(self, type, *args, **kwargs):
nullable_type = get_nullable_type(type)
def __init__(self, type_, *args, **kwargs):
nullable_type = get_nullable_type(type_)
if "sort" not in kwargs and issubclass(nullable_type, Connection):
# Let super class raise if type is not a Connection
try:
Expand All @@ -103,16 +108,25 @@ def __init__(self, type, *args, **kwargs):
)
elif "sort" in kwargs and kwargs["sort"] is None:
del kwargs["sort"]
super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs)
super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs)

@classmethod
def get_query(cls, model, info, sort=None, **args):
query = get_query(model, info.context)
if sort is not None:
if isinstance(sort, six.string_types):
query = query.order_by(sort.value)
else:
query = query.order_by(*(col.value for col in sort))
if not isinstance(sort, list):
sort = [sort]
sort_args = []
# ensure consistent handling of graphene Enums, enum values and
# plain strings
for item in sort:
if isinstance(item, enum.Enum):
sort_args.append(item.value.value)
elif isinstance(item, EnumValue):
sort_args.append(item.value)
else:
sort_args.append(item)
query = query.order_by(*sort_args)
return query


Expand All @@ -123,7 +137,7 @@ class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
Use at your own risk.
"""

def get_resolver(self, parent_resolver):
def wrap_resolve(self, parent_resolver):
return partial(
self.connection_resolver,
self.resolver,
Expand All @@ -148,13 +162,13 @@ def default_connection_field_factory(relationship, registry, **field_kwargs):
__connectionFactory = UnsortedSQLAlchemyConnectionField


def createConnectionField(_type, **field_kwargs):
def createConnectionField(type_, **field_kwargs):
warnings.warn(
'createConnectionField is deprecated and will be removed in the next '
'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.',
DeprecationWarning,
)
return __connectionFactory(_type, **field_kwargs)
return __connectionFactory(type_, **field_kwargs)


def registerConnectionFieldFactory(factoryMethod):
Expand Down
3 changes: 1 addition & 2 deletions graphene_sqlalchemy/registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections import defaultdict

import six
from sqlalchemy.types import Enum as SQLAlchemyEnumType

from graphene import Enum
Expand Down Expand Up @@ -43,7 +42,7 @@ def register_orm_field(self, obj_type, field_name, orm_field):
raise TypeError(
"Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)
)
if not field_name or not isinstance(field_name, six.string_types):
if not field_name or not isinstance(field_name, str):
raise TypeError("Expected a field name, but got: {!r}".format(field_name))
self._registry_orm_fields[obj_type][field_name] = orm_field

Expand Down
2 changes: 1 addition & 1 deletion graphene_sqlalchemy/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def convert_composite_class(composite, registry):
return graphene.Field(graphene.Int)


@pytest.yield_fixture(scope="function")
@pytest.fixture(scope="function")
def session_factory():
engine = create_engine(test_db_url)
Base.metadata.create_all(engine)
Expand Down
Loading

0 comments on commit d6dd67e

Please sign in to comment.