-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathresharpen.py
76 lines (53 loc) · 1.8 KB
/
resharpen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from functools import wraps
from typing import Callable
import latent_preview
import torch
# The callback factory as it existed when this module was imported; kept so
# the wrapper installed further down can still delegate to it.
ORIGINAL_PREP: Callable = latent_preview.prepare_callback
# Multiplier applied to the step-to-step latent difference. Any falsy value
# (the default 0.0) means the effect is switched off.
RESHARPEN_STRENGTH: float = 0.0
# Clone of the `x` tensor seen by the previous callback step, or None when
# nothing has been cached yet.
LATENT_CACHE: torch.Tensor = None
def disable_resharpen():
    """Reset the ReSharpen strength and drop the cached latent.

    Sets ``RESHARPEN_STRENGTH`` back to ``0.0`` so that the hijacked
    callbacks fall through to the original behaviour, and clears
    ``LATENT_CACHE`` so the tensor cloned during the last step is not kept
    alive (and cannot be diffed against) once the effect is off.
    """
    global RESHARPEN_STRENGTH, LATENT_CACHE
    RESHARPEN_STRENGTH = 0.0
    # Release the reference to the last cached tensor; it is only meaningful
    # while the effect is running.
    LATENT_CACHE = None
def hijack(PREP) -> Callable:
    """Wrap a callback factory so the callbacks it builds apply ReSharpen.

    Args:
        PREP: The callback factory to wrap. It is invoked with whatever
            arguments the wrapper receives and is expected to return a
            callable taking ``(step, x0, x, total_steps)``.

    Returns:
        A replacement for ``PREP`` with the same call signature. While
        ``RESHARPEN_STRENGTH`` is falsy it returns PREP's callback untouched;
        otherwise it returns a callback that adjusts ``x`` in place and then
        delegates to PREP's callback.
    """

    @wraps(PREP)
    def prep_callback(*args, **kwargs):
        global LATENT_CACHE
        # Every newly built callback starts from an empty cache, so a tensor
        # left over from an earlier callback is never used for the difference.
        LATENT_CACHE = None

        original_callback: Callable = PREP(*args, **kwargs)
        if not RESHARPEN_STRENGTH:
            return original_callback

        print("[ReSharpen] Enabled~")

        @torch.inference_mode()
        @wraps(original_callback)
        def hijack_callback(step, x0, x, total_steps):
            global LATENT_CACHE

            # The strength is re-read on every step, so resetting it to 0
            # while this callback is still in use turns the effect off.
            if RESHARPEN_STRENGTH:
                if LATENT_CACHE is not None:
                    # The subtraction already produces a fresh tensor, so no
                    # extra clone of `x` is needed to compute the difference.
                    delta = x.detach() - LATENT_CACHE
                    x += delta * RESHARPEN_STRENGTH
                # Cache a copy: `x` itself is modified in place above on the
                # next step, so a plain reference would not preserve this
                # step's values.
                LATENT_CACHE = x.detach().clone()

            return original_callback(step, x0, x, total_steps)

        return hijack_callback

    return prep_callback
# Install the hook at import time: latent_preview.prepare_callback now points
# at the wrapper returned by hijack(), which delegates to the original factory.
latent_preview.prepare_callback = hijack(ORIGINAL_PREP)
class ReSharpen:
    """Pass-through latent node whose ``details`` input sets the ReSharpen strength."""

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "hook"
    CATEGORY = "latent"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs: a latent and a ``details`` float in [-2.0, 2.0]."""
        details_options = {"default": 0.0, "min": -2.0, "max": 2.0, "step": 0.1}
        required_inputs = {
            "latent": ("LATENT",),
            "details": ("FLOAT", details_options),
        }
        return {"required": required_inputs}

    def hook(self, latent, details: float):
        """Store the strength derived from ``details`` and return ``latent`` unchanged.

        Args:
            latent: Value forwarded as-is in the returned 1-tuple.
            details: Slider value; it is divided by -10 before being stored,
                so the sign is flipped and the magnitude scaled down.

        Returns:
            A 1-tuple containing ``latent``.
        """
        global RESHARPEN_STRENGTH
        strength = details / -10.0
        RESHARPEN_STRENGTH = strength
        return (latent,)