Skip to content

Commit

Permalink
0.15.1
Browse files Browse the repository at this point in the history
  • Loading branch information
EcmaXp committed Sep 20, 2024
1 parent daa53ca commit db59873
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "reloader.py"
version = "0.15.0"
version = "0.15.1"
description = "A simple script reloader"
license = "MIT"
authors = ["EcmaXp <ecmaxp@ecmaxp.kr>"]
Expand Down
40 changes: 27 additions & 13 deletions reloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from watchdog.utils.event_debouncer import EventDebouncer

__author__ = "EcmaXp"
__version__ = "0.15.0"
__version__ = "0.15.1"
__license__ = "MIT"
__url__ = "https://pypi.org/project/reloader.py/"
__all__ = [
Expand Down Expand Up @@ -182,15 +182,19 @@ def patch(self, module_globals: dict):
try:
yield
finally:
self.patch_module(old_globals, module_globals)
self.patch_module(old_globals, module_globals, visited=set())

def patch_module(self, old_globals: dict, new_globals: dict):
def patch_module(self, old_globals: dict, new_globals: dict, *, visited: set[int]):
for key, new_value in new_globals.items():
old_value = old_globals.get(key)
if old_value is not new_value:
new_globals[key] = self.patch_object(old_value, new_value)
new_globals[key] = self.patch_object(
old_value,
new_value,
visited=visited,
)

def patch_object(self, old_value: Any, new_value: Any):
def patch_object(self, old_value: Any, new_value: Any, *, visited: set[int]):
if isinstance(new_value, MemberDescriptorType):
warnings.warn(
"MemberDescriptor is not supported",
Expand All @@ -200,9 +204,9 @@ def patch_object(self, old_value: Any, new_value: Any):
elif not self.check_object(old_value, new_value):
return new_value
elif isinstance(old_value, type) and isinstance(new_value, type):
return self.patch_class(old_value, new_value)
return self.patch_class(old_value, new_value, visited=visited)
elif callable(old_value) and callable(new_value):
return self.patch_callable(old_value, new_value)
return self.patch_callable(old_value, new_value, visited=visited)
else:
return new_value

Expand All @@ -212,12 +216,14 @@ def check_object(old_value: Any, new_value: Any):
new_module = getattr(new_value, "__module__", None)
return old_module == new_module

def patch_class(self, old_class: type, new_class: type):
self.patch_vars(old_class, new_class)
def patch_class(self, old_class: type, new_class: type, *, visited: set[int]):
self.patch_vars(old_class, new_class, visited=visited)
return old_class

def patch_callable(self, old_callable: Callable, new_callable: Callable):
self.patch_vars(old_callable, new_callable)
def patch_callable(
self, old_callable: Callable, new_callable: Callable, *, visited: set[int]
):
self.patch_vars(old_callable, new_callable, visited=visited)

old_func = inspect.unwrap(old_callable)
new_func = inspect.unwrap(new_callable)
Expand Down Expand Up @@ -246,14 +252,22 @@ def patch_callable(self, old_callable: Callable, new_callable: Callable):

return old_callable

def patch_vars(self, old_obj, new_obj):
def patch_vars(self, old_obj, new_obj, *, visited: set[int]):
if id(old_obj) in visited:
return

visited.add(id(old_obj))
old_vars = vars(old_obj)
for key, new_value in vars(new_obj).items():
old_value = old_vars.get(key)
if key == "__dict__":
continue

setattr(old_obj, key, self.patch_object(old_value, new_value))
setattr(
old_obj,
key,
self.patch_object(old_value, new_value, visited=visited),
)


class CodeModule:
Expand Down

0 comments on commit db59873

Please sign in to comment.