From 3a8744860dead6fb2275fa846ded876130c22cf8 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 23 Jul 2024 03:26:01 -0500 Subject: [PATCH 01/22] Added initial work on two novel context consistency techs: NaiveReuse and ContextRef; with this commit, NaiveReuse is functional, but not mask_opt, ContextRef is totally unfunctional as will require appropriate rework of Advanced-ControlNet refcn code to be pushed out --- animatediff/context.py | 13 +++- animatediff/context_extras.py | 112 ++++++++++++++++++++++++++++++++++ animatediff/nodes.py | 23 ++++++- animatediff/nodes_context.py | 88 ++++++++++++++++++++++++++ animatediff/sampling.py | 62 +++++++++++++++++-- animatediff/scheduling.py | 0 6 files changed, 289 insertions(+), 9 deletions(-) create mode 100644 animatediff/context_extras.py create mode 100644 animatediff/scheduling.py diff --git a/animatediff/context.py b/animatediff/context.py index b44fa42..50af0d8 100644 --- a/animatediff/context.py +++ b/animatediff/context.py @@ -11,8 +11,10 @@ from comfy.model_base import BaseModel from comfy.model_patcher import ModelPatcher +from .context_extras import ContextExtrasGroup from .utils_motion import get_sorted_list_via_attr + class ContextFuseMethod: FLAT = "flat" PYRAMID = "pyramid" @@ -76,6 +78,7 @@ def clone(self): class ContextOptionsGroup: def __init__(self): self.contexts: list[ContextOptions] = [] + self.extras = ContextExtrasGroup() self._current_context: ContextOptions = None self._current_used_steps: int = 0 self._current_index: int = 0 @@ -121,9 +124,10 @@ def has_index(self, index: int) -> int: def is_empty(self) -> bool: return len(self.contexts) == 0 - + def clone(self): cloned = ContextOptionsGroup() + cloned.extras = self.extras.clone() for context in self.contexts: cloned.contexts.append(context) cloned._set_first_as_current() @@ -132,6 +136,11 @@ def clone(self): def initialize_timesteps(self, model: BaseModel): for context in self.contexts: context.start_t = 
model.model_sampling.percent_to_sigma(context.start_percent) + self.extras.initialize_timesteps(model) + + def prepare_current(self, t: Tensor): + self.prepare_current_context(t) + self.extras.prepare_current(t) def prepare_current_context(self, t: Tensor): curr_t: float = t[0] @@ -620,7 +629,7 @@ def generate_context_visualization(context_opts: ContextOptionsGroup, model: Mod for i, t in enumerate(sigmas): # make context_opts reflect current step/sigma - context_opts.prepare_current_context([t]) + context_opts.prepare_current([t]) context_opts.step = start_step+i # check if context should even be active in this case diff --git a/animatediff/context_extras.py b/animatediff/context_extras.py new file mode 100644 index 0000000..519e0e3 --- /dev/null +++ b/animatediff/context_extras.py @@ -0,0 +1,112 @@ +from torch import Tensor + +from comfy.model_base import BaseModel + + +class ContextExtra: + def __init__(self, start_percent: float, end_percent: float): + # scheduling + self.start_percent = float(start_percent) + self.start_t = 999999999.9 + self.end_percent = float(end_percent) + self.end_t = 0.0 + self.curr_t = 999999999.9 + + def initialize_timesteps(self, model: BaseModel): + self.start_t = model.model_sampling.percent_to_sigma(self.start_percent) + self.end_t = model.model_sampling.percent_to_sigma(self.end_percent) + + def prepare_current(self, t: Tensor): + self.curr_t = t[0] + + def should_run(self): + if self.curr_t > self.start_t or self.curr_t < self.end_t: + return False + return True + + +################################ +# Context Ref +class ContextRefParams: + def __init__(self, + attn_style_fidelity=0.0, attn_ref_weight=0.0, attn_atrength=0.0, + adain_style_fidelity=0.0, adain_ref_weight=0.0, adain_strength=0.0): + # attn1 + self.attn_style_fidelity = attn_style_fidelity + self.attn_ref_weight = attn_ref_weight + self.attn_strength = attn_atrength + # adain + self.adain_style_fidelity = adain_style_fidelity + self.adain_ref_weight = 
adain_ref_weight + self.adain_strength = adain_strength + + +class ContextRef(ContextExtra): + def __init__(self, start_percent: float, end_percent: float, params: ContextRefParams): + super().__init__(start_percent=start_percent, end_percent=end_percent) + self.params = params + + def should_run(self): + return super().should_run() +#-------------------------------- + + +################################ +# NaiveReuse +class NaiveReuse(ContextExtra): + def __init__(self, start_percent: float, end_percent: float, weighted_mean: float, mask_opt: Tensor=None): + super().__init__(start_percent=start_percent, end_percent=end_percent) + self.weighted_mean = weighted_mean + self.mask_opt = mask_opt + + def should_run(self): + to_return = super().should_run() + # if weighted_mean is 0.0, then reuse will take no effect anyway + return to_return and self.weighted_mean > 0.0 +#-------------------------------- + + +class ContextExtrasGroup: + def __init__(self): + self.context_ref: ContextRef = None + self.naive_reuse: NaiveReuse = None + + def get_extras_list(self) -> list[ContextExtra]: + extras_list = [] + if self.context_ref is not None: + extras_list.append(self.context_ref) + if self.naive_reuse is not None: + extras_list.append(self.naive_reuse) + return extras_list + + def initialize_timesteps(self, model: BaseModel): + for extra in self.get_extras_list(): + extra.initialize_timesteps(model) + + def prepare_current(self, t: Tensor): + for extra in self.get_extras_list(): + extra.prepare_current(t) + + def should_run_context_ref(self): + if not self.context_ref: + return False + return self.context_ref.should_run() + + def should_run_naive_reuse(self): + if not self.naive_reuse: + return False + return self.naive_reuse.should_run() + + def add(self, extra: ContextExtra): + if type(extra) == ContextRef: + self.context_ref = extra + elif type(extra) == NaiveReuse: + self.naive_reuse = extra + else: + raise Exception(f"Unrecognized ContextExtras type: {type(extra)}") + + 
def clone(self): + cloned = ContextExtrasGroup() + cloned.context_ref = self.context_ref + cloned.naive_reuse = self.naive_reuse + return cloned diff --git a/animatediff/nodes.py b/animatediff/nodes.py index 07ba149..e14e9ca 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -28,7 +28,8 @@ from .nodes_sigma_schedule import (SigmaScheduleNode, RawSigmaScheduleNode, WeightedAverageSigmaScheduleNode, InterpolatedWeightedAverageSigmaScheduleNode, SplitAndCombineSigmaScheduleNode, SigmaScheduleToSigmasNode) from .nodes_context import (LegacyLoopedUniformContextOptionsNode, LoopedUniformContextOptionsNode, LoopedUniformViewOptionsNode, StandardUniformContextOptionsNode, StandardStaticContextOptionsNode, BatchedContextOptionsNode, StandardStaticViewOptionsNode, StandardUniformViewOptionsNode, ViewAsContextOptionsNode, - VisualizeContextOptionsK, VisualizeContextOptionsKAdv, VisualizeContextOptionsSCustom) + VisualizeContextOptionsK, VisualizeContextOptionsKAdv, VisualizeContextOptionsSCustom, + SetContextExtrasOnContextOptions, ContextExtras_NaiveReuse, ContextExtras_ContextRef) from .nodes_ad_settings import (AnimateDiffSettingsNode, ManualAdjustPENode, SweetspotStretchPENode, FullStretchPENode, WeightAdjustAllAddNode, WeightAdjustAllMultNode, WeightAdjustIndivAddNode, WeightAdjustIndivMultNode, WeightAdjustIndivAttnAddNode, WeightAdjustIndivAttnMultNode) @@ -54,13 +55,15 @@ "ADE_MultivalDynamicFloatInput": MultivalDynamicFloatInputNode, "ADE_MultivalScaledMask": MultivalScaledMaskNode, "ADE_MultivalConvertToMask": MultivalConvertToMaskNode, + ############################################################################### + #------------------------------------------------------------------------------ # Context Opts "ADE_StandardStaticContextOptions": StandardStaticContextOptionsNode, "ADE_StandardUniformContextOptions": StandardUniformContextOptionsNode, "ADE_LoopedUniformContextOptions": LoopedUniformContextOptionsNode, "ADE_ViewsOnlyContextOptions": 
ViewAsContextOptionsNode, "ADE_BatchedContextOptions": BatchedContextOptionsNode, - "ADE_AnimateDiffUniformContextOptions": LegacyLoopedUniformContextOptionsNode, # Legacy + "ADE_AnimateDiffUniformContextOptions": LegacyLoopedUniformContextOptionsNode, # Legacy/Deprecated "ADE_VisualizeContextOptionsK": VisualizeContextOptionsK, "ADE_VisualizeContextOptionsKAdv": VisualizeContextOptionsKAdv, "ADE_VisualizeContextOptionsSCustom": VisualizeContextOptionsSCustom, @@ -68,6 +71,12 @@ "ADE_StandardStaticViewOptions": StandardStaticViewOptionsNode, "ADE_StandardUniformViewOptions": StandardUniformViewOptionsNode, "ADE_LoopedUniformViewOptions": LoopedUniformViewOptionsNode, + # Context Extras + "ADE_ContextExtras_Set": SetContextExtrasOnContextOptions, + "ADE_ContextExtras_ContextRef": ContextExtras_ContextRef, + "ADE_ContextExtras_NaiveReuse": ContextExtras_NaiveReuse, + #------------------------------------------------------------------------------ + ############################################################################### # Iteration Opts "ADE_IterationOptsDefault": IterationOptionsNode, "ADE_IterationOptsFreeInit": FreeInitOptionsNode, @@ -184,13 +193,15 @@ "ADE_MultivalDynamicFloatInput": "Multival [Float List] πŸŽ­πŸ…πŸ…“", "ADE_MultivalScaledMask": "Multival Scaled Mask πŸŽ­πŸ…πŸ…“", "ADE_MultivalConvertToMask": "Multival to Mask πŸŽ­πŸ…πŸ…“", + ############################################################################### + #------------------------------------------------------------------------------ # Context Opts "ADE_StandardStaticContextOptions": "Context Optionsβ—†Standard Static πŸŽ­πŸ…πŸ…“", "ADE_StandardUniformContextOptions": "Context Optionsβ—†Standard Uniform πŸŽ­πŸ…πŸ…“", "ADE_LoopedUniformContextOptions": "Context Optionsβ—†Looped Uniform πŸŽ­πŸ…πŸ…“", "ADE_ViewsOnlyContextOptions": "Context Optionsβ—†Views Only [VRAMβ‡ˆ] πŸŽ­πŸ…πŸ…“", "ADE_BatchedContextOptions": "Context Optionsβ—†Batched [Non-AD] πŸŽ­πŸ…πŸ…“", - 
"ADE_AnimateDiffUniformContextOptions": "Context Optionsβ—†Looped Uniform πŸŽ­πŸ…πŸ…“", # Legacy + "ADE_AnimateDiffUniformContextOptions": "Context Optionsβ—†Looped Uniform πŸŽ­πŸ…πŸ…“", # Legacy/Deprecated "ADE_VisualizeContextOptionsK": "Visualize Context Options (K.) πŸŽ­πŸ…πŸ…“", "ADE_VisualizeContextOptionsKAdv": "Visualize Context Options (K.Adv.) πŸŽ­πŸ…πŸ…“", "ADE_VisualizeContextOptionsSCustom": "Visualize Context Options (S.Cus.) πŸŽ­πŸ…πŸ…“", @@ -198,6 +209,12 @@ "ADE_StandardStaticViewOptions": "View Optionsβ—†Standard Static πŸŽ­πŸ…πŸ…“", "ADE_StandardUniformViewOptions": "View Optionsβ—†Standard Uniform πŸŽ­πŸ…πŸ…“", "ADE_LoopedUniformViewOptions": "View Optionsβ—†Looped Uniform πŸŽ­πŸ…πŸ…“", + # Context Extras + "ADE_ContextExtras_Set": "Set Context Extras πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_ContextRef": "Context Extrasβ—†ContextRef πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_NaiveReuse": "Context Extrasβ—†NaiveReuse πŸŽ­πŸ…πŸ…“", + #------------------------------------------------------------------------------ + ############################################################################### # Iteration Opts "ADE_IterationOptsDefault": "Default Iteration Options πŸŽ­πŸ…πŸ…“", "ADE_IterationOptsFreeInit": "FreeInit Iteration Options πŸŽ­πŸ…πŸ…“", diff --git a/animatediff/nodes_context.py b/animatediff/nodes_context.py index a4ca94c..38c194d 100644 --- a/animatediff/nodes_context.py +++ b/animatediff/nodes_context.py @@ -6,6 +6,7 @@ from .context import (ContextFuseMethod, ContextOptions, ContextOptionsGroup, ContextSchedules, generate_context_visualization) +from .context_extras import ContextExtrasGroup, ContextRef, ContextRefParams, NaiveReuse from .utils_model import BIGMAX, MAX_RESOLUTION @@ -440,3 +441,90 @@ def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sigm images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length, sigmas=sigmas) return 
(images,) + + +######################### +# Context Extras +class SetContextExtrasOnContextOptions: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "context_opts": ("CONTEXT_OPTIONS",), + "context_extras": ("CONTEXT_EXTRAS",), + }, + "optional": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("CONTEXT_OPTIONS",) + RETURN_NAMES = ("CONTEXT_OPTS",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras" + FUNCTION = "set_context_extras" + + def set_context_extras(self, context_opts: ContextOptionsGroup, context_extras: ContextExtrasGroup): + context_opts = context_opts.clone() + context_opts.extras = context_extras.clone() + return (context_opts,) + + +class ContextExtras_NaiveReuse: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "prev_extras": ("CONTEXT_EXTRAS",), + "mask_opt": ("MASK",), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}), + "weighted_mean": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.001}), + "autosize": ("ADEAUTOSIZE", {"padding": 55}), + } + } + + RETURN_TYPES = ("CONTEXT_EXTRAS",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras" + FUNCTION = "create_context_extra" + + def create_context_extra(self, start_percent=0.0, end_percent=0.1, weighted_mean=0.95, mask_opt: Tensor=None, prev_extras: ContextExtrasGroup=None): + if prev_extras is None: + prev_extras = prev_extras = ContextExtrasGroup() + prev_extras = prev_extras.clone() + # create extra + naive_reuse = NaiveReuse(start_percent=start_percent, end_percent=end_percent, weighted_mean=weighted_mean, mask_opt=mask_opt) + prev_extras.add(naive_reuse) + return (prev_extras,) + + +class ContextExtras_ContextRef: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "prev_extras": ("CONTEXT_EXTRAS",), + 
"mask_opt": ("MASK",), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001}), + "autosize": ("ADEAUTOSIZE", {"padding": 55}), + } + } + + RETURN_TYPES = ("CONTEXT_EXTRAS",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras" + FUNCTION = "create_context_extra" + + def create_context_extra(self, start_percent=0.0, end_percent=0.1, mask_opt: Tensor=None, prev_extras: ContextExtrasGroup=None): + if prev_extras is None: + prev_extras = prev_extras = ContextExtrasGroup() + prev_extras = prev_extras.clone() + # create extra + # TODO: make customizable, and allow mask input + params = ContextRefParams(attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_atrength=1.0) + context_ref = ContextRef(start_percent=start_percent, end_percent=end_percent, params=params) + prev_extras.add(context_ref) + return (prev_extras,) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 02a2267..ddbde56 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -71,7 +71,7 @@ def prepare_current_keyframes(self, x: Tensor, timestep: Tensor): if self.motion_models is not None: self.motion_models.prepare_current_keyframe(x=x, t=timestep) if self.params.context_options is not None: - self.params.context_options.prepare_current_context(t=timestep) + self.params.context_options.prepare_current(t=timestep) if self.sample_settings.custom_cfg is not None: self.sample_settings.custom_cfg.prepare_current_keyframe(t=timestep) @@ -617,8 +617,10 @@ def evolved_sampling_function(model, x: Tensor, timestep: Tensor, uncond, cond, # add AD/evolved-sampling params to model_options (transformer_options) model_options = model_options.copy() - if "tranformer_options" not in model_options: - model_options["tranformer_options"] = {} + if "transformer_options" not in model_options: + model_options["transformer_options"] = {} + else: + 
model_options["transformer_options"] = model_options["transformer_options"].copy() model_options["transformer_options"]["ad_params"] = ADGS.create_exposed_params() if not ADGS.is_using_sliding_context(): @@ -798,7 +800,25 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list counts_final = [torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds] biases_final = [([0.0] * x_in.shape[0]) for _ in conds] - # perform calc_conds_batch per context window + CONTEXTREF_ATTN_MACHINE_STATE = "contextref_attn_machine_state" + CONTEXTREF_ADAIN_MACHINE_STATE = "contextref_adain_machine_state" + #context_ref_steps = [0,1,2,3,4,5]#[0] + context_ref = False + first_context = False + if ADGS.params.context_options.extras.should_run_context_ref(): + context_ref = True + first_context = True + + #naive_steps = [0,1]#[0,1,2]#[0,1,2,3] + naive_counts_mult = 50 + naive_init = False + cached_naive_conds = None + cached_naive_ctx_idxs = None + if ADGS.params.context_options.extras.should_run_naive_reuse(): + cached_naive_conds = [torch.zeros_like(x_in) for _ in conds] + #cached_naive_counts = [torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds] + naive_init = True + # perform calc_conds_batch per context window for ctx_idxs in context_windows: ADGS.params.sub_idxs = ctx_idxs if ADGS.motion_models is not None: @@ -817,6 +837,18 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list sub_timestep = timestep[full_idxs] sub_conds = [get_resized_cond(cond, full_idxs, len(ctx_idxs)) for cond in conds] + if context_ref: + if first_context: + first_context = False + model_options["transformer_options"][CONTEXTREF_ATTN_MACHINE_STATE] = "write" + model_options["transformer_options"][CONTEXTREF_ADAIN_MACHINE_STATE] = "write" + else: + model_options["transformer_options"][CONTEXTREF_ATTN_MACHINE_STATE] = "read" + model_options["transformer_options"][CONTEXTREF_ADAIN_MACHINE_STATE] = "read" + else: + 
model_options["transformer_options"][CONTEXTREF_ATTN_MACHINE_STATE] = "off" + model_options["transformer_options"][CONTEXTREF_ADAIN_MACHINE_STATE] = "off" + sub_conds_out = calc_cond_uncond_batch_wrapper(model, sub_conds, sub_x, sub_timestep, model_options) if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE: @@ -841,12 +873,34 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list for i in range(len(sub_conds_out)): conds_final[i][full_idxs] += sub_conds_out[i] * weights_tensor counts_final[i][full_idxs] += weights_tensor + if naive_init: + cached_naive_ctx_idxs = ctx_idxs + for i in range(len(sub_conds)): + cached_naive_conds[i][full_idxs] = conds_final[i][full_idxs] / counts_final[i][full_idxs] + #cached_naive_conds[i][full_idxs] = conds_final[i][full_idxs] + #cached_naive_counts[i][full_idxs] = counts_final[i][full_idxs] + naive_init = False if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE: # already normalized, so return as is del counts_final return conds_final else: + if cached_naive_conds is not None: + #start_idx = cached_naive_ctx_idxs[-1] + 1 + start_idx = cached_naive_ctx_idxs[0] + for z in range(start_idx, ADGS.params.full_length, len(cached_naive_ctx_idxs)): + for i in range(len(cached_naive_conds)): + new_ctx_idxs = [zz for zz in list(range(z, z+len(cached_naive_ctx_idxs))) if zz < ADGS.params.full_length] + # make sure when getting cached_naive idxs, they are adjusted for actual length leftover length + adjusted_cnaive_ctx_idxs = cached_naive_ctx_idxs[:len(new_ctx_idxs)] + weighted_mean = ADGS.params.context_options.extras.naive_reuse.weighted_mean + conds_final[i][new_ctx_idxs] = (weighted_mean * (cached_naive_conds[i][adjusted_cnaive_ctx_idxs]*counts_final[i][new_ctx_idxs])) + ((1.-weighted_mean) * conds_final[i][new_ctx_idxs]) + #conds_final[i][new_idxs] += (cached_naive_conds[i][cached_naive_full_idxs] / cached_naive_counts[i][cached_naive_full_idxs]) * counts + #counts = 
counts_final[i][new_idxs] * naive_counts_mult# / 2 + #conds_final[i][new_idxs] += (cached_naive_conds[i][cached_naive_full_idxs] / cached_naive_counts[i][cached_naive_full_idxs]) * counts + #counts_final[i][new_idxs] += counts# * 10#counts_final[i][full_idxs] + del cached_naive_conds # normalize conds via division by context usage counts for i in range(len(conds_final)): conds_final[i] /= counts_final[i] diff --git a/animatediff/scheduling.py b/animatediff/scheduling.py new file mode 100644 index 0000000..e69de29 From bd514925de8dafce908ea54507fc3f8c133776ee Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 25 Jul 2024 00:52:15 -0500 Subject: [PATCH 02/22] Added strength_multival inputs to NaiveReuse (working) and ContextRef (no functionality at all) nodes, some scaffolding code for cleanup --- animatediff/context.py | 1 + animatediff/context_extras.py | 34 +++++++++++++++++++++++++++++--- animatediff/nodes_context.py | 17 +++++++++------- animatediff/nodes_multival.py | 36 ++-------------------------------- animatediff/sampling.py | 3 ++- animatediff/utils_motion.py | 37 +++++++++++++++++++++++++++++++++++ 6 files changed, 83 insertions(+), 45 deletions(-) diff --git a/animatediff/context.py b/animatediff/context.py index 50af0d8..c0514f0 100644 --- a/animatediff/context.py +++ b/animatediff/context.py @@ -90,6 +90,7 @@ def reset(self): self._current_index = 0 self.step = 0 self._set_first_as_current() + self.extras.cleanup() @property def step(self): diff --git a/animatediff/context_extras.py b/animatediff/context_extras.py index 519e0e3..1d28edb 100644 --- a/animatediff/context_extras.py +++ b/animatediff/context_extras.py @@ -2,6 +2,8 @@ from comfy.model_base import BaseModel +from .utils_motion import prepare_mask_batch, extend_to_batch_size, get_combined_multival + class ContextExtra: def __init__(self, start_percent: float, end_percent: float): @@ -24,9 +26,12 @@ def should_run(self): return False return True + def cleanup(self): + pass + 
################################ -# Context Ref +# ContextRef class ContextRefParams: def __init__(self, attn_style_fidelity=0.0, attn_ref_weight=0.0, attn_atrength=0.0, @@ -54,11 +59,30 @@ def should_run(self): ################################ # NaiveReuse class NaiveReuse(ContextExtra): - def __init__(self, start_percent: float, end_percent: float, weighted_mean: float, mask_opt: Tensor=None): + def __init__(self, start_percent: float, end_percent: float, weighted_mean: float, multival_opt: Tensor=None): super().__init__(start_percent=start_percent, end_percent=end_percent) self.weighted_mean = weighted_mean - self.mask_opt = mask_opt + self.orig_multival = multival_opt + self.mask: Tensor = None + def cleanup(self): + super().cleanup() + del self.mask + self.mask = None + + def get_effective_weighted_mean(self, x: Tensor, idxs: list[int]): + if self.orig_multival is None: + return self.weighted_mean + # otherwise, is Tensor and should be extended to match dims and size of x; + # see if needs to be recalculated + if type(self.orig_multival) != Tensor: + return self.weighted_mean * self.orig_multival + elif self.mask is None or self.mask.shape[0] != x.shape[0] or self.mask.shape[-1] != x.shape[-1] or self.mask.shape[-2] != x.shape[-2]: + del self.mask + self.mask = prepare_mask_batch(self.orig_multival, x.shape) + self.mask = extend_to_batch_size(self.mask, x.shape[0]) + return self.weighted_mean * self.mask[idxs].to(dtype=x.dtype, device=x.device) + def should_run(self): to_return = super().should_run() # if weighted_mean is 0.0, then reuse will take no effect anyway @@ -105,6 +129,10 @@ def add(self, extra: ContextExtra): else: raise Exception(f"Unrecognized ContextExtras type: {type(extra)}") + def cleanup(self): + for extra in self.get_extras_list(): + extra.cleanup() + def clone(self): cloned = ContextExtrasGroup() cloned.context_ref = self.context_ref diff --git a/animatediff/nodes_context.py b/animatediff/nodes_context.py index 38c194d..8f69aad 100644 --- 
a/animatediff/nodes_context.py +++ b/animatediff/nodes_context.py @@ -1,5 +1,6 @@ import torch from torch import Tensor +from typing import Union import comfy.samplers from comfy.model_patcher import ModelPatcher @@ -477,11 +478,11 @@ def INPUT_TYPES(s): }, "optional": { "prev_extras": ("CONTEXT_EXTRAS",), - "mask_opt": ("MASK",), + "strength_multival": ("MULTIVAL",), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}), "weighted_mean": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.001}), - "autosize": ("ADEAUTOSIZE", {"padding": 55}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -489,12 +490,13 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras" FUNCTION = "create_context_extra" - def create_context_extra(self, start_percent=0.0, end_percent=0.1, weighted_mean=0.95, mask_opt: Tensor=None, prev_extras: ContextExtrasGroup=None): + def create_context_extra(self, start_percent=0.0, end_percent=0.1, weighted_mean=0.95, strength_multival: Union[float, Tensor]=None, + prev_extras: ContextExtrasGroup=None): if prev_extras is None: prev_extras = prev_extras = ContextExtrasGroup() prev_extras = prev_extras.clone() # create extra - naive_reuse = NaiveReuse(start_percent=start_percent, end_percent=end_percent, weighted_mean=weighted_mean, mask_opt=mask_opt) + naive_reuse = NaiveReuse(start_percent=start_percent, end_percent=end_percent, weighted_mean=weighted_mean, multival_opt=strength_multival) prev_extras.add(naive_reuse) return (prev_extras,) @@ -507,10 +509,10 @@ def INPUT_TYPES(s): }, "optional": { "prev_extras": ("CONTEXT_EXTRAS",), - "mask_opt": ("MASK",), + "strength_multival": ("MULTIVAL",), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001}), - "autosize": ("ADEAUTOSIZE", 
{"padding": 55}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -518,7 +520,8 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras" FUNCTION = "create_context_extra" - def create_context_extra(self, start_percent=0.0, end_percent=0.1, mask_opt: Tensor=None, prev_extras: ContextExtrasGroup=None): + def create_context_extra(self, start_percent=0.0, end_percent=0.1, strength_multival: Union[float, Tensor]=None, + prev_extras: ContextExtrasGroup=None): if prev_extras is None: prev_extras = prev_extras = ContextExtrasGroup() prev_extras = prev_extras.clone() diff --git a/animatediff/nodes_multival.py b/animatediff/nodes_multival.py index 1149104..4579e17 100644 --- a/animatediff/nodes_multival.py +++ b/animatediff/nodes_multival.py @@ -4,7 +4,7 @@ import torch from torch import Tensor -from .utils_motion import linear_conversion, normalize_min_max, extend_to_batch_size, extend_list_to_batch_size +from .utils_motion import create_multival_combo, linear_conversion, normalize_min_max, extend_to_batch_size, extend_list_to_batch_size class ScaleType: @@ -31,39 +31,7 @@ def INPUT_TYPES(s): FUNCTION = "create_multival" def create_multival(self, float_val: Union[float, list[float]]=1.0, mask_optional: Tensor=None): - # first, normalize inputs - # if float_val is iterable, treat as a list and assume inputs are floats - float_is_iterable = False - if isinstance(float_val, Iterable): - float_is_iterable = True - float_val = list(float_val) - # if mask present, make sure float_val list can be applied to list - match lengths - if mask_optional is not None: - if len(float_val) < mask_optional.shape[0]: - # copies last entry enough times to match mask shape - float_val = extend_list_to_batch_size(float_val, mask_optional.shape[0]) - if mask_optional.shape[0] < len(float_val): - mask_optional = extend_to_batch_size(mask_optional, len(float_val)) - float_val = float_val[:mask_optional.shape[0]] - float_val: Tensor = 
torch.tensor(float_val).unsqueeze(-1).unsqueeze(-1) - # now that inputs are normalized, figure out what value to actually return - if mask_optional is not None: - mask_optional = mask_optional.clone() - if float_is_iterable: - mask_optional = mask_optional[:] * float_val.to(mask_optional.dtype).to(mask_optional.device) - else: - mask_optional = mask_optional * float_val - return (mask_optional,) - else: - if not float_is_iterable: - return (float_val,) - # create a dummy mask of b,h,w=float_len,1,1 (sigle pixel) - # purpose is for float input to work with mask code, without special cases - float_len = float_val.shape[0] if float_is_iterable else 1 - shape = (float_len,1,1) - mask_optional = torch.ones(shape) - mask_optional = mask_optional[:] * float_val.to(mask_optional.dtype).to(mask_optional.device) - return (mask_optional,) + return (create_multival_combo(float_val=float_val, mask_optional=mask_optional),) class MultivalScaledMaskNode: diff --git a/animatediff/sampling.py b/animatediff/sampling.py index ddbde56..72c1794 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -114,6 +114,7 @@ def reset(self): del self.motion_models self.motion_models = None if self.params is not None: + self.params.context_options.reset() del self.params self.params = None if self.sample_settings is not None: @@ -894,7 +895,7 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list new_ctx_idxs = [zz for zz in list(range(z, z+len(cached_naive_ctx_idxs))) if zz < ADGS.params.full_length] # make sure when getting cached_naive idxs, they are adjusted for actual length leftover length adjusted_cnaive_ctx_idxs = cached_naive_ctx_idxs[:len(new_ctx_idxs)] - weighted_mean = ADGS.params.context_options.extras.naive_reuse.weighted_mean + weighted_mean = ADGS.params.context_options.extras.naive_reuse.get_effective_weighted_mean(x_in, new_ctx_idxs) conds_final[i][new_ctx_idxs] = (weighted_mean * 
(cached_naive_conds[i][adjusted_cnaive_ctx_idxs]*counts_final[i][new_ctx_idxs])) + ((1.-weighted_mean) * conds_final[i][new_ctx_idxs]) #conds_final[i][new_idxs] += (cached_naive_conds[i][cached_naive_full_idxs] / cached_naive_counts[i][cached_naive_full_idxs]) * counts #counts = counts_final[i][new_idxs] * naive_counts_mult# / 2 diff --git a/animatediff/utils_motion.py b/animatediff/utils_motion.py index fcc205a..0cf5b1a 100644 --- a/animatediff/utils_motion.py +++ b/animatediff/utils_motion.py @@ -3,6 +3,7 @@ import torch.nn.functional as F from torch import Tensor, nn from abc import ABC, abstractmethod +from collections.abc import Iterable import comfy.model_management as model_management import comfy.ops @@ -238,6 +239,42 @@ def get_mask(self, x: Tensor): return mask * self.multival +def create_multival_combo(float_val: Union[float, list[float]], mask_optional: Tensor=None): + # first, normalize inputs + # if float_val is iterable, treat as a list and assume inputs are floats + float_is_iterable = False + if isinstance(float_val, Iterable): + float_is_iterable = True + float_val = list(float_val) + # if mask present, make sure float_val list can be applied to list - match lengths + if mask_optional is not None: + if len(float_val) < mask_optional.shape[0]: + # copies last entry enough times to match mask shape + float_val = extend_list_to_batch_size(float_val, mask_optional.shape[0]) + if mask_optional.shape[0] < len(float_val): + mask_optional = extend_to_batch_size(mask_optional, len(float_val)) + float_val = float_val[:mask_optional.shape[0]] + float_val: Tensor = torch.tensor(float_val).unsqueeze(-1).unsqueeze(-1) + # now that inputs are normalized, figure out what value to actually return + if mask_optional is not None: + mask_optional = mask_optional.clone() + if float_is_iterable: + mask_optional = mask_optional[:] * float_val.to(mask_optional.dtype).to(mask_optional.device) + else: + mask_optional = mask_optional * float_val + return mask_optional + 
else: + if not float_is_iterable: + return float_val + # create a dummy mask of b,h,w=float_len,1,1 (sigle pixel) + # purpose is for float input to work with mask code, without special cases + float_len = float_val.shape[0] if float_is_iterable else 1 + shape = (float_len,1,1) + mask_optional = torch.ones(shape) + mask_optional = mask_optional[:] * float_val.to(mask_optional.dtype).to(mask_optional.device) + return mask_optional + + def get_combined_multival(multivalA: Union[float, Tensor], multivalB: Union[float, Tensor]) -> Union[float, Tensor]: # if one is None, use the other if multivalA == None: From b55fb0f56f51b8d399c4a54ca44f10088ff3e321 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 27 Jul 2024 18:53:19 -0500 Subject: [PATCH 03/22] Adjustments for soon-to-be ContextRef support in Advanced-ControlNet --- animatediff/sampling.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 72c1794..83b659f 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -801,8 +801,8 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list counts_final = [torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds] biases_final = [([0.0] * x_in.shape[0]) for _ in conds] - CONTEXTREF_ATTN_MACHINE_STATE = "contextref_attn_machine_state" - CONTEXTREF_ADAIN_MACHINE_STATE = "contextref_adain_machine_state" + CONTEXTREF_MACHINE_STATE = "contextref_machine_state" + CONTEXTREF_CLEAN_FUNC = "contextref_clean_func" #context_ref_steps = [0,1,2,3,4,5]#[0] context_ref = False first_context = False @@ -841,14 +841,11 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list if context_ref: if first_context: first_context = False - model_options["transformer_options"][CONTEXTREF_ATTN_MACHINE_STATE] = "write" - model_options["transformer_options"][CONTEXTREF_ADAIN_MACHINE_STATE] = "write" + 
model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = "write" else: - model_options["transformer_options"][CONTEXTREF_ATTN_MACHINE_STATE] = "read" - model_options["transformer_options"][CONTEXTREF_ADAIN_MACHINE_STATE] = "read" + model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = "read" else: - model_options["transformer_options"][CONTEXTREF_ATTN_MACHINE_STATE] = "off" - model_options["transformer_options"][CONTEXTREF_ADAIN_MACHINE_STATE] = "off" + model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = "off" sub_conds_out = calc_cond_uncond_batch_wrapper(model, sub_conds, sub_x, sub_timestep, model_options) @@ -881,7 +878,11 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list #cached_naive_conds[i][full_idxs] = conds_final[i][full_idxs] #cached_naive_counts[i][full_idxs] = counts_final[i][full_idxs] naive_init = False - + + # clean contextref stuff with provided ACN function, if applicable + if context_ref: + model_options["transformer_options"][CONTEXTREF_CLEAN_FUNC]() + if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE: # already normalized, so return as is del counts_final From 58f914a42603b2ff0cd706e00ebfa3df223c23ef Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 27 Jul 2024 21:06:06 -0500 Subject: [PATCH 04/22] Small cleanup --- animatediff/sampling.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 83b659f..fcc36ed 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -803,15 +803,12 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list CONTEXTREF_MACHINE_STATE = "contextref_machine_state" CONTEXTREF_CLEAN_FUNC = "contextref_clean_func" - #context_ref_steps = [0,1,2,3,4,5]#[0] context_ref = False first_context = False if ADGS.params.context_options.extras.should_run_context_ref(): context_ref = True first_context = True - #naive_steps = 
[0,1]#[0,1,2]#[0,1,2,3] - naive_counts_mult = 50 naive_init = False cached_naive_conds = None cached_naive_ctx_idxs = None @@ -875,8 +872,6 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list cached_naive_ctx_idxs = ctx_idxs for i in range(len(sub_conds)): cached_naive_conds[i][full_idxs] = conds_final[i][full_idxs] / counts_final[i][full_idxs] - #cached_naive_conds[i][full_idxs] = conds_final[i][full_idxs] - #cached_naive_counts[i][full_idxs] = counts_final[i][full_idxs] naive_init = False # clean contextref stuff with provided ACN function, if applicable @@ -898,10 +893,6 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list adjusted_cnaive_ctx_idxs = cached_naive_ctx_idxs[:len(new_ctx_idxs)] weighted_mean = ADGS.params.context_options.extras.naive_reuse.get_effective_weighted_mean(x_in, new_ctx_idxs) conds_final[i][new_ctx_idxs] = (weighted_mean * (cached_naive_conds[i][adjusted_cnaive_ctx_idxs]*counts_final[i][new_ctx_idxs])) + ((1.-weighted_mean) * conds_final[i][new_ctx_idxs]) - #conds_final[i][new_idxs] += (cached_naive_conds[i][cached_naive_full_idxs] / cached_naive_counts[i][cached_naive_full_idxs]) * counts - #counts = counts_final[i][new_idxs] * naive_counts_mult# / 2 - #conds_final[i][new_idxs] += (cached_naive_conds[i][cached_naive_full_idxs] / cached_naive_counts[i][cached_naive_full_idxs]) * counts - #counts_final[i][new_idxs] += counts# * 10#counts_final[i][full_idxs] del cached_naive_conds # normalize conds via division by context usage counts for i in range(len(conds_final)): From d1fc329c619498f3c7b2c62da73fd2f0bb257c32 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 28 Jul 2024 19:12:56 -0500 Subject: [PATCH 05/22] Fixed ContextRef so it keeps track of current cond/uncond idx - matching commit made to Advanced-ControlNet --- animatediff/sampling.py | 172 ++++++++++++++++++++++++---------------- 1 file changed, 102 insertions(+), 70 deletions(-) diff --git 
a/animatediff/sampling.py b/animatediff/sampling.py index fcc36ed..e8342fb 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -404,6 +404,28 @@ def __exit__(self, *args, **kwargs): self.previous_dwi_gn_cast_weights = None +class ContextRefInjector: + def __init__(self): + self.orig_can_concat_cond = None + + def inject(self): + self.orig_can_concat_cond = comfy.samplers.can_concat_cond + comfy.samplers.can_concat_cond = ContextRefInjector.can_concat_cond_contextref_factory(self.orig_can_concat_cond) + + def restore(self): + if self.orig_can_concat_cond is not None: + comfy.samplers.can_concat_cond = self.orig_can_concat_cond + + @staticmethod + def can_concat_cond_contextref_factory(orig_func: Callable): + def can_concat_cond_contextref_injection(c1, c2, *args, **kwargs): + #return orig_func(c1, c2, *args, **kwargs) + if c1 is c2: + return True + return False + return can_concat_cond_contextref_injection + + def motion_sample_factory(orig_comfy_sample: Callable, is_custom: bool=False) -> Callable: def motion_sample(model: ModelPatcherAndInjector, noise: Tensor, *args, **kwargs): # check if model is intended for injecting @@ -801,82 +823,92 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list counts_final = [torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds] biases_final = [([0.0] * x_in.shape[0]) for _ in conds] + CONTEXTREF_CONTROL_LIST_ALL = "contextref_control_list_all" CONTEXTREF_MACHINE_STATE = "contextref_machine_state" CONTEXTREF_CLEAN_FUNC = "contextref_clean_func" context_ref = False + context_ref_injector = None first_context = False - if ADGS.params.context_options.extras.should_run_context_ref(): - context_ref = True - first_context = True - - naive_init = False - cached_naive_conds = None - cached_naive_ctx_idxs = None - if ADGS.params.context_options.extras.should_run_naive_reuse(): - cached_naive_conds = [torch.zeros_like(x_in) for _ in conds] - #cached_naive_counts = 
[torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds] - naive_init = True - # perform calc_conds_batch per context window - for ctx_idxs in context_windows: - ADGS.params.sub_idxs = ctx_idxs - if ADGS.motion_models is not None: - ADGS.motion_models.set_sub_idxs(ctx_idxs) - ADGS.motion_models.set_video_length(len(ctx_idxs), ADGS.params.full_length) - # update exposed params - model_options["transformer_options"]["ad_params"]["sub_idxs"] = ctx_idxs - model_options["transformer_options"]["ad_params"]["context_length"] = len(ctx_idxs) - # account for all portions of input frames - full_idxs = [] - for n in range(batched_conds): - for ind in ctx_idxs: - full_idxs.append((ADGS.params.full_length*n)+ind) - # get subsections of x, timestep, conds - sub_x = x_in[full_idxs] - sub_timestep = timestep[full_idxs] - sub_conds = [get_resized_cond(cond, full_idxs, len(ctx_idxs)) for cond in conds] - - if context_ref: - if first_context: - first_context = False - model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = "write" + # need to make sure that contextref stuff gets cleaned up, no matter what + try: + if ADGS.params.context_options.extras.should_run_context_ref(): + context_ref = True + first_context = True + # use injector to ensure only 1 cond or uncond will be batched at a time + context_ref_injector = ContextRefInjector() + context_ref_injector.inject() + + naive_init = False + cached_naive_conds = None + cached_naive_ctx_idxs = None + if ADGS.params.context_options.extras.should_run_naive_reuse(): + cached_naive_conds = [torch.zeros_like(x_in) for _ in conds] + #cached_naive_counts = [torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds] + naive_init = True + # perform calc_conds_batch per context window + for ctx_idxs in context_windows: + ADGS.params.sub_idxs = ctx_idxs + if ADGS.motion_models is not None: + ADGS.motion_models.set_sub_idxs(ctx_idxs) + ADGS.motion_models.set_video_length(len(ctx_idxs), 
ADGS.params.full_length) + # update exposed params + model_options["transformer_options"]["ad_params"]["sub_idxs"] = ctx_idxs + model_options["transformer_options"]["ad_params"]["context_length"] = len(ctx_idxs) + # account for all portions of input frames + full_idxs = [] + for n in range(batched_conds): + for ind in ctx_idxs: + full_idxs.append((ADGS.params.full_length*n)+ind) + # get subsections of x, timestep, conds + sub_x = x_in[full_idxs] + sub_timestep = timestep[full_idxs] + sub_conds = [get_resized_cond(cond, full_idxs, len(ctx_idxs)) for cond in conds] + + if context_ref: + # set cond counter to 0 (each cond encountered will increment it by 1) + model_options["transformer_options"][CONTEXTREF_CONTROL_LIST_ALL][0].contextref_cond_idx = 0 + if first_context: + first_context = False + model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = "write" + else: + model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = "read" else: - model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = "read" - else: - model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = "off" - - sub_conds_out = calc_cond_uncond_batch_wrapper(model, sub_conds, sub_x, sub_timestep, model_options) - - if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE: - full_length = ADGS.params.full_length - for pos, idx in enumerate(ctx_idxs): - # bias is the influence of a specific index in relation to the whole context window - bias = 1 - abs(idx - (ctx_idxs[0] + ctx_idxs[-1]) / 2) / ((ctx_idxs[-1] - ctx_idxs[0] + 1e-2) / 2) - bias = max(1e-2, bias) - # take weighted average relative to total bias of current idx - # and account for batched_conds + model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = "off" + + sub_conds_out = calc_cond_uncond_batch_wrapper(model, sub_conds, sub_x, sub_timestep, model_options) + + if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE: + full_length = ADGS.params.full_length + for pos, 
idx in enumerate(ctx_idxs): + # bias is the influence of a specific index in relation to the whole context window + bias = 1 - abs(idx - (ctx_idxs[0] + ctx_idxs[-1]) / 2) / ((ctx_idxs[-1] - ctx_idxs[0] + 1e-2) / 2) + bias = max(1e-2, bias) + # take weighted average relative to total bias of current idx + # and account for batched_conds + for i in range(len(sub_conds_out)): + for n in range(batched_conds): + bias_total = biases_final[i][(full_length*n)+idx] + prev_weight = (bias_total / (bias_total + bias)) + new_weight = (bias / (bias_total + bias)) + conds_final[i][(full_length*n)+idx] = conds_final[i][(full_length*n)+idx] * prev_weight + sub_conds_out[i][(full_length*n)+pos] * new_weight + biases_final[i][(full_length*n)+idx] = bias_total + bias + else: + # add conds and counts based on weights of fuse method + weights = get_context_weights(len(ctx_idxs), ADGS.params.context_options.fuse_method, sigma=timestep) * batched_conds + weights_tensor = torch.Tensor(weights).to(device=x_in.device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) for i in range(len(sub_conds_out)): - for n in range(batched_conds): - bias_total = biases_final[i][(full_length*n)+idx] - prev_weight = (bias_total / (bias_total + bias)) - new_weight = (bias / (bias_total + bias)) - conds_final[i][(full_length*n)+idx] = conds_final[i][(full_length*n)+idx] * prev_weight + sub_conds_out[i][(full_length*n)+pos] * new_weight - biases_final[i][(full_length*n)+idx] = bias_total + bias - else: - # add conds and counts based on weights of fuse method - weights = get_context_weights(len(ctx_idxs), ADGS.params.context_options.fuse_method, sigma=timestep) * batched_conds - weights_tensor = torch.Tensor(weights).to(device=x_in.device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) - for i in range(len(sub_conds_out)): - conds_final[i][full_idxs] += sub_conds_out[i] * weights_tensor - counts_final[i][full_idxs] += weights_tensor - if naive_init: - cached_naive_ctx_idxs = ctx_idxs - for i in range(len(sub_conds)): - 
cached_naive_conds[i][full_idxs] = conds_final[i][full_idxs] / counts_final[i][full_idxs] - naive_init = False - - # clean contextref stuff with provided ACN function, if applicable - if context_ref: - model_options["transformer_options"][CONTEXTREF_CLEAN_FUNC]() + conds_final[i][full_idxs] += sub_conds_out[i] * weights_tensor + counts_final[i][full_idxs] += weights_tensor + if naive_init: + cached_naive_ctx_idxs = ctx_idxs + for i in range(len(sub_conds)): + cached_naive_conds[i][full_idxs] = conds_final[i][full_idxs] / counts_final[i][full_idxs] + naive_init = False + finally: + # clean contextref stuff with provided ACN function, if applicable + if context_ref: + model_options["transformer_options"][CONTEXTREF_CLEAN_FUNC]() + context_ref_injector.restore() if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE: # already normalized, so return as is From 1c68c568094a97aab67cfe0dc74eef2d89760ef7 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 29 Jul 2024 01:35:12 -0500 Subject: [PATCH 06/22] NaiveReuse is now functional for Uniform Context Options and relative fuse_method --- animatediff/sampling.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index e8342fb..e8b294b 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -820,7 +820,12 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list # prepare final conds, out_counts, and biases conds_final = [torch.zeros_like(x_in) for _ in conds] - counts_final = [torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds] + if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE: + # counts_final not used for RELATIVE fuse_method + counts_final = [torch.ones((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds] + else: + # default counts_final initialization + counts_final = [torch.zeros((x_in.shape[0], 1, 1, 
1), device=x_in.device) for _ in conds] biases_final = [([0.0] * x_in.shape[0]) for _ in conds] CONTEXTREF_CONTROL_LIST_ALL = "contextref_control_list_all" @@ -910,22 +915,24 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list model_options["transformer_options"][CONTEXTREF_CLEAN_FUNC]() context_ref_injector.restore() + # handle NaiveReuse + if cached_naive_conds is not None: + start_idx = cached_naive_ctx_idxs[0] + for z in range(0, ADGS.params.full_length, len(cached_naive_ctx_idxs)): + for i in range(len(cached_naive_conds)): + # get the 'true' idxs of this window + new_ctx_idxs = [(zz+start_idx) % ADGS.params.full_length for zz in list(range(z, z+len(cached_naive_ctx_idxs))) if zz < ADGS.params.full_length] + # make sure when getting cached_naive idxs, they are adjusted for actual length leftover length + adjusted_naive_ctx_idxs = cached_naive_ctx_idxs[:len(new_ctx_idxs)] + weighted_mean = ADGS.params.context_options.extras.naive_reuse.get_effective_weighted_mean(x_in, new_ctx_idxs) + conds_final[i][new_ctx_idxs] = (weighted_mean * (cached_naive_conds[i][adjusted_naive_ctx_idxs]*counts_final[i][new_ctx_idxs])) + ((1.-weighted_mean) * conds_final[i][new_ctx_idxs]) + del cached_naive_conds + if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE: # already normalized, so return as is del counts_final return conds_final else: - if cached_naive_conds is not None: - #start_idx = cached_naive_ctx_idxs[-1] + 1 - start_idx = cached_naive_ctx_idxs[0] - for z in range(start_idx, ADGS.params.full_length, len(cached_naive_ctx_idxs)): - for i in range(len(cached_naive_conds)): - new_ctx_idxs = [zz for zz in list(range(z, z+len(cached_naive_ctx_idxs))) if zz < ADGS.params.full_length] - # make sure when getting cached_naive idxs, they are adjusted for actual length leftover length - adjusted_cnaive_ctx_idxs = cached_naive_ctx_idxs[:len(new_ctx_idxs)] - weighted_mean = 
ADGS.params.context_options.extras.naive_reuse.get_effective_weighted_mean(x_in, new_ctx_idxs) - conds_final[i][new_ctx_idxs] = (weighted_mean * (cached_naive_conds[i][adjusted_cnaive_ctx_idxs]*counts_final[i][new_ctx_idxs])) + ((1.-weighted_mean) * conds_final[i][new_ctx_idxs]) - del cached_naive_conds # normalize conds via division by context usage counts for i in range(len(conds_final)): conds_final[i] /= counts_final[i] From 11252fc34feca3f668940cc87ff3dcca4048452a Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 31 Jul 2024 06:31:21 -0500 Subject: [PATCH 07/22] Added sliding reference window mode for ContextRef + ContextRef Mode nodes --- animatediff/context_extras.py | 21 ++++++++++++++- animatediff/nodes.py | 7 ++++- animatediff/nodes_context.py | 49 +++++++++++++++++++++++++++++++++-- animatediff/sampling.py | 41 +++++++++++++++++------------ animatediff/utils_model.py | 6 +++++ 5 files changed, 104 insertions(+), 20 deletions(-) diff --git a/animatediff/context_extras.py b/animatediff/context_extras.py index 1d28edb..350ce9c 100644 --- a/animatediff/context_extras.py +++ b/animatediff/context_extras.py @@ -46,10 +46,29 @@ def __init__(self, self.adain_strength = adain_strength +class ContextRefMode: + FIRST = "first" + SLIDING = "sliding" + _LIST = [FIRST, SLIDING] + + def __init__(self, mode: str, sliding_width=2): + self.mode = mode + self.sliding_width = sliding_width + + @classmethod + def init_first(cls): + return ContextRefMode(cls.FIRST) + + @classmethod + def init_sliding(cls, sliding_width): + return ContextRefMode(cls.SLIDING, sliding_width=sliding_width) + + class ContextRef(ContextExtra): - def __init__(self, start_percent: float, end_percent: float, params: ContextRefParams): + def __init__(self, start_percent: float, end_percent: float, params: ContextRefParams, mode: ContextRefMode): super().__init__(start_percent=start_percent, end_percent=end_percent) self.params = params + self.mode = mode def should_run(self): return 
super().should_run() diff --git a/animatediff/nodes.py b/animatediff/nodes.py index e14e9ca..4796f90 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -29,7 +29,8 @@ from .nodes_context import (LegacyLoopedUniformContextOptionsNode, LoopedUniformContextOptionsNode, LoopedUniformViewOptionsNode, StandardUniformContextOptionsNode, StandardStaticContextOptionsNode, BatchedContextOptionsNode, StandardStaticViewOptionsNode, StandardUniformViewOptionsNode, ViewAsContextOptionsNode, VisualizeContextOptionsK, VisualizeContextOptionsKAdv, VisualizeContextOptionsSCustom, - SetContextExtrasOnContextOptions, ContextExtras_NaiveReuse, ContextExtras_ContextRef) + SetContextExtrasOnContextOptions, ContextExtras_NaiveReuse, ContextExtras_ContextRef, + ContextRef_ModeFirst, ContextRef_ModeSliding) from .nodes_ad_settings import (AnimateDiffSettingsNode, ManualAdjustPENode, SweetspotStretchPENode, FullStretchPENode, WeightAdjustAllAddNode, WeightAdjustAllMultNode, WeightAdjustIndivAddNode, WeightAdjustIndivMultNode, WeightAdjustIndivAttnAddNode, WeightAdjustIndivAttnMultNode) @@ -75,6 +76,8 @@ "ADE_ContextExtras_Set": SetContextExtrasOnContextOptions, "ADE_ContextExtras_ContextRef": ContextExtras_ContextRef, "ADE_ContextExtras_NaiveReuse": ContextExtras_NaiveReuse, + "ADE_ContextExtras_ContextRef_ModeFirst": ContextRef_ModeFirst, + "ADE_ContextExtras_ContextRef_ModeSliding": ContextRef_ModeSliding, #------------------------------------------------------------------------------ ############################################################################### # Iteration Opts @@ -213,6 +216,8 @@ "ADE_ContextExtras_Set": "Set Context Extras πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_ContextRef": "Context Extrasβ—†ContextRef πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_NaiveReuse": "Context Extrasβ—†NaiveReuse πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_ContextRef_ModeFirst": "ContextRef Modeβ—†First πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_ContextRef_ModeSliding": "ContextRef Modeβ—†Sliding πŸŽ­πŸ…πŸ…“", 
#------------------------------------------------------------------------------ ############################################################################### # Iteration Opts diff --git a/animatediff/nodes_context.py b/animatediff/nodes_context.py index 8f69aad..37b6c46 100644 --- a/animatediff/nodes_context.py +++ b/animatediff/nodes_context.py @@ -7,7 +7,7 @@ from .context import (ContextFuseMethod, ContextOptions, ContextOptionsGroup, ContextSchedules, generate_context_visualization) -from .context_extras import ContextExtrasGroup, ContextRef, ContextRefParams, NaiveReuse +from .context_extras import ContextExtrasGroup, ContextRef, ContextRefParams, ContextRefMode, NaiveReuse from .utils_model import BIGMAX, MAX_RESOLUTION @@ -510,6 +510,7 @@ def INPUT_TYPES(s): "optional": { "prev_extras": ("CONTEXT_EXTRAS",), "strength_multival": ("MULTIVAL",), + "contextref_mode": ("CONTEXTREF_MODE",), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001}), "autosize": ("ADEAUTOSIZE", {"padding": 0}), @@ -521,6 +522,7 @@ def INPUT_TYPES(s): FUNCTION = "create_context_extra" def create_context_extra(self, start_percent=0.0, end_percent=0.1, strength_multival: Union[float, Tensor]=None, + contextref_mode: ContextRefMode=None, prev_extras: ContextExtrasGroup=None): if prev_extras is None: prev_extras = prev_extras = ContextExtrasGroup() @@ -528,6 +530,49 @@ def create_context_extra(self, start_percent=0.0, end_percent=0.1, strength_mult # create extra # TODO: make customizable, and allow mask input params = ContextRefParams(attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_atrength=1.0) - context_ref = ContextRef(start_percent=start_percent, end_percent=end_percent, params=params) + if contextref_mode is None: + contextref_mode = ContextRefMode.init_first() + context_ref = ContextRef(start_percent=start_percent, end_percent=end_percent, params=params, 
mode=contextref_mode) prev_extras.add(context_ref) return (prev_extras,) + + +class ContextRef_ModeFirst: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "autosize": ("ADEAUTOSIZE", {"padding": 25}), + }, + } + + RETURN_TYPES = ("CONTEXTREF_MODE",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/ContextRef" + FUNCTION = "create_contextref_mode" + + def create_contextref_mode(self): + mode = ContextRefMode.init_first() + return (mode,) + + +class ContextRef_ModeSliding: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "sliding_width": ("INT", {"default": 2, "min": 2, "max": BIGMAX, "step": 1}), + "autosize": ("ADEAUTOSIZE", {"padding": 42}), + } + } + + RETURN_TYPES = ("CONTEXTREF_MODE",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/ContextRef" + FUNCTION = "create_contextref_mode" + + def create_contextref_mode(self, sliding_width): + mode = ContextRefMode.init_sliding(sliding_width=sliding_width) + return (mode,) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index e8b294b..87aedf4 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -26,8 +26,9 @@ from .conditioning import COND_CONST, LoraHookGroup, conditioning_set_values from .context import ContextFuseMethod, ContextSchedules, get_context_weights, get_context_windows +from .context_extras import ContextRefMode from .sample_settings import IterationOptions, SampleSettings, SeedNoiseGeneration, NoisedImageToInject -from .utils_model import ModelTypeSD, vae_encode_raw_batched, vae_decode_raw_batched +from .utils_model import ModelTypeSD, MachineState, vae_encode_raw_batched, vae_decode_raw_batched from .utils_motion import composite_extend, get_combined_multival, prepare_mask_batch, extend_to_batch_size from .model_injection import InjectionParams, ModelPatcherAndInjector, MotionModelGroup, MotionModelPatcher from .motion_module_ad import AnimateDiffFormat, 
AnimateDiffInfo, AnimateDiffVersion, VanillaTemporalModule @@ -831,27 +832,31 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list CONTEXTREF_CONTROL_LIST_ALL = "contextref_control_list_all" CONTEXTREF_MACHINE_STATE = "contextref_machine_state" CONTEXTREF_CLEAN_FUNC = "contextref_clean_func" - context_ref = False - context_ref_injector = None + contextref_active = False + contextref_injector = None + contextref_mode = None first_context = False # need to make sure that contextref stuff gets cleaned up, no matter what try: if ADGS.params.context_options.extras.should_run_context_ref(): - context_ref = True + contextref_active = True first_context = True + contextref_mode = ADGS.params.context_options.extras.context_ref.mode # use injector to ensure only 1 cond or uncond will be batched at a time - context_ref_injector = ContextRefInjector() - context_ref_injector.inject() + contextref_injector = ContextRefInjector() + contextref_injector.inject() - naive_init = False + curr_window_idx = -1 + naivereuse_active = False cached_naive_conds = None cached_naive_ctx_idxs = None if ADGS.params.context_options.extras.should_run_naive_reuse(): cached_naive_conds = [torch.zeros_like(x_in) for _ in conds] #cached_naive_counts = [torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds] - naive_init = True + naivereuse_active = True # perform calc_conds_batch per context window for ctx_idxs in context_windows: + curr_window_idx += 1 ADGS.params.sub_idxs = ctx_idxs if ADGS.motion_models is not None: ADGS.motion_models.set_sub_idxs(ctx_idxs) @@ -869,16 +874,20 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list sub_timestep = timestep[full_idxs] sub_conds = [get_resized_cond(cond, full_idxs, len(ctx_idxs)) for cond in conds] - if context_ref: + if contextref_active: # set cond counter to 0 (each cond encountered will increment it by 1) 
model_options["transformer_options"][CONTEXTREF_CONTROL_LIST_ALL][0].contextref_cond_idx = 0 if first_context: first_context = False - model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = "write" + model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.WRITE else: - model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = "read" + model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.READ + if contextref_mode.mode == ContextRefMode.SLIDING: # if sliding, check if time to READ and WRITE + if curr_window_idx % (contextref_mode.sliding_width-1) == 0: + model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.READ_WRITE else: - model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = "off" + model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.OFF + #logger.info(f"window: {curr_window_idx} - {model_options['transformer_options'][CONTEXTREF_MACHINE_STATE]}") sub_conds_out = calc_cond_uncond_batch_wrapper(model, sub_conds, sub_x, sub_timestep, model_options) @@ -904,16 +913,16 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list for i in range(len(sub_conds_out)): conds_final[i][full_idxs] += sub_conds_out[i] * weights_tensor counts_final[i][full_idxs] += weights_tensor - if naive_init: + if naivereuse_active: cached_naive_ctx_idxs = ctx_idxs for i in range(len(sub_conds)): cached_naive_conds[i][full_idxs] = conds_final[i][full_idxs] / counts_final[i][full_idxs] - naive_init = False + naivereuse_active = False finally: # clean contextref stuff with provided ACN function, if applicable - if context_ref: + if contextref_active: model_options["transformer_options"][CONTEXTREF_CLEAN_FUNC]() - context_ref_injector.restore() + contextref_injector.restore() # handle NaiveReuse if cached_naive_conds is not None: diff --git a/animatediff/utils_model.py b/animatediff/utils_model.py index 0e62d00..48c432b 100644 --- 
a/animatediff/utils_model.py +++ b/animatediff/utils_model.py @@ -26,6 +26,12 @@ MAX_RESOLUTION = 16384 # mirrors ComfyUI's nodes.py MAX_RESOLUTION +class MachineState: + READ = "read" + WRITE = "write" + READ_WRITE = "read_write" + OFF = "off" + def vae_encode_raw_dynamic_batched(vae: VAE, pixels: Tensor, max_batch=16, min_batch=1, max_size=512*512, show_pbar=False): b, h, w, c = pixels.shape From 0ba022c775126976908f2504f13ce176aec05554 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 2 Aug 2024 17:18:00 -0500 Subject: [PATCH 08/22] Added ContextRef Tune nodes to modify values of underlying ReferenceAdvanced ACN object --- animatediff/context_extras.py | 24 ++++++++++---- animatediff/nodes.py | 7 +++- animatediff/nodes_context.py | 60 +++++++++++++++++++++++++++++++++-- animatediff/sampling.py | 3 +- 4 files changed, 83 insertions(+), 11 deletions(-) diff --git a/animatediff/context_extras.py b/animatediff/context_extras.py index 350ce9c..9265ade 100644 --- a/animatediff/context_extras.py +++ b/animatediff/context_extras.py @@ -34,16 +34,26 @@ def cleanup(self): # ContextRef class ContextRefParams: def __init__(self, - attn_style_fidelity=0.0, attn_ref_weight=0.0, attn_atrength=0.0, + attn_style_fidelity=0.0, attn_ref_weight=0.0, attn_strength=0.0, adain_style_fidelity=0.0, adain_ref_weight=0.0, adain_strength=0.0): # attn1 - self.attn_style_fidelity = attn_style_fidelity - self.attn_ref_weight = attn_ref_weight - self.attn_strength = attn_atrength + self.attn_style_fidelity = float(attn_style_fidelity) + self.attn_ref_weight = float(attn_ref_weight) + self.attn_strength = float(attn_strength) # adain - self.adain_style_fidelity = adain_style_fidelity - self.adain_ref_weight = adain_ref_weight - self.adain_strength = adain_strength + self.adain_style_fidelity = float(adain_style_fidelity) + self.adain_ref_weight = float(adain_ref_weight) + self.adain_strength = float(adain_strength) + + def create_dict(self): + return { + "attn_style_fidelity": 
self.attn_style_fidelity, + "attn_ref_weight": self.attn_ref_weight, + "attn_strength": self.attn_strength, + "adain_style_fidelity": self.adain_style_fidelity, + "adain_ref_weight": self.adain_ref_weight, + "adain_strength": self.adain_strength, + } class ContextRefMode: diff --git a/animatediff/nodes.py b/animatediff/nodes.py index 4796f90..3e5edff 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -30,7 +30,8 @@ StandardStaticViewOptionsNode, StandardUniformViewOptionsNode, ViewAsContextOptionsNode, VisualizeContextOptionsK, VisualizeContextOptionsKAdv, VisualizeContextOptionsSCustom, SetContextExtrasOnContextOptions, ContextExtras_NaiveReuse, ContextExtras_ContextRef, - ContextRef_ModeFirst, ContextRef_ModeSliding) + ContextRef_ModeFirst, ContextRef_ModeSliding, + ContextRef_TuneAttn, ContextRef_TuneAttnAdain) from .nodes_ad_settings import (AnimateDiffSettingsNode, ManualAdjustPENode, SweetspotStretchPENode, FullStretchPENode, WeightAdjustAllAddNode, WeightAdjustAllMultNode, WeightAdjustIndivAddNode, WeightAdjustIndivMultNode, WeightAdjustIndivAttnAddNode, WeightAdjustIndivAttnMultNode) @@ -78,6 +79,8 @@ "ADE_ContextExtras_NaiveReuse": ContextExtras_NaiveReuse, "ADE_ContextExtras_ContextRef_ModeFirst": ContextRef_ModeFirst, "ADE_ContextExtras_ContextRef_ModeSliding": ContextRef_ModeSliding, + "ADE_ContextExtras_ContextRef_TuneAttn": ContextRef_TuneAttn, + "ADE_ContextExtras_ContextRef_TuneAttnAdain": ContextRef_TuneAttnAdain, #------------------------------------------------------------------------------ ############################################################################### # Iteration Opts @@ -218,6 +221,8 @@ "ADE_ContextExtras_NaiveReuse": "Context Extrasβ—†NaiveReuse πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_ContextRef_ModeFirst": "ContextRef Modeβ—†First πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_ContextRef_ModeSliding": "ContextRef Modeβ—†Sliding πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_ContextRef_TuneAttn": "ContextRef Tuneβ—†Attn πŸŽ­πŸ…πŸ…“", + 
"ADE_ContextExtras_ContextRef_TuneAttnAdain": "ContextRef Tuneβ—†Attn+Adain πŸŽ­πŸ…πŸ…“", #------------------------------------------------------------------------------ ############################################################################### # Iteration Opts diff --git a/animatediff/nodes_context.py b/animatediff/nodes_context.py index 37b6c46..717b07e 100644 --- a/animatediff/nodes_context.py +++ b/animatediff/nodes_context.py @@ -511,6 +511,7 @@ def INPUT_TYPES(s): "prev_extras": ("CONTEXT_EXTRAS",), "strength_multival": ("MULTIVAL",), "contextref_mode": ("CONTEXTREF_MODE",), + "contextref_tune": ("CONTEXTREF_TUNE",), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001}), "autosize": ("ADEAUTOSIZE", {"padding": 0}), @@ -523,16 +524,18 @@ def INPUT_TYPES(s): def create_context_extra(self, start_percent=0.0, end_percent=0.1, strength_multival: Union[float, Tensor]=None, contextref_mode: ContextRefMode=None, + contextref_tune: ContextRefParams=None, prev_extras: ContextExtrasGroup=None): if prev_extras is None: prev_extras = prev_extras = ContextExtrasGroup() prev_extras = prev_extras.clone() # create extra # TODO: make customizable, and allow mask input - params = ContextRefParams(attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_atrength=1.0) + if contextref_tune is None: + contextref_tune = ContextRefParams(attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_strength=1.0) if contextref_mode is None: contextref_mode = ContextRefMode.init_first() - context_ref = ContextRef(start_percent=start_percent, end_percent=end_percent, params=params, mode=contextref_mode) + context_ref = ContextRef(start_percent=start_percent, end_percent=end_percent, params=contextref_tune, mode=contextref_mode) prev_extras.add(context_ref) return (prev_extras,) @@ -576,3 +579,56 @@ def INPUT_TYPES(s): def create_contextref_mode(self, sliding_width): mode = 
ContextRefMode.init_sliding(sliding_width=sliding_width) return (mode,) + + +class ContextRef_TuneAttnAdain: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "attn_style_fidelity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "attn_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "attn_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "adain_style_fidelity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "adain_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "adain_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "autosize": ("ADEAUTOSIZE", {"padding": 65}), + } + } + + RETURN_TYPES = ("CONTEXTREF_TUNE",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/ContextRef" + FUNCTION = "create_contextref_tune" + + def create_contextref_tune(self, attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_strength=1.0, + adain_style_fidelity=1.0, adain_ref_weight=1.0, adain_strength=1.0): + params = ContextRefParams(attn_style_fidelity=attn_style_fidelity, adain_style_fidelity=adain_style_fidelity, + attn_ref_weight=attn_ref_weight, adain_ref_weight=adain_ref_weight, + attn_strength=attn_strength, adain_strength=adain_strength) + return (params,) + + +class ContextRef_TuneAttn: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "attn_style_fidelity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "attn_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "attn_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "autosize": ("ADEAUTOSIZE", {"padding": 15}), + } + } + + RETURN_TYPES = ("CONTEXTREF_TUNE",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/ContextRef" + FUNCTION = "create_contextref_tune" + + def 
create_contextref_tune(self, attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_strength=1.0): + return ContextRef_TuneAttnAdain.create_contextref_tune(self, + attn_style_fidelity=attn_style_fidelity, attn_ref_weight=attn_ref_weight, attn_strength=attn_strength, + adain_ref_weight=0.0, adain_style_fidelity=0.0, adain_strength=0.0) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 87aedf4..5d91048 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -876,7 +876,8 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list if contextref_active: # set cond counter to 0 (each cond encountered will increment it by 1) - model_options["transformer_options"][CONTEXTREF_CONTROL_LIST_ALL][0].contextref_cond_idx = 0 + for refcn in model_options["transformer_options"][CONTEXTREF_CONTROL_LIST_ALL]: + refcn.contextref_cond_idx = 0 if first_context: first_context = False model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.WRITE From 2ffa1ef8e5d1e3086ed16f98cbe51108023e6f80 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 4 Aug 2024 16:25:02 -0500 Subject: [PATCH 09/22] Added curr_t checks for Context Keyframes, AD Keyframes, and CustomCFG Keyframes to work as expected for tiling --- animatediff/context.py | 7 +++++++ animatediff/model_injection.py | 7 +++++++ animatediff/sample_settings.py | 6 ++++++ 3 files changed, 20 insertions(+) diff --git a/animatediff/context.py b/animatediff/context.py index c0514f0..8ba6dcf 100644 --- a/animatediff/context.py +++ b/animatediff/context.py @@ -82,12 +82,14 @@ def __init__(self): self._current_context: ContextOptions = None self._current_used_steps: int = 0 self._current_index: int = 0 + self._previous_t = -1 self._step = 0 def reset(self): self._current_context = None self._current_used_steps = 0 self._current_index = 0 + self._previous_t = -1 self.step = 0 self._set_first_as_current() self.extras.cleanup() @@ -145,6 +147,9 @@ def
prepare_current(self, t: Tensor): def prepare_current_context(self, t: Tensor): curr_t: float = t[0] + # if same as previous, do nothing as step already accounted for + if curr_t == self._previous_t: + return prev_index = self._current_index # if met guaranteed steps, look for next context in case need to switch if self._current_used_steps >= self._current_context.guarantee_steps: @@ -166,6 +171,8 @@ def prepare_current_context(self, t: Tensor): break # update steps current context is used self._current_used_steps += 1 + # update previous_t + self._previous_t = curr_t def _set_first_as_current(self): if len(self.contexts) > 0: diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 0e5bf14..c40697d 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -751,6 +751,7 @@ def __init__(self, *args, **kwargs): self.current_used_steps = 0 self.current_keyframe: ADKeyframe = None self.current_index = -1 + self.previous_t = -1 self.current_scale: Union[float, Tensor] = None self.current_effect: Union[float, Tensor] = None self.current_cameractrl_effect: Union[float, Tensor] = None @@ -803,6 +804,9 @@ def initialize_timesteps(self, model: BaseModel): def prepare_current_keyframe(self, x: Tensor, t: Tensor): curr_t: float = t[0] + # if curr_t was previous_t, then do nothing (already accounted for this step) + if curr_t == self.previous_t: + return prev_index = self.current_index # if met guaranteed steps, look for next keyframe in case need to switch if self.current_keyframe is None or self.current_used_steps >= self.current_keyframe.guarantee_steps: @@ -862,6 +866,8 @@ def prepare_current_keyframe(self, x: Tensor, t: Tensor): self.was_within_range = True # update steps current keyframe is used self.current_used_steps += 1 + # update previous_t + self.previous_t = curr_t def prepare_img_features(self, x: Tensor, cond_or_uncond: list[int], ad_params: dict[str], latent_format): # if no img_encoder, done @@ -1006,6 +1012,7 
@@ def cleanup(self): self.current_used_steps = 0 self.current_keyframe = None self.current_index = -1 + self.previous_t = -1 self.current_scale = None self.current_effect = None self.combined_scale = None diff --git a/animatediff/sample_settings.py b/animatediff/sample_settings.py index e14dc72..bff8445 100644 --- a/animatediff/sample_settings.py +++ b/animatediff/sample_settings.py @@ -546,6 +546,7 @@ def __init__(self): self._current_keyframe: CustomCFGKeyframe = None self._current_used_steps: int = 0 self._current_index: int = 0 + self._previous_t = -1 def reset(self): self._current_keyframe = None @@ -584,6 +585,9 @@ def initialize_timesteps(self, model: BaseModel): def prepare_current_keyframe(self, t: Tensor): curr_t: float = t[0] + # if curr_t same as before, do nothing as step already accounted for + if curr_t == self._previous_t: + return prev_index = self._current_index # if met guaranteed steps, look for next keyframe in case need to switch if self._current_used_steps >= self._current_keyframe.guarantee_steps: @@ -604,6 +608,8 @@ def prepare_current_keyframe(self, t: Tensor): else: break # update steps current context is used self._current_used_steps += 1 + # update previous_t + self._previous_t = curr_t def get_cfg_scale(self, cond: Tensor): cond_scale = self.cfg_multival From 425641235215e21392102093aa7d17f16f6793d5 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 4 Aug 2024 21:11:29 -0500 Subject: [PATCH 10/22] Add ContextRef Mode - Indexes node --- animatediff/context_extras.py | 13 +++-- animatediff/nodes.py | 4 +- animatediff/nodes_context.py | 26 ++++++++++ animatediff/sampling.py | 26 ++++++++-- animatediff/utils_scheduling.py | 90 +++++++++++++++++++++++++++++++++ 5 files changed, 152 insertions(+), 7 deletions(-) create mode 100644 animatediff/utils_scheduling.py diff --git a/animatediff/context_extras.py b/animatediff/context_extras.py index 9265ade..af22e6b 100644 --- a/animatediff/context_extras.py +++ 
b/animatediff/context_extras.py @@ -59,19 +59,26 @@ def create_dict(self): class ContextRefMode: FIRST = "first" SLIDING = "sliding" - _LIST = [FIRST, SLIDING] + INDEXES = "indexes" + _LIST = [FIRST, SLIDING, INDEXES] - def __init__(self, mode: str, sliding_width=2): + def __init__(self, mode: str, sliding_width=2, indexes: set[int]=set([0])): self.mode = mode self.sliding_width = sliding_width + self.indexes = indexes + self.single_trigger = True @classmethod def init_first(cls): return ContextRefMode(cls.FIRST) @classmethod - def init_sliding(cls, sliding_width): + def init_sliding(cls, sliding_width: int): return ContextRefMode(cls.SLIDING, sliding_width=sliding_width) + + @classmethod + def init_indexes(cls, indexes: set[int]): + return ContextRefMode(cls.INDEXES, indexes=indexes) class ContextRef(ContextExtra): diff --git a/animatediff/nodes.py b/animatediff/nodes.py index 3e5edff..0b73761 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -30,7 +30,7 @@ StandardStaticViewOptionsNode, StandardUniformViewOptionsNode, ViewAsContextOptionsNode, VisualizeContextOptionsK, VisualizeContextOptionsKAdv, VisualizeContextOptionsSCustom, SetContextExtrasOnContextOptions, ContextExtras_NaiveReuse, ContextExtras_ContextRef, - ContextRef_ModeFirst, ContextRef_ModeSliding, + ContextRef_ModeFirst, ContextRef_ModeSliding, ContextRef_ModeIndexes, ContextRef_TuneAttn, ContextRef_TuneAttnAdain) from .nodes_ad_settings import (AnimateDiffSettingsNode, ManualAdjustPENode, SweetspotStretchPENode, FullStretchPENode, WeightAdjustAllAddNode, WeightAdjustAllMultNode, WeightAdjustIndivAddNode, WeightAdjustIndivMultNode, @@ -79,6 +79,7 @@ "ADE_ContextExtras_NaiveReuse": ContextExtras_NaiveReuse, "ADE_ContextExtras_ContextRef_ModeFirst": ContextRef_ModeFirst, "ADE_ContextExtras_ContextRef_ModeSliding": ContextRef_ModeSliding, + "ADE_ContextExtras_ContextRef_ModeIndexes": ContextRef_ModeIndexes, "ADE_ContextExtras_ContextRef_TuneAttn": ContextRef_TuneAttn, 
"ADE_ContextExtras_ContextRef_TuneAttnAdain": ContextRef_TuneAttnAdain, #------------------------------------------------------------------------------ @@ -221,6 +222,7 @@ "ADE_ContextExtras_NaiveReuse": "Context Extrasβ—†NaiveReuse πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_ContextRef_ModeFirst": "ContextRef Modeβ—†First πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_ContextRef_ModeSliding": "ContextRef Modeβ—†Sliding πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_ContextRef_ModeIndexes": "ContextRef Modeβ—†Indexes πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_ContextRef_TuneAttn": "ContextRef Tuneβ—†Attn πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_ContextRef_TuneAttnAdain": "ContextRef Tuneβ—†Attn+Adain πŸŽ­πŸ…πŸ…“", #------------------------------------------------------------------------------ diff --git a/animatediff/nodes_context.py b/animatediff/nodes_context.py index 717b07e..481e5c0 100644 --- a/animatediff/nodes_context.py +++ b/animatediff/nodes_context.py @@ -9,6 +9,7 @@ generate_context_visualization) from .context_extras import ContextExtrasGroup, ContextRef, ContextRefParams, ContextRefMode, NaiveReuse from .utils_model import BIGMAX, MAX_RESOLUTION +from .utils_scheduling import convert_str_to_indexes LENGTH_MAX = 128 # keep an eye on these max values; @@ -581,6 +582,31 @@ def create_contextref_mode(self, sliding_width): return (mode,) +class ContextRef_ModeIndexes: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "switch_on_idxs": ("STRING", {"default": ""}), + "always_include_0": ("BOOLEAN", {"default": True},), + "autosize": ("ADEAUTOSIZE", {"padding": 50}), + }, + } + + RETURN_TYPES = ("CONTEXTREF_MODE",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/ContextRef" + FUNCTION = "create_contextref_mode" + + def create_contextref_mode(self, switch_on_idxs: str, always_include_0: bool): + idxs = set(convert_str_to_indexes(indexes_str=switch_on_idxs, length=0, allow_range=False)) + if always_include_0 and 0 not in idxs: + idxs.add(0) + mode = 
ContextRefMode.init_indexes(indexes=idxs) + return (mode,) + + class ContextRef_TuneAttnAdain: @classmethod def INPUT_TYPES(s): diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 5d91048..6c289b2 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -835,13 +835,14 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list contextref_active = False contextref_injector = None contextref_mode = None - first_context = False + contextref_idxs_set = None + first_context = True # need to make sure that contextref stuff gets cleaned up, no matter what try: if ADGS.params.context_options.extras.should_run_context_ref(): contextref_active = True - first_context = True contextref_mode = ADGS.params.context_options.extras.context_ref.mode + contextref_idxs_set = contextref_mode.indexes.copy() # use injector to ensure only 1 cond or uncond will be batched at a time contextref_injector = ContextRefInjector() contextref_injector.inject() @@ -879,13 +880,28 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list for refcn in model_options["transformer_options"][CONTEXTREF_CONTROL_LIST_ALL]: refcn.contextref_cond_idx = 0 if first_context: - first_context = False model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.WRITE else: model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.READ if contextref_mode.mode == ContextRefMode.SLIDING: # if sliding, check if time to READ and WRITE if curr_window_idx % (contextref_mode.sliding_width-1) == 0: model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.READ_WRITE + # override with indexes mode, if set + if contextref_mode.mode == ContextRefMode.INDEXES: + contains_idx = False + for i in ctx_idxs: + if i in contextref_idxs_set: + contains_idx = True + # single trigger decides if each index should only trigger READ_WRITE once per step + if not contextref_mode.single_trigger: + break + 
contextref_idxs_set.remove(i) + if contains_idx: + model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.READ_WRITE + if first_context: + model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.WRITE + else: + model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.READ else: model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.OFF #logger.info(f"window: {curr_window_idx} - {model_options['transformer_options'][CONTEXTREF_MACHINE_STATE]}") @@ -914,11 +930,15 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list for i in range(len(sub_conds_out)): conds_final[i][full_idxs] += sub_conds_out[i] * weights_tensor counts_final[i][full_idxs] += weights_tensor + # handle NaiveReuse if naivereuse_active: cached_naive_ctx_idxs = ctx_idxs for i in range(len(sub_conds)): cached_naive_conds[i][full_idxs] = conds_final[i][full_idxs] / counts_final[i][full_idxs] naivereuse_active = False + # toggle first_context off, if needed + if first_context: + first_context = False finally: # clean contextref stuff with provided ACN function, if applicable if contextref_active: diff --git a/animatediff/utils_scheduling.py b/animatediff/utils_scheduling.py new file mode 100644 index 0000000..92ac5ab --- /dev/null +++ b/animatediff/utils_scheduling.py @@ -0,0 +1,90 @@ +from typing import Union + +from torch import Tensor + + +def validate_index(index: int, length: int=0, is_range: bool=False, allow_negative=False, allow_missing=False) -> int: + # if part of range, do nothing + if is_range: + return index + # otherwise, validate index + # validate not out of range - only when latent_count is passed in + if length > 0 and index > length-1 and not allow_missing: + raise IndexError(f"Index '{index}' out of range for {length} item(s).") + # if negative, validate not out of range + if index < 0: + if not allow_negative: + raise IndexError(f"Negative indeces not allowed, but was 
'{index}'.") + conv_index = length+index + if conv_index < 0 and not allow_missing: + raise IndexError(f"Index '{index}', converted to '{conv_index}' out of range for {length} item(s).") + index = conv_index + return index + + +def convert_to_index_int(raw_index: str, length: int=0, is_range: bool=False, allow_negative=False, allow_missing=False) -> int: + try: + return validate_index(int(raw_index), length=length, is_range=is_range, allow_negative=allow_negative, allow_missing=allow_missing) + except ValueError as e: + raise ValueError(f"Index '{raw_index}' must be an integer.", e) + + +def convert_str_to_indexes(indexes_str: str, length: int=0, allow_range=True, allow_missing=False) -> list[int]: + if not indexes_str: + return [] + int_indexes = list(range(0, length)) + allow_negative = length > 0 + chosen_indexes = [] + # parse string - allow positive ints, negative ints, and ranges separated by ':' + groups = indexes_str.split(",") + groups = [g.strip() for g in groups] + for g in groups: + # parse range of indeces (e.g. 2:16) + if ':' in g: + if not allow_range: + raise Exception("Ranges (:) not allowed for this input.") + index_range = g.split(":", 2) + index_range = [r.strip() for r in index_range] + + start_index = index_range[0] + if len(start_index) > 0: + start_index = convert_to_index_int(start_index, length=length, is_range=True, allow_negative=allow_negative, allow_missing=allow_missing) + else: + start_index = 0 + end_index = index_range[1] + if len(end_index) > 0: + end_index = convert_to_index_int(end_index, length=length, is_range=True, allow_negative=allow_negative, allow_missing=allow_missing) + else: + end_index = length + # support step as well, to allow things like reversing, every-other, etc. 
+ step = 1 + if len(index_range) > 2: + step = index_range[2] + if len(step) > 0: + step = convert_to_index_int(step, length=length, is_range=True, allow_negative=True, allow_missing=True) + else: + step = 1 + # if latents were passed in, base indeces on known latent count + if len(int_indexes) > 0: + chosen_indexes.extend(int_indexes[start_index:end_index][::step]) + # otherwise, assume indeces are valid + else: + chosen_indexes.extend(list(range(start_index, end_index, step))) + # parse individual indeces + else: + chosen_indexes.append(convert_to_index_int(g, length=length, allow_negative=allow_negative, allow_missing=allow_missing)) + return chosen_indexes + + +def select_indexes(input_obj: Union[Tensor, list], idxs: list): + if type(input_obj) == Tensor: + return input_obj[idxs] + else: + return [input_obj[i] for i in idxs] + + +def select_indexes_from_str(input_obj: Union[Tensor, list], indexes: str, allow_range=True, err_if_missing=True, err_if_empty=True): + real_idxs = convert_str_to_indexes(indexes, len(input_obj), allow_range=allow_range, allow_missing=not err_if_missing) + if err_if_empty and len(real_idxs) == 0: + raise Exception(f"Nothing was selected based on indexes found in '{indexes}'.") + return select_indexes(input_obj, real_idxs) From 9dee45bc22ab673b303ff0479ffb381c065f6073 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 4 Aug 2024 21:28:52 -0500 Subject: [PATCH 11/22] Remove pointless batched_conds calculations in sliding_calc_conds_batch; this is a wrapper around cond batching, so there never is any batching in this layer --- animatediff/sampling.py | 33 ++++++++++++--------------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 6c289b2..cd56a3d 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -813,8 +813,6 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list # get context windows 
ADGS.params.context_options.step = ADGS.current_step context_windows = get_context_windows(ADGS.params.full_length, ADGS.params.context_options) - # figure out how input is split - batched_conds = x_in.size(0)//ADGS.params.full_length if ADGS.motion_models is not None: ADGS.motion_models.set_view_options(ADGS.params.context_options.view_options) @@ -865,15 +863,10 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list # update exposed params model_options["transformer_options"]["ad_params"]["sub_idxs"] = ctx_idxs model_options["transformer_options"]["ad_params"]["context_length"] = len(ctx_idxs) - # account for all portions of input frames - full_idxs = [] - for n in range(batched_conds): - for ind in ctx_idxs: - full_idxs.append((ADGS.params.full_length*n)+ind) # get subsections of x, timestep, conds - sub_x = x_in[full_idxs] - sub_timestep = timestep[full_idxs] - sub_conds = [get_resized_cond(cond, full_idxs, len(ctx_idxs)) for cond in conds] + sub_x = x_in[ctx_idxs] + sub_timestep = timestep[ctx_idxs] + sub_conds = [get_resized_cond(cond, ctx_idxs, len(ctx_idxs)) for cond in conds] if contextref_active: # set cond counter to 0 (each cond encountered will increment it by 1) @@ -915,26 +908,24 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list bias = 1 - abs(idx - (ctx_idxs[0] + ctx_idxs[-1]) / 2) / ((ctx_idxs[-1] - ctx_idxs[0] + 1e-2) / 2) bias = max(1e-2, bias) # take weighted average relative to total bias of current idx - # and account for batched_conds for i in range(len(sub_conds_out)): - for n in range(batched_conds): - bias_total = biases_final[i][(full_length*n)+idx] - prev_weight = (bias_total / (bias_total + bias)) - new_weight = (bias / (bias_total + bias)) - conds_final[i][(full_length*n)+idx] = conds_final[i][(full_length*n)+idx] * prev_weight + sub_conds_out[i][(full_length*n)+pos] * new_weight - biases_final[i][(full_length*n)+idx] = bias_total + bias + bias_total = biases_final[i][idx] 
+ prev_weight = (bias_total / (bias_total + bias)) + new_weight = (bias / (bias_total + bias)) + conds_final[i][idx] = conds_final[i][idx] * prev_weight + sub_conds_out[i][pos] * new_weight + biases_final[i][idx] = bias_total + bias else: # add conds and counts based on weights of fuse method - weights = get_context_weights(len(ctx_idxs), ADGS.params.context_options.fuse_method, sigma=timestep) * batched_conds + weights = get_context_weights(len(ctx_idxs), ADGS.params.context_options.fuse_method, sigma=timestep) weights_tensor = torch.Tensor(weights).to(device=x_in.device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) for i in range(len(sub_conds_out)): - conds_final[i][full_idxs] += sub_conds_out[i] * weights_tensor - counts_final[i][full_idxs] += weights_tensor + conds_final[i][ctx_idxs] += sub_conds_out[i] * weights_tensor + counts_final[i][ctx_idxs] += weights_tensor # handle NaiveReuse if naivereuse_active: cached_naive_ctx_idxs = ctx_idxs for i in range(len(sub_conds)): - cached_naive_conds[i][full_idxs] = conds_final[i][full_idxs] / counts_final[i][full_idxs] + cached_naive_conds[i][ctx_idxs] = conds_final[i][ctx_idxs] / counts_final[i][ctx_idxs] naivereuse_active = False # toggle first_context off, if needed if first_context: From f051ca9944f6d56fd5390bc9d9ec7d7cd0988d17 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 4 Aug 2024 22:00:10 -0500 Subject: [PATCH 12/22] Workflows can be cancelled between context window execution, instead of only per-step --- animatediff/sampling.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index cd56a3d..950da57 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -855,6 +855,8 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list naivereuse_active = True # perform calc_conds_batch per context window for ctx_idxs in context_windows: + # allow processing to end between context window executions for faster Cancel + 
comfy.model_management.throw_exception_if_processing_interrupted() curr_window_idx += 1 ADGS.params.sub_idxs = ctx_idxs if ADGS.motion_models is not None: From 6167eee9f4e66dc2adbd1b514a7d3905290c8d55 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 5 Aug 2024 04:32:22 -0500 Subject: [PATCH 13/22] Added basic NaiveReuse Keyframe nodes --- animatediff/context_extras.py | 139 +++++++++++++++++++++++++++++++--- animatediff/nodes.py | 11 ++- animatediff/nodes_context.py | 101 ++++++++++++++++++++++-- animatediff/utils_motion.py | 30 ++++++-- 4 files changed, 252 insertions(+), 29 deletions(-) diff --git a/animatediff/context_extras.py b/animatediff/context_extras.py index af22e6b..59438b5 100644 --- a/animatediff/context_extras.py +++ b/animatediff/context_extras.py @@ -1,8 +1,12 @@ +from typing import Union +import math +import torch from torch import Tensor from comfy.model_base import BaseModel -from .utils_motion import prepare_mask_batch, extend_to_batch_size, get_combined_multival +from .utils_motion import (prepare_mask_batch, extend_to_batch_size, get_combined_multival, resize_multival, + get_sorted_list_via_attr) class ContextExtra: @@ -93,34 +97,147 @@ def should_run(self): ################################ -# NaiveReuse +# NaiveReuse +class NaiveReuseKeyframe: + def __init__(self, mult_multival: Union[float, Tensor], start_percent=0.0, guarantee_steps=1): + self.mult_multival = mult_multival + # scheduling + self.start_percent = float(start_percent) + self.start_t = 999999999.9 + self.guarantee_steps = guarantee_steps + + def clone(self): + c = NaiveReuseKeyframe(mult_multival=self.mult_multival, + start_percent=self.start_percent, guarantee_steps=self.guarantee_steps) + c.start_t = self.start_t + return c + +class NaiveReuseKeyframeGroup: + def __init__(self): + self.keyframes: list[NaiveReuseKeyframe] = [] + self._current_keyframe: NaiveReuseKeyframe = None + self._current_used_steps: int = 0 + self._current_index: int = 0 + self._previous_t = -1 
+ + def reset(self): + self._current_keyframe = None + self._current_used_steps = 0 + self._current_index = 0 + self._set_first_as_current() + + def add(self, keyframe: NaiveReuseKeyframe): + # add to end of list, then sort + self.keyframes.append(keyframe) + self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent") + self._set_first_as_current() + + def _set_first_as_current(self): + if len(self.keyframes) > 0: + self._current_keyframe = self.keyframes[0] + else: + self._current_keyframe = None + + def has_index(self, index: int) -> int: + return index >=0 and index < len(self.keyframes) + + def is_empty(self) -> bool: + return len(self.keyframes) == 0 + + def clone(self): + cloned = NaiveReuseKeyframeGroup() + for keyframe in self.keyframes: + cloned.keyframes.append(keyframe) + cloned._set_first_as_current() + return cloned + + def initialize_timesteps(self, model: BaseModel): + for keyframe in self.keyframes: + keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) + + def prepare_current_keyframe(self, t: Tensor): + curr_t: float = t[0] + # if curr_t same as before, do nothing as step already accounted for + if curr_t == self._previous_t: + return + prev_index = self._current_index + # if met guaranteed steps, look for next keyframe in case need to switch + if self._current_used_steps >= self._current_keyframe.guarantee_steps: + # if has next index, loop through and see if need t oswitch + if self.has_index(self._current_index+1): + for i in range(self._current_index+1, len(self.keyframes)): + eval_c = self.keyframes[i] + # check if start_t is greater or equal to curr_t + # NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling + if eval_c.start_t >= curr_t: + self._current_index = i + self._current_keyframe = eval_c + self._current_used_steps = 0 + # if guarantee_steps greater than zero, stop searching for other keyframes + if self._current_keyframe.guarantee_steps > 0: + break + # if 
eval_c is outside the percent range, stop looking further + else: break + # update steps current context is used + self._current_used_steps += 1 + # update previous_t + self._previous_t = curr_t + + # properties shadow those of NaiveReuseKeyframe + @property + def mult_multival(self): + if self._current_keyframe != None: + return self._current_keyframe.mult_multival + return None + class NaiveReuse(ContextExtra): - def __init__(self, start_percent: float, end_percent: float, weighted_mean: float, multival_opt: Tensor=None): + def __init__(self, start_percent: float, end_percent: float, weighted_mean: float, multival_opt: Union[float, Tensor]=None, naivereuse_kf: NaiveReuseKeyframeGroup=None): super().__init__(start_percent=start_percent, end_percent=end_percent) self.weighted_mean = weighted_mean self.orig_multival = multival_opt self.mask: Tensor = None + self.keyframe = naivereuse_kf if naivereuse_kf else NaiveReuseKeyframeGroup() + self._prev_keyframe = None def cleanup(self): super().cleanup() del self.mask self.mask = None + self._prev_keyframe = None + self.keyframe.reset() + + def initialize_timesteps(self, model: BaseModel): + super().initialize_timesteps(model) + self.keyframe.initialize_timesteps(model) + + def prepare_current(self, t: Tensor): + super().prepare_current(t) + self.keyframe.prepare_current_keyframe(t) def get_effective_weighted_mean(self, x: Tensor, idxs: list[int]): - if self.orig_multival is None: + if self.orig_multival is None and self.keyframe.mult_multival is None: return self.weighted_mean - # otherwise, is Tensor and should be extended to match dims and size of x; - # see if needs to be recalculated - if type(self.orig_multival) != Tensor: - return self.weighted_mean * self.orig_multival - elif self.mask is None or self.mask.shape[0] != x.shape[0] or self.mask.shape[-1] != x.shape[-1] or self.mask.shape[-2] != x.shape[-2]: + # check if keyframe changed + keyframe_changed = False + if self.keyframe._current_keyframe != 
self._prev_keyframe: + keyframe_changed = True + self._prev_keyframe = self.keyframe._current_keyframe + + if type(self.orig_multival) != Tensor and type(self.keyframe.mult_multival) != Tensor: + return self.weighted_mean * get_combined_multival(self.orig_multival, self.keyframe.mult_multival) + + if self.mask is None or keyframe_changed or self.mask.shape[0] != x.shape[0] or self.mask.shape[-1] != x.shape[-1] or self.mask.shape[-2] != x.shape[-2]: del self.mask - self.mask = prepare_mask_batch(self.orig_multival, x.shape) - self.mask = extend_to_batch_size(self.mask, x.shape[0]) + real_mult_multival = resize_multival(self.keyframe.mult_multival, batch_size=x.shape[0], height=x.shape[-1], width=x.shape[-2]) + self.mask = resize_multival(self.orig_multival, batch_size=x.shape[0], height=x.shape[-1], width=x.shape[-2]) + self.mask = get_combined_multival(self.mask, real_mult_multival) return self.weighted_mean * self.mask[idxs].to(dtype=x.dtype, device=x.device) def should_run(self): to_return = super().should_run() + # if keyframe has 0.0 val, should not run + if self.keyframe.mult_multival is not None and type(self.keyframe.mult_multival) != Tensor and math.isclose(self.keyframe.mult_multival, 0.0): + return False # if weighted_mean is 0.0, then reuse will take no effect anyway return to_return and self.weighted_mean > 0.0 #-------------------------------- diff --git a/animatediff/nodes.py b/animatediff/nodes.py index 0b73761..e88b7d1 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -31,7 +31,8 @@ VisualizeContextOptionsK, VisualizeContextOptionsKAdv, VisualizeContextOptionsSCustom, SetContextExtrasOnContextOptions, ContextExtras_NaiveReuse, ContextExtras_ContextRef, ContextRef_ModeFirst, ContextRef_ModeSliding, ContextRef_ModeIndexes, - ContextRef_TuneAttn, ContextRef_TuneAttnAdain) + ContextRef_TuneAttn, ContextRef_TuneAttnAdain, + NaiveReuse_KeyframeNode, NaiveReuse_KeyframeMultivalNode) from .nodes_ad_settings import (AnimateDiffSettingsNode, 
ManualAdjustPENode, SweetspotStretchPENode, FullStretchPENode, WeightAdjustAllAddNode, WeightAdjustAllMultNode, WeightAdjustIndivAddNode, WeightAdjustIndivMultNode, WeightAdjustIndivAttnAddNode, WeightAdjustIndivAttnMultNode) @@ -76,12 +77,14 @@ # Context Extras "ADE_ContextExtras_Set": SetContextExtrasOnContextOptions, "ADE_ContextExtras_ContextRef": ContextExtras_ContextRef, - "ADE_ContextExtras_NaiveReuse": ContextExtras_NaiveReuse, "ADE_ContextExtras_ContextRef_ModeFirst": ContextRef_ModeFirst, "ADE_ContextExtras_ContextRef_ModeSliding": ContextRef_ModeSliding, "ADE_ContextExtras_ContextRef_ModeIndexes": ContextRef_ModeIndexes, "ADE_ContextExtras_ContextRef_TuneAttn": ContextRef_TuneAttn, "ADE_ContextExtras_ContextRef_TuneAttnAdain": ContextRef_TuneAttnAdain, + "ADE_ContextExtras_NaiveReuse": ContextExtras_NaiveReuse, + "ADE_ContextExtras_NaiveReuse_Keyframe": NaiveReuse_KeyframeNode, + "ADE_ContextExtras_NaiveReuse_KeyframeMultival": NaiveReuse_KeyframeMultivalNode, #------------------------------------------------------------------------------ ############################################################################### # Iteration Opts @@ -219,12 +222,14 @@ # Context Extras "ADE_ContextExtras_Set": "Set Context Extras πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_ContextRef": "Context Extrasβ—†ContextRef πŸŽ­πŸ…πŸ…“", - "ADE_ContextExtras_NaiveReuse": "Context Extrasβ—†NaiveReuse πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_ContextRef_ModeFirst": "ContextRef Modeβ—†First πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_ContextRef_ModeSliding": "ContextRef Modeβ—†Sliding πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_ContextRef_ModeIndexes": "ContextRef Modeβ—†Indexes πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_ContextRef_TuneAttn": "ContextRef Tuneβ—†Attn πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_ContextRef_TuneAttnAdain": "ContextRef Tuneβ—†Attn+Adain πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_NaiveReuse": "Context Extrasβ—†NaiveReuse πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_NaiveReuse_Keyframe": "NaiveReuse Keyframe πŸŽ­πŸ…πŸ…“", + 
"ADE_ContextExtras_NaiveReuse_KeyframeMultival": "NaiveReuse Keyframe [Multival] πŸŽ­πŸ…πŸ…“", #------------------------------------------------------------------------------ ############################################################################### # Iteration Opts diff --git a/animatediff/nodes_context.py b/animatediff/nodes_context.py index 481e5c0..983c8db 100644 --- a/animatediff/nodes_context.py +++ b/animatediff/nodes_context.py @@ -7,7 +7,8 @@ from .context import (ContextFuseMethod, ContextOptions, ContextOptionsGroup, ContextSchedules, generate_context_visualization) -from .context_extras import ContextExtrasGroup, ContextRef, ContextRefParams, ContextRefMode, NaiveReuse +from .context_extras import (ContextExtrasGroup, ContextRef, ContextRefParams, ContextRefMode, NaiveReuse, + NaiveReuseKeyframe, NaiveReuseKeyframeGroup) from .utils_model import BIGMAX, MAX_RESOLUTION from .utils_scheduling import convert_str_to_indexes @@ -480,6 +481,7 @@ def INPUT_TYPES(s): "optional": { "prev_extras": ("CONTEXT_EXTRAS",), "strength_multival": ("MULTIVAL",), + "naivereuse_kf": ("NAIVEREUSE_KEYFRAME",), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}), "weighted_mean": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.001}), @@ -492,16 +494,73 @@ def INPUT_TYPES(s): FUNCTION = "create_context_extra" def create_context_extra(self, start_percent=0.0, end_percent=0.1, weighted_mean=0.95, strength_multival: Union[float, Tensor]=None, - prev_extras: ContextExtrasGroup=None): + naivereuse_kf: NaiveReuseKeyframeGroup=None, prev_extras: ContextExtrasGroup=None): if prev_extras is None: prev_extras = prev_extras = ContextExtrasGroup() prev_extras = prev_extras.clone() # create extra - naive_reuse = NaiveReuse(start_percent=start_percent, end_percent=end_percent, weighted_mean=weighted_mean, multival_opt=strength_multival) + naive_reuse = 
NaiveReuse(start_percent=start_percent, end_percent=end_percent, weighted_mean=weighted_mean, multival_opt=strength_multival, + naivereuse_kf=naivereuse_kf) prev_extras.add(naive_reuse) return (prev_extras,) +class NaiveReuse_KeyframeMultivalNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "mult_multival": ("MULTIVAL",), + }, + "optional": { + "prev_kf": ("NAIVEREUSE_KEYFRAME",), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), + "autosize": ("ADEAUTOSIZE", {"padding": 80}), + } + } + + RETURN_TYPES = ("NAIVEREUSE_KEYFRAME",) + RETURN_NAMES = ("NAIVEREUSE_KF",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/naivereuse" + FUNCTION = "create_keyframe" + + def create_keyframe(self, prev_kf=None, mult_multival=1.0, + start_percent=0.0, guarantee_steps=1): + if prev_kf is None: + prev_kf = NaiveReuseKeyframeGroup() + prev_kf = prev_kf.clone() + kf = NaiveReuseKeyframe(mult_multival=mult_multival, start_percent=start_percent, guarantee_steps=guarantee_steps) + prev_kf.add(kf) + return (prev_kf,) + + +class NaiveReuse_KeyframeNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "prev_kf": ("NAIVEREUSE_KEYFRAME",), + "mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), + "autosize": ("ADEAUTOSIZE", {"padding": 10}), + } + } + + RETURN_TYPES = ("NAIVEREUSE_KEYFRAME",) + RETURN_NAMES = ("NAIVEREUSE_KF",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/naivereuse" + FUNCTION = "create_keyframe" + + def create_keyframe(self, prev_kf=None, mult=1.0, + start_percent=0.0, guarantee_steps=1): + return NaiveReuse_KeyframeMultivalNode.create_keyframe(self, prev_kf=prev_kf, 
mult_multival=float(mult), + start_percent=start_percent, guarantee_steps=guarantee_steps) + + class ContextExtras_ContextRef: @classmethod def INPUT_TYPES(s): @@ -541,6 +600,32 @@ def create_context_extra(self, start_percent=0.0, end_percent=0.1, strength_mult return (prev_extras,) +class ContextRef_KeyframeNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "prev_keyframe": ("CONTEXTREF_KEYFRAME",), + "mult_multival": ("MULTIVAL",), + "mode_replace": ("CONTEXTREF_MODE",), + "tune_replace": ("CONTEXTREF_TUNE",), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), + } + } + + RETURN_TYPES = ("CONTEXTREF_KEYFRAME",) + RETURN_NAMES = ("CONTEXTREF_KF",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" + FUNCTION = "create_keyframe" + + def create_keyframe(self, prev_keyframe=None, mult_multival=1.0, mode_replace=None, tune_replace=None, + start_percent=1.0, guarantee_steps=1): + pass + + class ContextRef_ModeFirst: @classmethod def INPUT_TYPES(s): @@ -553,7 +638,7 @@ def INPUT_TYPES(s): } RETURN_TYPES = ("CONTEXTREF_MODE",) - CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/ContextRef" + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" FUNCTION = "create_contextref_mode" def create_contextref_mode(self): @@ -574,7 +659,7 @@ def INPUT_TYPES(s): } RETURN_TYPES = ("CONTEXTREF_MODE",) - CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/ContextRef" + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" FUNCTION = "create_contextref_mode" def create_contextref_mode(self, sliding_width): @@ -596,7 +681,7 @@ def INPUT_TYPES(s): } RETURN_TYPES = ("CONTEXTREF_MODE",) - CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/ContextRef" + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context 
extras/contextref" FUNCTION = "create_contextref_mode" def create_contextref_mode(self, switch_on_idxs: str, always_include_0: bool): @@ -625,7 +710,7 @@ def INPUT_TYPES(s): } RETURN_TYPES = ("CONTEXTREF_TUNE",) - CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/ContextRef" + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" FUNCTION = "create_contextref_tune" def create_contextref_tune(self, attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_strength=1.0, @@ -651,7 +736,7 @@ def INPUT_TYPES(s): } RETURN_TYPES = ("CONTEXTREF_TUNE",) - CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/ContextRef" + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" FUNCTION = "create_contextref_tune" def create_contextref_tune(self, attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_strength=1.0): diff --git a/animatediff/utils_motion.py b/animatediff/utils_motion.py index 0cf5b1a..8425f02 100644 --- a/animatediff/utils_motion.py +++ b/animatediff/utils_motion.py @@ -275,7 +275,7 @@ def create_multival_combo(float_val: Union[float, list[float]], mask_optional: T return mask_optional -def get_combined_multival(multivalA: Union[float, Tensor], multivalB: Union[float, Tensor]) -> Union[float, Tensor]: +def get_combined_multival(multivalA: Union[float, Tensor], multivalB: Union[float, Tensor], force_leader_A=False) -> Union[float, Tensor]: # if one is None, use the other if multivalA == None: return multivalB @@ -284,14 +284,18 @@ def get_combined_multival(multivalA: Union[float, Tensor], multivalB: Union[floa # both have a value - combine them based on type # if both are Tensors, make dims match before multiplying if type(multivalA) == Tensor and type(multivalB) == Tensor: - areaA = multivalA.shape[1]*multivalA.shape[2] - areaB = multivalB.shape[1]*multivalB.shape[2] - # match height/width to mask with larger area - leader,follower = (multivalA,multivalB) if areaA >= areaB else (multivalB,multivalA) 
- batch_size = multivalA.shape[0] if multivalA.shape[0] >= multivalB.shape[0] else multivalB.shape[0] + if force_leader_A: + leader,follower = (multivalA,multivalB) + batch_size = multivalA.shape[0] + else: + areaA = multivalA.shape[1]*multivalA.shape[2] + areaB = multivalB.shape[1]*multivalB.shape[2] + # match height/width to mask with larger area + leader,follower = (multivalA,multivalB) if areaA >= areaB else (multivalB,multivalA) + batch_size = multivalA.shape[0] if multivalA.shape[0] >= multivalB.shape[0] else multivalB.shape[0] # make follower same dimensions as leader follower = torch.unsqueeze(follower, 1) - follower = comfy.utils.common_upscale(follower, leader.shape[2], leader.shape[1], "bilinear", "center") + follower = comfy.utils.common_upscale(follower, leader.shape[-1], leader.shape[-2], "bilinear", "center") follower = torch.squeeze(follower, 1) # make sure batch size will match leader = extend_to_batch_size(leader, batch_size) @@ -301,6 +305,18 @@ def get_combined_multival(multivalA: Union[float, Tensor], multivalB: Union[floa return multivalA * multivalB +def resize_multival(multival: Union[float, Tensor], batch_size: int, height: int, width: int): + if multival == None: + return 1.0 + if type(multival) != Tensor: + return multival + multival = torch.unsqueeze(multival, 1) + multival = comfy.utils.common_upscale(multival, height, width, "bilinear", "center") + multival = torch.squeeze(multival, 1) + multival = extend_to_batch_size(multival, batch_size) + return multival + + def get_combined_input(inputA: Union[InputPIA, None], inputB: Union[InputPIA, None], x: Tensor): if inputA is None: inputA = InputPIA_Multival(1.0) From d4b328b0b01f39db6b8ef367c0857dcbedfa6b72 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 5 Aug 2024 14:33:14 -0500 Subject: [PATCH 14/22] Fix NaiveReuse throwing error if no KFs attached --- animatediff/context_extras.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/animatediff/context_extras.py 
b/animatediff/context_extras.py index 59438b5..a323a52 100644 --- a/animatediff/context_extras.py +++ b/animatediff/context_extras.py @@ -156,6 +156,8 @@ def initialize_timesteps(self, model: BaseModel): keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) def prepare_current_keyframe(self, t: Tensor): + if self.is_empty(): + return curr_t: float = t[0] # if curr_t same as before, do nothing as step already accounted for if curr_t == self._previous_t: From 0f6aa7c2512b66ec13006a98590d3c1667276aad Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 6 Aug 2024 05:34:16 -0500 Subject: [PATCH 15/22] Added ContextRef Keyframes, made strength_multival for ContextRef functional, removed redundant version of NaiveReuse Keyframe node, and instead mult_multival is on the normal node --- animatediff/context_extras.py | 179 ++++++++++++++++++++++++++++++++-- animatediff/nodes.py | 10 +- animatediff/nodes_context.py | 74 ++++++-------- animatediff/sampling.py | 20 ++-- 4 files changed, 220 insertions(+), 63 deletions(-) diff --git a/animatediff/context_extras.py b/animatediff/context_extras.py index a323a52..3af7df6 100644 --- a/animatediff/context_extras.py +++ b/animatediff/context_extras.py @@ -9,6 +9,9 @@ get_sorted_list_via_attr) +CONTEXTREF_VERSION = 1 + + class ContextExtra: def __init__(self, start_percent: float, end_percent: float): # scheduling @@ -36,7 +39,7 @@ def cleanup(self): ################################ # ContextRef -class ContextRefParams: +class ContextRefTune: def __init__(self, attn_style_fidelity=0.0, attn_ref_weight=0.0, attn_strength=0.0, adain_style_fidelity=0.0, adain_ref_weight=0.0, adain_strength=0.0): @@ -85,11 +88,162 @@ def init_indexes(cls, indexes: set[int]): return ContextRefMode(cls.INDEXES, indexes=indexes) +class ContextRefKeyframe: + def __init__(self, mult=1.0, mult_multival: Union[float, Tensor]=None, tune_replace: ContextRefTune=None, mode_replace: ContextRefMode=None, + start_percent=0.0, 
guarantee_steps=1, inherit_missing=True): + self.mult = mult + self.orig_mult_multival = mult_multival + self.orig_tune_replace = tune_replace + self.orig_mode_replace = mode_replace + self.mult_multival = self.orig_mult_multival + self.tune_replace = self.orig_tune_replace + self.mode_replace = self.orig_mode_replace + # scheduling + self.start_percent = float(start_percent) + self.guarantee_steps = guarantee_steps + self.inherit_missing = inherit_missing + + def clone(self): + c = ContextRefKeyframe(mult=self.mult, mult_multival=self.orig_mult_multival, tune_replace=self.orig_tune_replace, mode_replace=self.orig_mode_replace, + start_percent=self.start_percent, guarantee_steps=self.guarantee_steps, inherit_missing=self.inherit_missing) + c.mult_multival = self.mult_multival + c.tune_replace = self.tune_replace + c.mode_replace = self.mode_replace + return c + + +class ContextRefKeyframeGroup: + def __init__(self): + self.keyframes: list[ContextRefKeyframe] = [] + self._current_keyframe: NaiveReuseKeyframe = None + self._current_used_steps: int = 0 + self._current_index: int = 0 + self._previous_t = -1 + + def reset(self): + self._current_keyframe = None + self._current_used_steps = 0 + self._current_index = 0 + self._set_first_as_current() + + def add(self, keyframe: ContextRefKeyframe): + # add to end of list, then sort + self.keyframes.append(keyframe) + self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent") + self._set_first_as_current() + self._prepare_all_keyframe_vals() + + def _set_first_as_current(self): + if len(self.keyframes) > 0: + self._current_keyframe = self.keyframes[0] + else: + self._current_keyframe = None + + def _prepare_all_keyframe_vals(self): + if self.is_empty(): + return + multival = None + tune = None + mode = None + for kf in self.keyframes: + # if shouldn't inherit, clear cache + if not kf.inherit_missing: + multival = None + tune = None + mode = None + # assign cached values, if origs were None + # Mult 
################# + if kf.orig_mult_multival is None: + kf.mult_multival = multival + else: + kf.mult_multival = kf.orig_mult_multival + # Tune ################# + if kf.orig_tune_replace is None: + kf.tune_replace = tune + else: + kf.tune_replace = kf.orig_tune_replace + # Mode ################# + if kf.orig_mode_replace is None: + kf.mode_replace = mode + else: + kf.mode_replace = kf.orig_mode_replace + # save new caches, in case next keyframe inherits missing + if kf.mult_multival is not None: + multival = kf.mult_multival + if kf.tune_replace is not None: + tune = kf.tune_replace + if kf.mode_replace is not None: + mode = kf.mode_replace + + def has_index(self, index: int) -> int: + return index >=0 and index < len(self.keyframes) + + def is_empty(self) -> bool: + return len(self.keyframes) == 0 + + def clone(self): + cloned = ContextRefKeyframeGroup() + for keyframe in self.keyframes: + cloned.keyframes.append(keyframe.clone()) + cloned._set_first_as_current() + cloned._prepare_all_keyframe_vals() + return cloned + + def create_list_of_dicts(self): + # for each keyframe, create a dict representing values relevant to TimestepKeyframe creation in ACN + c = [] + for kf in self.keyframes: + d = {} + # scheduling + d["start_percent"] = kf.start_percent + d["guarantee_steps"] = kf.guarantee_steps + d["inherit_missing"] = kf.inherit_missing + # values + if type(kf.mult_multival) == Tensor: + d["strength"] = kf.mult + d["mask"] = kf.mult_multival + else: + if kf.mult_multival is None: + d["strength"] = kf.mult + else: + d["strength"] = kf.mult * kf.mult_multival + d["mask"] = None + d["tune"] = kf.tune_replace + d["mode"] = kf.mode_replace + # add to list + c.append(d) + return c + + class ContextRef(ContextExtra): - def __init__(self, start_percent: float, end_percent: float, params: ContextRefParams, mode: ContextRefMode): + def __init__(self, start_percent: float, end_percent: float, + strength_multival: Union[float, Tensor], tune: ContextRefTune, mode: 
ContextRefMode, + keyframe: ContextRefKeyframeGroup=None): super().__init__(start_percent=start_percent, end_percent=end_percent) - self.params = params + self.tune = tune self.mode = mode + self.keyframe = keyframe if keyframe else ContextRefKeyframeGroup() + self.version = CONTEXTREF_VERSION + # stuff for ACN usage + self.strength = 1.0 + self.mask = None + self._strength_multival = strength_multival + self.strength_multival = strength_multival + + @property + def strength_multival(self): + return self.strength_multival + @strength_multival.setter + def strength_multival(self, value): + if value is None: + value = 1.0 + if type(value) == Tensor: + self.strength = 1.0 + self.mask = value + else: + self.strength = value + self.mask = None + self._strength_multival = value def should_run(self): return super().should_run() @@ -99,7 +253,8 @@ def should_run(self): ################################ # NaiveReuse class NaiveReuseKeyframe: - def __init__(self, mult_multival: Union[float, Tensor], start_percent=0.0, guarantee_steps=1): + def __init__(self, mult=1.0, mult_multival: Union[float, Tensor]=1.0, start_percent=0.0, guarantee_steps=1): + self.mult = mult self.mult_multival = mult_multival # scheduling self.start_percent = float(start_percent) @@ -107,7 +262,7 @@ def __init__(self, mult_multival: Union[float, Tensor], start_percent=0.0, guara self.guarantee_steps = guarantee_steps def clone(self): - c = NaiveReuseKeyframe(mult_multival=self.mult_multival, + c = NaiveReuseKeyframe(mult=self.mult, mult_multival=self.mult_multival, start_percent=self.start_percent, guarantee_steps=self.guarantee_steps) c.start_t = self.start_t return c @@ -186,6 +341,12 @@ def prepare_current_keyframe(self, t: Tensor): self._previous_t = curr_t # properties shadow those of NaiveReuseKeyframe + @property + def mult(self): + if self._current_keyframe != None: + return self._current_keyframe.mult + return 1.0 + @property def mult_multival(self): if self._current_keyframe != None: @@ 
-218,7 +379,7 @@ def prepare_current(self, t: Tensor): def get_effective_weighted_mean(self, x: Tensor, idxs: list[int]): if self.orig_multival is None and self.keyframe.mult_multival is None: - return self.weighted_mean + return self.weighted_mean * self.keyframe.mult # check if keyframe changed keyframe_changed = False if self.keyframe._current_keyframe != self._prev_keyframe: @@ -226,14 +387,14 @@ def get_effective_weighted_mean(self, x: Tensor, idxs: list[int]): self._prev_keyframe = self.keyframe._current_keyframe if type(self.orig_multival) != Tensor and type(self.keyframe.mult_multival) != Tensor: - return self.weighted_mean * get_combined_multival(self.orig_multival, self.keyframe.mult_multival) + return self.weighted_mean * self.keyframe.mult * get_combined_multival(self.orig_multival, self.keyframe.mult_multival) if self.mask is None or keyframe_changed or self.mask.shape[0] != x.shape[0] or self.mask.shape[-1] != x.shape[-1] or self.mask.shape[-2] != x.shape[-2]: del self.mask real_mult_multival = resize_multival(self.keyframe.mult_multival, batch_size=x.shape[0], height=x.shape[-1], width=x.shape[-2]) self.mask = resize_multival(self.orig_multival, batch_size=x.shape[0], height=x.shape[-1], width=x.shape[-2]) self.mask = get_combined_multival(self.mask, real_mult_multival) - return self.weighted_mean * self.mask[idxs].to(dtype=x.dtype, device=x.device) + return self.weighted_mean * self.keyframe.mult * self.mask[idxs].to(dtype=x.dtype, device=x.device) def should_run(self): to_return = super().should_run() @@ -241,7 +402,7 @@ def should_run(self): if self.keyframe.mult_multival is not None and type(self.keyframe.mult_multival) != Tensor and math.isclose(self.keyframe.mult_multival, 0.0): return False # if weighted_mean is 0.0, then reuse will take no effect anyway - return to_return and self.weighted_mean > 0.0 + return to_return and self.weighted_mean > 0.0 and self.keyframe.mult > 0.0 #-------------------------------- diff --git a/animatediff/nodes.py 
b/animatediff/nodes.py index e88b7d1..a2beae2 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -31,8 +31,8 @@ VisualizeContextOptionsK, VisualizeContextOptionsKAdv, VisualizeContextOptionsSCustom, SetContextExtrasOnContextOptions, ContextExtras_NaiveReuse, ContextExtras_ContextRef, ContextRef_ModeFirst, ContextRef_ModeSliding, ContextRef_ModeIndexes, - ContextRef_TuneAttn, ContextRef_TuneAttnAdain, - NaiveReuse_KeyframeNode, NaiveReuse_KeyframeMultivalNode) + ContextRef_TuneAttn, ContextRef_TuneAttnAdain, ContextRef_KeyframeMultivalNode, + NaiveReuse_KeyframeMultivalNode) from .nodes_ad_settings import (AnimateDiffSettingsNode, ManualAdjustPENode, SweetspotStretchPENode, FullStretchPENode, WeightAdjustAllAddNode, WeightAdjustAllMultNode, WeightAdjustIndivAddNode, WeightAdjustIndivMultNode, WeightAdjustIndivAttnAddNode, WeightAdjustIndivAttnMultNode) @@ -82,9 +82,9 @@ "ADE_ContextExtras_ContextRef_ModeIndexes": ContextRef_ModeIndexes, "ADE_ContextExtras_ContextRef_TuneAttn": ContextRef_TuneAttn, "ADE_ContextExtras_ContextRef_TuneAttnAdain": ContextRef_TuneAttnAdain, + "ADE_ContextExtras_ContextRef_Keyframe": ContextRef_KeyframeMultivalNode, "ADE_ContextExtras_NaiveReuse": ContextExtras_NaiveReuse, - "ADE_ContextExtras_NaiveReuse_Keyframe": NaiveReuse_KeyframeNode, - "ADE_ContextExtras_NaiveReuse_KeyframeMultival": NaiveReuse_KeyframeMultivalNode, + "ADE_ContextExtras_NaiveReuse_Keyframe": NaiveReuse_KeyframeMultivalNode, #------------------------------------------------------------------------------ ############################################################################### # Iteration Opts @@ -227,9 +227,9 @@ "ADE_ContextExtras_ContextRef_ModeIndexes": "ContextRef Modeβ—†Indexes πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_ContextRef_TuneAttn": "ContextRef Tuneβ—†Attn πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_ContextRef_TuneAttnAdain": "ContextRef Tuneβ—†Attn+Adain πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_ContextRef_Keyframe": "ContextRef Keyframe πŸŽ­πŸ…πŸ…“", 
"ADE_ContextExtras_NaiveReuse": "Context Extrasβ—†NaiveReuse πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_NaiveReuse_Keyframe": "NaiveReuse Keyframe πŸŽ­πŸ…πŸ…“", - "ADE_ContextExtras_NaiveReuse_KeyframeMultival": "NaiveReuse Keyframe [Multival] πŸŽ­πŸ…πŸ…“", #------------------------------------------------------------------------------ ############################################################################### # Iteration Opts diff --git a/animatediff/nodes_context.py b/animatediff/nodes_context.py index 983c8db..bcd933b 100644 --- a/animatediff/nodes_context.py +++ b/animatediff/nodes_context.py @@ -7,9 +7,11 @@ from .context import (ContextFuseMethod, ContextOptions, ContextOptionsGroup, ContextSchedules, generate_context_visualization) -from .context_extras import (ContextExtrasGroup, ContextRef, ContextRefParams, ContextRefMode, NaiveReuse, - NaiveReuseKeyframe, NaiveReuseKeyframeGroup) +from .context_extras import (ContextExtrasGroup, + ContextRef, ContextRefTune, ContextRefMode, ContextRefKeyframeGroup, ContextRefKeyframe, + NaiveReuse, NaiveReuseKeyframe, NaiveReuseKeyframeGroup) from .utils_model import BIGMAX, MAX_RESOLUTION +from .utils_motion import get_combined_multival from .utils_scheduling import convert_str_to_indexes @@ -510,13 +512,14 @@ class NaiveReuse_KeyframeMultivalNode: def INPUT_TYPES(s): return { "required": { - "mult_multival": ("MULTIVAL",), }, "optional": { "prev_kf": ("NAIVEREUSE_KEYFRAME",), + "mult_multival": ("MULTIVAL",), + "mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), - "autosize": ("ADEAUTOSIZE", {"padding": 80}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -525,42 +528,16 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/naivereuse" FUNCTION = "create_keyframe" - def create_keyframe(self, prev_kf=None, 
mult_multival=1.0, + def create_keyframe(self, prev_kf=None, mult=1.0, mult_multival=1.0, start_percent=0.0, guarantee_steps=1): if prev_kf is None: prev_kf = NaiveReuseKeyframeGroup() prev_kf = prev_kf.clone() - kf = NaiveReuseKeyframe(mult_multival=mult_multival, start_percent=start_percent, guarantee_steps=guarantee_steps) + kf = NaiveReuseKeyframe(mult=mult, mult_multival=mult_multival, start_percent=start_percent, guarantee_steps=guarantee_steps) prev_kf.add(kf) return (prev_kf,) -class NaiveReuse_KeyframeNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - }, - "optional": { - "prev_kf": ("NAIVEREUSE_KEYFRAME",), - "mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), - "autosize": ("ADEAUTOSIZE", {"padding": 10}), - } - } - - RETURN_TYPES = ("NAIVEREUSE_KEYFRAME",) - RETURN_NAMES = ("NAIVEREUSE_KF",) - CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/naivereuse" - FUNCTION = "create_keyframe" - - def create_keyframe(self, prev_kf=None, mult=1.0, - start_percent=0.0, guarantee_steps=1): - return NaiveReuse_KeyframeMultivalNode.create_keyframe(self, prev_kf=prev_kf, mult_multival=float(mult), - start_percent=start_percent, guarantee_steps=guarantee_steps) - - class ContextExtras_ContextRef: @classmethod def INPUT_TYPES(s): @@ -572,6 +549,7 @@ def INPUT_TYPES(s): "strength_multival": ("MULTIVAL",), "contextref_mode": ("CONTEXTREF_MODE",), "contextref_tune": ("CONTEXTREF_TUNE",), + "contextref_kf": ("CONTEXTREF_KEYFRAME",), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001}), "autosize": ("ADEAUTOSIZE", {"padding": 0}), @@ -583,36 +561,40 @@ def INPUT_TYPES(s): FUNCTION = "create_context_extra" def create_context_extra(self, 
start_percent=0.0, end_percent=0.1, strength_multival: Union[float, Tensor]=None, - contextref_mode: ContextRefMode=None, - contextref_tune: ContextRefParams=None, - prev_extras: ContextExtrasGroup=None): + contextref_mode: ContextRefMode=None, contextref_tune: ContextRefTune=None, + contextref_kf: ContextRefKeyframeGroup=None, prev_extras: ContextExtrasGroup=None): if prev_extras is None: prev_extras = prev_extras = ContextExtrasGroup() prev_extras = prev_extras.clone() # create extra # TODO: make customizable, and allow mask input if contextref_tune is None: - contextref_tune = ContextRefParams(attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_strength=1.0) + contextref_tune = ContextRefTune(attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_strength=1.0) if contextref_mode is None: contextref_mode = ContextRefMode.init_first() - context_ref = ContextRef(start_percent=start_percent, end_percent=end_percent, params=contextref_tune, mode=contextref_mode) + context_ref = ContextRef(start_percent=start_percent, end_percent=end_percent, + strength_multival=strength_multival, tune=contextref_tune, mode=contextref_mode, + keyframe=contextref_kf) prev_extras.add(context_ref) return (prev_extras,) -class ContextRef_KeyframeNode: +class ContextRef_KeyframeMultivalNode: @classmethod def INPUT_TYPES(s): return { "required": { }, "optional": { - "prev_keyframe": ("CONTEXTREF_KEYFRAME",), + "prev_kf": ("CONTEXTREF_KEYFRAME",), "mult_multival": ("MULTIVAL",), "mode_replace": ("CONTEXTREF_MODE",), "tune_replace": ("CONTEXTREF_TUNE",), + "mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), + "inherit_missing": ("BOOLEAN", {"default": True}, ), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -621,9 +603,15 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context 
extras/contextref" FUNCTION = "create_keyframe" - def create_keyframe(self, prev_keyframe=None, mult_multival=1.0, mode_replace=None, tune_replace=None, - start_percent=1.0, guarantee_steps=1): - pass + def create_keyframe(self, prev_kf=None, mult=1.0, mult_multival=1.0, mode_replace=None, tune_replace=None, + start_percent=1.0, guarantee_steps=1, inherit_missing=True): + if prev_kf is None: + prev_kf = ContextRefKeyframeGroup() + prev_kf = prev_kf.clone() + kf = ContextRefKeyframe(mult=mult, mult_multival=mult_multival, tune_replace=tune_replace, mode_replace=mode_replace, + start_percent=start_percent, guarantee_steps=guarantee_steps, inherit_missing=inherit_missing) + prev_kf.add(kf) + return (prev_kf,) class ContextRef_ModeFirst: @@ -715,7 +703,7 @@ def INPUT_TYPES(s): def create_contextref_tune(self, attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_strength=1.0, adain_style_fidelity=1.0, adain_ref_weight=1.0, adain_strength=1.0): - params = ContextRefParams(attn_style_fidelity=attn_style_fidelity, adain_style_fidelity=adain_style_fidelity, + params = ContextRefTune(attn_style_fidelity=attn_style_fidelity, adain_style_fidelity=adain_style_fidelity, attn_ref_weight=attn_ref_weight, adain_ref_weight=adain_ref_weight, attn_strength=attn_strength, adain_strength=adain_strength) return (params,) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 950da57..b4e2615 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -838,12 +838,20 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list # need to make sure that contextref stuff gets cleaned up, no matter what try: if ADGS.params.context_options.extras.should_run_context_ref(): - contextref_active = True - contextref_mode = ADGS.params.context_options.extras.context_ref.mode - contextref_idxs_set = contextref_mode.indexes.copy() - # use injector to ensure only 1 cond or uncond will be batched at a time - contextref_injector = ContextRefInjector() - 
contextref_injector.inject() + # check if ContextRef ReferenceAdvanced ACN objs should_run + actually_should_run = True + for refcn in model_options["transformer_options"][CONTEXTREF_CONTROL_LIST_ALL]: + refcn.prepare_current_timestep(timestep) + if not refcn.should_run(): + actually_should_run = False + break + if actually_should_run: + contextref_active = True + contextref_mode = ADGS.params.context_options.extras.context_ref.mode + contextref_idxs_set = contextref_mode.indexes.copy() + # use injector to ensure only 1 cond or uncond will be batched at a time + contextref_injector = ContextRefInjector() + contextref_injector.inject() curr_window_idx = -1 naivereuse_active = False From 221b4a0190afcce1ca826eb60e7aab026c9eb486 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 6 Aug 2024 18:38:27 -0500 Subject: [PATCH 16/22] Moved Context Extras nodes to its own file --- animatediff/nodes.py | 4 +- animatediff/nodes_context.py | 291 --------------------------- animatediff/nodes_context_extras.py | 300 ++++++++++++++++++++++++++++ 3 files changed, 302 insertions(+), 293 deletions(-) create mode 100644 animatediff/nodes_context_extras.py diff --git a/animatediff/nodes.py b/animatediff/nodes.py index a2beae2..71ecdcb 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -28,8 +28,8 @@ from .nodes_sigma_schedule import (SigmaScheduleNode, RawSigmaScheduleNode, WeightedAverageSigmaScheduleNode, InterpolatedWeightedAverageSigmaScheduleNode, SplitAndCombineSigmaScheduleNode, SigmaScheduleToSigmasNode) from .nodes_context import (LegacyLoopedUniformContextOptionsNode, LoopedUniformContextOptionsNode, LoopedUniformViewOptionsNode, StandardUniformContextOptionsNode, StandardStaticContextOptionsNode, BatchedContextOptionsNode, StandardStaticViewOptionsNode, StandardUniformViewOptionsNode, ViewAsContextOptionsNode, - VisualizeContextOptionsK, VisualizeContextOptionsKAdv, VisualizeContextOptionsSCustom, - SetContextExtrasOnContextOptions, 
ContextExtras_NaiveReuse, ContextExtras_ContextRef, + VisualizeContextOptionsK, VisualizeContextOptionsKAdv, VisualizeContextOptionsSCustom) +from .nodes_context_extras import (SetContextExtrasOnContextOptions, ContextExtras_NaiveReuse, ContextExtras_ContextRef, ContextRef_ModeFirst, ContextRef_ModeSliding, ContextRef_ModeIndexes, ContextRef_TuneAttn, ContextRef_TuneAttnAdain, ContextRef_KeyframeMultivalNode, NaiveReuse_KeyframeMultivalNode) diff --git a/animatediff/nodes_context.py b/animatediff/nodes_context.py index bcd933b..c14a8b8 100644 --- a/animatediff/nodes_context.py +++ b/animatediff/nodes_context.py @@ -1,4 +1,3 @@ -import torch from torch import Tensor from typing import Union @@ -7,12 +6,7 @@ from .context import (ContextFuseMethod, ContextOptions, ContextOptionsGroup, ContextSchedules, generate_context_visualization) -from .context_extras import (ContextExtrasGroup, - ContextRef, ContextRefTune, ContextRefMode, ContextRefKeyframeGroup, ContextRefKeyframe, - NaiveReuse, NaiveReuseKeyframe, NaiveReuseKeyframeGroup) from .utils_model import BIGMAX, MAX_RESOLUTION -from .utils_motion import get_combined_multival -from .utils_scheduling import convert_str_to_indexes LENGTH_MAX = 128 # keep an eye on these max values; @@ -446,288 +440,3 @@ def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sigm images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length, sigmas=sigmas) return (images,) - - -######################### -# Context Extras -class SetContextExtrasOnContextOptions: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "context_opts": ("CONTEXT_OPTIONS",), - "context_extras": ("CONTEXT_EXTRAS",), - }, - "optional": { - "autosize": ("ADEAUTOSIZE", {"padding": 0}), - } - } - - RETURN_TYPES = ("CONTEXT_OPTIONS",) - RETURN_NAMES = ("CONTEXT_OPTS",) - CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras" - FUNCTION = "set_context_extras" 
- - def set_context_extras(self, context_opts: ContextOptionsGroup, context_extras: ContextExtrasGroup): - context_opts = context_opts.clone() - context_opts.extras = context_extras.clone() - return (context_opts,) - - -class ContextExtras_NaiveReuse: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - }, - "optional": { - "prev_extras": ("CONTEXT_EXTRAS",), - "strength_multival": ("MULTIVAL",), - "naivereuse_kf": ("NAIVEREUSE_KEYFRAME",), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}), - "weighted_mean": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.001}), - "autosize": ("ADEAUTOSIZE", {"padding": 0}), - } - } - - RETURN_TYPES = ("CONTEXT_EXTRAS",) - CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras" - FUNCTION = "create_context_extra" - - def create_context_extra(self, start_percent=0.0, end_percent=0.1, weighted_mean=0.95, strength_multival: Union[float, Tensor]=None, - naivereuse_kf: NaiveReuseKeyframeGroup=None, prev_extras: ContextExtrasGroup=None): - if prev_extras is None: - prev_extras = prev_extras = ContextExtrasGroup() - prev_extras = prev_extras.clone() - # create extra - naive_reuse = NaiveReuse(start_percent=start_percent, end_percent=end_percent, weighted_mean=weighted_mean, multival_opt=strength_multival, - naivereuse_kf=naivereuse_kf) - prev_extras.add(naive_reuse) - return (prev_extras,) - - -class NaiveReuse_KeyframeMultivalNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - }, - "optional": { - "prev_kf": ("NAIVEREUSE_KEYFRAME",), - "mult_multival": ("MULTIVAL",), - "mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), - "autosize": ("ADEAUTOSIZE", {"padding": 0}), - } - } - - 
RETURN_TYPES = ("NAIVEREUSE_KEYFRAME",) - RETURN_NAMES = ("NAIVEREUSE_KF",) - CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/naivereuse" - FUNCTION = "create_keyframe" - - def create_keyframe(self, prev_kf=None, mult=1.0, mult_multival=1.0, - start_percent=0.0, guarantee_steps=1): - if prev_kf is None: - prev_kf = NaiveReuseKeyframeGroup() - prev_kf = prev_kf.clone() - kf = NaiveReuseKeyframe(mult=mult, mult_multival=mult_multival, start_percent=start_percent, guarantee_steps=guarantee_steps) - prev_kf.add(kf) - return (prev_kf,) - - -class ContextExtras_ContextRef: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - }, - "optional": { - "prev_extras": ("CONTEXT_EXTRAS",), - "strength_multival": ("MULTIVAL",), - "contextref_mode": ("CONTEXTREF_MODE",), - "contextref_tune": ("CONTEXTREF_TUNE",), - "contextref_kf": ("CONTEXTREF_KEYFRAME",), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001}), - "autosize": ("ADEAUTOSIZE", {"padding": 0}), - } - } - - RETURN_TYPES = ("CONTEXT_EXTRAS",) - CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras" - FUNCTION = "create_context_extra" - - def create_context_extra(self, start_percent=0.0, end_percent=0.1, strength_multival: Union[float, Tensor]=None, - contextref_mode: ContextRefMode=None, contextref_tune: ContextRefTune=None, - contextref_kf: ContextRefKeyframeGroup=None, prev_extras: ContextExtrasGroup=None): - if prev_extras is None: - prev_extras = prev_extras = ContextExtrasGroup() - prev_extras = prev_extras.clone() - # create extra - # TODO: make customizable, and allow mask input - if contextref_tune is None: - contextref_tune = ContextRefTune(attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_strength=1.0) - if contextref_mode is None: - contextref_mode = ContextRefMode.init_first() - context_ref = ContextRef(start_percent=start_percent, 
end_percent=end_percent, - strength_multival=strength_multival, tune=contextref_tune, mode=contextref_mode, - keyframe=contextref_kf) - prev_extras.add(context_ref) - return (prev_extras,) - - -class ContextRef_KeyframeMultivalNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - }, - "optional": { - "prev_kf": ("CONTEXTREF_KEYFRAME",), - "mult_multival": ("MULTIVAL",), - "mode_replace": ("CONTEXTREF_MODE",), - "tune_replace": ("CONTEXTREF_TUNE",), - "mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), - "inherit_missing": ("BOOLEAN", {"default": True}, ), - "autosize": ("ADEAUTOSIZE", {"padding": 0}), - } - } - - RETURN_TYPES = ("CONTEXTREF_KEYFRAME",) - RETURN_NAMES = ("CONTEXTREF_KF",) - CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" - FUNCTION = "create_keyframe" - - def create_keyframe(self, prev_kf=None, mult=1.0, mult_multival=1.0, mode_replace=None, tune_replace=None, - start_percent=1.0, guarantee_steps=1, inherit_missing=True): - if prev_kf is None: - prev_kf = ContextRefKeyframeGroup() - prev_kf = prev_kf.clone() - kf = ContextRefKeyframe(mult=mult, mult_multival=mult_multival, tune_replace=tune_replace, mode_replace=mode_replace, - start_percent=start_percent, guarantee_steps=guarantee_steps, inherit_missing=inherit_missing) - prev_kf.add(kf) - return (prev_kf,) - - -class ContextRef_ModeFirst: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - }, - "optional": { - "autosize": ("ADEAUTOSIZE", {"padding": 25}), - }, - } - - RETURN_TYPES = ("CONTEXTREF_MODE",) - CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" - FUNCTION = "create_contextref_mode" - - def create_contextref_mode(self): - mode = ContextRefMode.init_first() - return (mode,) - - -class ContextRef_ModeSliding: - @classmethod 
- def INPUT_TYPES(s): - return { - "required": { - }, - "optional": { - "sliding_width": ("INT", {"default": 2, "min": 2, "max": BIGMAX, "step": 1}), - "autosize": ("ADEAUTOSIZE", {"padding": 42}), - } - } - - RETURN_TYPES = ("CONTEXTREF_MODE",) - CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" - FUNCTION = "create_contextref_mode" - - def create_contextref_mode(self, sliding_width): - mode = ContextRefMode.init_sliding(sliding_width=sliding_width) - return (mode,) - - -class ContextRef_ModeIndexes: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - }, - "optional": { - "switch_on_idxs": ("STRING", {"default": ""}), - "always_include_0": ("BOOLEAN", {"default": True},), - "autosize": ("ADEAUTOSIZE", {"padding": 50}), - }, - } - - RETURN_TYPES = ("CONTEXTREF_MODE",) - CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" - FUNCTION = "create_contextref_mode" - - def create_contextref_mode(self, switch_on_idxs: str, always_include_0: bool): - idxs = set(convert_str_to_indexes(indexes_str=switch_on_idxs, length=0, allow_range=False)) - if always_include_0 and 0 not in idxs: - idxs.add(0) - mode = ContextRefMode.init_indexes(indexes=idxs) - return (mode,) - - -class ContextRef_TuneAttnAdain: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - }, - "optional": { - "attn_style_fidelity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "attn_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "attn_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "adain_style_fidelity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "adain_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "adain_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "autosize": ("ADEAUTOSIZE", {"padding": 65}), - } - } - - RETURN_TYPES = ("CONTEXTREF_TUNE",) - 
CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" - FUNCTION = "create_contextref_tune" - - def create_contextref_tune(self, attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_strength=1.0, - adain_style_fidelity=1.0, adain_ref_weight=1.0, adain_strength=1.0): - params = ContextRefTune(attn_style_fidelity=attn_style_fidelity, adain_style_fidelity=adain_style_fidelity, - attn_ref_weight=attn_ref_weight, adain_ref_weight=adain_ref_weight, - attn_strength=attn_strength, adain_strength=adain_strength) - return (params,) - - -class ContextRef_TuneAttn: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - }, - "optional": { - "attn_style_fidelity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "attn_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "attn_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "autosize": ("ADEAUTOSIZE", {"padding": 15}), - } - } - - RETURN_TYPES = ("CONTEXTREF_TUNE",) - CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" - FUNCTION = "create_contextref_tune" - - def create_contextref_tune(self, attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_strength=1.0): - return ContextRef_TuneAttnAdain.create_contextref_tune(self, - attn_style_fidelity=attn_style_fidelity, attn_ref_weight=attn_ref_weight, attn_strength=attn_strength, - adain_ref_weight=0.0, adain_style_fidelity=0.0, adain_strength=0.0) diff --git a/animatediff/nodes_context_extras.py b/animatediff/nodes_context_extras.py new file mode 100644 index 0000000..6675632 --- /dev/null +++ b/animatediff/nodes_context_extras.py @@ -0,0 +1,300 @@ +from torch import Tensor +from typing import Union + +from .context import (ContextOptionsGroup) +from .context_extras import (ContextExtrasGroup, + ContextRef, ContextRefTune, ContextRefMode, ContextRefKeyframeGroup, ContextRefKeyframe, + NaiveReuse, NaiveReuseKeyframe, NaiveReuseKeyframeGroup) +from 
.utils_model import BIGMAX +from .utils_scheduling import convert_str_to_indexes + + +class SetContextExtrasOnContextOptions: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "context_opts": ("CONTEXT_OPTIONS",), + "context_extras": ("CONTEXT_EXTRAS",), + }, + "optional": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("CONTEXT_OPTIONS",) + RETURN_NAMES = ("CONTEXT_OPTS",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras" + FUNCTION = "set_context_extras" + + def set_context_extras(self, context_opts: ContextOptionsGroup, context_extras: ContextExtrasGroup): + context_opts = context_opts.clone() + context_opts.extras = context_extras.clone() + return (context_opts,) + + +######################################### +# NaiveReuse +class ContextExtras_NaiveReuse: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "prev_extras": ("CONTEXT_EXTRAS",), + "strength_multival": ("MULTIVAL",), + "naivereuse_kf": ("NAIVEREUSE_KEYFRAME",), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}), + "weighted_mean": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.001}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("CONTEXT_EXTRAS",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras" + FUNCTION = "create_context_extra" + + def create_context_extra(self, start_percent=0.0, end_percent=0.1, weighted_mean=0.95, strength_multival: Union[float, Tensor]=None, + naivereuse_kf: NaiveReuseKeyframeGroup=None, prev_extras: ContextExtrasGroup=None): + if prev_extras is None: + prev_extras = prev_extras = ContextExtrasGroup() + prev_extras = prev_extras.clone() + # create extra + naive_reuse = NaiveReuse(start_percent=start_percent, end_percent=end_percent, weighted_mean=weighted_mean, multival_opt=strength_multival, + 
naivereuse_kf=naivereuse_kf) + prev_extras.add(naive_reuse) + return (prev_extras,) + + +class NaiveReuse_KeyframeMultivalNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "prev_kf": ("NAIVEREUSE_KEYFRAME",), + "mult_multival": ("MULTIVAL",), + "mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("NAIVEREUSE_KEYFRAME",) + RETURN_NAMES = ("NAIVEREUSE_KF",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/naivereuse" + FUNCTION = "create_keyframe" + + def create_keyframe(self, prev_kf=None, mult=1.0, mult_multival=1.0, + start_percent=0.0, guarantee_steps=1): + if prev_kf is None: + prev_kf = NaiveReuseKeyframeGroup() + prev_kf = prev_kf.clone() + kf = NaiveReuseKeyframe(mult=mult, mult_multival=mult_multival, start_percent=start_percent, guarantee_steps=guarantee_steps) + prev_kf.add(kf) + return (prev_kf,) +#---------------------------------------- +######################################### + + +######################################### +# ContextRef +class ContextExtras_ContextRef: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "prev_extras": ("CONTEXT_EXTRAS",), + "strength_multival": ("MULTIVAL",), + "contextref_mode": ("CONTEXTREF_MODE",), + "contextref_tune": ("CONTEXTREF_TUNE",), + "contextref_kf": ("CONTEXTREF_KEYFRAME",), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("CONTEXT_EXTRAS",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras" + FUNCTION = "create_context_extra" + + def 
create_context_extra(self, start_percent=0.0, end_percent=0.1, strength_multival: Union[float, Tensor]=None, + contextref_mode: ContextRefMode=None, contextref_tune: ContextRefTune=None, + contextref_kf: ContextRefKeyframeGroup=None, prev_extras: ContextExtrasGroup=None): + if prev_extras is None: + prev_extras = prev_extras = ContextExtrasGroup() + prev_extras = prev_extras.clone() + # create extra + # TODO: make customizable, and allow mask input + if contextref_tune is None: + contextref_tune = ContextRefTune(attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_strength=1.0) + if contextref_mode is None: + contextref_mode = ContextRefMode.init_first() + context_ref = ContextRef(start_percent=start_percent, end_percent=end_percent, + strength_multival=strength_multival, tune=contextref_tune, mode=contextref_mode, + keyframe=contextref_kf) + prev_extras.add(context_ref) + return (prev_extras,) + + +class ContextRef_KeyframeMultivalNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "prev_kf": ("CONTEXTREF_KEYFRAME",), + "mult_multival": ("MULTIVAL",), + "mode_replace": ("CONTEXTREF_MODE",), + "tune_replace": ("CONTEXTREF_TUNE",), + "mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), + "inherit_missing": ("BOOLEAN", {"default": True}, ), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("CONTEXTREF_KEYFRAME",) + RETURN_NAMES = ("CONTEXTREF_KF",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" + FUNCTION = "create_keyframe" + + def create_keyframe(self, prev_kf=None, mult=1.0, mult_multival=1.0, mode_replace=None, tune_replace=None, + start_percent=1.0, guarantee_steps=1, inherit_missing=True): + if prev_kf is None: + prev_kf = ContextRefKeyframeGroup() + prev_kf = prev_kf.clone() + kf = 
ContextRefKeyframe(mult=mult, mult_multival=mult_multival, tune_replace=tune_replace, mode_replace=mode_replace, + start_percent=start_percent, guarantee_steps=guarantee_steps, inherit_missing=inherit_missing) + prev_kf.add(kf) + return (prev_kf,) + + +class ContextRef_ModeFirst: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "autosize": ("ADEAUTOSIZE", {"padding": 25}), + }, + } + + RETURN_TYPES = ("CONTEXTREF_MODE",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" + FUNCTION = "create_contextref_mode" + + def create_contextref_mode(self): + mode = ContextRefMode.init_first() + return (mode,) + + +class ContextRef_ModeSliding: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "sliding_width": ("INT", {"default": 2, "min": 2, "max": BIGMAX, "step": 1}), + "autosize": ("ADEAUTOSIZE", {"padding": 42}), + } + } + + RETURN_TYPES = ("CONTEXTREF_MODE",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" + FUNCTION = "create_contextref_mode" + + def create_contextref_mode(self, sliding_width): + mode = ContextRefMode.init_sliding(sliding_width=sliding_width) + return (mode,) + + +class ContextRef_ModeIndexes: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "switch_on_idxs": ("STRING", {"default": ""}), + "always_include_0": ("BOOLEAN", {"default": True},), + "autosize": ("ADEAUTOSIZE", {"padding": 50}), + }, + } + + RETURN_TYPES = ("CONTEXTREF_MODE",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" + FUNCTION = "create_contextref_mode" + + def create_contextref_mode(self, switch_on_idxs: str, always_include_0: bool): + idxs = set(convert_str_to_indexes(indexes_str=switch_on_idxs, length=0, allow_range=False)) + if always_include_0 and 0 not in idxs: + idxs.add(0) + mode = ContextRefMode.init_indexes(indexes=idxs) + return (mode,) + + +class ContextRef_TuneAttnAdain: 
+ @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "attn_style_fidelity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "attn_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "attn_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "adain_style_fidelity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "adain_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "adain_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "autosize": ("ADEAUTOSIZE", {"padding": 65}), + } + } + + RETURN_TYPES = ("CONTEXTREF_TUNE",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" + FUNCTION = "create_contextref_tune" + + def create_contextref_tune(self, attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_strength=1.0, + adain_style_fidelity=1.0, adain_ref_weight=1.0, adain_strength=1.0): + params = ContextRefTune(attn_style_fidelity=attn_style_fidelity, adain_style_fidelity=adain_style_fidelity, + attn_ref_weight=attn_ref_weight, adain_ref_weight=adain_ref_weight, + attn_strength=attn_strength, adain_strength=adain_strength) + return (params,) + + +class ContextRef_TuneAttn: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "attn_style_fidelity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "attn_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "attn_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "autosize": ("ADEAUTOSIZE", {"padding": 15}), + } + } + + RETURN_TYPES = ("CONTEXTREF_TUNE",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" + FUNCTION = "create_contextref_tune" + + def create_contextref_tune(self, attn_style_fidelity=1.0, attn_ref_weight=1.0, attn_strength=1.0): + return 
ContextRef_TuneAttnAdain.create_contextref_tune(self, + attn_style_fidelity=attn_style_fidelity, attn_ref_weight=attn_ref_weight, attn_strength=attn_strength, + adain_ref_weight=0.0, adain_style_fidelity=0.0, adain_strength=0.0) +#---------------------------------------- +######################################### From 0ecd0ac82c3f9a02d3b0b0c253fb7583f8e04a97 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 6 Aug 2024 20:05:12 -0500 Subject: [PATCH 17/22] Added ContextRef Keyframe Interp. node, renamed Interpolation to Interp. in node names to decrease visual node size --- animatediff/nodes.py | 10 +++-- animatediff/nodes_conditioning.py | 1 + animatediff/nodes_context_extras.py | 58 ++++++++++++++++++++++++++++- animatediff/nodes_sample.py | 1 + animatediff/nodes_sigma_schedule.py | 3 ++ 5 files changed, 67 insertions(+), 6 deletions(-) diff --git a/animatediff/nodes.py b/animatediff/nodes.py index 71ecdcb..e8a0a57 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -31,7 +31,7 @@ VisualizeContextOptionsK, VisualizeContextOptionsKAdv, VisualizeContextOptionsSCustom) from .nodes_context_extras import (SetContextExtrasOnContextOptions, ContextExtras_NaiveReuse, ContextExtras_ContextRef, ContextRef_ModeFirst, ContextRef_ModeSliding, ContextRef_ModeIndexes, - ContextRef_TuneAttn, ContextRef_TuneAttnAdain, ContextRef_KeyframeMultivalNode, + ContextRef_TuneAttn, ContextRef_TuneAttnAdain, ContextRef_KeyframeMultivalNode, ContextRef_KeyframeInterpolationNode, NaiveReuse_KeyframeMultivalNode) from .nodes_ad_settings import (AnimateDiffSettingsNode, ManualAdjustPENode, SweetspotStretchPENode, FullStretchPENode, WeightAdjustAllAddNode, WeightAdjustAllMultNode, WeightAdjustIndivAddNode, WeightAdjustIndivMultNode, @@ -83,6 +83,7 @@ "ADE_ContextExtras_ContextRef_TuneAttn": ContextRef_TuneAttn, "ADE_ContextExtras_ContextRef_TuneAttnAdain": ContextRef_TuneAttnAdain, "ADE_ContextExtras_ContextRef_Keyframe": ContextRef_KeyframeMultivalNode, + 
"ADE_ContextExtras_ContextRef_KeyframeInterpolation": ContextRef_KeyframeInterpolationNode, "ADE_ContextExtras_NaiveReuse": ContextExtras_NaiveReuse, "ADE_ContextExtras_NaiveReuse_Keyframe": NaiveReuse_KeyframeMultivalNode, #------------------------------------------------------------------------------ @@ -228,6 +229,7 @@ "ADE_ContextExtras_ContextRef_TuneAttn": "ContextRef Tuneβ—†Attn πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_ContextRef_TuneAttnAdain": "ContextRef Tuneβ—†Attn+Adain πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_ContextRef_Keyframe": "ContextRef Keyframe πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_ContextRef_KeyframeInterpolation": "ContextRef Keyframe Interp. πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_NaiveReuse": "Context Extrasβ—†NaiveReuse πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_NaiveReuse_Keyframe": "NaiveReuse Keyframe πŸŽ­πŸ…πŸ…“", #------------------------------------------------------------------------------ @@ -246,7 +248,7 @@ "ADE_SetLoraHookKeyframe": "Set LoRA Hook Keyframes πŸŽ­πŸ…πŸ…“", "ADE_AttachLoraHookToCLIP": "Set CLIP LoRA Hook πŸŽ­πŸ…πŸ…“", "ADE_LoraHookKeyframe": "LoRA Hook Keyframe πŸŽ­πŸ…πŸ…“", - "ADE_LoraHookKeyframeInterpolation": "LoRA Hook Keyframes Interpolation πŸŽ­πŸ…πŸ…“", + "ADE_LoraHookKeyframeInterpolation": "LoRA Hook Keyframes Interp. πŸŽ­πŸ…πŸ…“", "ADE_LoraHookKeyframeFromStrengthList": "LoRA Hook Keyframes From List πŸŽ­πŸ…πŸ…“", "ADE_AttachLoraHookToConditioning": "Set Model LoRA Hook πŸŽ­πŸ…πŸ…“", "ADE_PairedConditioningSetMask": "Set Props on Conds πŸŽ­πŸ…πŸ…“", @@ -278,7 +280,7 @@ "ADE_CustomCFG": "Custom CFG [Multival] πŸŽ­πŸ…πŸ…“", "ADE_CustomCFGKeyframeSimple": "Custom CFG Keyframe πŸŽ­πŸ…πŸ…“", "ADE_CustomCFGKeyframe": "Custom CFG Keyframe [Multival] πŸŽ­πŸ…πŸ…“", - "ADE_CustomCFGKeyframeInterpolation": "Custom CFG Keyframes Interpolation πŸŽ­πŸ…πŸ…“", + "ADE_CustomCFGKeyframeInterpolation": "Custom CFG Keyframes Interp. 
πŸŽ­πŸ…πŸ…“", "ADE_CustomCFGKeyframeFromList": "Custom CFG Keyframes From List πŸŽ­πŸ…πŸ…“", "ADE_CFGExtrasPAGSimple": "CFG Extrasβ—†PAG πŸŽ­πŸ…πŸ…“", "ADE_CFGExtrasPAG": "CFG Extrasβ—†PAG [Multival] πŸŽ­πŸ…πŸ…“", @@ -287,7 +289,7 @@ "ADE_SigmaSchedule": "Create Sigma Schedule πŸŽ­πŸ…πŸ…“", "ADE_RawSigmaSchedule": "Create Raw Sigma Schedule πŸŽ­πŸ…πŸ…“", "ADE_SigmaScheduleWeightedAverage": "Sigma Schedule Weighted Mean πŸŽ­πŸ…πŸ…“", - "ADE_SigmaScheduleWeightedAverageInterp": "Sigma Schedule Interpolated Mean πŸŽ­πŸ…πŸ…“", + "ADE_SigmaScheduleWeightedAverageInterp": "Sigma Schedule Interp. Mean πŸŽ­πŸ…πŸ…“", "ADE_SigmaScheduleSplitAndCombine": "Sigma Schedule Split Combine πŸŽ­πŸ…πŸ…“", "ADE_SigmaScheduleToSigmas": "Sigma Schedule To Sigmas πŸŽ­πŸ…πŸ…“", "ADE_NoisedImageInjection": "Image Injection πŸŽ­πŸ…πŸ…“", diff --git a/animatediff/nodes_conditioning.py b/animatediff/nodes_conditioning.py index 9c694b2..57f1233 100644 --- a/animatediff/nodes_conditioning.py +++ b/animatediff/nodes_conditioning.py @@ -337,6 +337,7 @@ def INPUT_TYPES(s): }, "optional": { "prev_hook_kf": ("LORA_HOOK_KEYFRAMES",), + "autosize": ("ADEAUTOSIZE", {"padding": 70}), } } diff --git a/animatediff/nodes_context_extras.py b/animatediff/nodes_context_extras.py index 6675632..e9860df 100644 --- a/animatediff/nodes_context_extras.py +++ b/animatediff/nodes_context_extras.py @@ -5,8 +5,9 @@ from .context_extras import (ContextExtrasGroup, ContextRef, ContextRefTune, ContextRefMode, ContextRefKeyframeGroup, ContextRefKeyframe, NaiveReuse, NaiveReuseKeyframe, NaiveReuseKeyframeGroup) -from .utils_model import BIGMAX +from .utils_model import BIGMAX, InterpolationMethod from .utils_scheduling import convert_str_to_indexes +from .logger import logger class SetContextExtrasOnContextOptions: @@ -168,7 +169,8 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" FUNCTION = "create_keyframe" - def create_keyframe(self, prev_kf=None, mult=1.0, 
mult_multival=1.0, mode_replace=None, tune_replace=None, + def create_keyframe(self, prev_kf: ContextRefKeyframeGroup=None, + mult=1.0, mult_multival=1.0, mode_replace=None, tune_replace=None, start_percent=1.0, guarantee_steps=1, inherit_missing=True): if prev_kf is None: prev_kf = ContextRefKeyframeGroup() @@ -179,6 +181,58 @@ def create_keyframe(self, prev_kf=None, mult=1.0, mult_multival=1.0, mode_replac return (prev_kf,) +class ContextRef_KeyframeInterpolationNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "mult_start": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "mult_end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "interpolation": (InterpolationMethod._LIST, ), + "intervals": ("INT", {"default": 50, "min": 2, "max": 100, "step": 1}), + "inherit_missing": ("BOOLEAN", {"default": True}), + "print_keyframes": ("BOOLEAN", {"default": False}), + }, + "optional": { + "prev_kf": ("CONTEXTREF_KEYFRAME",), + "mult_multival": ("MULTIVAL",), + "mode_replace": ("CONTEXTREF_MODE",), + "tune_replace": ("CONTEXTREF_TUNE",), + "autosize": ("ADEAUTOSIZE", {"padding": 50}), + } + } + + RETURN_TYPES = ("CONTEXTREF_KEYFRAME",) + RETURN_NAMES = ("CONTEXTREF_KF",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" + FUNCTION = "create_keyframe" + + def create_keyframe(self, + start_percent: float, end_percent: float, + mult_start: float, mult_end: float, interpolation: str, intervals: int, + inherit_missing=True, prev_kf: ContextRefKeyframeGroup=None, + mult_multival=1.0, mode_replace=None, tune_replace=None, print_keyframes=False): + if prev_kf is None: + prev_kf = ContextRefKeyframeGroup() + prev_kf = prev_kf.clone() + percents = InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, 
length=intervals, method=InterpolationMethod.LINEAR) + mults = InterpolationMethod.get_weights(num_from=mult_start, num_to=mult_end, length=intervals, method=interpolation) + + is_first = True + for percent, mult in zip(percents, mults): + guarantee_steps = 0 + if is_first: + guarantee_steps = 1 + is_first = False + prev_kf.add(ContextRefKeyframe(mult=mult, mult_multival=mult_multival, tune_replace=tune_replace, mode_replace=mode_replace, + start_percent=percent, guarantee_steps=guarantee_steps, inherit_missing=inherit_missing)) + if print_keyframes: + logger.info(f"ContextRefKeyframe - start_percent:{percent} = {mult}") + return (prev_kf,) + + class ContextRef_ModeFirst: @classmethod def INPUT_TYPES(s): diff --git a/animatediff/nodes_sample.py b/animatediff/nodes_sample.py index d988a62..70f4f1c 100644 --- a/animatediff/nodes_sample.py +++ b/animatediff/nodes_sample.py @@ -337,6 +337,7 @@ def INPUT_TYPES(s): "optional": { "prev_custom_cfg": ("CUSTOM_CFG",), "cfg_extras": ("CFG_EXTRAS",), + "autosize": ("ADEAUTOSIZE", {"padding": 70}), } } diff --git a/animatediff/nodes_sigma_schedule.py b/animatediff/nodes_sigma_schedule.py index 20580a2..828e2ca 100644 --- a/animatediff/nodes_sigma_schedule.py +++ b/animatediff/nodes_sigma_schedule.py @@ -101,6 +101,9 @@ def INPUT_TYPES(s): "weight_A_Start": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.001}), "weight_A_End": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.001}), "interpolation": (InterpolationMethod._LIST,), + }, + "optional": { + "autosize": ("ADEAUTOSIZE", {"padding": 70}), } } From fd7f0ffa24db2ba7782a72dbd440212b0174ad19 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 6 Aug 2024 21:57:25 -0500 Subject: [PATCH 18/22] Added ContextRef Keyframe From List node, added NaiveReuse Keyframe Interp. 
and From List nodes, added inherit_missing to NaiveReuse Keyframes --- animatediff/context_extras.py | 31 ++++- animatediff/nodes.py | 13 ++- animatediff/nodes_context_extras.py | 168 +++++++++++++++++++++++++++- animatediff/utils_motion.py | 8 +- 4 files changed, 204 insertions(+), 16 deletions(-) diff --git a/animatediff/context_extras.py b/animatediff/context_extras.py index 3af7df6..2f19856 100644 --- a/animatediff/context_extras.py +++ b/animatediff/context_extras.py @@ -106,9 +106,6 @@ def __init__(self, mult=1.0, mult_multival: Union[float, Tensor]=None, tune_repl def clone(self): c = ContextRefKeyframe(mult=self.mult, mult_multival=self.orig_mult_multival, tune_replace=self.orig_tune_replace, mode_replace=self.orig_mode_replace, start_percent=self.start_percent, guarantee_steps=self.guarantee_steps, inherit_missing=self.inherit_missing) - c.mult_multival = self.mult_multival - c.tune_replace = self.tune_replace - c.mode_replace = self.mode_replace return c @@ -253,13 +250,15 @@ def should_run(self): ################################ # NaiveReuse class NaiveReuseKeyframe: - def __init__(self, mult=1.0, mult_multival: Union[float, Tensor]=1.0, start_percent=0.0, guarantee_steps=1): + def __init__(self, mult=1.0, mult_multival: Union[float, Tensor]=None, start_percent=0.0, guarantee_steps=1, inherit_missing=True): self.mult = mult + self.orig_mult_multival = mult_multival self.mult_multival = mult_multival # scheduling self.start_percent = float(start_percent) self.start_t = 999999999.9 self.guarantee_steps = guarantee_steps + self.inherit_missing = inherit_missing def clone(self): c = NaiveReuseKeyframe(mult=self.mult, mult_multival=self.mult_multival, @@ -267,6 +266,7 @@ def clone(self): c.start_t = self.start_t return c + class NaiveReuseKeyframeGroup: def __init__(self): self.keyframes: list[NaiveReuseKeyframe] = [] @@ -286,13 +286,32 @@ def add(self, keyframe: NaiveReuseKeyframe): self.keyframes.append(keyframe) self.keyframes = 
get_sorted_list_via_attr(self.keyframes, "start_percent") self._set_first_as_current() + self._prepare_all_keyframe_vals() def _set_first_as_current(self): if len(self.keyframes) > 0: self._current_keyframe = self.keyframes[0] else: self._current_keyframe = None - + + def _prepare_all_keyframe_vals(self): + if self.is_empty(): + return + multival = None + for kf in self.keyframes: + # if shouldn't inherit, clear cache + if not kf.inherit_missing: + multival = None + # assign cached values, if origs were None + # Mult ################# + if kf.orig_mult_multival is None: + kf.mult_multival = multival + else: + kf.mult_multival = kf.orig_mult_multival + # save new caches, in case next keyframe inherits missing + if kf.mult_multival is not None: + multival = kf.mult_multival + def has_index(self, index: int) -> int: return index >=0 and index < len(self.keyframes) @@ -304,6 +323,7 @@ def clone(self): for keyframe in self.keyframes: cloned.keyframes.append(keyframe) cloned._set_first_as_current() + cloned._prepare_all_keyframe_vals() return cloned def initialize_timesteps(self, model: BaseModel): @@ -353,6 +373,7 @@ def mult_multival(self): return self._current_keyframe.mult_multival return None + class NaiveReuse(ContextExtra): def __init__(self, start_percent: float, end_percent: float, weighted_mean: float, multival_opt: Union[float, Tensor]=None, naivereuse_kf: NaiveReuseKeyframeGroup=None): super().__init__(start_percent=start_percent, end_percent=end_percent) diff --git a/animatediff/nodes.py b/animatediff/nodes.py index e8a0a57..69a5496 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -31,8 +31,9 @@ VisualizeContextOptionsK, VisualizeContextOptionsKAdv, VisualizeContextOptionsSCustom) from .nodes_context_extras import (SetContextExtrasOnContextOptions, ContextExtras_NaiveReuse, ContextExtras_ContextRef, ContextRef_ModeFirst, ContextRef_ModeSliding, ContextRef_ModeIndexes, - ContextRef_TuneAttn, ContextRef_TuneAttnAdain, 
ContextRef_KeyframeMultivalNode, ContextRef_KeyframeInterpolationNode, - NaiveReuse_KeyframeMultivalNode) + ContextRef_TuneAttn, ContextRef_TuneAttnAdain, + ContextRef_KeyframeMultivalNode, ContextRef_KeyframeInterpolationNode, ContextRef_KeyframeFromListNode, + NaiveReuse_KeyframeMultivalNode, NaiveReuse_KeyframeInterpolationNode, NaiveReuse_KeyframeFromListNode) from .nodes_ad_settings import (AnimateDiffSettingsNode, ManualAdjustPENode, SweetspotStretchPENode, FullStretchPENode, WeightAdjustAllAddNode, WeightAdjustAllMultNode, WeightAdjustIndivAddNode, WeightAdjustIndivMultNode, WeightAdjustIndivAttnAddNode, WeightAdjustIndivAttnMultNode) @@ -84,8 +85,11 @@ "ADE_ContextExtras_ContextRef_TuneAttnAdain": ContextRef_TuneAttnAdain, "ADE_ContextExtras_ContextRef_Keyframe": ContextRef_KeyframeMultivalNode, "ADE_ContextExtras_ContextRef_KeyframeInterpolation": ContextRef_KeyframeInterpolationNode, + "ADE_ContextExtras_ContextRef_KeyframeFromList": ContextRef_KeyframeFromListNode, "ADE_ContextExtras_NaiveReuse": ContextExtras_NaiveReuse, "ADE_ContextExtras_NaiveReuse_Keyframe": NaiveReuse_KeyframeMultivalNode, + "ADE_ContextExtras_NaiveReuse_KeyframeInterpolation": NaiveReuse_KeyframeInterpolationNode, + "ADE_ContextExtras_NaiveReuse_KeyframeFromList": NaiveReuse_KeyframeFromListNode, #------------------------------------------------------------------------------ ############################################################################### # Iteration Opts @@ -229,9 +233,12 @@ "ADE_ContextExtras_ContextRef_TuneAttn": "ContextRef Tuneβ—†Attn πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_ContextRef_TuneAttnAdain": "ContextRef Tuneβ—†Attn+Adain πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_ContextRef_Keyframe": "ContextRef Keyframe πŸŽ­πŸ…πŸ…“", - "ADE_ContextExtras_ContextRef_KeyframeInterpolation": "ContextRef Keyframe Interp. πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_ContextRef_KeyframeInterpolation": "ContextRef Keyframes Interp. 
πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_ContextRef_KeyframeFromList": "ContextRef Keyframes From List πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_NaiveReuse": "Context Extrasβ—†NaiveReuse πŸŽ­πŸ…πŸ…“", "ADE_ContextExtras_NaiveReuse_Keyframe": "NaiveReuse Keyframe πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_NaiveReuse_KeyframeInterpolation": "NaiveReuse Keyframes Interp. πŸŽ­πŸ…πŸ…“", + "ADE_ContextExtras_NaiveReuse_KeyframeFromList": "NaiveReuse Keyframes From List πŸŽ­πŸ…πŸ…“", #------------------------------------------------------------------------------ ############################################################################### # Iteration Opts diff --git a/animatediff/nodes_context_extras.py b/animatediff/nodes_context_extras.py index e9860df..8a28e8d 100644 --- a/animatediff/nodes_context_extras.py +++ b/animatediff/nodes_context_extras.py @@ -1,5 +1,6 @@ from torch import Tensor from typing import Union +from collections.abc import Iterable from .context import (ContextOptionsGroup) from .context_extras import (ContextExtrasGroup, @@ -81,6 +82,7 @@ def INPUT_TYPES(s): "mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), + "inherit_missing": ("BOOLEAN", {"default": True}, ), "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -90,14 +92,117 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/naivereuse" FUNCTION = "create_keyframe" - def create_keyframe(self, prev_kf=None, mult=1.0, mult_multival=1.0, - start_percent=0.0, guarantee_steps=1): + def create_keyframe(self, prev_kf=None, mult=1.0, mult_multival=None, + start_percent=0.0, guarantee_steps=1, inherit_missing=True): if prev_kf is None: prev_kf = NaiveReuseKeyframeGroup() prev_kf = prev_kf.clone() - kf = NaiveReuseKeyframe(mult=mult, mult_multival=mult_multival, start_percent=start_percent, 
guarantee_steps=guarantee_steps) + kf = NaiveReuseKeyframe(mult=mult, mult_multival=mult_multival, + start_percent=start_percent, guarantee_steps=guarantee_steps, inherit_missing=inherit_missing) prev_kf.add(kf) return (prev_kf,) + + +class NaiveReuse_KeyframeInterpolationNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "mult_start": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "mult_end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "interpolation": (InterpolationMethod._LIST, ), + "intervals": ("INT", {"default": 50, "min": 2, "max": 100, "step": 1}), + "inherit_missing": ("BOOLEAN", {"default": True}), + "print_keyframes": ("BOOLEAN", {"default": False}), + }, + "optional": { + "prev_kf": ("NAIVEREUSE_KEYFRAME",), + "mult_multival": ("MULTIVAL",), + "autosize": ("ADEAUTOSIZE", {"padding": 50}), + } + } + + RETURN_TYPES = ("NAIVEREUSE_KEYFRAME",) + RETURN_NAMES = ("NAIVEREUSE_KF",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/naivereuse" + FUNCTION = "create_keyframe" + + def create_keyframe(self, + start_percent: float, end_percent: float, + mult_start: float, mult_end: float, interpolation: str, intervals: int, + inherit_missing=True, prev_kf: NaiveReuseKeyframeGroup=None, + mult_multival=None, print_keyframes=False): + if prev_kf is None: + prev_kf = NaiveReuseKeyframeGroup() + prev_kf = prev_kf.clone() + prev_kf = prev_kf.clone() + percents = InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=intervals, method=InterpolationMethod.LINEAR) + mults = InterpolationMethod.get_weights(num_from=mult_start, num_to=mult_end, length=intervals, method=interpolation) + + is_first = True + for percent, mult in zip(percents, mults): + guarantee_steps = 0 + if is_first: + 
guarantee_steps = 1 + is_first = False + prev_kf.add(NaiveReuseKeyframe(mult=mult, mult_multival=mult_multival, + start_percent=percent, guarantee_steps=guarantee_steps, inherit_missing=inherit_missing)) + if print_keyframes: + logger.info(f"NaiveReuseKeyframe - start_percent:{percent} = {mult}") + return (prev_kf,) + + +class NaiveReuse_KeyframeFromListNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "mults_float": ("FLOAT", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "inherit_missing": ("BOOLEAN", {"default": True}), + "print_keyframes": ("BOOLEAN", {"default": False}), + }, + "optional": { + "prev_kf": ("NAIVEREUSE_KEYFRAME",), + "mult_multival": ("MULTIVAL",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("NAIVEREUSE_KEYFRAME",) + RETURN_NAMES = ("NAIVEREUSE_KF",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/naivereuse" + FUNCTION = "create_keyframe" + + def create_keyframe(self, mults_float: Union[float, list[float]], + start_percent: float, end_percent: float, + inherit_missing=True, prev_kf: NaiveReuseKeyframeGroup=None, + mult_multival=None, print_keyframes=False): + if prev_kf is None: + prev_kf = NaiveReuseKeyframeGroup() + prev_kf = prev_kf.clone() + if type(mults_float) in (float, int): + mults_float = [float(mults_float)] + elif isinstance(mults_float, Iterable): + pass + else: + raise Exception(f"strengths_float must be either an interable input or a float, but was {type(mults_float).__repr__}.") + percents = InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(mults_float), method=InterpolationMethod.LINEAR) + + is_first = True + for percent, mult in zip(percents, mults_float): + guarantee_steps = 0 + if is_first: + guarantee_steps = 1 + is_first = 
False + prev_kf.add(NaiveReuseKeyframe(mult=mult, mult_multival=mult_multival, + start_percent=percent, guarantee_steps=guarantee_steps, inherit_missing=inherit_missing)) + if print_keyframes: + logger.info(f"NaiveReuseKeyframe - start_percent:{percent} = {mult}") + return (prev_kf,) #---------------------------------------- ######################################### @@ -170,7 +275,7 @@ def INPUT_TYPES(s): FUNCTION = "create_keyframe" def create_keyframe(self, prev_kf: ContextRefKeyframeGroup=None, - mult=1.0, mult_multival=1.0, mode_replace=None, tune_replace=None, + mult=1.0, mult_multival=None, mode_replace=None, tune_replace=None, start_percent=1.0, guarantee_steps=1, inherit_missing=True): if prev_kf is None: prev_kf = ContextRefKeyframeGroup() @@ -213,7 +318,7 @@ def create_keyframe(self, start_percent: float, end_percent: float, mult_start: float, mult_end: float, interpolation: str, intervals: int, inherit_missing=True, prev_kf: ContextRefKeyframeGroup=None, - mult_multival=1.0, mode_replace=None, tune_replace=None, print_keyframes=False): + mult_multival=None, mode_replace=None, tune_replace=None, print_keyframes=False): if prev_kf is None: prev_kf = ContextRefKeyframeGroup() prev_kf = prev_kf.clone() @@ -233,6 +338,59 @@ def create_keyframe(self, return (prev_kf,) +class ContextRef_KeyframeFromListNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "mults_float": ("FLOAT", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "inherit_missing": ("BOOLEAN", {"default": True}), + "print_keyframes": ("BOOLEAN", {"default": False}), + }, + "optional": { + "prev_kf": ("CONTEXTREF_KEYFRAME",), + "mult_multival": ("MULTIVAL",), + "mode_replace": ("CONTEXTREF_MODE",), + "tune_replace": ("CONTEXTREF_TUNE",), + "autosize": ("ADEAUTOSIZE", {"padding": 50}), + } + } + 
+ RETURN_TYPES = ("CONTEXTREF_KEYFRAME",) + RETURN_NAMES = ("CONTEXTREF_KF",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras/contextref" + FUNCTION = "create_keyframe" + + def create_keyframe(self, mults_float: Union[float, list[float]], + start_percent: float, end_percent: float, + inherit_missing=True, prev_kf: ContextRefKeyframeGroup=None, + mult_multival=None, mode_replace=None, tune_replace=None, print_keyframes=False): + if prev_kf is None: + prev_kf = ContextRefKeyframeGroup() + prev_kf = prev_kf.clone() + if type(mults_float) in (float, int): + mults_float = [float(mults_float)] + elif isinstance(mults_float, Iterable): + pass + else: + raise Exception(f"strengths_float must be either an interable input or a float, but was {type(mults_float).__repr__}.") + percents = InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(mults_float), method=InterpolationMethod.LINEAR) + + is_first = True + for percent, mult in zip(percents, mults_float): + guarantee_steps = 0 + if is_first: + guarantee_steps = 1 + is_first = False + prev_kf.add(ContextRefKeyframe(mult=mult, mult_multival=mult_multival, tune_replace=tune_replace, mode_replace=mode_replace, + start_percent=percent, guarantee_steps=guarantee_steps, inherit_missing=inherit_missing)) + if print_keyframes: + logger.info(f"ContextRefKeyframe - start_percent:{percent} = {mult}") + return (prev_kf,) + + class ContextRef_ModeFirst: @classmethod def INPUT_TYPES(s): diff --git a/animatediff/utils_motion.py b/animatediff/utils_motion.py index 8425f02..501a7f6 100644 --- a/animatediff/utils_motion.py +++ b/animatediff/utils_motion.py @@ -276,10 +276,12 @@ def create_multival_combo(float_val: Union[float, list[float]], mask_optional: T def get_combined_multival(multivalA: Union[float, Tensor], multivalB: Union[float, Tensor], force_leader_A=False) -> Union[float, Tensor]: + if multivalA is None and multivalB is None: + return 1.0 # if one is None, use the other - if 
multivalA == None: + if multivalA is None: return multivalB - elif multivalB == None: + elif multivalB is None: return multivalA # both have a value - combine them based on type # if both are Tensors, make dims match before multiplying @@ -306,7 +308,7 @@ def get_combined_multival(multivalA: Union[float, Tensor], multivalB: Union[floa def resize_multival(multival: Union[float, Tensor], batch_size: int, height: int, width: int): - if multival == None: + if multival is None: return 1.0 if type(multival) != Tensor: return multival From ee3409f30e300e3c20d540ecb1a26e6538c06ece Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 7 Aug 2024 02:36:22 -0500 Subject: [PATCH 19/22] Added support for ContextRef mode_replace and tune_replace --- animatediff/nodes_context.py | 7 ++++--- animatediff/sampling.py | 5 +++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/animatediff/nodes_context.py b/animatediff/nodes_context.py index c14a8b8..fd42683 100644 --- a/animatediff/nodes_context.py +++ b/animatediff/nodes_context.py @@ -6,6 +6,7 @@ from .context import (ContextFuseMethod, ContextOptions, ContextOptionsGroup, ContextSchedules, generate_context_visualization) +from .model_injection import ModelPatcherAndInjector from .utils_model import BIGMAX, MAX_RESOLUTION @@ -378,7 +379,7 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/visualize" FUNCTION = "visualize" - def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str, + def visualize(self, model: ModelPatcherAndInjector, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str, visual_width: 1280, latents_length=32, steps=20, start_step=0, end_step=20): images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length, sampler_name=sampler_name, scheduler=scheduler, @@ -408,7 +409,7 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context 
opts/visualize" FUNCTION = "visualize" - def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str, + def visualize(self, model: ModelPatcherAndInjector, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str, visual_width: 1280, latents_length=32, steps=20, denoise=1.0): images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length, sampler_name=sampler_name, scheduler=scheduler, @@ -435,7 +436,7 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/visualize" FUNCTION = "visualize" - def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sigmas, + def visualize(self, model: ModelPatcherAndInjector, context_opts: ContextOptionsGroup, sigmas, visual_width: 1280, latents_length=32): images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length, sigmas=sigmas) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 38878b7..5b40bdc 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -849,10 +849,11 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list refcn.prepare_current_timestep(timestep) if not refcn.should_run(): actually_should_run = False - break if actually_should_run: contextref_active = True - contextref_mode = ADGS.params.context_options.extras.context_ref.mode + for refcn in model_options["transformer_options"][CONTEXTREF_CONTROL_LIST_ALL]: + # get mode_override if present, mode otherwise + contextref_mode = refcn.get_contextref_mode_replace() or ADGS.params.context_options.extras.context_ref.mode contextref_idxs_set = contextref_mode.indexes.copy() # use injector to ensure only 1 cond or uncond will be batched at a time contextref_injector = ContextRefInjector() From ecbd4ff362f5b76f68ba88877f8c50de4b558c5e Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: 
Thu, 8 Aug 2024 02:49:25 -0500 Subject: [PATCH 20/22] Added check for Advanced-ControlNet cooperation and matching errors to throw --- animatediff/sampling.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 5b40bdc..6a73aea 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -843,6 +843,13 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list # need to make sure that contextref stuff gets cleaned up, no matter what try: if ADGS.params.context_options.extras.should_run_context_ref(): + # check that ACN provided ContextRef as requested + temp_refcn_list = model_options["transformer_options"].get(CONTEXTREF_CONTROL_LIST_ALL, None) + if temp_refcn_list is None: + raise Exception("Advanced-ControlNet nodes are either missing or too outdated to support ContextRef. Update/install ComfyUI-Advanced-ControlNet to use ContextRef.") + if len(temp_refcn_list) == 0: + raise Exception("Unexpected ContextRef issue; Advanced-ControlNet did not provide any ContextRef objs for AnimateDiff-Evolved.") + del temp_refcn_list # check if ContextRef ReferenceAdvanced ACN objs should_run actually_should_run = True for refcn in model_options["transformer_options"][CONTEXTREF_CONTROL_LIST_ALL]: From fade43bb76abd5828a0d8ba21743131341f76f9e Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 8 Aug 2024 03:03:29 -0500 Subject: [PATCH 21/22] version bump --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c268580..2684188 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui-animatediff-evolved" description = "Improved AnimateDiff integration for ComfyUI." 
-version = "1.0.12" +version = "1.1.0" license = { file = "LICENSE" } dependencies = [] From 1de8a7c3afa515e8aef619c48f76a817371c7c6d Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 8 Aug 2024 03:06:44 -0500 Subject: [PATCH 22/22] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 9b59cf4..c5d4911 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,7 @@ NOTE: you can also use custom locations for models/motion loras by making use of - Per-frame GLIGEN coordinates control - Currently requires GLIGENTextBoxApplyBatch from KJNodes to do so, but I will add native nodes to do this soon. - Image Injection mid-sampling +- ContextRef and NaiveReuse (novel cross-context consistency techniques) ## Upcoming Features - Example workflows for **every feature** in AnimateDiff-Evolved repo, and hopefully a long Youtube video showing all features (Goal: before Elden Ring DLC releases. Working on it right now.)