Skip to content

Commit

Permalink
Feature/bind by type handling generic (#110)
Browse files Browse the repository at this point in the history
  • Loading branch information
maxzhenzhera authored Oct 4, 2023
1 parent e88c8f9 commit 0e7adb3
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/workflow.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
os: ubuntu-latest
- python: "3.10"
os: ubuntu-latest
- python: "3.11.0-beta.1 - 3.11"
- python: "3.11"
os: ubuntu-latest
# test OSs
- python: "3.x"
Expand Down
4 changes: 2 additions & 2 deletions di/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ def bind_by_type(
def hook(
param: inspect.Parameter | None, dependent: DependentBase[Any]
) -> DependentBase[Any] | None:
if dependent.call is dependency:
if dependent.call == dependency:
return provider
if param is None:
return None
type_annotation_option = get_type(param)
if type_annotation_option is None:
return None
type_annotation = type_annotation_option.value
if type_annotation is dependency:
if type_annotation == dependency:
return provider
if covariant:
if inspect.isclass(type_annotation) and inspect.isclass(dependency):
Expand Down
8 changes: 3 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "di"
version = "0.77.0"
version = "0.78.0"
description = "Dependency injection toolkit"
authors = ["Adrian Garcia Badaracco <adrian@adriangb.com>"]
readme = "README.md"
Expand Down Expand Up @@ -38,6 +38,7 @@ anyio = ["anyio"]
# linting
black = "~23"
mypy = "~1"
ruff = "^0.0.286"
pre-commit = "~2"
# testing
pytest = "~7"
Expand All @@ -48,18 +49,15 @@ coverage = { extras = ["toml"], version = "~6" }
# docs
mkdocs = "~1"
mkdocs-material = "~8,!=8.1.3"
mkdocstrings = {version = "^0.19.0", extras = ["python"]}
mike = "~1"
# benchmarking
pyinstrument = "~4"
mkdocstrings = {version = "^0.19.0", extras = ["python"]}
ruff = "^0.0.286"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.isort]
profile = "black"

[tool.coverage.run]
branch = true
Expand Down
64 changes: 63 additions & 1 deletion tests/test_binding.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import List
import sys
from abc import abstractmethod
from typing import List, TypeVar

import pytest

Expand All @@ -7,6 +9,11 @@
from di.executors import SyncExecutor
from di.typing import Annotated

if sys.version_info < (3, 8): # pragma: no cover
from typing_extensions import Protocol
else: # pragma: no cover
from typing import Protocol


class Request:
def __init__(self, value: int = 0) -> None:
Expand Down Expand Up @@ -47,6 +54,61 @@ def __init__(self, v: int = 1) -> None:
assert res.v == 1


T_co = TypeVar("T_co", covariant=True)


def test_bind_generic():
container = Container()
executor = SyncExecutor()
expected = 100

class GetterInterface(Protocol[T_co]):
@abstractmethod
def get(self) -> T_co:
...

class GetterIntImpl(GetterInterface[int]):
def __init__(self, v: int) -> None:
self.v = v

def get(self) -> int:
return self.v

def factory() -> GetterIntImpl:
return GetterIntImpl(expected)

hook = bind_by_type(
Dependent(factory),
GetterInterface[int],
)
container.bind(hook)

# ===========================================
# clean `_tp_cache`
from typing import _cleanups as cache_cleanups # type: ignore[attr-defined]

for cache_cleanup in cache_cleanups:
cache_cleanup()
# ===========================================

class IntService:
"""Declared after binding and cache clearing."""

def __init__(self, getter: GetterInterface[int]) -> None:
self.getter = getter

scopes = [None]
flat_dependent = Dependent(GetterInterface[int])
wired_dependent = Dependent(IntService)
with container.enter_scope(None) as state:
flat_solved = container.solve(flat_dependent, scopes)
wired_solved = container.solve(wired_dependent, scopes)
flat = flat_solved.execute_sync(executor, state)
wired = wired_solved.execute_sync(executor, state)

assert flat.get() == wired.getter.get() == expected


def test_bind_transitive_dependency_results_skips_subdpendencies():
"""If we bind a transitive dependency none of it's sub-dependencies should be executed
since they are no longer required.
Expand Down

0 comments on commit 0e7adb3

Please sign in to comment.