diff --git a/README.md b/README.md index 9b59cf4..c5d4911 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,7 @@ NOTE: you can also use custom locations for models/motion loras by making use of - Per-frame GLIGEN coordinates control - Currently requires GLIGENTextBoxApplyBatch from KJNodes to do so, but I will add native nodes to do this soon. - Image Injection mid-sampling +- ContextRef and NaiveReuse (novel cross-context consistency techniques) ## Upcoming Features - Example workflows for **every feature** in AnimateDiff-Evolved repo, and hopefully a long Youtube video showing all features (Goal: before Elden Ring DLC releases. Working on it right now.) diff --git a/animatediff/context.py b/animatediff/context.py index b44fa42..8ba6dcf 100644 --- a/animatediff/context.py +++ b/animatediff/context.py @@ -11,8 +11,10 @@ from comfy.model_base import BaseModel from comfy.model_patcher import ModelPatcher +from .context_extras import ContextExtrasGroup from .utils_motion import get_sorted_list_via_attr + class ContextFuseMethod: FLAT = "flat" PYRAMID = "pyramid" @@ -76,17 +78,21 @@ def clone(self): class ContextOptionsGroup: def __init__(self): self.contexts: list[ContextOptions] = [] + self.extras = ContextExtrasGroup() self._current_context: ContextOptions = None self._current_used_steps: int = 0 self._current_index: int = 0 + self._previous_t = -1 self._step = 0 def reset(self): self._current_context = None self._current_used_steps = 0 self._current_index = 0 + self._previous_t = -1 self.step = 0 self._set_first_as_current() + self.extras.cleanup() @property def step(self): @@ -121,9 +127,10 @@ def has_index(self, index: int) -> int: def is_empty(self) -> bool: return len(self.contexts) == 0 - + def clone(self): cloned = ContextOptionsGroup() + cloned.extras = self.extras.clone() for context in self.contexts: cloned.contexts.append(context) cloned._set_first_as_current() @@ -132,9 +139,17 @@ def clone(self): def initialize_timesteps(self, model: BaseModel): for context in self.contexts: context.start_t = model.model_sampling.percent_to_sigma(context.start_percent) + self.extras.initialize_timesteps(model) + + def prepare_current(self, t: Tensor): + self.prepare_current_context(t) + self.extras.prepare_current(t) def prepare_current_context(self, t: Tensor): curr_t: float = t[0] + # if same as previous, do nothing as step already accounted for + if curr_t == self._previous_t: + return prev_index = self._current_index # if met guaranteed steps, look for next context in case need to switch if self._current_used_steps >= self._current_context.guarantee_steps: @@ -156,6 +171,8 @@ def prepare_current_context(self, t: Tensor): break # update steps current context is used self._current_used_steps += 1 + # update previous_t + self._previous_t = curr_t def _set_first_as_current(self): if len(self.contexts) > 0: @@ -620,7 +637,7 @@ def generate_context_visualization(context_opts: ContextOptionsGroup, model: Mod for i, t in enumerate(sigmas): # make context_opts reflect current step/sigma - context_opts.prepare_current_context([t]) + context_opts.prepare_current([t]) context_opts.step = start_step+i # check if context should even be active in this case diff --git a/animatediff/context_extras.py b/animatediff/context_extras.py new file mode 100644 index 0000000..2f19856 --- /dev/null +++ b/animatediff/context_extras.py @@ -0,0 +1,477 @@ +from typing import Union +import math +import torch +from torch import Tensor + +from comfy.model_base import BaseModel + +from .utils_motion import 
(prepare_mask_batch, extend_to_batch_size, get_combined_multival, resize_multival,
+                           get_sorted_list_via_attr)
+
+
+CONTEXTREF_VERSION = 1
+
+
+class ContextExtra:
+    def __init__(self, start_percent: float, end_percent: float):
+        # scheduling
+        self.start_percent = float(start_percent)
+        self.start_t = 999999999.9
+        self.end_percent = float(end_percent)
+        self.end_t = 0.0
+        self.curr_t = 999999999.9
+
+    def initialize_timesteps(self, model: BaseModel):
+        self.start_t = model.model_sampling.percent_to_sigma(self.start_percent)
+        self.end_t = model.model_sampling.percent_to_sigma(self.end_percent)
+
+    def prepare_current(self, t: Tensor):
+        self.curr_t = t[0]
+
+    def should_run(self):
+        if self.curr_t > self.start_t or self.curr_t < self.end_t:
+            return False
+        return True
+
+    def cleanup(self):
+        pass
+
+
+################################
+# ContextRef
+class ContextRefTune:
+    def __init__(self,
+                 attn_style_fidelity=0.0, attn_ref_weight=0.0, attn_strength=0.0,
+                 adain_style_fidelity=0.0, adain_ref_weight=0.0, adain_strength=0.0):
+        # attn1
+        self.attn_style_fidelity = float(attn_style_fidelity)
+        self.attn_ref_weight = float(attn_ref_weight)
+        self.attn_strength = float(attn_strength)
+        # adain
+        self.adain_style_fidelity = float(adain_style_fidelity)
+        self.adain_ref_weight = float(adain_ref_weight)
+        self.adain_strength = float(adain_strength)
+
+    def create_dict(self):
+        return {
+            "attn_style_fidelity": self.attn_style_fidelity,
+            "attn_ref_weight": self.attn_ref_weight,
+            "attn_strength": self.attn_strength,
+            "adain_style_fidelity": self.adain_style_fidelity,
+            "adain_ref_weight": self.adain_ref_weight,
+            "adain_strength": self.adain_strength,
+        }
+
+
+class ContextRefMode:
+    FIRST = "first"
+    SLIDING = "sliding"
+    INDEXES = "indexes"
+    _LIST = [FIRST, SLIDING, INDEXES]
+
+    def __init__(self, mode: str, sliding_width=2, indexes: set[int]=set([0])):
+        self.mode = mode
+        self.sliding_width = sliding_width
+        self.indexes = indexes
+        self.single_trigger = True
+
+    @classmethod
+    def init_first(cls):
+        return ContextRefMode(cls.FIRST)
+
+    @classmethod
+    def init_sliding(cls, sliding_width: int):
+        return ContextRefMode(cls.SLIDING, sliding_width=sliding_width)
+
+    @classmethod
+    def init_indexes(cls, indexes: set[int]):
+        return ContextRefMode(cls.INDEXES, indexes=indexes)
+
+
+class ContextRefKeyframe:
+    def __init__(self, mult=1.0, mult_multival: Union[float, Tensor]=None, tune_replace: ContextRefTune=None, mode_replace: ContextRefMode=None,
+                 start_percent=0.0, guarantee_steps=1, inherit_missing=True):
+        self.mult = mult
+        self.orig_mult_multival = mult_multival
+        self.orig_tune_replace = tune_replace
+        self.orig_mode_replace = mode_replace
+        self.mult_multival = self.orig_mult_multival
+        self.tune_replace = self.orig_tune_replace
+        self.mode_replace = self.orig_mode_replace
+        # scheduling
+        self.start_percent = float(start_percent)
+        self.guarantee_steps = guarantee_steps
+        self.inherit_missing = inherit_missing
+
+    def clone(self):
+        c = ContextRefKeyframe(mult=self.mult, mult_multival=self.orig_mult_multival, tune_replace=self.orig_tune_replace, mode_replace=self.orig_mode_replace,
+                               start_percent=self.start_percent, guarantee_steps=self.guarantee_steps, inherit_missing=self.inherit_missing)
+        return c
+
+
+class ContextRefKeyframeGroup:
+    def __init__(self):
+        self.keyframes: list[ContextRefKeyframe] = []
+        self._current_keyframe: ContextRefKeyframe = None
+        self._current_used_steps: int = 0
+        self._current_index: int = 0
+        self._previous_t = -1
+
+    def reset(self):
+        self._current_keyframe = None
+        self._current_used_steps = 0
+        self._current_index = 0
+        self._set_first_as_current()
+
+    def add(self, keyframe: ContextRefKeyframe):
+        # add to end of list, then sort
+        self.keyframes.append(keyframe)
+        self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
+        self._set_first_as_current()
+        self._prepare_all_keyframe_vals()
+
+    def _set_first_as_current(self):
+        if len(self.keyframes) > 0:
+            self._current_keyframe = self.keyframes[0]
+        else:
+            self._current_keyframe = None
+
+    def _prepare_all_keyframe_vals(self):
+        if self.is_empty():
+            return
+        multival = None
+        tune = None
+        mode = None
+        for kf in self.keyframes:
+            # if shouldn't inherit, clear cache
+            if not kf.inherit_missing:
+                multival = None
+                tune = None
+                mode = None
+            # assign cached values, if origs were None
+            # Mult #################
+            if kf.orig_mult_multival is None:
+                kf.mult_multival = multival
+            else:
+                kf.mult_multival = kf.orig_mult_multival
+            # Tune #################
+            if kf.orig_tune_replace is None:
+                kf.tune_replace = tune
+            else:
+                kf.tune_replace = kf.orig_tune_replace
+            # Mode #################
+            if kf.orig_mode_replace is None:
+                kf.mode_replace = mode
+            else:
+                kf.mode_replace = kf.orig_mode_replace
+            # save new caches, in case next keyframe inherits missing
+            if kf.mult_multival is not None:
+                multival = kf.mult_multival
+            if kf.tune_replace is not None:
+                tune = kf.tune_replace
+            if kf.mode_replace is not None:
+                mode = kf.mode_replace
+
+    def has_index(self, index: int) -> int:
+        return index >= 0 and index < len(self.keyframes)
+
+    def is_empty(self) -> bool:
+        return len(self.keyframes) == 0
+
+    def clone(self):
+        cloned = ContextRefKeyframeGroup()
+        for keyframe in self.keyframes:
+            cloned.keyframes.append(keyframe.clone())
+        cloned._set_first_as_current()
+        cloned._prepare_all_keyframe_vals()
+        return cloned
+
+    def create_list_of_dicts(self):
+        # for each keyframe, create a dict representing values relevant to TimestepKeyframe creation in ACN
+        c = []
+        for kf in self.keyframes:
+            d = {}
+            # scheduling
+            d["start_percent"] = kf.start_percent
+            d["guarantee_steps"] = kf.guarantee_steps
+            d["inherit_missing"] = kf.inherit_missing
+            # values
+            if type(kf.mult_multival) == Tensor:
+                d["strength"] = kf.mult
+                d["mask"] = kf.mult_multival
+            else:
+                if kf.mult_multival is None:
+                    d["strength"] = kf.mult
+                else:
+                    d["strength"] = kf.mult * kf.mult_multival
+                d["mask"] = None
+            d["tune"] = kf.tune_replace
+            d["mode"] = kf.mode_replace
+            # add to list
+            c.append(d)
+        return c
+
+
+class ContextRef(ContextExtra):
+    def __init__(self, start_percent: float, end_percent: float,
+                 strength_multival: Union[float, Tensor], tune: ContextRefTune, mode: ContextRefMode,
+                 keyframe: ContextRefKeyframeGroup=None):
+        super().__init__(start_percent=start_percent, end_percent=end_percent)
+        self.tune = tune
+        self.mode = mode
+        self.keyframe = keyframe if keyframe else ContextRefKeyframeGroup()
+        self.version = CONTEXTREF_VERSION
+        # stuff for ACN usage
+        self.strength = 1.0
+        self.mask = None
+        self._strength_multival = strength_multival
+        self.strength_multival = strength_multival
+
+    @property
+    def strength_multival(self):
+        return self._strength_multival
+    @strength_multival.setter
+    def strength_multival(self, value):
+        if value is None:
+            value = 1.0
+        if type(value) == Tensor:
+            self.strength = 1.0
+            self.mask = value
+        else:
+            self.strength = value
+            self.mask = None
+        self._strength_multival = value
+
+    def should_run(self):
+        return super().should_run()
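+# Minimal usage sketch (illustrative values only; behavior follows the setter above):
+#   ref = ContextRef(start_percent=0.0, end_percent=0.25, strength_multival=0.8,
+#                    tune=ContextRefTune(), mode=ContextRefMode.init_first())
+#   ref.strength, ref.mask   # -> 0.8, None (float becomes strength, no mask)
+#   ref.strength_multival = torch.ones((1, 64, 64))
+#   ref.strength, ref.mask   # -> 1.0, mask Tensor of shape (1, 64, 64)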
+#-------------------------------- + + +################################ +# NaiveReuse +class NaiveReuseKeyframe: + def __init__(self, mult=1.0, mult_multival: Union[float, Tensor]=None, start_percent=0.0, guarantee_steps=1, inherit_missing=True): + self.mult = mult + self.orig_mult_multival = mult_multival + self.mult_multival = mult_multival + # scheduling + self.start_percent = float(start_percent) + self.start_t = 999999999.9 + self.guarantee_steps = guarantee_steps + self.inherit_missing = inherit_missing + + def clone(self): + c = NaiveReuseKeyframe(mult=self.mult, mult_multival=self.mult_multival, + start_percent=self.start_percent, guarantee_steps=self.guarantee_steps) + c.start_t = self.start_t + return c + + +class NaiveReuseKeyframeGroup: + def __init__(self): + self.keyframes: list[NaiveReuseKeyframe] = [] + self._current_keyframe: NaiveReuseKeyframe = None + self._current_used_steps: int = 0 + self._current_index: int = 0 + self._previous_t = -1 + + def reset(self): + self._current_keyframe = None + self._current_used_steps = 0 + self._current_index = 0 + self._set_first_as_current() + + def add(self, keyframe: NaiveReuseKeyframe): + # add to end of list, then sort + self.keyframes.append(keyframe) + self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent") + self._set_first_as_current() + self._prepare_all_keyframe_vals() + + def _set_first_as_current(self): + if len(self.keyframes) > 0: + self._current_keyframe = self.keyframes[0] + else: + self._current_keyframe = None + + def _prepare_all_keyframe_vals(self): + if self.is_empty(): + return + multival = None + for kf in self.keyframes: + # if shouldn't inherit, clear cache + if not kf.inherit_missing: + multival = None + # assign cached values, if origs were None + # Mult ################# + if kf.orig_mult_multival is None: + kf.mult_multival = multival + else: + kf.mult_multival = kf.orig_mult_multival + # save new caches, in case next keyframe inherits missing + if kf.mult_multival is not None: + multival = kf.mult_multival + + def has_index(self, index: int) -> int: + return index >=0 and index < len(self.keyframes) + + def is_empty(self) -> bool: + return len(self.keyframes) == 0 + + def clone(self): + cloned = NaiveReuseKeyframeGroup() + for keyframe in self.keyframes: + cloned.keyframes.append(keyframe) + cloned._set_first_as_current() + cloned._prepare_all_keyframe_vals() + return cloned + + def initialize_timesteps(self, model: BaseModel): + for keyframe in self.keyframes: + keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) + + def prepare_current_keyframe(self, t: Tensor): + if self.is_empty(): + return + curr_t: float = t[0] + # if curr_t same as before, do nothing as step already accounted for + if curr_t == self._previous_t: + return + prev_index = self._current_index + # if met guaranteed steps, look for next keyframe in case need to switch + if self._current_used_steps >= self._current_keyframe.guarantee_steps: + # if has next index, loop through and see if need t oswitch + if self.has_index(self._current_index+1): + for i in range(self._current_index+1, len(self.keyframes)): + eval_c = self.keyframes[i] + # check if start_t is greater or equal to curr_t + # NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling + if eval_c.start_t >= curr_t: + self._current_index = i + self._current_keyframe = eval_c + self._current_used_steps = 0 + # if guarantee_steps greater than zero, stop searching for other keyframes + if 
self._current_keyframe.guarantee_steps > 0: + break + # if eval_c is outside the percent range, stop looking further + else: break + # update steps current context is used + self._current_used_steps += 1 + # update previous_t + self._previous_t = curr_t + + # properties shadow those of NaiveReuseKeyframe + @property + def mult(self): + if self._current_keyframe != None: + return self._current_keyframe.mult + return 1.0 + + @property + def mult_multival(self): + if self._current_keyframe != None: + return self._current_keyframe.mult_multival + return None + + +class NaiveReuse(ContextExtra): + def __init__(self, start_percent: float, end_percent: float, weighted_mean: float, multival_opt: Union[float, Tensor]=None, naivereuse_kf: NaiveReuseKeyframeGroup=None): + super().__init__(start_percent=start_percent, end_percent=end_percent) + self.weighted_mean = weighted_mean + self.orig_multival = multival_opt + self.mask: Tensor = None + self.keyframe = naivereuse_kf if naivereuse_kf else NaiveReuseKeyframeGroup() + self._prev_keyframe = None + + def cleanup(self): + super().cleanup() + del self.mask + self.mask = None + self._prev_keyframe = None + self.keyframe.reset() + + def initialize_timesteps(self, model: BaseModel): + super().initialize_timesteps(model) + self.keyframe.initialize_timesteps(model) + + def prepare_current(self, t: Tensor): + super().prepare_current(t) + self.keyframe.prepare_current_keyframe(t) + + def get_effective_weighted_mean(self, x: Tensor, idxs: list[int]): + if self.orig_multival is None and self.keyframe.mult_multival is None: + return self.weighted_mean * self.keyframe.mult + # check if keyframe changed + keyframe_changed = False + if self.keyframe._current_keyframe != self._prev_keyframe: + keyframe_changed = True + self._prev_keyframe = self.keyframe._current_keyframe + + if type(self.orig_multival) != Tensor and type(self.keyframe.mult_multival) != Tensor: + return self.weighted_mean * self.keyframe.mult * get_combined_multival(self.orig_multival, self.keyframe.mult_multival) + + if self.mask is None or keyframe_changed or self.mask.shape[0] != x.shape[0] or self.mask.shape[-1] != x.shape[-1] or self.mask.shape[-2] != x.shape[-2]: + del self.mask + real_mult_multival = resize_multival(self.keyframe.mult_multival, batch_size=x.shape[0], height=x.shape[-1], width=x.shape[-2]) + self.mask = resize_multival(self.orig_multival, batch_size=x.shape[0], height=x.shape[-1], width=x.shape[-2]) + self.mask = get_combined_multival(self.mask, real_mult_multival) + return self.weighted_mean * self.keyframe.mult * self.mask[idxs].to(dtype=x.dtype, device=x.device) + + def should_run(self): + to_return = super().should_run() + # if keyframe has 0.0 val, should not run + if self.keyframe.mult_multival is not None and type(self.keyframe.mult_multival) != Tensor and math.isclose(self.keyframe.mult_multival, 0.0): + return False + # if weighted_mean is 0.0, then reuse will take no effect anyway + return to_return and self.weighted_mean > 0.0 and self.keyframe.mult > 0.0 +#-------------------------------- + + +class ContextExtrasGroup: + def __init__(self): + self.context_ref: ContextRef = None + self.naive_reuse: NaiveReuse = None + + def get_extras_list(self) -> list[ContextExtra]: + extras_list = [] + if self.context_ref is not None: + extras_list.append(self.context_ref) + if self.naive_reuse is not None: + extras_list.append(self.naive_reuse) + return extras_list + + def initialize_timesteps(self, model: BaseModel): + for extra in self.get_extras_list(): + 
extra.initialize_timesteps(model) + + def prepare_current(self, t: Tensor): + for extra in self.get_extras_list(): + extra.prepare_current(t) + + def should_run_context_ref(self): + if not self.context_ref: + return False + return self.context_ref.should_run() + + def should_run_naive_reuse(self): + if not self.naive_reuse: + return False + return self.naive_reuse.should_run() + + def add(self, extra: ContextExtra): + if type(extra) == ContextRef: + self.context_ref = extra + elif type(extra) == NaiveReuse: + self.naive_reuse = extra + else: + raise Exception(f"Unrecognized ContextExtras type: {type(extra)}") + + def cleanup(self): + for extra in self.get_extras_list(): + extra.cleanup() + + def clone(self): + cloned = ContextExtrasGroup() + cloned.context_ref = self.context_ref + cloned.naive_reuse = self.naive_reuse + return cloned diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 3e952eb..2a9f0ff 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -751,6 +751,7 @@ def __init__(self, *args, **kwargs): self.current_used_steps = 0 self.current_keyframe: ADKeyframe = None self.current_index = -1 + self.previous_t = -1 self.current_scale: Union[float, Tensor] = None self.current_effect: Union[float, Tensor] = None self.current_cameractrl_effect: Union[float, Tensor] = None @@ -803,6 +804,9 @@ def initialize_timesteps(self, model: BaseModel): def prepare_current_keyframe(self, x: Tensor, t: Tensor): curr_t: float = t[0] + # if curr_t was previous_t, then do nothing (already accounted for this step) + if curr_t == self.previous_t: + return prev_index = self.current_index # if met guaranteed steps, look for next keyframe in case need to switch if self.current_keyframe is None or self.current_used_steps >= self.current_keyframe.guarantee_steps: @@ -862,6 +866,8 @@ def prepare_current_keyframe(self, x: Tensor, t: Tensor): self.was_within_range = True # update steps current keyframe is used self.current_used_steps += 1 + # update previous_t + self.previous_t = curr_t def prepare_img_features(self, x: Tensor, cond_or_uncond: list[int], ad_params: dict[str], latent_format): # if no img_encoder, done @@ -1006,6 +1012,7 @@ def cleanup(self): self.current_used_steps = 0 self.current_keyframe = None self.current_index = -1 + self.previous_t = -1 self.current_scale = None self.current_effect = None self.combined_scale = None diff --git a/animatediff/nodes.py b/animatediff/nodes.py index 07ba149..69a5496 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -29,6 +29,11 @@ from .nodes_context import (LegacyLoopedUniformContextOptionsNode, LoopedUniformContextOptionsNode, LoopedUniformViewOptionsNode, StandardUniformContextOptionsNode, StandardStaticContextOptionsNode, BatchedContextOptionsNode, StandardStaticViewOptionsNode, StandardUniformViewOptionsNode, ViewAsContextOptionsNode, VisualizeContextOptionsK, VisualizeContextOptionsKAdv, VisualizeContextOptionsSCustom) +from .nodes_context_extras import (SetContextExtrasOnContextOptions, ContextExtras_NaiveReuse, ContextExtras_ContextRef, + ContextRef_ModeFirst, ContextRef_ModeSliding, ContextRef_ModeIndexes, + ContextRef_TuneAttn, ContextRef_TuneAttnAdain, + ContextRef_KeyframeMultivalNode, ContextRef_KeyframeInterpolationNode, ContextRef_KeyframeFromListNode, + NaiveReuse_KeyframeMultivalNode, NaiveReuse_KeyframeInterpolationNode, NaiveReuse_KeyframeFromListNode) from .nodes_ad_settings import (AnimateDiffSettingsNode, ManualAdjustPENode, SweetspotStretchPENode, FullStretchPENode, 
WeightAdjustAllAddNode, WeightAdjustAllMultNode, WeightAdjustIndivAddNode, WeightAdjustIndivMultNode, WeightAdjustIndivAttnAddNode, WeightAdjustIndivAttnMultNode) @@ -54,13 +59,15 @@ "ADE_MultivalDynamicFloatInput": MultivalDynamicFloatInputNode, "ADE_MultivalScaledMask": MultivalScaledMaskNode, "ADE_MultivalConvertToMask": MultivalConvertToMaskNode, + ############################################################################### + #------------------------------------------------------------------------------ # Context Opts "ADE_StandardStaticContextOptions": StandardStaticContextOptionsNode, "ADE_StandardUniformContextOptions": StandardUniformContextOptionsNode, "ADE_LoopedUniformContextOptions": LoopedUniformContextOptionsNode, "ADE_ViewsOnlyContextOptions": ViewAsContextOptionsNode, "ADE_BatchedContextOptions": BatchedContextOptionsNode, - "ADE_AnimateDiffUniformContextOptions": LegacyLoopedUniformContextOptionsNode, # Legacy + "ADE_AnimateDiffUniformContextOptions": LegacyLoopedUniformContextOptionsNode, # Legacy/Deprecated "ADE_VisualizeContextOptionsK": VisualizeContextOptionsK, "ADE_VisualizeContextOptionsKAdv": VisualizeContextOptionsKAdv, "ADE_VisualizeContextOptionsSCustom": VisualizeContextOptionsSCustom, @@ -68,6 +75,23 @@ "ADE_StandardStaticViewOptions": StandardStaticViewOptionsNode, "ADE_StandardUniformViewOptions": StandardUniformViewOptionsNode, "ADE_LoopedUniformViewOptions": LoopedUniformViewOptionsNode, + # Context Extras + "ADE_ContextExtras_Set": SetContextExtrasOnContextOptions, + "ADE_ContextExtras_ContextRef": ContextExtras_ContextRef, + "ADE_ContextExtras_ContextRef_ModeFirst": ContextRef_ModeFirst, + "ADE_ContextExtras_ContextRef_ModeSliding": ContextRef_ModeSliding, + "ADE_ContextExtras_ContextRef_ModeIndexes": ContextRef_ModeIndexes, + "ADE_ContextExtras_ContextRef_TuneAttn": ContextRef_TuneAttn, + "ADE_ContextExtras_ContextRef_TuneAttnAdain": ContextRef_TuneAttnAdain, + "ADE_ContextExtras_ContextRef_Keyframe": ContextRef_KeyframeMultivalNode, + "ADE_ContextExtras_ContextRef_KeyframeInterpolation": ContextRef_KeyframeInterpolationNode, + "ADE_ContextExtras_ContextRef_KeyframeFromList": ContextRef_KeyframeFromListNode, + "ADE_ContextExtras_NaiveReuse": ContextExtras_NaiveReuse, + "ADE_ContextExtras_NaiveReuse_Keyframe": NaiveReuse_KeyframeMultivalNode, + "ADE_ContextExtras_NaiveReuse_KeyframeInterpolation": NaiveReuse_KeyframeInterpolationNode, + "ADE_ContextExtras_NaiveReuse_KeyframeFromList": NaiveReuse_KeyframeFromListNode, + #------------------------------------------------------------------------------ + ############################################################################### # Iteration Opts "ADE_IterationOptsDefault": IterationOptionsNode, "ADE_IterationOptsFreeInit": FreeInitOptionsNode, @@ -184,13 +208,15 @@ "ADE_MultivalDynamicFloatInput": "Multival [Float List] πŸŽ­πŸ…πŸ…“", "ADE_MultivalScaledMask": "Multival Scaled Mask πŸŽ­πŸ…πŸ…“", "ADE_MultivalConvertToMask": "Multival to Mask πŸŽ­πŸ…πŸ…“", + ############################################################################### + #------------------------------------------------------------------------------ # Context Opts "ADE_StandardStaticContextOptions": "Context Optionsβ—†Standard Static πŸŽ­πŸ…πŸ…“", "ADE_StandardUniformContextOptions": "Context Optionsβ—†Standard Uniform πŸŽ­πŸ…πŸ…“", "ADE_LoopedUniformContextOptions": "Context Optionsβ—†Looped Uniform πŸŽ­πŸ…πŸ…“", "ADE_ViewsOnlyContextOptions": "Context Optionsβ—†Views Only [VRAMβ‡ˆ] πŸŽ­πŸ…πŸ…“", "ADE_BatchedContextOptions": 
"Context Optionsβ—†Batched [Non-AD] πŸŽ­πŸ…πŸ…“", - "ADE_AnimateDiffUniformContextOptions": "Context Optionsβ—†Looped Uniform πŸŽ­πŸ…πŸ…“", # Legacy + "ADE_AnimateDiffUniformContextOptions": "Context Optionsβ—†Looped Uniform πŸŽ­πŸ…πŸ…“", # Legacy/Deprecated "ADE_VisualizeContextOptionsK": "Visualize Context Options (K.) πŸŽ­πŸ…πŸ…“", "ADE_VisualizeContextOptionsKAdv": "Visualize Context Options (K.Adv.) πŸŽ­πŸ…πŸ…“", "ADE_VisualizeContextOptionsSCustom": "Visualize Context Options (S.Cus.) πŸŽ­πŸ…πŸ…“", @@ -198,6 +224,23 @@ "ADE_StandardStaticViewOptions": "View Optionsβ—†Standard Static πŸŽ­πŸ…πŸ…“", "ADE_StandardUniformViewOptions": "View Optionsβ—†Standard Uniform πŸŽ­πŸ…πŸ…“", "ADE_LoopedUniformViewOptions": "View Optionsβ—†Looped Uniform πŸŽ­πŸ…πŸ…“", + # Context Extras + "ADE_ContextExtras_Set": "Set Context Extras πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_ContextRef": "Context Extrasβ—†ContextRef πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_ContextRef_ModeFirst": "ContextRef Modeβ—†First πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_ContextRef_ModeSliding": "ContextRef Modeβ—†Sliding πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_ContextRef_ModeIndexes": "ContextRef Modeβ—†Indexes πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_ContextRef_TuneAttn": "ContextRef Tuneβ—†Attn πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_ContextRef_TuneAttnAdain": "ContextRef Tuneβ—†Attn+Adain πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_ContextRef_Keyframe": "ContextRef Keyframe πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_ContextRef_KeyframeInterpolation": "ContextRef Keyframes Interp. πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_ContextRef_KeyframeFromList": "ContextRef Keyframes From List πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_NaiveReuse": "Context Extrasβ—†NaiveReuse πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_NaiveReuse_Keyframe": "NaiveReuse Keyframe πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_NaiveReuse_KeyframeInterpolation": "NaiveReuse Keyframes Interp. πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_NaiveReuse_KeyframeFromList": "NaiveReuse Keyframes From List πŸŽ­πŸ…πŸ…“", + #------------------------------------------------------------------------------ + ############################################################################### # Iteration Opts "ADE_IterationOptsDefault": "Default Iteration Options πŸŽ­πŸ…πŸ…“", "ADE_IterationOptsFreeInit": "FreeInit Iteration Options πŸŽ­πŸ…πŸ…“", @@ -212,7 +255,7 @@ "ADE_SetLoraHookKeyframe": "Set LoRA Hook Keyframes πŸŽ­πŸ…πŸ…“", "ADE_AttachLoraHookToCLIP": "Set CLIP LoRA Hook πŸŽ­πŸ…πŸ…“", "ADE_LoraHookKeyframe": "LoRA Hook Keyframe πŸŽ­πŸ…πŸ…“", - "ADE_LoraHookKeyframeInterpolation": "LoRA Hook Keyframes Interpolation πŸŽ­πŸ…πŸ…“", + "ADE_LoraHookKeyframeInterpolation": "LoRA Hook Keyframes Interp. πŸŽ­πŸ…πŸ…“", "ADE_LoraHookKeyframeFromStrengthList": "LoRA Hook Keyframes From List πŸŽ­πŸ…πŸ…“", "ADE_AttachLoraHookToConditioning": "Set Model LoRA Hook πŸŽ­πŸ…πŸ…“", "ADE_PairedConditioningSetMask": "Set Props on Conds πŸŽ­πŸ…πŸ…“", @@ -244,7 +287,7 @@ "ADE_CustomCFG": "Custom CFG [Multival] πŸŽ­πŸ…πŸ…“", "ADE_CustomCFGKeyframeSimple": "Custom CFG Keyframe πŸŽ­πŸ…πŸ…“", "ADE_CustomCFGKeyframe": "Custom CFG Keyframe [Multival] πŸŽ­πŸ…πŸ…“", - "ADE_CustomCFGKeyframeInterpolation": "Custom CFG Keyframes Interpolation πŸŽ­πŸ…πŸ…“", + "ADE_CustomCFGKeyframeInterpolation": "Custom CFG Keyframes Interp. 
πŸŽ­πŸ…πŸ…“", "ADE_CustomCFGKeyframeFromList": "Custom CFG Keyframes From List πŸŽ­πŸ…πŸ…“", "ADE_CFGExtrasPAGSimple": "CFG Extrasβ—†PAG πŸŽ­πŸ…πŸ…“", "ADE_CFGExtrasPAG": "CFG Extrasβ—†PAG [Multival] πŸŽ­πŸ…πŸ…“", @@ -253,7 +296,7 @@ "ADE_SigmaSchedule": "Create Sigma Schedule πŸŽ­πŸ…πŸ…“", "ADE_RawSigmaSchedule": "Create Raw Sigma Schedule πŸŽ­πŸ…πŸ…“", "ADE_SigmaScheduleWeightedAverage": "Sigma Schedule Weighted Mean πŸŽ­πŸ…πŸ…“", - "ADE_SigmaScheduleWeightedAverageInterp": "Sigma Schedule Interpolated Mean πŸŽ­πŸ…πŸ…“", + "ADE_SigmaScheduleWeightedAverageInterp": "Sigma Schedule Interp. Mean πŸŽ­πŸ…πŸ…“", "ADE_SigmaScheduleSplitAndCombine": "Sigma Schedule Split Combine πŸŽ­πŸ…πŸ…“", "ADE_SigmaScheduleToSigmas": "Sigma Schedule To Sigmas πŸŽ­πŸ…πŸ…“", "ADE_NoisedImageInjection": "Image Injection πŸŽ­πŸ…πŸ…“", diff --git a/animatediff/nodes_conditioning.py b/animatediff/nodes_conditioning.py index 9c694b2..57f1233 100644 --- a/animatediff/nodes_conditioning.py +++ b/animatediff/nodes_conditioning.py @@ -337,6 +337,7 @@ def INPUT_TYPES(s): }, "optional": { "prev_hook_kf": ("LORA_HOOK_KEYFRAMES",), + "autosize": ("ADEAUTOSIZE", {"padding": 70}), } } diff --git a/animatediff/nodes_context.py b/animatediff/nodes_context.py index a4ca94c..fd42683 100644 --- a/animatediff/nodes_context.py +++ b/animatediff/nodes_context.py @@ -1,11 +1,12 @@ -import torch from torch import Tensor +from typing import Union import comfy.samplers from comfy.model_patcher import ModelPatcher from .context import (ContextFuseMethod, ContextOptions, ContextOptionsGroup, ContextSchedules, generate_context_visualization) +from .model_injection import ModelPatcherAndInjector from .utils_model import BIGMAX, MAX_RESOLUTION @@ -378,7 +379,7 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/visualize" FUNCTION = "visualize" - def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str, + def visualize(self, model: ModelPatcherAndInjector, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str, visual_width: 1280, latents_length=32, steps=20, start_step=0, end_step=20): images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length, sampler_name=sampler_name, scheduler=scheduler, @@ -408,7 +409,7 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/visualize" FUNCTION = "visualize" - def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str, + def visualize(self, model: ModelPatcherAndInjector, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str, visual_width: 1280, latents_length=32, steps=20, denoise=1.0): images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length, sampler_name=sampler_name, scheduler=scheduler, @@ -435,7 +436,7 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/visualize" FUNCTION = "visualize" - def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sigmas, + def visualize(self, model: ModelPatcherAndInjector, context_opts: ContextOptionsGroup, sigmas, visual_width: 1280, latents_length=32): images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length, sigmas=sigmas) diff --git a/animatediff/nodes_context_extras.py b/animatediff/nodes_context_extras.py new file mode 100644 
index 0000000..8a28e8d --- /dev/null +++ b/animatediff/nodes_context_extras.py @@ -0,0 +1,512 @@ +from torch import Tensor +from typing import Union +from collections.abc import Iterable + +from .context import (ContextOptionsGroup) +from .context_extras import (ContextExtrasGroup, + ContextRef, ContextRefTune, ContextRefMode, ContextRefKeyframeGroup, ContextRefKeyframe, + NaiveReuse, NaiveReuseKeyframe, NaiveReuseKeyframeGroup) +from .utils_model import BIGMAX, InterpolationMethod +from .utils_scheduling import convert_str_to_indexes +from .logger import logger + + +class SetContextExtrasOnContextOptions: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "context_opts": ("CONTEXT_OPTIONS",), + "context_extras": ("CONTEXT_EXTRAS",), + }, + "optional": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("CONTEXT_OPTIONS",) + RETURN_NAMES = ("CONTEXT_OPTS",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras" + FUNCTION = "set_context_extras" + + def set_context_extras(self, context_opts: ContextOptionsGroup, context_extras: ContextExtrasGroup): + context_opts = context_opts.clone() + context_opts.extras = context_extras.clone() + return (context_opts,) + + +######################################### +# NaiveReuse +class ContextExtras_NaiveReuse: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "prev_extras": ("CONTEXT_EXTRAS",), + "strength_multival": ("MULTIVAL",), + "naivereuse_kf": ("NAIVEREUSE_KEYFRAME",), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}), + "weighted_mean": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.001}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("CONTEXT_EXTRAS",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras" + FUNCTION = "create_context_extra" + + def create_context_extra(self, start_percent=0.0, end_percent=0.1, weighted_mean=0.95, strength_multival: Union[float, Tensor]=None, + naivereuse_kf: NaiveReuseKeyframeGroup=None, prev_extras: ContextExtrasGroup=None): + if prev_extras is None: + prev_extras = prev_extras = ContextExtrasGroup() + prev_extras = prev_extras.clone() + # create extra + naive_reuse = NaiveReuse(start_percent=start_percent, end_percent=end_percent, weighted_mean=weighted_mean, multival_opt=strength_multival, + naivereuse_kf=naivereuse_kf) + prev_extras.add(naive_reuse) + return (prev_extras,) + + +class NaiveReuse_KeyframeMultivalNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "prev_kf": ("NAIVEREUSE_KEYFRAME",), + "mult_multival": ("MULTIVAL",), + "mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), + "inherit_missing": ("BOOLEAN", {"default": True}, ), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("NAIVEREUSE_KEYFRAME",) + RETURN_NAMES = ("NAIVEREUSE_KF",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/naivereuse" + FUNCTION = "create_keyframe" + + def create_keyframe(self, prev_kf=None, mult=1.0, mult_multival=None, + start_percent=0.0, guarantee_steps=1, inherit_missing=True): + if prev_kf is None: + prev_kf = NaiveReuseKeyframeGroup() + prev_kf = prev_kf.clone() + kf = 
NaiveReuseKeyframe(mult=mult, mult_multival=mult_multival, + start_percent=start_percent, guarantee_steps=guarantee_steps, inherit_missing=inherit_missing) + prev_kf.add(kf) + return (prev_kf,) + + +class NaiveReuse_KeyframeInterpolationNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "mult_start": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "mult_end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "interpolation": (InterpolationMethod._LIST, ), + "intervals": ("INT", {"default": 50, "min": 2, "max": 100, "step": 1}), + "inherit_missing": ("BOOLEAN", {"default": True}), + "print_keyframes": ("BOOLEAN", {"default": False}), + }, + "optional": { + "prev_kf": ("NAIVEREUSE_KEYFRAME",), + "mult_multival": ("MULTIVAL",), + "autosize": ("ADEAUTOSIZE", {"padding": 50}), + } + } + + RETURN_TYPES = ("NAIVEREUSE_KEYFRAME",) + RETURN_NAMES = ("NAIVEREUSE_KF",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/naivereuse" + FUNCTION = "create_keyframe" + + def create_keyframe(self, + start_percent: float, end_percent: float, + mult_start: float, mult_end: float, interpolation: str, intervals: int, + inherit_missing=True, prev_kf: NaiveReuseKeyframeGroup=None, + mult_multival=None, print_keyframes=False): + if prev_kf is None: + prev_kf = NaiveReuseKeyframeGroup() + prev_kf = prev_kf.clone() + prev_kf = prev_kf.clone() + percents = InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=intervals, method=InterpolationMethod.LINEAR) + mults = InterpolationMethod.get_weights(num_from=mult_start, num_to=mult_end, length=intervals, method=interpolation) + + is_first = True + for percent, mult in zip(percents, mults): + guarantee_steps = 0 + if is_first: + guarantee_steps = 1 + is_first = False + prev_kf.add(NaiveReuseKeyframe(mult=mult, mult_multival=mult_multival, + start_percent=percent, guarantee_steps=guarantee_steps, inherit_missing=inherit_missing)) + if print_keyframes: + logger.info(f"NaiveReuseKeyframe - start_percent:{percent} = {mult}") + return (prev_kf,) + + +class NaiveReuse_KeyframeFromListNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "mults_float": ("FLOAT", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "inherit_missing": ("BOOLEAN", {"default": True}), + "print_keyframes": ("BOOLEAN", {"default": False}), + }, + "optional": { + "prev_kf": ("NAIVEREUSE_KEYFRAME",), + "mult_multival": ("MULTIVAL",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("NAIVEREUSE_KEYFRAME",) + RETURN_NAMES = ("NAIVEREUSE_KF",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/naivereuse" + FUNCTION = "create_keyframe" + + def create_keyframe(self, mults_float: Union[float, list[float]], + start_percent: float, end_percent: float, + inherit_missing=True, prev_kf: NaiveReuseKeyframeGroup=None, + mult_multival=None, print_keyframes=False): + if prev_kf is None: + prev_kf = NaiveReuseKeyframeGroup() + prev_kf = prev_kf.clone() + if type(mults_float) in (float, int): + mults_float = [float(mults_float)] + elif isinstance(mults_float, Iterable): + pass + else: + raise 
Exception(f"mults_float must be either an iterable input or a float, but was {type(mults_float).__name__}.")
+        percents = InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(mults_float), method=InterpolationMethod.LINEAR)
+
+        is_first = True
+        for percent, mult in zip(percents, mults_float):
+            guarantee_steps = 0
+            if is_first:
+                guarantee_steps = 1
+                is_first = False
+            prev_kf.add(NaiveReuseKeyframe(mult=mult, mult_multival=mult_multival,
+                                           start_percent=percent, guarantee_steps=guarantee_steps, inherit_missing=inherit_missing))
+            if print_keyframes:
+                logger.info(f"NaiveReuseKeyframe - start_percent:{percent} = {mult}")
+        return (prev_kf,)
+#----------------------------------------
+#########################################
+
+
+#########################################
+# ContextRef
+class ContextExtras_ContextRef:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+            },
+            "optional": {
+                "prev_extras": ("CONTEXT_EXTRAS",),
+                "strength_multival": ("MULTIVAL",),
+                "contextref_mode": ("CONTEXTREF_MODE",),
+                "contextref_tune": ("CONTEXTREF_TUNE",),
+                "contextref_kf": ("CONTEXTREF_KEYFRAME",),
+                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+                "end_percent": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001}),
+                "autosize": ("ADEAUTOSIZE", {"padding": 0}),
+            }
+        }
+
+    RETURN_TYPES = ("CONTEXT_EXTRAS",)
+    CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras"
+    FUNCTION = "create_context_extra"
+
+    def create_context_extra(self, start_percent=0.0, end_percent=0.1, strength_multival: Union[float, Tensor]=None,
+                             contextref_mode: ContextRefMode=None, contextref_tune: ContextRefTune=None,
+                             contextref_kf: ContextRefKeyframeGroup=None, prev_extras: ContextExtrasGroup=None):
+        if prev_extras is None:
+            prev_extras = ContextExtrasGroup()
+        prev_extras = prev_extras.clone()
+        # create extra
+        # TODO: make customizable, and allow mask input
+        if contextref_tune is None:
+            contextref_tune = ContextRefTune(attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_strength=1.0)
+        if contextref_mode is None:
+            contextref_mode = ContextRefMode.init_first()
+        context_ref = ContextRef(start_percent=start_percent, end_percent=end_percent,
+                                 strength_multival=strength_multival, tune=contextref_tune, mode=contextref_mode,
+                                 keyframe=contextref_kf)
+        prev_extras.add(context_ref)
+        return (prev_extras,)
+
+
+class ContextRef_KeyframeMultivalNode:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+            },
+            "optional": {
+                "prev_kf": ("CONTEXTREF_KEYFRAME",),
+                "mult_multival": ("MULTIVAL",),
+                "mode_replace": ("CONTEXTREF_MODE",),
+                "tune_replace": ("CONTEXTREF_TUNE",),
+                "mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+                "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}),
+                "inherit_missing": ("BOOLEAN", {"default": True}, ),
+                "autosize": ("ADEAUTOSIZE", {"padding": 0}),
+            }
+        }
+
+    RETURN_TYPES = ("CONTEXTREF_KEYFRAME",)
+    RETURN_NAMES = ("CONTEXTREF_KF",)
+    CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref"
+    FUNCTION = "create_keyframe"
+
+    def create_keyframe(self, prev_kf: ContextRefKeyframeGroup=None,
+                        mult=1.0, mult_multival=None, mode_replace=None, tune_replace=None,
+                        start_percent=0.0, guarantee_steps=1, inherit_missing=True):
+        if prev_kf is None:
+            prev_kf = ContextRefKeyframeGroup()
+        prev_kf = 
prev_kf.clone() + kf = ContextRefKeyframe(mult=mult, mult_multival=mult_multival, tune_replace=tune_replace, mode_replace=mode_replace, + start_percent=start_percent, guarantee_steps=guarantee_steps, inherit_missing=inherit_missing) + prev_kf.add(kf) + return (prev_kf,) + + +class ContextRef_KeyframeInterpolationNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "mult_start": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "mult_end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "interpolation": (InterpolationMethod._LIST, ), + "intervals": ("INT", {"default": 50, "min": 2, "max": 100, "step": 1}), + "inherit_missing": ("BOOLEAN", {"default": True}), + "print_keyframes": ("BOOLEAN", {"default": False}), + }, + "optional": { + "prev_kf": ("CONTEXTREF_KEYFRAME",), + "mult_multival": ("MULTIVAL",), + "mode_replace": ("CONTEXTREF_MODE",), + "tune_replace": ("CONTEXTREF_TUNE",), + "autosize": ("ADEAUTOSIZE", {"padding": 50}), + } + } + + RETURN_TYPES = ("CONTEXTREF_KEYFRAME",) + RETURN_NAMES = ("CONTEXTREF_KF",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" + FUNCTION = "create_keyframe" + + def create_keyframe(self, + start_percent: float, end_percent: float, + mult_start: float, mult_end: float, interpolation: str, intervals: int, + inherit_missing=True, prev_kf: ContextRefKeyframeGroup=None, + mult_multival=None, mode_replace=None, tune_replace=None, print_keyframes=False): + if prev_kf is None: + prev_kf = ContextRefKeyframeGroup() + prev_kf = prev_kf.clone() + percents = InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=intervals, method=InterpolationMethod.LINEAR) + mults = InterpolationMethod.get_weights(num_from=mult_start, num_to=mult_end, length=intervals, method=interpolation) + + is_first = True + for percent, mult in zip(percents, mults): + guarantee_steps = 0 + if is_first: + guarantee_steps = 1 + is_first = False + prev_kf.add(ContextRefKeyframe(mult=mult, mult_multival=mult_multival, tune_replace=tune_replace, mode_replace=mode_replace, + start_percent=percent, guarantee_steps=guarantee_steps, inherit_missing=inherit_missing)) + if print_keyframes: + logger.info(f"ContextRefKeyframe - start_percent:{percent} = {mult}") + return (prev_kf,) + + +class ContextRef_KeyframeFromListNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "mults_float": ("FLOAT", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "inherit_missing": ("BOOLEAN", {"default": True}), + "print_keyframes": ("BOOLEAN", {"default": False}), + }, + "optional": { + "prev_kf": ("CONTEXTREF_KEYFRAME",), + "mult_multival": ("MULTIVAL",), + "mode_replace": ("CONTEXTREF_MODE",), + "tune_replace": ("CONTEXTREF_TUNE",), + "autosize": ("ADEAUTOSIZE", {"padding": 50}), + } + } + + RETURN_TYPES = ("CONTEXTREF_KEYFRAME",) + RETURN_NAMES = ("CONTEXTREF_KF",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" + FUNCTION = "create_keyframe" + + def create_keyframe(self, mults_float: Union[float, list[float]], + start_percent: float, end_percent: float, + inherit_missing=True, prev_kf: 
ContextRefKeyframeGroup=None, + mult_multival=None, mode_replace=None, tune_replace=None, print_keyframes=False): + if prev_kf is None: + prev_kf = ContextRefKeyframeGroup() + prev_kf = prev_kf.clone() + if type(mults_float) in (float, int): + mults_float = [float(mults_float)] + elif isinstance(mults_float, Iterable): + pass + else: + raise Exception(f"strengths_float must be either an interable input or a float, but was {type(mults_float).__repr__}.") + percents = InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(mults_float), method=InterpolationMethod.LINEAR) + + is_first = True + for percent, mult in zip(percents, mults_float): + guarantee_steps = 0 + if is_first: + guarantee_steps = 1 + is_first = False + prev_kf.add(ContextRefKeyframe(mult=mult, mult_multival=mult_multival, tune_replace=tune_replace, mode_replace=mode_replace, + start_percent=percent, guarantee_steps=guarantee_steps, inherit_missing=inherit_missing)) + if print_keyframes: + logger.info(f"ContextRefKeyframe - start_percent:{percent} = {mult}") + return (prev_kf,) + + +class ContextRef_ModeFirst: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "autosize": ("ADEAUTOSIZE", {"padding": 25}), + }, + } + + RETURN_TYPES = ("CONTEXTREF_MODE",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" + FUNCTION = "create_contextref_mode" + + def create_contextref_mode(self): + mode = ContextRefMode.init_first() + return (mode,) + + +class ContextRef_ModeSliding: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "sliding_width": ("INT", {"default": 2, "min": 2, "max": BIGMAX, "step": 1}), + "autosize": ("ADEAUTOSIZE", {"padding": 42}), + } + } + + RETURN_TYPES = ("CONTEXTREF_MODE",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" + FUNCTION = "create_contextref_mode" + + def create_contextref_mode(self, sliding_width): + mode = ContextRefMode.init_sliding(sliding_width=sliding_width) + return (mode,) + + +class ContextRef_ModeIndexes: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "switch_on_idxs": ("STRING", {"default": ""}), + "always_include_0": ("BOOLEAN", {"default": True},), + "autosize": ("ADEAUTOSIZE", {"padding": 50}), + }, + } + + RETURN_TYPES = ("CONTEXTREF_MODE",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" + FUNCTION = "create_contextref_mode" + + def create_contextref_mode(self, switch_on_idxs: str, always_include_0: bool): + idxs = set(convert_str_to_indexes(indexes_str=switch_on_idxs, length=0, allow_range=False)) + if always_include_0 and 0 not in idxs: + idxs.add(0) + mode = ContextRefMode.init_indexes(indexes=idxs) + return (mode,) + + +class ContextRef_TuneAttnAdain: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "attn_style_fidelity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "attn_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "attn_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "adain_style_fidelity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "adain_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "adain_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "autosize": ("ADEAUTOSIZE", {"padding": 65}), + } + } + + RETURN_TYPES = 
("CONTEXTREF_TUNE",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" + FUNCTION = "create_contextref_tune" + + def create_contextref_tune(self, attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_strength=1.0, + adain_style_fidelity=1.0, adain_ref_weight=1.0, adain_strength=1.0): + params = ContextRefTune(attn_style_fidelity=attn_style_fidelity, adain_style_fidelity=adain_style_fidelity, + attn_ref_weight=attn_ref_weight, adain_ref_weight=adain_ref_weight, + attn_strength=attn_strength, adain_strength=adain_strength) + return (params,) + + +class ContextRef_TuneAttn: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "attn_style_fidelity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "attn_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "attn_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "autosize": ("ADEAUTOSIZE", {"padding": 15}), + } + } + + RETURN_TYPES = ("CONTEXTREF_TUNE",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" + FUNCTION = "create_contextref_tune" + + def create_contextref_tune(self, attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_strength=1.0): + return ContextRef_TuneAttnAdain.create_contextref_tune(self, + attn_style_fidelity=attn_style_fidelity, attn_ref_weight=attn_ref_weight, attn_strength=attn_strength, + adain_ref_weight=0.0, adain_style_fidelity=0.0, adain_strength=0.0) +#---------------------------------------- +######################################### diff --git a/animatediff/nodes_multival.py b/animatediff/nodes_multival.py index 1149104..4579e17 100644 --- a/animatediff/nodes_multival.py +++ b/animatediff/nodes_multival.py @@ -4,7 +4,7 @@ import torch from torch import Tensor -from .utils_motion import linear_conversion, normalize_min_max, extend_to_batch_size, extend_list_to_batch_size +from .utils_motion import create_multival_combo, linear_conversion, normalize_min_max, extend_to_batch_size, extend_list_to_batch_size class ScaleType: @@ -31,39 +31,7 @@ def INPUT_TYPES(s): FUNCTION = "create_multival" def create_multival(self, float_val: Union[float, list[float]]=1.0, mask_optional: Tensor=None): - # first, normalize inputs - # if float_val is iterable, treat as a list and assume inputs are floats - float_is_iterable = False - if isinstance(float_val, Iterable): - float_is_iterable = True - float_val = list(float_val) - # if mask present, make sure float_val list can be applied to list - match lengths - if mask_optional is not None: - if len(float_val) < mask_optional.shape[0]: - # copies last entry enough times to match mask shape - float_val = extend_list_to_batch_size(float_val, mask_optional.shape[0]) - if mask_optional.shape[0] < len(float_val): - mask_optional = extend_to_batch_size(mask_optional, len(float_val)) - float_val = float_val[:mask_optional.shape[0]] - float_val: Tensor = torch.tensor(float_val).unsqueeze(-1).unsqueeze(-1) - # now that inputs are normalized, figure out what value to actually return - if mask_optional is not None: - mask_optional = mask_optional.clone() - if float_is_iterable: - mask_optional = mask_optional[:] * float_val.to(mask_optional.dtype).to(mask_optional.device) - else: - mask_optional = mask_optional * float_val - return (mask_optional,) - else: - if not float_is_iterable: - return (float_val,) - # create a dummy mask of b,h,w=float_len,1,1 (sigle pixel) - # purpose is for float input to work with mask code, without special 
cases - float_len = float_val.shape[0] if float_is_iterable else 1 - shape = (float_len,1,1) - mask_optional = torch.ones(shape) - mask_optional = mask_optional[:] * float_val.to(mask_optional.dtype).to(mask_optional.device) - return (mask_optional,) + return (create_multival_combo(float_val=float_val, mask_optional=mask_optional),) class MultivalScaledMaskNode: diff --git a/animatediff/nodes_sample.py b/animatediff/nodes_sample.py index d988a62..70f4f1c 100644 --- a/animatediff/nodes_sample.py +++ b/animatediff/nodes_sample.py @@ -337,6 +337,7 @@ def INPUT_TYPES(s): "optional": { "prev_custom_cfg": ("CUSTOM_CFG",), "cfg_extras": ("CFG_EXTRAS",), + "autosize": ("ADEAUTOSIZE", {"padding": 70}), } } diff --git a/animatediff/nodes_sigma_schedule.py b/animatediff/nodes_sigma_schedule.py index 20580a2..828e2ca 100644 --- a/animatediff/nodes_sigma_schedule.py +++ b/animatediff/nodes_sigma_schedule.py @@ -101,6 +101,9 @@ def INPUT_TYPES(s): "weight_A_Start": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.001}), "weight_A_End": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.001}), "interpolation": (InterpolationMethod._LIST,), + }, + "optional": { + "autosize": ("ADEAUTOSIZE", {"padding": 70}), } } diff --git a/animatediff/sample_settings.py b/animatediff/sample_settings.py index e14dc72..bff8445 100644 --- a/animatediff/sample_settings.py +++ b/animatediff/sample_settings.py @@ -546,6 +546,7 @@ def __init__(self): self._current_keyframe: CustomCFGKeyframe = None self._current_used_steps: int = 0 self._current_index: int = 0 + self._previous_t = -1 def reset(self): self._current_keyframe = None @@ -584,6 +585,9 @@ def initialize_timesteps(self, model: BaseModel): def prepare_current_keyframe(self, t: Tensor): curr_t: float = t[0] + # if curr_t same as before, do nothing as step already accounted for + if curr_t == self._previous_t: + return prev_index = self._current_index # if met guaranteed steps, look for next keyframe in case need to switch if self._current_used_steps >= self._current_keyframe.guarantee_steps: @@ -604,6 +608,8 @@ def prepare_current_keyframe(self, t: Tensor): else: break # update steps current context is used self._current_used_steps += 1 + # update previous_t + self._previous_t = curr_t def get_cfg_scale(self, cond: Tensor): cond_scale = self.cfg_multival diff --git a/animatediff/sampling.py b/animatediff/sampling.py index dc743fb..6a73aea 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -26,8 +26,9 @@ from .conditioning import COND_CONST, LoraHookGroup, conditioning_set_values from .context import ContextFuseMethod, ContextSchedules, get_context_weights, get_context_windows +from .context_extras import ContextRefMode from .sample_settings import IterationOptions, SampleSettings, SeedNoiseGeneration, NoisedImageToInject -from .utils_model import ModelTypeSD, vae_encode_raw_batched, vae_decode_raw_batched +from .utils_model import ModelTypeSD, MachineState, vae_encode_raw_batched, vae_decode_raw_batched from .utils_motion import composite_extend, get_combined_multival, prepare_mask_batch, extend_to_batch_size from .model_injection import InjectionParams, ModelPatcherAndInjector, MotionModelGroup, MotionModelPatcher from .motion_module_ad import AnimateDiffFormat, AnimateDiffInfo, AnimateDiffVersion, VanillaTemporalModule @@ -71,7 +72,7 @@ def prepare_current_keyframes(self, x: Tensor, timestep: Tensor): if self.motion_models is not None: self.motion_models.prepare_current_keyframe(x=x, t=timestep) if self.params.context_options 
is not None: - self.params.context_options.prepare_current_context(t=timestep) + self.params.context_options.prepare_current(t=timestep) if self.sample_settings.custom_cfg is not None: self.sample_settings.custom_cfg.prepare_current_keyframe(t=timestep) @@ -114,6 +115,7 @@ def reset(self): del self.motion_models self.motion_models = None if self.params is not None: + self.params.context_options.reset() del self.params self.params = None if self.sample_settings is not None: @@ -403,6 +405,28 @@ def __exit__(self, *args, **kwargs): self.previous_dwi_gn_cast_weights = None +class ContextRefInjector: + def __init__(self): + self.orig_can_concat_cond = None + + def inject(self): + self.orig_can_concat_cond = comfy.samplers.can_concat_cond + comfy.samplers.can_concat_cond = ContextRefInjector.can_concat_cond_contextref_factory(self.orig_can_concat_cond) + + def restore(self): + if self.orig_can_concat_cond is not None: + comfy.samplers.can_concat_cond = self.orig_can_concat_cond + + @staticmethod + def can_concat_cond_contextref_factory(orig_func: Callable): + def can_concat_cond_contextref_injection(c1, c2, *args, **kwargs): + #return orig_func(c1, c2, *args, **kwargs) + if c1 is c2: + return True + return False + return can_concat_cond_contextref_injection + + def motion_sample_factory(orig_comfy_sample: Callable, is_custom: bool=False) -> Callable: def motion_sample(model: ModelPatcherAndInjector, noise: Tensor, *args, **kwargs): # check if model is intended for injecting @@ -622,8 +646,10 @@ def evolved_sampling_function(model, x: Tensor, timestep: Tensor, uncond, cond, # add AD/evolved-sampling params to model_options (transformer_options) model_options = model_options.copy() - if "tranformer_options" not in model_options: - model_options["tranformer_options"] = {} + if "transformer_options" not in model_options: + model_options["transformer_options"] = {} + else: + model_options["transformer_options"] = model_options["transformer_options"].copy() model_options["transformer_options"]["ad_params"] = ADGS.create_exposed_params() if not ADGS.is_using_sliding_context(): @@ -792,61 +818,160 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list # get context windows ADGS.params.context_options.step = ADGS.current_step context_windows = get_context_windows(ADGS.params.full_length, ADGS.params.context_options) - # figure out how input is split - batched_conds = x_in.size(0)//ADGS.params.full_length if ADGS.motion_models is not None: ADGS.motion_models.set_view_options(ADGS.params.context_options.view_options) # prepare final conds, out_counts, and biases conds_final = [torch.zeros_like(x_in) for _ in conds] - counts_final = [torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds] + if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE: + # counts_final not used for RELATIVE fuse_method + counts_final = [torch.ones((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds] + else: + # default counts_final initialization + counts_final = [torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds] biases_final = [([0.0] * x_in.shape[0]) for _ in conds] - # perform calc_conds_batch per context window - for ctx_idxs in context_windows: - ADGS.params.sub_idxs = ctx_idxs - if ADGS.motion_models is not None: - ADGS.motion_models.set_sub_idxs(ctx_idxs) - ADGS.motion_models.set_video_length(len(ctx_idxs), ADGS.params.full_length) - # update exposed params - model_options["transformer_options"]["ad_params"]["sub_idxs"] = 
@@ -622,8 +646,10 @@ def evolved_sampling_function(model, x: Tensor, timestep: Tensor, uncond, cond,
     # add AD/evolved-sampling params to model_options (transformer_options)
     model_options = model_options.copy()
-    if "tranformer_options" not in model_options:
-        model_options["tranformer_options"] = {}
+    if "transformer_options" not in model_options:
+        model_options["transformer_options"] = {}
+    else:
+        model_options["transformer_options"] = model_options["transformer_options"].copy()
     model_options["transformer_options"]["ad_params"] = ADGS.create_exposed_params()
 
     if not ADGS.is_using_sliding_context():
@@ -792,61 +818,160 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list:
     # get context windows
     ADGS.params.context_options.step = ADGS.current_step
     context_windows = get_context_windows(ADGS.params.full_length, ADGS.params.context_options)
-    # figure out how input is split
-    batched_conds = x_in.size(0)//ADGS.params.full_length
 
     if ADGS.motion_models is not None:
         ADGS.motion_models.set_view_options(ADGS.params.context_options.view_options)
 
     # prepare final conds, out_counts, and biases
     conds_final = [torch.zeros_like(x_in) for _ in conds]
-    counts_final = [torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds]
+    if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE:
+        # counts_final not used for RELATIVE fuse_method
+        counts_final = [torch.ones((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds]
+    else:
+        # default counts_final initialization
+        counts_final = [torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds]
     biases_final = [([0.0] * x_in.shape[0]) for _ in conds]
 
-    # perform calc_conds_batch per context window
-    for ctx_idxs in context_windows:
-        ADGS.params.sub_idxs = ctx_idxs
-        if ADGS.motion_models is not None:
-            ADGS.motion_models.set_sub_idxs(ctx_idxs)
-            ADGS.motion_models.set_video_length(len(ctx_idxs), ADGS.params.full_length)
-        # update exposed params
-        model_options["transformer_options"]["ad_params"]["sub_idxs"] = ctx_idxs
-        model_options["transformer_options"]["ad_params"]["context_length"] = len(ctx_idxs)
-        # account for all portions of input frames
-        full_idxs = []
-        for n in range(batched_conds):
-            for ind in ctx_idxs:
-                full_idxs.append((ADGS.params.full_length*n)+ind)
-        # get subsections of x, timestep, conds
-        sub_x = x_in[full_idxs]
-        sub_timestep = timestep[full_idxs]
-        sub_conds = [get_resized_cond(cond, full_idxs, len(ctx_idxs)) for cond in conds]
-
-        sub_conds_out = calc_cond_uncond_batch_wrapper(model, sub_conds, sub_x, sub_timestep, model_options)
-
-        if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE:
-            full_length = ADGS.params.full_length
-            for pos, idx in enumerate(ctx_idxs):
-                # bias is the influence of a specific index in relation to the whole context window
-                bias = 1 - abs(idx - (ctx_idxs[0] + ctx_idxs[-1]) / 2) / ((ctx_idxs[-1] - ctx_idxs[0] + 1e-2) / 2)
-                bias = max(1e-2, bias)
-                # take weighted average relative to total bias of current idx
-                # and account for batched_conds
-                for i in range(len(sub_conds_out)):
-                    for n in range(batched_conds):
-                        bias_total = biases_final[i][(full_length*n)+idx]
+    CONTEXTREF_CONTROL_LIST_ALL = "contextref_control_list_all"
+    CONTEXTREF_MACHINE_STATE = "contextref_machine_state"
+    CONTEXTREF_CLEAN_FUNC = "contextref_clean_func"
+    contextref_active = False
+    contextref_injector = None
+    contextref_mode = None
+    contextref_idxs_set = None
+    first_context = True
+    # need to make sure that contextref stuff gets cleaned up, no matter what
+    try:
+        if ADGS.params.context_options.extras.should_run_context_ref():
+            # check that ACN provided ContextRef as requested
+            temp_refcn_list = model_options["transformer_options"].get(CONTEXTREF_CONTROL_LIST_ALL, None)
+            if temp_refcn_list is None:
+                raise Exception("Advanced-ControlNet nodes are either missing or too outdated to support ContextRef. Update/install ComfyUI-Advanced-ControlNet to use ContextRef.")
+            if len(temp_refcn_list) == 0:
+                raise Exception("Unexpected ContextRef issue; Advanced-ControlNet did not provide any ContextRef objs for AnimateDiff-Evolved.")
+            del temp_refcn_list
+            # check if ContextRef ReferenceAdvanced ACN objs should_run
+            actually_should_run = True
+            for refcn in model_options["transformer_options"][CONTEXTREF_CONTROL_LIST_ALL]:
+                refcn.prepare_current_timestep(timestep)
+                if not refcn.should_run():
+                    actually_should_run = False
+            if actually_should_run:
+                contextref_active = True
+                for refcn in model_options["transformer_options"][CONTEXTREF_CONTROL_LIST_ALL]:
+                    # get mode_override if present, mode otherwise
+                    contextref_mode = refcn.get_contextref_mode_replace() or ADGS.params.context_options.extras.context_ref.mode
+                    contextref_idxs_set = contextref_mode.indexes.copy()
+                # use injector to ensure only 1 cond or uncond will be batched at a time
+                contextref_injector = ContextRefInjector()
+                contextref_injector.inject()
+
+        curr_window_idx = -1
+        naivereuse_active = False
+        cached_naive_conds = None
+        cached_naive_ctx_idxs = None
+        if ADGS.params.context_options.extras.should_run_naive_reuse():
+            cached_naive_conds = [torch.zeros_like(x_in) for _ in conds]
+            #cached_naive_counts = [torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds]
+            naivereuse_active = True
+        # perform calc_conds_batch per context window
+        for ctx_idxs in context_windows:
+            # allow processing to end between context window executions for faster Cancel
+            comfy.model_management.throw_exception_if_processing_interrupted()
+            curr_window_idx += 1
+            ADGS.params.sub_idxs = ctx_idxs
+            if ADGS.motion_models is not None:
+                ADGS.motion_models.set_sub_idxs(ctx_idxs)
+                ADGS.motion_models.set_video_length(len(ctx_idxs), ADGS.params.full_length)
+            # update exposed params
+            model_options["transformer_options"]["ad_params"]["sub_idxs"] = ctx_idxs
+            model_options["transformer_options"]["ad_params"]["context_length"] = len(ctx_idxs)
+            # get subsections of x, timestep, conds
+            sub_x = x_in[ctx_idxs]
+            sub_timestep = timestep[ctx_idxs]
+            sub_conds = [get_resized_cond(cond, ctx_idxs, len(ctx_idxs)) for cond in conds]
+
+            if contextref_active:
+                # set cond counter to 0 (each cond encountered will increment it by 1)
+                for refcn in model_options["transformer_options"][CONTEXTREF_CONTROL_LIST_ALL]:
+                    refcn.contextref_cond_idx = 0
+                if first_context:
+                    model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.WRITE
+                else:
+                    model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.READ
+                    if contextref_mode.mode == ContextRefMode.SLIDING:  # if sliding, check if time to READ and WRITE
+                        if curr_window_idx % (contextref_mode.sliding_width-1) == 0:
+                            model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.READ_WRITE
+                # override with indexes mode, if set
+                if contextref_mode.mode == ContextRefMode.INDEXES:
+                    contains_idx = False
+                    for i in ctx_idxs:
+                        if i in contextref_idxs_set:
+                            contains_idx = True
+                            # single trigger decides if each index should only trigger READ_WRITE once per step
+                            if not contextref_mode.single_trigger:
+                                break
+                            contextref_idxs_set.remove(i)
+                    if contains_idx:
+                        model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.READ_WRITE
+                        if first_context:
+                            model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.WRITE
+                    else:
+                        model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.READ
+            else:
+                model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.OFF
+            #logger.info(f"window: {curr_window_idx} - {model_options['transformer_options'][CONTEXTREF_MACHINE_STATE]}")
+
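# NOTE (editorial aside, not part of the patch): the block above picks a ContextRef "machine state"
# per context window - WRITE on the first window, READ afterwards, and READ_WRITE when the sliding
# width or a trigger index says the reference should be refreshed. A simplified, standalone
# restatement of that decision; it omits the single_trigger bookkeeping and uses illustrative constants.
READ, WRITE, READ_WRITE = "read", "write", "read_write"

def pick_state(window_idx, ctx_idxs, first_context, mode="first", sliding_width=2, trigger_idxs=frozenset()):
    state = WRITE if first_context else READ
    if not first_context and mode == "sliding" and window_idx % (sliding_width - 1) == 0:
        state = READ_WRITE  # periodically refresh the reference while reading it
    if mode == "indexes":
        if any(i in trigger_idxs for i in ctx_idxs):
            state = WRITE if first_context else READ_WRITE
        else:
            state = READ
    return state

# 3 windows over 16 frames, sliding mode with width 3: the third window re-writes the reference
windows = [[0, 1, 2, 3, 4, 5], [5, 6, 7, 8, 9, 10], [10, 11, 12, 13, 14, 15]]
states = [pick_state(i, w, first_context=(i == 0), mode="sliding", sliding_width=3)
          for i, w in enumerate(windows)]
assert states == [WRITE, READ, READ_WRITE]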
+            sub_conds_out = calc_cond_uncond_batch_wrapper(model, sub_conds, sub_x, sub_timestep, model_options)
+
+            if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE:
+                full_length = ADGS.params.full_length
+                for pos, idx in enumerate(ctx_idxs):
+                    # bias is the influence of a specific index in relation to the whole context window
+                    bias = 1 - abs(idx - (ctx_idxs[0] + ctx_idxs[-1]) / 2) / ((ctx_idxs[-1] - ctx_idxs[0] + 1e-2) / 2)
+                    bias = max(1e-2, bias)
+                    # take weighted average relative to total bias of current idx
+                    for i in range(len(sub_conds_out)):
+                        bias_total = biases_final[i][idx]
                         prev_weight = (bias_total / (bias_total + bias))
                         new_weight = (bias / (bias_total + bias))
-                        conds_final[i][(full_length*n)+idx] = conds_final[i][(full_length*n)+idx] * prev_weight + sub_conds_out[i][(full_length*n)+pos] * new_weight
-                        biases_final[i][(full_length*n)+idx] = bias_total + bias
-        else:
-            # add conds and counts based on weights of fuse method
-            weights = get_context_weights(len(ctx_idxs), ADGS.params.context_options.fuse_method, sigma=timestep) * batched_conds
-            weights_tensor = torch.Tensor(weights).to(device=x_in.device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
-            for i in range(len(sub_conds_out)):
-                conds_final[i][full_idxs] += sub_conds_out[i] * weights_tensor
-                counts_final[i][full_idxs] += weights_tensor
-
+                        conds_final[i][idx] = conds_final[i][idx] * prev_weight + sub_conds_out[i][pos] * new_weight
+                        biases_final[i][idx] = bias_total + bias
+            else:
+                # add conds and counts based on weights of fuse method
+                weights = get_context_weights(len(ctx_idxs), ADGS.params.context_options.fuse_method, sigma=timestep)
+                weights_tensor = torch.Tensor(weights).to(device=x_in.device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+                for i in range(len(sub_conds_out)):
+                    conds_final[i][ctx_idxs] += sub_conds_out[i] * weights_tensor
+                    counts_final[i][ctx_idxs] += weights_tensor
+            # handle NaiveReuse
+            if naivereuse_active:
+                cached_naive_ctx_idxs = ctx_idxs
+                for i in range(len(sub_conds)):
+                    cached_naive_conds[i][ctx_idxs] = conds_final[i][ctx_idxs] / counts_final[i][ctx_idxs]
+                naivereuse_active = False
+            # toggle first_context off, if needed
+            if first_context:
+                first_context = False
+    finally:
+        # clean contextref stuff with provided ACN function, if applicable
+        if contextref_active:
+            model_options["transformer_options"][CONTEXTREF_CLEAN_FUNC]()
+            contextref_injector.restore()
+
+    # handle NaiveReuse
+    if cached_naive_conds is not None:
+        start_idx = cached_naive_ctx_idxs[0]
+        for z in range(0, ADGS.params.full_length, len(cached_naive_ctx_idxs)):
+            for i in range(len(cached_naive_conds)):
+                # get the 'true' idxs of this window
+                new_ctx_idxs = [(zz+start_idx) % ADGS.params.full_length for zz in list(range(z, z+len(cached_naive_ctx_idxs))) if zz < ADGS.params.full_length]
+                # make sure that when getting cached_naive idxs, they are adjusted for the actual leftover length
+                adjusted_naive_ctx_idxs = cached_naive_ctx_idxs[:len(new_ctx_idxs)]
+                weighted_mean = ADGS.params.context_options.extras.naive_reuse.get_effective_weighted_mean(x_in, new_ctx_idxs)
+                conds_final[i][new_ctx_idxs] = (weighted_mean * (cached_naive_conds[i][adjusted_naive_ctx_idxs]*counts_final[i][new_ctx_idxs])) + ((1.-weighted_mean) * conds_final[i][new_ctx_idxs])
+        del cached_naive_conds
+
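# NOTE (editorial aside, not part of the patch): with the RELATIVE fuse method above, each window's
# output is folded into conds_final through a running weighted average whose per-frame bias peaks at
# the window's center, so overlapping windows blend smoothly instead of being summed and renormalized.
# A small numeric sketch of that update rule on plain floats (toy values, not real latents).
def center_bias(idx, ctx_idxs):
    # influence of idx relative to the whole context window; highest at the center
    bias = 1 - abs(idx - (ctx_idxs[0] + ctx_idxs[-1]) / 2) / ((ctx_idxs[-1] - ctx_idxs[0] + 1e-2) / 2)
    return max(1e-2, bias)

full_length = 12
conds_final = [0.0] * full_length
biases_final = [0.0] * full_length

# two overlapping windows producing constant "predictions" 1.0 and 3.0
for ctx_idxs, value in (([0, 1, 2, 3, 4, 5, 6, 7], 1.0), ([4, 5, 6, 7, 8, 9, 10, 11], 3.0)):
    for idx in ctx_idxs:
        bias = center_bias(idx, ctx_idxs)
        bias_total = biases_final[idx]
        prev_weight = bias_total / (bias_total + bias)
        new_weight = bias / (bias_total + bias)
        conds_final[idx] = conds_final[idx] * prev_weight + value * new_weight
        biases_final[idx] = bias_total + bias

# frames covered by one window keep that window's value; overlapping frames 4-7 land in between,
# pulled toward whichever window has them closer to its center
assert conds_final[0] == 1.0 and conds_final[11] == 3.0
assert 1.0 < conds_final[5] < 3.0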
     if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE:
         # already normalized, so return as is
         del counts_final
diff --git a/animatediff/scheduling.py b/animatediff/scheduling.py
new file mode 100644
index 0000000..e69de29
diff --git a/animatediff/utils_model.py b/animatediff/utils_model.py
index 0e62d00..48c432b 100644
--- a/animatediff/utils_model.py
+++ b/animatediff/utils_model.py
@@ -26,6 +26,12 @@
 MAX_RESOLUTION = 16384  # mirrors ComfyUI's nodes.py MAX_RESOLUTION
 
+class MachineState:
+    READ = "read"
+    WRITE = "write"
+    READ_WRITE = "read_write"
+    OFF = "off"
+
 def vae_encode_raw_dynamic_batched(vae: VAE, pixels: Tensor, max_batch=16, min_batch=1, max_size=512*512, show_pbar=False):
     b, h, w, c = pixels.shape
diff --git a/animatediff/utils_motion.py b/animatediff/utils_motion.py
index fcc205a..501a7f6 100644
--- a/animatediff/utils_motion.py
+++ b/animatediff/utils_motion.py
@@ -3,6 +3,7 @@
 import torch.nn.functional as F
 from torch import Tensor, nn
 from abc import ABC, abstractmethod
+from collections.abc import Iterable
 
 import comfy.model_management as model_management
 import comfy.ops
@@ -238,23 +239,65 @@ def get_mask(self, x: Tensor):
         return mask * self.multival
 
 
-def get_combined_multival(multivalA: Union[float, Tensor], multivalB: Union[float, Tensor]) -> Union[float, Tensor]:
+def create_multival_combo(float_val: Union[float, list[float]], mask_optional: Tensor=None):
+    # first, normalize inputs
+    # if float_val is iterable, treat as a list and assume inputs are floats
+    float_is_iterable = False
+    if isinstance(float_val, Iterable):
+        float_is_iterable = True
+        float_val = list(float_val)
+        # if mask present, make sure float_val list can be applied to list - match lengths
+        if mask_optional is not None:
+            if len(float_val) < mask_optional.shape[0]:
+                # copies last entry enough times to match mask shape
+                float_val = extend_list_to_batch_size(float_val, mask_optional.shape[0])
+            if mask_optional.shape[0] < len(float_val):
+                mask_optional = extend_to_batch_size(mask_optional, len(float_val))
+            float_val = float_val[:mask_optional.shape[0]]
+        float_val: Tensor = torch.tensor(float_val).unsqueeze(-1).unsqueeze(-1)
+    # now that inputs are normalized, figure out what value to actually return
+    if mask_optional is not None:
+        mask_optional = mask_optional.clone()
+        if float_is_iterable:
+            mask_optional = mask_optional[:] * float_val.to(mask_optional.dtype).to(mask_optional.device)
+        else:
+            mask_optional = mask_optional * float_val
+        return mask_optional
+    else:
+        if not float_is_iterable:
+            return float_val
+        # create a dummy mask of b,h,w=float_len,1,1 (single pixel)
+        # purpose is for float input to work with mask code, without special cases
+        float_len = float_val.shape[0] if float_is_iterable else 1
+        shape = (float_len,1,1)
+        mask_optional = torch.ones(shape)
+        mask_optional = mask_optional[:] * float_val.to(mask_optional.dtype).to(mask_optional.device)
+        return mask_optional
+
+
+def get_combined_multival(multivalA: Union[float, Tensor], multivalB: Union[float, Tensor], force_leader_A=False) -> Union[float, Tensor]:
+    if multivalA is None and multivalB is None:
+        return 1.0
     # if one is None, use the other
-    if multivalA == None:
+    if multivalA is None:
         return multivalB
-    elif multivalB == None:
+    elif multivalB is None:
         return multivalA
     # both have a value - combine them based on type
     # if both are Tensors, make dims match before multiplying
     if type(multivalA) == Tensor and type(multivalB) == Tensor:
-        areaA = multivalA.shape[1]*multivalA.shape[2]
-        areaB = multivalB.shape[1]*multivalB.shape[2]
-        # match height/width to mask with larger area
-        leader,follower = (multivalA,multivalB) if areaA >= areaB else (multivalB,multivalA)
-        batch_size = multivalA.shape[0] if multivalA.shape[0] >= multivalB.shape[0] else multivalB.shape[0]
+        if force_leader_A:
+            leader,follower = (multivalA,multivalB)
+            batch_size = multivalA.shape[0]
+        else:
+            areaA = multivalA.shape[1]*multivalA.shape[2]
+            areaB = multivalB.shape[1]*multivalB.shape[2]
+            # match height/width to mask with larger area
+            leader,follower = (multivalA,multivalB) if areaA >= areaB else (multivalB,multivalA)
+            batch_size = multivalA.shape[0] if multivalA.shape[0] >= multivalB.shape[0] else multivalB.shape[0]
         # make follower same dimensions as leader
         follower = torch.unsqueeze(follower, 1)
-        follower = comfy.utils.common_upscale(follower, leader.shape[2], leader.shape[1], "bilinear", "center")
+        follower = comfy.utils.common_upscale(follower, leader.shape[-1], leader.shape[-2], "bilinear", "center")
         follower = torch.squeeze(follower, 1)
         # make sure batch size will match
         leader = extend_to_batch_size(leader, batch_size)
@@ -264,6 +307,18 @@ def get_combined_multival(multivalA: Union[float, Tensor], multivalB: Union[float, Tensor]) -> Union[float, Tensor]:
     return multivalA * multivalB
 
 
+def resize_multival(multival: Union[float, Tensor], batch_size: int, height: int, width: int):
+    if multival is None:
+        return 1.0
+    if type(multival) != Tensor:
+        return multival
+    multival = torch.unsqueeze(multival, 1)
+    multival = comfy.utils.common_upscale(multival, height, width, "bilinear", "center")
+    multival = torch.squeeze(multival, 1)
+    multival = extend_to_batch_size(multival, batch_size)
+    return multival
+
+
 def get_combined_input(inputA: Union[InputPIA, None], inputB: Union[InputPIA, None], x: Tensor):
     if inputA is None:
         inputA = InputPIA_Multival(1.0)
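The reworked multival helpers above accept either plain floats or per-frame mask tensors and, when two masks are combined, upscale the follower to the leader's height/width before multiplying. A quick torch-only illustration of that intent; shapes are arbitrary and F.interpolate stands in for comfy.utils.common_upscale, so this is an approximation rather than the helpers themselves:

import torch
import torch.nn.functional as F

# "multivals" are either plain floats or (frames, h, w) mask tensors
strength = 0.8
per_frame = torch.ones(16, 4, 4) * torch.linspace(0.0, 1.0, 16).view(-1, 1, 1)

# float x mask just scales the mask (float x float would stay a float)
combined = per_frame * strength

# mask x mask: resize the smaller-area mask to the larger one before multiplying,
# which mirrors the leader/follower logic above
small = torch.rand(16, 2, 2)
resized = F.interpolate(small.unsqueeze(1), size=(4, 4), mode="bilinear", align_corners=False).squeeze(1)
combined_masks = per_frame * resized

assert combined.shape == (16, 4, 4) and combined_masks.shape == (16, 4, 4)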
diff --git a/animatediff/utils_scheduling.py b/animatediff/utils_scheduling.py
new file mode 100644
index 0000000..92ac5ab
--- /dev/null
+++ b/animatediff/utils_scheduling.py
@@ -0,0 +1,90 @@
+from typing import Union
+
+from torch import Tensor
+
+
+def validate_index(index: int, length: int=0, is_range: bool=False, allow_negative=False, allow_missing=False) -> int:
+    # if part of range, do nothing
+    if is_range:
+        return index
+    # otherwise, validate index
+    # validate not out of range - only when latent_count is passed in
+    if length > 0 and index > length-1 and not allow_missing:
+        raise IndexError(f"Index '{index}' out of range for {length} item(s).")
+    # if negative, validate not out of range
+    if index < 0:
+        if not allow_negative:
+            raise IndexError(f"Negative indices not allowed, but was '{index}'.")
+        conv_index = length+index
+        if conv_index < 0 and not allow_missing:
+            raise IndexError(f"Index '{index}', converted to '{conv_index}', is out of range for {length} item(s).")
+        index = conv_index
+    return index
+
+
+def convert_to_index_int(raw_index: str, length: int=0, is_range: bool=False, allow_negative=False, allow_missing=False) -> int:
+    try:
+        return validate_index(int(raw_index), length=length, is_range=is_range, allow_negative=allow_negative, allow_missing=allow_missing)
+    except ValueError as e:
+        raise ValueError(f"Index '{raw_index}' must be an integer.", e)
+
+
+def convert_str_to_indexes(indexes_str: str, length: int=0, allow_range=True, allow_missing=False) -> list[int]:
+    if not indexes_str:
+        return []
+    int_indexes = list(range(0, length))
+    allow_negative = length > 0
+    chosen_indexes = []
+    # parse string - allow positive ints, negative ints, and ranges separated by ':'
+    groups = indexes_str.split(",")
+    groups = [g.strip() for g in groups]
+    for g in groups:
+        # parse range of indices (e.g. 2:16)
+        if ':' in g:
+            if not allow_range:
+                raise Exception("Ranges (:) not allowed for this input.")
+            index_range = g.split(":", 2)
+            index_range = [r.strip() for r in index_range]
+
+            start_index = index_range[0]
+            if len(start_index) > 0:
+                start_index = convert_to_index_int(start_index, length=length, is_range=True, allow_negative=allow_negative, allow_missing=allow_missing)
+            else:
+                start_index = 0
+            end_index = index_range[1]
+            if len(end_index) > 0:
+                end_index = convert_to_index_int(end_index, length=length, is_range=True, allow_negative=allow_negative, allow_missing=allow_missing)
+            else:
+                end_index = length
+            # support step as well, to allow things like reversing, every-other, etc.
+            step = 1
+            if len(index_range) > 2:
+                step = index_range[2]
+                if len(step) > 0:
+                    step = convert_to_index_int(step, length=length, is_range=True, allow_negative=True, allow_missing=True)
+                else:
+                    step = 1
+            # if latents were passed in, base indices on known latent count
+            if len(int_indexes) > 0:
+                chosen_indexes.extend(int_indexes[start_index:end_index][::step])
+            # otherwise, assume indices are valid
+            else:
+                chosen_indexes.extend(list(range(start_index, end_index, step)))
+        # parse individual indices
+        else:
+            chosen_indexes.append(convert_to_index_int(g, length=length, allow_negative=allow_negative, allow_missing=allow_missing))
+    return chosen_indexes
+
+
+def select_indexes(input_obj: Union[Tensor, list], idxs: list):
+    if type(input_obj) == Tensor:
+        return input_obj[idxs]
+    else:
+        return [input_obj[i] for i in idxs]
+
+
+def select_indexes_from_str(input_obj: Union[Tensor, list], indexes: str, allow_range=True, err_if_missing=True, err_if_empty=True):
+    real_idxs = convert_str_to_indexes(indexes, len(input_obj), allow_range=allow_range, allow_missing=not err_if_missing)
+    if err_if_empty and len(real_idxs) == 0:
+        raise Exception(f"Nothing was selected based on indexes found in '{indexes}'.")
+    return select_indexes(input_obj, real_idxs)
diff --git a/pyproject.toml b/pyproject.toml
index c268580..2684188 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 name = "comfyui-animatediff-evolved"
 description = "Improved AnimateDiff integration for ComfyUI."
-version = "1.0.12"
+version = "1.1.0"
 license = { file = "LICENSE" }
 dependencies = []
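convert_str_to_indexes in the new utils_scheduling.py accepts comma-separated entries that may be plain indexes, negative indexes, or start:end:step ranges resolved against a known length. A rough, self-contained re-implementation of the accepted syntax for illustration only; the real function routes everything through validate_index and supports allow_missing:

def parse_indexes(spec: str, length: int) -> list[int]:
    out = []
    for part in (p.strip() for p in spec.split(",") if p.strip()):
        if ":" in part:
            start, end, *rest = part.split(":")
            step = int(rest[0]) if rest and rest[0] else 1
            start = int(start) if start else 0
            end = int(end) if end else length
            out.extend(list(range(length))[start:end:step])
        else:
            idx = int(part)
            out.append(idx if idx >= 0 else length + idx)
    return out

# e.g. for 16 frames: explicit indexes, a stepped range, and a negative index
assert parse_indexes("0, 2, 5:10:2, -1", 16) == [0, 2, 5, 7, 9, 15]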