From 8e4c6a15d2b294b1e7ea113febe8f9b32bf0f9c1 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 7 Dec 2024 20:15:53 -0600 Subject: [PATCH 01/12] Progress on HelloMeme support, may move this into Advanced-ControlNet later --- animatediff/adapter_hellomeme.py | 307 +++++++++++++++++++++++++++++++ 1 file changed, 307 insertions(+) create mode 100644 animatediff/adapter_hellomeme.py diff --git a/animatediff/adapter_hellomeme.py b/animatediff/adapter_hellomeme.py new file mode 100644 index 0000000..ef53020 --- /dev/null +++ b/animatediff/adapter_hellomeme.py @@ -0,0 +1,307 @@ +# main code adapted from HelloMeme: https://github.com/HelloVision/HelloMeme +from typing import Optional, Union, Callable + +import copy +import math +import torch +from torch import Tensor, nn + +from einops import rearrange + +import comfy.ops +from comfy.ldm.modules.diffusionmodules import openaimodel +from comfy.ldm.modules.attention import CrossAttention, FeedForward +from comfy.model_patcher import ModelPatcher + + +def zero_module(module: nn.Module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +def create_HM_forward_timestep_embed_patch(): + return (SKReferenceAttention, hm_forward_timestep_embed_patch_ade) + + +def hm_forward_timestep_embed_patch_ade(layer, x, emb, context, transformer_options, *args, **kwargs): + return layer(x, transformer_options=transformer_options) + + +class HMReferenceAdapter(nn.Module): + def __init__(self, + block_out_channels: tuple[int] = (320, 640, 1280, 1280), + num_attention_heads: Optional[Union[int, tuple[int]]] = 8, + ops=comfy.ops.disable_weight_init + ): + super().__init__() + + self.block_out_channels = block_out_channels + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(block_out_channels) + self.num_attention_heads = num_attention_heads + + self.reference_modules_down = nn.ModuleList([]) + self.reference_modules_mid = None + self.reference_modules_up = nn.ModuleList([]) + + for i in range(len(block_out_channels)): + output_channel = block_out_channels[i] + + self.reference_modules_down.append( + SKReferenceAttention( + in_channels=output_channel, + num_attention_heads=num_attention_heads[i] + ) + ) + + self.reference_modules_mid = SKReferenceAttention( + in_channels=block_out_channels[-1], + num_attention_heads=num_attention_heads[-1] + ) + + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + + output_channel = reversed_block_out_channels[0] + for i in range(len(block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + if i > 0: + self.reference_modules_up.append( + SKReferenceAttention( + in_channels=prev_output_channel, + num_attention_heads=reversed_num_attention_heads[i] + ) + ) + + def inject(self, model: ModelPatcher): + unet: openaimodel.UNetModel = model.model.diffusion_model + del unet + + def eject(self, model: ModelPatcher): + unet: openaimodel.UNetModel = model.model.diffusion_model + del unet + + +class SKReferenceAttention(nn.Module): + def __init__(self, + in_channels: int, + num_attention_heads: int=1, + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + num_positional_embeddings: int = 64*2, + ops = comfy.ops.disable_weight_init, + ): + super().__init__() + self.pos_embed = SinusoidalPositionalEmbedding(in_channels, max_seq_length=num_positional_embeddings) + self.attn1 = CrossAttention( + query_dim=in_channels, + heads=num_attention_heads, + dim_head=in_channels // num_attention_heads, + dropout=0.0, + ) + self.attn2 = CrossAttention( + query_dim=in_channels, + heads=num_attention_heads, + dim_head=in_channels // num_attention_heads, + dropout=0.0, + ) + self.norm = ops.LayerNorm(in_channels, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + self.proj = zero_module(ops.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)) + + # def forward(self, hidden_states: Tensor, ref_states: Tensor, num_frames: int): + def forward(self, hidden_states: Tensor, transformer_options: dict[str]): + h, w = hidden_states.shape[-2:] + + ref_states: Tensor = transformer_options["ade_ref_states"] + ad_params: dict[str] = transformer_options["ad_params"] + num_frames = ad_params.get("context_length", ad_params["full_length"]) + + if ref_states.shape[0] != hidden_states.shape[0]: + ref_states = ref_states.repeat_interleave(num_frames, dim=0) + cat_states = torch.cat([hidden_states, ref_states], dim=-1) + + cat_states = rearrange(cat_states.contiguous(), "b c h w -> (b h) w c") + res1 = self.attn1(self.norm(self.pos_embed(cat_states))) + res1 = rearrange(res1[:, :w, :], "(b h) w c -> b c h w", h=h) + + cat_states2 = torch.cat([res1, ref_states], dim=-2) + cat_states2 = rearrange(cat_states2.contiguous(), "b c h w -> (b w) h c") + res2 = self.attn2(self.norm(self.pos_embed(cat_states2))) + + res2 = rearrange(res2[:, :h, :], "(b w) h c -> b c h w", w=w) + + return hidden_states + self.proj(res2) + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1, ops=comfy.ops.disable_weight_init): + """3x3 convolution with padding""" + return ops.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ops=comfy.ops.disable_weight_init, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride, ops=ops) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes, ops=ops) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class SKCrossAttention(nn.Module): + def __init__(self, + channel_in, + channel_out, + heads: int=8, + cross_attention_dim: int=320, + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + num_positional_embeddings: int = 64, + num_positional_embeddings_hidden: int = 64, + ops=comfy.ops.disable_weight_init + ): + super().__init__() + self.conv = BasicBlock( + inplanes=channel_in, + planes=channel_out, + stride=2, + downsample=nn.Sequential( + ops.Conv2d(channel_in, channel_out, kernel_size=1, stride=2, bias=False), + nn.InstanceNorm2d(channel_out), + nn.SiLU(), + ), + norm_layer=nn.InstanceNorm2d + ) + + self.pos_embed = SinusoidalPositionalEmbedding(channel_out, max_seq_length=num_positional_embeddings) + self.pos_embed_hidden = SinusoidalPositionalEmbedding(cross_attention_dim, max_seq_length=num_positional_embeddings_hidden) + + self.norm1 = ops.LayerNorm(channel_out, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + self.attn1 = CrossAttention( + query_dim=channel_out, + heads=heads, + dim_head=channel_out // heads, + dropout=0.0, + context_dim=cross_attention_dim, + ) + + self.norm2 = nn.LayerNorm(channel_out, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + self.attn2 = CrossAttention( + query_dim=channel_out, + heads=heads, + dim_head=channel_out // heads, + dropout=0.0, + context_dim=cross_attention_dim, + ) + + self.ff = FeedForward( + channel_out, + mult=2, + dropout=0.0, + glu=True, + operations=ops, + ) + + self.proj = zero_module(ops.Conv2d(channel_out, channel_out, kernel_size=3, padding=1)) + + def forward(self, input: Tensor, hidden_states: Tensor): + x: Tensor = self.conv(input) + h, w = x.shape[-2:] + x = rearrange(x, "b c h w -> (b h) w c") + x = self.attn1(self.norm1(self.pos_embed(x)), self.pos_embed_hidden(hidden_states.repeat_interleave(h, dim=0).contiguous())) + x = rearrange(x, "(b h) w c -> (b w) h c", h=h) + x = self.ff(self.attn2(self.norm2(self.pos_embed(x)), self.pos_embed_hidden(hidden_states.repeat_interleave(w, dim=0).contiguous()))) + x = rearrange(x, "(b w) h c -> b c h w", w=w) + x = self.proj(x) + return x + + +# from diffusers +class SinusoidalPositionalEmbedding(nn.Module): + """Apply positional information to a sequence of embeddings. + + Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to + them + + Args: + embed_dim: (int): Dimension of the positional embedding. + max_seq_length: Maximum sequence length to apply positional embeddings + + """ + + def __init__(self, embed_dim: int, max_seq_length: int = 32): + super().__init__() + position = torch.arange(max_seq_length).unsqueeze(1) + div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)) + pe = torch.zeros(1, max_seq_length, embed_dim) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x: Tensor): + _, seq_length, _ = x.shape + x = x + self.pe[:, :seq_length] + return x + + +class InsertReferenceAdapter(object): + def __init__(self): + self.reference_modules_down = None + self.reference_modules_mid = None + self.reference_modules_up = None + + def insert_reference_adapter(self, adapter: HMReferenceAdapter): + self.reference_modules_down = copy.deepcopy(adapter.reference_modules_down) + self.reference_modules_mid = copy.deepcopy(adapter.reference_modules_mid) + self.reference_modules_up = copy.deepcopy(adapter.reference_modules_up) From 4f3516860ef8b516be7ed031596651a4d4031628 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 10 Dec 2024 18:24:01 -0600 Subject: [PATCH 02/12] Testing unet layout for HM RefNet support --- animatediff/adapter_hellomeme.py | 61 ++++++++++++++++++++++++++++---- animatediff/nodes.py | 5 +++ animatediff/nodes_hellomeme.py | 40 +++++++++++++++++++++ 3 files changed, 99 insertions(+), 7 deletions(-) create mode 100644 animatediff/nodes_hellomeme.py diff --git a/animatediff/adapter_hellomeme.py b/animatediff/adapter_hellomeme.py index ef53020..aa17340 100644 --- a/animatediff/adapter_hellomeme.py +++ b/animatediff/adapter_hellomeme.py @@ -11,7 +11,7 @@ import comfy.ops from comfy.ldm.modules.diffusionmodules import openaimodel from comfy.ldm.modules.attention import CrossAttention, FeedForward -from comfy.model_patcher import ModelPatcher +from comfy.model_patcher import ModelPatcher, PatcherInjection def zero_module(module: nn.Module): @@ -21,13 +21,13 @@ def zero_module(module: nn.Module): def create_HM_forward_timestep_embed_patch(): - return (SKReferenceAttention, hm_forward_timestep_embed_patch_ade) + return (SKReferenceAttention, _hm_forward_timestep_embed_patch_ade) - -def hm_forward_timestep_embed_patch_ade(layer, x, emb, context, transformer_options, *args, **kwargs): +def _hm_forward_timestep_embed_patch_ade(layer, x, emb, context, transformer_options, *args, **kwargs): return layer(x, transformer_options=transformer_options) + class HMReferenceAdapter(nn.Module): def __init__(self, block_out_channels: tuple[int] = (320, 640, 1280, 1280), @@ -51,13 +51,15 @@ def __init__(self, self.reference_modules_down.append( SKReferenceAttention( in_channels=output_channel, - num_attention_heads=num_attention_heads[i] + num_attention_heads=num_attention_heads[i], + ops=ops ) ) self.reference_modules_mid = SKReferenceAttention( in_channels=block_out_channels[-1], - num_attention_heads=num_attention_heads[-1] + num_attention_heads=num_attention_heads[-1], + ops=ops ) reversed_block_out_channels = list(reversed(block_out_channels)) @@ -72,17 +74,62 @@ def __init__(self, self.reference_modules_up.append( SKReferenceAttention( in_channels=prev_output_channel, - num_attention_heads=reversed_num_attention_heads[i] + num_attention_heads=reversed_num_attention_heads[i], + ops=ops ) ) def inject(self, model: ModelPatcher): unet: openaimodel.UNetModel = model.model.diffusion_model + # inject input (down) blocks + if self.reference_modules_down is not None: + self._inject_down(unet.input_blocks) + # inject mid block + if self.reference_modules_mid is not None: + self._inject_mid([unet.middle_block]) + # inject output (up) blocks + if self.reference_modules_up is not None: + self._inject_up(unet.output_blocks) del unet + def _inject_down(self, unet_blocks: nn.ModuleList): + b = 20 + + def _inject_up(self, unet_blocks: nn.ModuleList): + b = 20 + + def _inject_mid(self, unet_blocks: nn.ModuleList): + # add middle block at the end + injection_count = 0 + unet_idx = 0 + injection_goal = 1 + def eject(self, model: ModelPatcher): unet: openaimodel.UNetModel = model.model.diffusion_model + # eject input (down) blocks + if hasattr(unet, "input_blocks"): + self._eject(unet.input_blocks) + # eject mid block (encapsulate in list to make compatible) + if hasattr(unet, "middle_block"): + self._eject([unet.middle_block]) + # eject output (up) blocks + if hasattr(unet, "output_blocks"): + self._eject(unet.output_blocks) del unet + + def _eject(self, unet_blocks: nn.ModuleList): + # eject all SKReferenceAttention objects from all blocks + for block in unet_blocks: + idx_to_pop = [] + for idx, component in enumerate(block): + if type(component) == SKReferenceAttention: + idx_to_pop.append(idx) + # pop in reverse order, as to not disturb what the indeces refer to + for idx in sorted(idx_to_pop, reverse=True): + block.pop(idx) + + def create_injector(self): + return PatcherInjection(inject=self.inject, eject=self.eject) class SKReferenceAttention(nn.Module): diff --git a/animatediff/nodes.py b/animatediff/nodes.py index 44480bc..60f73f1 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -10,6 +10,7 @@ CameraCtrlReplaceCameraParameters, CameraCtrlSetOriginalAspectRatio) from .nodes_pia import (ApplyAnimateDiffPIAModel, LoadAnimateDiffAndInjectPIANode, InputPIA_MultivalNode, InputPIA_PaperPresetsNode, PIA_ADKeyframeNode) from .nodes_fancyvideo import (ApplyAnimateDiffFancyVideo,) +from .nodes_hellomeme import (TestHMRefNetInjection,) from .nodes_multival import MultivalDynamicNode, MultivalScaledMaskNode, MultivalDynamicFloatInputNode, MultivalDynamicFloatsNode, MultivalConvertToMaskNode from .nodes_conditioning import (CreateLoraHookKeyframeInterpolationDEPR, MaskableLoraLoaderDEPR, MaskableLoraLoaderModelOnlyDEPR, MaskableSDModelLoaderDEPR, MaskableSDModelLoaderModelOnlyDEPR, @@ -212,6 +213,8 @@ "ADE_InjectPIAIntoAnimateDiffModel": LoadAnimateDiffAndInjectPIANode, # FancyVideo #ApplyAnimateDiffFancyVideo.NodeID: ApplyAnimateDiffFancyVideo, + # HelloMeme + TestHMRefNetInjection.NodeID: TestHMRefNetInjection, # Deprecated Nodes "AnimateDiffLoaderV1": AnimateDiffLoaderDEPR, "ADE_AnimateDiffLoaderV1Advanced": AnimateDiffLoaderAdvancedDEPR, @@ -383,6 +386,8 @@ "ADE_InjectPIAIntoAnimateDiffModel": "πŸ§ͺInject PIA into AnimateDiff Model πŸŽ­πŸ…πŸ…“β‘‘", # FancyVideo ApplyAnimateDiffFancyVideo.NodeID: ApplyAnimateDiffFancyVideo.NodeName, + # HelloMeme + TestHMRefNetInjection.NodeID: TestHMRefNetInjection.NodeName, # Deprecated Nodes "AnimateDiffLoaderV1": "🚫AnimateDiff Loader [DEPRECATED] πŸŽ­πŸ…πŸ…“", "ADE_AnimateDiffLoaderV1Advanced": "🚫AnimateDiff Loader (Advanced) [DEPRECATED] πŸŽ­πŸ…πŸ…“", diff --git a/animatediff/nodes_hellomeme.py b/animatediff/nodes_hellomeme.py new file mode 100644 index 0000000..dcc3cc1 --- /dev/null +++ b/animatediff/nodes_hellomeme.py @@ -0,0 +1,40 @@ +from typing import Union +import torch + +from comfy.model_patcher import ModelPatcher +import comfy.model_management + + +from .adapter_hellomeme import HMReferenceAdapter, create_HM_forward_timestep_embed_patch + + +class TestHMRefNetInjection: + NodeID = "ADE_TestHMRefNetInjection" + NodeName = "Test HMRefNetInjection" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + } + } + + RETURN_TYPES = ("MODEL",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘/HelloMeme" + FUNCTION = "inject_hmref" + + def inject_hmref(self, model: ModelPatcher): + model = model.clone() + + hmref = HMReferenceAdapter() + hmref.to(comfy.model_management.unet_dtype()) + hmref.to(comfy.model_management.unet_offload_device()) + mp_hmref = ModelPatcher(model=hmref, + load_device=comfy.model_management.get_torch_device(), + offload_device=comfy.model_management.unet_offload_device()) + model.set_additional_models("ADE_HMREF", [mp_hmref]) + model.set_model_forward_timestep_embed_patch(create_HM_forward_timestep_embed_patch()) + model.set_injections("ADE_HMREF", [hmref.create_injector()]) + + return (model,) From ebb5496001a3b557866526bd32033ea65ce58851 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 26 Dec 2024 07:41:53 -0600 Subject: [PATCH 03/12] Technically initial support for HelloMeme RefNet; it needs more work, but after some more testing with the diffusers repo, it does not seem like something that can be extended well. Will continue to explore --- animatediff/adapter_hellomeme.py | 219 ++++++++++++++++++++++++++++--- animatediff/model_injection.py | 19 +-- animatediff/nodes.py | 2 +- animatediff/nodes_hellomeme.py | 31 +++-- animatediff/sampling.py | 33 +++++ 5 files changed, 266 insertions(+), 38 deletions(-) diff --git a/animatediff/adapter_hellomeme.py b/animatediff/adapter_hellomeme.py index aa17340..194d8ac 100644 --- a/animatediff/adapter_hellomeme.py +++ b/animatediff/adapter_hellomeme.py @@ -1,5 +1,6 @@ # main code adapted from HelloMeme: https://github.com/HelloVision/HelloMeme -from typing import Optional, Union, Callable +from __future__ import annotations +from typing import Optional, Union, Callable, TYPE_CHECKING import copy import math @@ -9,9 +10,28 @@ from einops import rearrange import comfy.ops +import comfy.model_management +import comfy.patcher_extension +from comfy.patcher_extension import WrappersMP +import comfy.utils from comfy.ldm.modules.diffusionmodules import openaimodel from comfy.ldm.modules.attention import CrossAttention, FeedForward from comfy.model_patcher import ModelPatcher, PatcherInjection +if TYPE_CHECKING: + from comfy.sd import VAE + from comfy.model_base import BaseModel + +from .utils_model import get_motion_model_path, vae_encode_raw_batched +from .utils_motion import extend_to_batch_size +from .logger import logger + + +class HMRefConst: + HMREF = "ADE_HMREF" + REF_STATES = "ade_ref_states" + REF_MODE = "ade_ref_mode" + WRITE = "write" + READ = "read" def zero_module(module: nn.Module): @@ -28,10 +48,132 @@ def _hm_forward_timestep_embed_patch_ade(layer, x, emb, context, transformer_opt +class HMModelPatcher(ModelPatcher): + '''Class used only for type hints.''' + def __init__(self): + self.model: HMReferenceAdapter + + +def create_HMModelPatcher(model: HMReferenceAdapter, load_device, offload_device) -> HMModelPatcher: + patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device) + return patcher + + +def load_hmreferenceadapter(model_name: str): + model_path = get_motion_model_path(model_name) + logger.info(f"Loading HMReferenceAdapter {model_name}") + state_dict = comfy.utils.load_torch_file(model_path, safe_load=True) + state_dict = prepare_hmref_state_dict(state_dict=state_dict, name=model_name) + # initialize HMReferenceAdapter + if comfy.model_management.unet_manual_cast(comfy.model_management.unet_dtype(), comfy.model_management.get_torch_device()) is None: + ops = comfy.ops.disable_weight_init + else: + ops = comfy.ops.manual_cast + hmref = HMReferenceAdapter(ops=ops) + hmref.to(comfy.model_management.unet_dtype()) + hmref.to(comfy.model_management.unet_offload_device()) + load_result = hmref.load_state_dict(state_dict, strict=True) + hmref_model = create_HMModelPatcher(model=hmref, load_device=comfy.model_management.get_torch_device(), + offload_device=comfy.model_management.unet_offload_device()) + return hmref_model + + +def prepare_hmref_state_dict(state_dict: dict[str, Tensor], name: str): + for key in list(state_dict.keys()): + # the last down module is not used at all; don't bother loading it + if key.startswith("reference_modules_down.3"): + state_dict.pop(key) + return state_dict + + +def create_hmref_attachment(model: ModelPatcher, attachment: HMRefAttachment): + model.set_attachments(HMRefConst.HMREF, attachment) + + +def get_hmref_attachment(model: ModelPatcher) -> Union[HMRefAttachment, None]: + return model.get_attachment(HMRefConst.HMREF) + + +def create_hmref_apply_model_wrapper(model_options: dict): + comfy.patcher_extension.add_wrapper_with_key(WrappersMP.APPLY_MODEL, + HMRefConst.HMREF, + _hmref_apply_model_wrapper, + model_options, is_model_options=True) + + +def _hmref_apply_model_wrapper(executor, *args, **kwargs): + # args (from BaseModel._apply_model): + # 0: x + # 1: t + # 2: c_concat + # 3: c_crossattn + # 4: control + # 5: transformer_options + transformer_options: dict[str] = args[5] + try: + transformer_options[HMRefConst.REF_STATES] = HMRefStates() + transformer_options[HMRefConst.REF_MODE] = HMRefConst.WRITE + # run in WRITE mode to get REF_STATES filled up + executor(*args, **kwargs) + # run in READ mode now + transformer_options[HMRefConst.REF_MODE] = HMRefConst.READ + return executor(*args, **kwargs) + finally: + # clean up transformer_options + del transformer_options[HMRefConst.REF_STATES] + del transformer_options[HMRefConst.REF_MODE] + +class HMRefAttachment: + def __init__(self, + image: Tensor, + vae: VAE): + self.image = image + self.vae = vae + # cached values + self.cached_shape = None + self.ref_latent: Tensor = None + + def on_model_patcher_clone(self, *args, **kwargs): + n = HMRefAttachment(image=self.image, vae=self.vae) + return n + + def prepare_ref_latent(self, model: BaseModel, x: Tensor): + # if already prepared, return it on expected device + if self.ref_latent is not None: + return self.ref_latent.to(device=x.device, dtype=x.dtype) + # get currently used models so they can be properly reloaded after perfoming VAE Encoding + cached_loaded_models = comfy.model_management.loaded_models(only_currently_used=True) + try: + b, c, h, w = x.shape + # transform range [0, 1] into [-1, 1] + #usable_ref = self.image.clone() + usable_ref = 2.0 * self.image - 1.0 + # resize image to match latent size + usable_ref = usable_ref.movedim(-1, 1) + usable_ref = comfy.utils.common_upscale(samples=usable_ref, width=w*self.vae.downscale_ratio, height=h*self.vae.downscale_ratio, + upscale_method="bilinear", crop="center") + usable_ref = usable_ref.movedim(1, -1) + # VAE encode images + logger.info("VAE Encoding HMREF input images...") + usable_ref = model.process_latent_in(vae_encode_raw_batched(vae=self.vae, pixels=usable_ref, show_pbar=False)) + logger.info("VAE Encoding HMREF input images complete.") + # make usable_ref expected length + usable_ref = extend_to_batch_size(usable_ref, b) + self.ref_latent = usable_ref.to(device=x.device, dtype=x.dtype) + return self.ref_latent + finally: + comfy.model_management.load_models_gpu(cached_loaded_models) + + def cleanup(self, *args, **kwargs): + del self.ref_latent + self.ref_latent = None + + class HMReferenceAdapter(nn.Module): def __init__(self, block_out_channels: tuple[int] = (320, 640, 1280, 1280), num_attention_heads: Optional[Union[int, tuple[int]]] = 8, + ignore_last_down: bool = True, ops=comfy.ops.disable_weight_init ): super().__init__() @@ -45,12 +187,15 @@ def __init__(self, self.reference_modules_mid = None self.reference_modules_up = nn.ModuleList([]) - for i in range(len(block_out_channels)): + # ignore last block (only CrossAttn blocks matter), unless otherwise specified + channels_to_parse = block_out_channels if not ignore_last_down else block_out_channels[:-1] + for i in range(len(channels_to_parse)): output_channel = block_out_channels[i] self.reference_modules_down.append( SKReferenceAttention( in_channels=output_channel, + index=i, num_attention_heads=num_attention_heads[i], ops=ops ) @@ -58,6 +203,7 @@ def __init__(self, self.reference_modules_mid = SKReferenceAttention( in_channels=block_out_channels[-1], + index=0, num_attention_heads=num_attention_heads[-1], ops=ops ) @@ -67,13 +213,16 @@ def __init__(self, output_channel = reversed_block_out_channels[0] for i in range(len(block_out_channels)): + # need prev_output_channel due to Upsample locations prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] + # ignore first block (only CrossAttn blocks matter) if i > 0: self.reference_modules_up.append( SKReferenceAttention( in_channels=prev_output_channel, + index=i, num_attention_heads=reversed_num_attention_heads[i], ops=ops ) @@ -83,26 +232,39 @@ def inject(self, model: ModelPatcher): unet: openaimodel.UNetModel = model.model.diffusion_model # inject input (down) blocks if self.reference_modules_down is not None: - self._inject_down(unet.input_blocks) + self._inject_up_down(unet.input_blocks, self.reference_modules_down, "Downsample") # inject mid block if self.reference_modules_mid is not None: self._inject_mid([unet.middle_block]) + #print(unet.middle_block) # inject output (up) blocks if self.reference_modules_up is not None: - self._inject_up(unet.output_blocks) + self._inject_up_down(unet.output_blocks, self.reference_modules_up, "Upsample") del unet - def _inject_down(self, unet_blocks: nn.ModuleList): - b = 20 - - def _inject_up(self, unet_blocks: nn.ModuleList): - b = 20 + def _inject_up_down(self, unet_blocks: nn.ModuleList, ref_blocks: nn.ModuleList, sample_module_name: str): + injection_count = 0 + unet_idx = 0 + injection_goal = len(ref_blocks) + # only stop injecting when modules exhausted + while injection_count < injection_goal: + if unet_idx >= len(unet_blocks): + break + # look for sample_module_name block in current unet_idx + sample_idx = -1 + for idx, component in enumerate(unet_blocks[unet_idx]): + if type(component).__name__ == sample_module_name: + sample_idx = idx + # if found, place ref_block right after it + if sample_idx >= 0: + unet_blocks[unet_idx].insert(sample_idx+1, ref_blocks[injection_count]) + injection_count += 1 + # increment unet_idx + unet_idx += 1 def _inject_mid(self, unet_blocks: nn.ModuleList): # add middle block at the end - injection_count = 0 - unet_idx = 0 - injection_goal = 1 + unet_blocks[0].insert(len(unet_blocks[0]), self.reference_modules_mid) def eject(self, model: ModelPatcher): unet: openaimodel.UNetModel = model.model.diffusion_model @@ -112,6 +274,7 @@ def eject(self, model: ModelPatcher): # eject mid block (encapsulate in list to make compatible) if hasattr(unet, "middle_block"): self._eject([unet.middle_block]) + #print(unet.middle_block) # eject output (up) blocks if hasattr(unet, "output_blocks"): self._eject(unet.output_blocks) @@ -132,9 +295,15 @@ def create_injector(self): return PatcherInjection(inject=self.inject, eject=self.eject) +class HMRefStates: + def __init__(self): + self.states = {} + + class SKReferenceAttention(nn.Module): def __init__(self, in_channels: int, + index: int, num_attention_heads: int=1, norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, @@ -142,6 +311,7 @@ def __init__(self, ops = comfy.ops.disable_weight_init, ): super().__init__() + self.index = index self.pos_embed = SinusoidalPositionalEmbedding(in_channels, max_seq_length=num_positional_embeddings) self.attn1 = CrossAttention( query_dim=in_channels, @@ -158,23 +328,34 @@ def __init__(self, self.norm = ops.LayerNorm(in_channels, elementwise_affine=norm_elementwise_affine, eps=norm_eps) self.proj = zero_module(ops.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)) + def get_ref_state_id(self, transformer_options: dict[str]): + block = transformer_options["block"] + return f"{block[0]}_{self.index}" + # def forward(self, hidden_states: Tensor, ref_states: Tensor, num_frames: int): def forward(self, hidden_states: Tensor, transformer_options: dict[str]): + ref_mode = transformer_options.get(HMRefConst.REF_MODE, HMRefConst.WRITE) + if ref_mode == HMRefConst.WRITE: + ref_states: HMRefStates = transformer_options.setdefault(HMRefConst.REF_STATES, HMRefStates()) + ref_states.states[self.get_ref_state_id(transformer_options)] = hidden_states.clone() + return hidden_states + h, w = hidden_states.shape[-2:] - ref_states: Tensor = transformer_options["ade_ref_states"] - ad_params: dict[str] = transformer_options["ad_params"] - num_frames = ad_params.get("context_length", ad_params["full_length"]) + states: Tensor = transformer_options[HMRefConst.REF_STATES].states[self.get_ref_state_id(transformer_options)] + num_frames = hidden_states.shape[0] // len(transformer_options["cond_or_uncond"]) + #ad_params: dict[str] = transformer_options["ad_params"] + #num_frames = ad_params.get("context_length", ad_params["full_length"]) - if ref_states.shape[0] != hidden_states.shape[0]: - ref_states = ref_states.repeat_interleave(num_frames, dim=0) - cat_states = torch.cat([hidden_states, ref_states], dim=-1) + if states.shape[0] != hidden_states.shape[0]: + states = states.repeat_interleave(num_frames, dim=0) + cat_states = torch.cat([hidden_states, states], dim=-1) cat_states = rearrange(cat_states.contiguous(), "b c h w -> (b h) w c") res1 = self.attn1(self.norm(self.pos_embed(cat_states))) res1 = rearrange(res1[:, :w, :], "(b h) w c -> b c h w", h=h) - cat_states2 = torch.cat([res1, ref_states], dim=-2) + cat_states2 = torch.cat([res1, states], dim=-2) cat_states2 = rearrange(cat_states2.contiguous(), "b c h w -> (b w) h c") res2 = self.attn2(self.norm(self.pos_embed(cat_states2))) diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 5987d40..6a22845 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -1,3 +1,4 @@ +from __future__ import annotations import copy from typing import Union, Callable from collections import namedtuple @@ -56,7 +57,7 @@ def __init__(self, model: ModelPatcher): self.model = model def set_all_properties(self, outer_sampler_wrapper: Callable, calc_cond_batch_wrapper: Callable, - params: 'InjectionParams', sample_settings: SampleSettings=None, motion_models: 'MotionModelGroup'=None): + params: InjectionParams, sample_settings: SampleSettings=None, motion_models: MotionModelGroup=None): self.set_outer_sample_wrapper(outer_sampler_wrapper) self.set_calc_cond_batch_wrapper(calc_cond_batch_wrapper) self.set_sample_settings(sample_settings = sample_settings if sample_settings is not None else SampleSettings()) @@ -122,16 +123,18 @@ def get_name_string(self, show_version=False): def get_sample_settings(self) -> SampleSettings: - return self.model.get_attachment(self.SAMPLE_SETTINGS) + sample_settings = self.model.get_attachment(self.SAMPLE_SETTINGS) + return sample_settings if sample_settings is not None else SampleSettings() def set_sample_settings(self, sample_settings: SampleSettings): self.model.set_attachments(self.SAMPLE_SETTINGS, sample_settings) - def get_params(self) -> 'InjectionParams': - return self.model.get_attachment(self.PARAMS) + def get_params(self) -> InjectionParams: + params = self.model.get_attachment(self.PARAMS) + return params if params is not None else InjectionParams() - def set_params(self, params: 'InjectionParams'): + def set_params(self, params: InjectionParams): self.model.set_attachments(self.PARAMS, params) if params.context_options.context_length is not None: self.set_ACN_outer_sample_wrapper(throw_exception=False) @@ -236,7 +239,7 @@ def _mm_clean_callback(self: MotionModelPatcher, *args, **kwargs): attachment.cleanup(self) -def get_mm_attachment(patcher: MotionModelPatcher) -> 'MotionModelAttachment': +def get_mm_attachment(patcher: MotionModelPatcher) -> MotionModelAttachment: return patcher.get_attachment(ModelPatcherHelper.ADE) @@ -623,7 +626,7 @@ def __getitem__(self, index) -> MotionModelPatcher: def is_empty(self) -> bool: return len(self.models) == 0 - def clone(self) -> 'MotionModelGroup': + def clone(self) -> MotionModelGroup: cloned = MotionModelGroup() for mm in self.models: cloned.add(mm) @@ -1080,7 +1083,7 @@ def set_motion_model_settings(self, motion_model_settings: AnimateDiffSettings): def reset_context(self): self.context_options = ContextOptionsGroup.default() - def clone(self) -> 'InjectionParams': + def clone(self) -> InjectionParams: new_params = InjectionParams( self.unlimited_area_hack, self.apply_mm_groupnorm_hack, apply_v2_properly=self.apply_v2_properly, ) diff --git a/animatediff/nodes.py b/animatediff/nodes.py index 60f73f1..2167ee2 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -214,7 +214,7 @@ # FancyVideo #ApplyAnimateDiffFancyVideo.NodeID: ApplyAnimateDiffFancyVideo, # HelloMeme - TestHMRefNetInjection.NodeID: TestHMRefNetInjection, + #TestHMRefNetInjection.NodeID: TestHMRefNetInjection, # Deprecated Nodes "AnimateDiffLoaderV1": AnimateDiffLoaderDEPR, "ADE_AnimateDiffLoaderV1Advanced": AnimateDiffLoaderAdvancedDEPR, diff --git a/animatediff/nodes_hellomeme.py b/animatediff/nodes_hellomeme.py index dcc3cc1..0d002f5 100644 --- a/animatediff/nodes_hellomeme.py +++ b/animatediff/nodes_hellomeme.py @@ -1,11 +1,18 @@ from typing import Union import torch +from torch import Tensor +from comfy.sd import VAE from comfy.model_patcher import ModelPatcher import comfy.model_management -from .adapter_hellomeme import HMReferenceAdapter, create_HM_forward_timestep_embed_patch +from .adapter_hellomeme import (HMRefConst, HMModelPatcher, HMRefAttachment, load_hmreferenceadapter, + create_hmref_attachment, + create_HM_forward_timestep_embed_patch) +from .model_injection import ModelPatcherHelper +from .sampling import outer_sample_wrapper +from .utils_model import get_available_motion_models class TestHMRefNetInjection: @@ -17,6 +24,9 @@ def INPUT_TYPES(s): return { "required": { "model": ("MODEL",), + "image": ("IMAGE",), + "vae": ("VAE",), + "hmref": (get_available_motion_models(),), } } @@ -24,17 +34,18 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘/HelloMeme" FUNCTION = "inject_hmref" - def inject_hmref(self, model: ModelPatcher): + def inject_hmref(self, model: ModelPatcher, image: Tensor, vae: VAE, + hmref: str): model = model.clone() - hmref = HMReferenceAdapter() - hmref.to(comfy.model_management.unet_dtype()) - hmref.to(comfy.model_management.unet_offload_device()) - mp_hmref = ModelPatcher(model=hmref, - load_device=comfy.model_management.get_torch_device(), - offload_device=comfy.model_management.unet_offload_device()) - model.set_additional_models("ADE_HMREF", [mp_hmref]) + mp_hmref: HMModelPatcher = load_hmreferenceadapter(hmref) + model.set_additional_models(HMRefConst.HMREF, [mp_hmref]) model.set_model_forward_timestep_embed_patch(create_HM_forward_timestep_embed_patch()) - model.set_injections("ADE_HMREF", [hmref.create_injector()]) + model.set_injections(HMRefConst.HMREF, [mp_hmref.model.create_injector()]) + create_hmref_attachment(model, HMRefAttachment(image=image, vae=vae)) + + helper = ModelPatcherHelper(model) + helper.set_outer_sample_wrapper(outer_sample_wrapper) + del helper return (model,) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 157ad92..f0fdf84 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -26,6 +26,7 @@ from .utils_motion import composite_extend, prepare_mask_batch, extend_to_batch_size from .model_injection import InjectionParams, ModelPatcherHelper, MotionModelGroup, get_mm_attachment from .motion_module_ad import AnimateDiffFormat, AnimateDiffInfo, AnimateDiffVersion +from .adapter_hellomeme import HMRefConst, HMRefStates, get_hmref_attachment, create_hmref_apply_model_wrapper from .logger import logger @@ -367,6 +368,7 @@ def outer_sample_wrapper(executor: WrapperExecutor, *args, **kwargs): cached_latents = None cached_noise = None function_injections = FunctionInjectionHolder() + hmref_attachment = None try: guider: comfy.samplers.CFGGuider = executor.class_obj @@ -435,6 +437,34 @@ def ad_callback(step, x0, x, total_steps): iter_kwargs = {} # NOTE: original KSampler stuff is not doable here, so skipping... + + # NOTE: this will never be used as I have hidden HelloMeme RefNet nodes from being loaded + # if have HMRef, then do what's needed + hmref_attachment = get_hmref_attachment(helper.model) + if hmref_attachment is not None: + ref_latents = None + #sigmas: Tensor = args[3] + try: + hmref_attachment.prepare_ref_latent(helper.model.model, x=latents) + # NOTE: trying out using the hmrefnet more like other types of refnets + create_hmref_apply_model_wrapper(guider.model_options) + # ref_latents = hmref_attachment.prepare_ref_latent(helper.model.model, x=latents) + # ref_sigmas = torch.tensor([helper.model.model.model_sampling.sigma_min, torch.tensor(0.0)]).to(device=sigmas.device, dtype=sigmas.dtype) + # new_args = args.copy() + # new_args[3] = ref_sigmas + # make sure transformer_options has necessary HMREF stuff + # guider.model_options["transformer_options"][HMRefConst.REF_STATES] = HMRefStates() + # guider.model_options["transformer_options"][HMRefConst.REF_MODE] = HMRefConst.WRITE + # ADGS.update_with_inject_params(params) + # ADGS.start_step = 0 + # ADGS.current_step = ADGS.start_step + # ADGS.last_step = 0 + # executor(*tuple(new_args), **kwargs) + # guider.model_options["transformer_options"][HMRefConst.REF_MODE] = HMRefConst.READ + finally: + del ref_latents + #del sigmas + for curr_i in range(iter_opts.iterations): # handle GLOBALSTATE vars and step tally # NOTE: only KSampler/KSampler (Advanced) would have steps; @@ -493,6 +523,9 @@ def ad_callback(step, x0, x, total_steps): del cached_latents del cached_noise del orig_model_options + if hmref_attachment is not None: + hmref_attachment.cleanup() + del hmref_attachment # reset global state ADGS.reset() # clean motion_models From 1431497ff8777975aee1faa8472f5eca74e8325a Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 27 Dec 2024 17:52:30 -0600 Subject: [PATCH 04/12] Progress on MotionCtrl support, fixed PIA nodes autosize not being stored in hidden --- animatediff/adapter_motionctrl.py | 92 +++++++++++++++++++++++++++ animatediff/motion_module_ad.py | 14 ++++ animatediff/nodes_motionctrl.py | 102 ++++++++++++++++++++++++++++++ animatediff/nodes_pia.py | 8 ++- 4 files changed, 215 insertions(+), 1 deletion(-) create mode 100644 animatediff/adapter_motionctrl.py create mode 100644 animatediff/nodes_motionctrl.py diff --git a/animatediff/adapter_motionctrl.py b/animatediff/adapter_motionctrl.py new file mode 100644 index 0000000..6599f0f --- /dev/null +++ b/animatediff/adapter_motionctrl.py @@ -0,0 +1,92 @@ +# main code adapted from https://github.com/TencentARC/MotionCtrl/tree/animatediff +from __future__ import annotations +from torch import nn, Tensor + +from comfy.model_patcher import ModelPatcher +import comfy.model_management +import comfy.ops +import comfy.utils + +from .adapter_cameractrl import ResnetBlockCameraCtrl +from .motion_module_ad import AnimateDiffModel +from .utils_model import get_motion_model_path + +# cmcm (Camera Control) +def injection_motionctrl_cmcm(motion_model: AnimateDiffModel, cmcm_name: str): + pass + + +# omcm (Object Control) +def load_motionctrl_omcm(omcm_name: str): + omcm_path = get_motion_model_path(omcm_name) + state_dict = comfy.utils.load_torch_file(omcm_path, safe_load=True) + for key in list(state_dict.keys()): + # remove 'module.' prefix + if key.startswith('module.'): + new_key = key.replace('module.', '') + state_dict[new_key] = state_dict[key] + state_dict.pop(key) + + if comfy.model_management.unet_manual_cast(comfy.model_management.unet_dtype(), comfy.model_management.get_torch_device()) is None: + ops = comfy.ops.disable_weight_init + else: + ops = comfy.ops.manual_cast + adapter = MotionCtrlAdapter(ops=ops) + adapter.load_state_dict(state_dict=state_dict, strict=True) + adapter.to( + device = comfy.model_management.unet_offload_device(), + dtype = comfy.model_management.unet_dtype() + ) + omcm_modelpatcher = _create_OMCMModelPatcher(model=adapter, + load_device=comfy.model_management.get_torch_device(), + offload_device=comfy.model_management.unet_offload_device()) + return omcm_modelpatcher + + +def _create_OMCMModelPatcher(model, load_device, offload_device) -> ObjectControlModelPatcher: + patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device) + return patcher + + +class ObjectControlModelPatcher(ModelPatcher): + '''Class only used for type hints.''' + def __init__(self): + self.model: MotionCtrlAdapter + + +class MotionCtrlAdapter(nn.Module): + def __init__(self, + downscale_factor=8, + channels=[320, 640, 1280, 1280], + nums_rb=2, cin=128, # 2*8*8 + ksize=3, sk=True, + use_conv=False, + ops=comfy.ops.disable_weight_init): + super(MotionCtrlAdapter, self).__init__() + self.downscale_factor = downscale_factor + self.unshuffle = nn.PixelUnshuffle(downscale_factor) + self.channels = channels + self.nums_rb = nums_rb + self.body = [] + for i in range(len(channels)): + for j in range(nums_rb): + if (i != 0) and (j == 0): + self.body.append( + ResnetBlockCameraCtrl(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv, ops=ops)) + else: + self.body.append( + ResnetBlockCameraCtrl(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv, ops=ops)) + self.body = nn.ModuleList(self.body) + self.conv_in = ops.Conv2d(cin, channels[0], 3, 1, 1) + + def forward(self, x: Tensor): + x = self.unshuffle(x) + # extract features + features = [] + x = self.conv_in(x) + for i in range(len(self.channels)): + for j in range(self.nums_rb): + idx = i * self.nums_rb + j + x = self.body[idx](x) + features.append(x) + return features diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index 21c265b..f14bae5 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -1313,6 +1313,8 @@ def __init__( self.ff = FeedForward(dim, dropout=dropout, glu=(activation_fn == "geglu"), operations=ops) self.ff_norm = ops.LayerNorm(dim) + # for MotionCtrl (CMCM) use + self.cc_projections: comfy.ops.disable_weight_init.Linear = None def set_scale_multiplier(self, idx: int, multiplier: Union[float, None]): self.attention_blocks[idx].set_scale_multiplier(multiplier) @@ -1346,6 +1348,7 @@ def forward( elif view_options.context_length == video_length and not view_options.use_on_equal_length: view_options = None if not view_options: + count = 0 for attention_block, norm, scale_mask in zip(self.attention_blocks, self.norms, scale_masks): norm_hidden_states = norm(hidden_states).to(hidden_states.dtype) hidden_states = ( @@ -1362,6 +1365,15 @@ def forward( transformer_options=transformer_options, ) + hidden_states ) + # do MotionCtrl-CMCM stuff if needed + if self.cc_projections is not None and count==0 and 'ADE_RT' in transformer_options: + RT: Tensor = transformer_options['ADE_RT'] + B, t, _ = RT.shape + RT = RT.reshape(B*t, 1, -1) + RT = RT.repeat(1, hidden_states.shape[1]) + hidden_states = torch.cat([hidden_states, RT], dim=-1) + hidden_states = self.cc_projections(hidden_states) + count += 1 else: # views idea gotten from diffusers AnimateDiff FreeNoise implementation: # https://github.com/arthur-qiu/FreeNoise-AnimateDiff/blob/main/animatediff/models/motion_module.py @@ -1381,6 +1393,7 @@ def forward( sub_hidden_states = rearrange(hidden_states[:, sub_idxs], "b f d c -> (b f) d c") if has_camera_feature: mm_kwargs["camera_feature"] = orig_camera_feature[:, sub_idxs, :] + count = 0 for attention_block, norm, scale_mask in zip(self.attention_blocks, self.norms, scale_masks): norm_hidden_states = norm(sub_hidden_states).to(sub_hidden_states.dtype) sub_hidden_states = ( @@ -1397,6 +1410,7 @@ def forward( transformer_options=transformer_options, ) + sub_hidden_states ) + count += 1 sub_hidden_states = rearrange(sub_hidden_states, "(b f) d c -> b f d c", f=len(sub_idxs)) weights = get_context_weights(len(sub_idxs), view_options.fuse_method) * batched_conds diff --git a/animatediff/nodes_motionctrl.py b/animatediff/nodes_motionctrl.py new file mode 100644 index 0000000..7c84753 --- /dev/null +++ b/animatediff/nodes_motionctrl.py @@ -0,0 +1,102 @@ +import torch +from torch import Tensor + +from .ad_settings import AnimateDiffSettings +from .adapter_motionctrl import injection_motionctrl_cmcm, load_motionctrl_omcm + +from .motion_module_ad import AllPerBlocks +from .model_injection import MotionModelPatcher, MotionModelGroup, load_motion_module_gen2 +from .motion_lora import MotionLoraList + +from .nodes_gen2 import ApplyAnimateDiffModelNode +from .utils_model import get_available_motion_models +from .utils_motion import ADKeyframeGroup + + +class LoadMotionCtrlCMCM: + NodeID = "ADE_LoadMotionCtrl_CMCMMOdel" + NodeName = "Load AnimateDiff+MotionCtrl Camera Model πŸŽ­πŸ…πŸ…“β‘‘" + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model_name": (get_available_motion_models(),), + "motionctrl_cmcm": (get_available_motion_models(),), + }, + "optional": { + "ad_settings": ("AD_SETTINGS",), + } + } + + RETURN_TYPES = ("MOTION_MODEL_ADE",) + RETURN_NAMES = ("MOTION_MODEL",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘/MotionCtrl" + FUNCTION = "load_motionctrl_cmcm" + + def load_motionctrl_cmcm(self, model_name: str, motionctrl_cmcm: str, ad_settings: AnimateDiffSettings=None): + motion_model = load_motion_module_gen2(model_name=model_name, motion_model_settings=ad_settings) + motion_model = injection_motionctrl_cmcm(motion_model, cmcm_name=motionctrl_cmcm) + return (motion_model,) + + +class LoadMotionCtrlOMCM: + NodeID = "ADE_LoadMotionCtrl_OMCMMOdel" + NodeName = "Load MotionCtrl Object Model πŸŽ­πŸ…πŸ…“β‘‘" + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "motionctrl_omcm": (get_available_motion_models(),), + } + } + + RETURN_TYPES = ("OMCM_MOTIONCTRL",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘/MotionCtrl" + FUNCTION = "load_motionctrl_omcm" + + def load_motionctrl_omcm(self, motionctrl_omcm: str): + omcm_modelpatcher = load_motionctrl_omcm(motionctrl_omcm) + return (omcm_modelpatcher,) + + +class ApplyAnimateDiffMotionCtrlModel: + NodeID = "ADE_ApplyAnimateDiffModelWithMotionCtrl" + NodeName = "Apply AnimateDiff+MotionCtrl Model πŸŽ­πŸ…πŸ…“β‘‘" + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "motion_model": ("MOTION_MODEL_ADE",), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + }, + "optional": { + "omcm_motionctrl": ("OMCM_MOTIONCTRL",), + "motion_lora": ("MOTION_LORA",), + "scale_multival": ("MULTIVAL",), + "effect_multival": ("MULTIVAL",), + "ad_keyframes": ("AD_KEYFRAMES",), + "prev_m_models": ("M_MODELS",), + "per_block": ("PER_BLOCK",), + }, + "hidden": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("M_MODELS",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘/MotionCtrl" + FUNCTION = "apply_motion_model" + + def apply_motion_model(self, motion_model: MotionModelPatcher, start_percent: float=0.0, end_percent: float=1.0, + motion_lora: MotionLoraList=None, ad_keyframes: ADKeyframeGroup=None, + scale_multival=None, effect_multival=None, per_block: AllPerBlocks=None, + prev_m_models: MotionModelGroup=None,): + (new_m_models,) = ApplyAnimateDiffModelNode.apply_motion_model(self, motion_model, start_percent=start_percent, end_percent=end_percent, + motion_lora=motion_lora, ad_keyframes=ad_keyframes, per_block=per_block, + scale_multival=scale_multival, effect_multival=effect_multival, prev_m_models=prev_m_models) + # most recent added model will always be first in list + curr_model = new_m_models.models[0] + # check if model has CMCM; if so, make sure something is provided for it + # check if OMCM is provided; if so, make sure something is provided for it + return (new_m_models,) diff --git a/animatediff/nodes_pia.py b/animatediff/nodes_pia.py index a7b9716..a61553a 100644 --- a/animatediff/nodes_pia.py +++ b/animatediff/nodes_pia.py @@ -126,6 +126,8 @@ def INPUT_TYPES(s): "ad_keyframes": ("AD_KEYFRAMES",), "prev_m_models": ("M_MODELS",), "per_block": ("PER_BLOCK",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -202,6 +204,8 @@ def INPUT_TYPES(s): "pia_input": ("PIA_INPUT",), "inherit_missing": ("BOOLEAN", {"default": True}, ), "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -254,8 +258,10 @@ def INPUT_TYPES(s): "optional": { "mult_multival": ("MULTIVAL",), "print_values": ("BOOLEAN", {"default": False},), - "autosize": ("ADEAUTOSIZE", {"padding": 0}), #"effect_multival": ("MULTIVAL",), + }, + "hidden": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } From faa5ec365a6285b193c0c56e8de1aa0f0c20f168 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 28 Dec 2024 12:57:19 -0600 Subject: [PATCH 05/12] More MotionCtrl progress, added per_block_replace to AnimateDiff Keyframe, small refactoring to make this possible, fully deprecated AnimateDiff Loader [Legacy] node --- animatediff/adapter_motionctrl.py | 51 +++++++++++--- animatediff/model_injection.py | 48 ++++++++++--- animatediff/motion_module_ad.py | 106 +++++++--------------------- animatediff/nodes.py | 17 +++-- animatediff/nodes_deprecated.py | 94 ++++++++++++++++++++++++- animatediff/nodes_gen1.py | 112 +++--------------------------- animatediff/nodes_gen2.py | 10 ++- animatediff/nodes_motionctrl.py | 11 +-- animatediff/nodes_per_block.py | 18 ++--- animatediff/utils_motion.py | 77 ++++++++++++++++++++ 10 files changed, 322 insertions(+), 222 deletions(-) diff --git a/animatediff/adapter_motionctrl.py b/animatediff/adapter_motionctrl.py index 6599f0f..fef7d1d 100644 --- a/animatediff/adapter_motionctrl.py +++ b/animatediff/adapter_motionctrl.py @@ -8,24 +8,50 @@ import comfy.utils from .adapter_cameractrl import ResnetBlockCameraCtrl +from .ad_settings import AnimateDiffSettings from .motion_module_ad import AnimateDiffModel +from .model_injection import apply_mm_settings from .utils_model import get_motion_model_path + # cmcm (Camera Control) -def injection_motionctrl_cmcm(motion_model: AnimateDiffModel, cmcm_name: str): - pass +def inject_motionctrl_cmcm(motion_model: AnimateDiffModel, cmcm_name: str, ad_settings: AnimateDiffSettings=None, + apply_non_ccs=True): + cmcm_path = get_motion_model_path(cmcm_name) + state_dict = comfy.utils.load_torch_file(cmcm_path, safe_load=True) + _remove_module_prefix(state_dict) + # if applicable, apply ad_settings to cmcm to match expected behavior + if ad_settings is not None: + state_dict = apply_mm_settings(model_dict=state_dict, mm_settings=ad_settings) + motion_model.init_motionctrl_cc_projections(state_dict=state_dict) + # seperate out PE keys so can be applied separately in case dims don't match + apply_dict = {} + for key in list(state_dict.keys()): + if "cc_projection" in key: + apply_dict[key] = state_dict[key] + state_dict.pop(key) + pe_dict = {} + for key in list(state_dict.keys()): + if "pos_encoder" in key: + pe_dict[key] = state_dict[key] + state_dict.pop(key) + if apply_non_ccs: + apply_dict.update(state_dict) + for key, value in pe_dict.items(): + comfy.utils.set_attr(motion_model, key, value) + _, unexpected = motion_model.load_state_dict(apply_dict, strict=False) + if len(unexpected) > 0: + raise Exception(f"MotionCtrl CMCM model had unexpected keys: {unexpected}") + # make sure model is still has proper dtype and offload device + motion_model.to(comfy.model_management.unet_dtype()) + motion_model.to(comfy.model_management.unet_offload_device()) # omcm (Object Control) def load_motionctrl_omcm(omcm_name: str): omcm_path = get_motion_model_path(omcm_name) state_dict = comfy.utils.load_torch_file(omcm_path, safe_load=True) - for key in list(state_dict.keys()): - # remove 'module.' prefix - if key.startswith('module.'): - new_key = key.replace('module.', '') - state_dict[new_key] = state_dict[key] - state_dict.pop(key) + _remove_module_prefix(state_dict) if comfy.model_management.unet_manual_cast(comfy.model_management.unet_dtype(), comfy.model_management.get_torch_device()) is None: ops = comfy.ops.disable_weight_init @@ -48,6 +74,15 @@ def _create_OMCMModelPatcher(model, load_device, offload_device) -> ObjectContro return patcher +def _remove_module_prefix(state_dict: dict[str, Tensor]): + for key in list(state_dict.keys()): + # remove 'module.' prefix + if key.startswith('module.'): + new_key = key.replace('module.', '') + state_dict[new_key] = state_dict[key] + state_dict.pop(key) + + class ObjectControlModelPatcher(ModelPatcher): '''Class only used for type hints.''' def __init__(self): diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 6a22845..954d7b5 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -22,10 +22,11 @@ from .ad_settings import AnimateDiffSettings, AdjustPE, AdjustWeight from .adapter_cameractrl import CameraPoseEncoder, CameraEntry, prepare_pose_embedding from .context import ContextOptions, ContextOptions, ContextOptionsGroup -from .motion_module_ad import (AnimateDiffModel, AnimateDiffFormat, AnimateDiffInfo, EncoderOnlyAnimateDiffModel, VersatileAttention, PerBlock, AllPerBlocks, +from .motion_module_ad import (AnimateDiffModel, AnimateDiffFormat, AnimateDiffInfo, EncoderOnlyAnimateDiffModel, VersatileAttention, VanillaTemporalModule, has_mid_block, normalize_ad_state_dict, get_position_encoding_max_len) from .logger import logger from .utils_motion import (ADKeyframe, ADKeyframeGroup, MotionCompatibilityError, InputPIA, + PerBlock, AllPerBlocks, get_combined_per_block_list, get_combined_multival, get_combined_input, get_combined_input_effect_multival, ade_broadcast_image_to, extend_to_batch_size, prepare_mask_batch) from .conditioning import HookRef, LoraHook, LoraHookGroup, LoraHookMode @@ -266,6 +267,9 @@ def __init__(self): self.camera_features: list[Tensor] = None # temporary self.camera_features_shape: tuple = None self.cameractrl_multival: Union[float, Tensor] = None + ## temp + self.current_cameractrl_effect: Union[float, Tensor] = None + self.combined_cameractrl_effect: Union[float, Tensor] = None # PIA self.orig_pia_images: Tensor = None @@ -275,6 +279,10 @@ def __init__(self): self.prev_pia_latents_shape: tuple = None self.prev_current_pia_input: InputPIA = None self.pia_multival: Union[float, Tensor] = None + ## temp + self.current_pia_input: InputPIA = None + self.combined_pia_mask: Union[float, Tensor] = None + self.combined_pia_effect: Union[float, Tensor] = None # FancyVideo self.orig_fancy_images: Tensor = None @@ -290,14 +298,10 @@ def __init__(self): self.previous_t = -1 self.current_scale: Union[float, Tensor] = None self.current_effect: Union[float, Tensor] = None - self.current_cameractrl_effect: Union[float, Tensor] = None - self.current_pia_input: InputPIA = None + self.current_per_block_list: Union[list[PerBlock], None] = None self.combined_scale: Union[float, Tensor] = None self.combined_effect: Union[float, Tensor] = None - self.combined_per_block_list: Union[float, Tensor] = None - self.combined_cameractrl_effect: Union[float, Tensor] = None - self.combined_pia_mask: Union[float, Tensor] = None - self.combined_pia_effect: Union[float, Tensor] = None + self.combined_per_block_list: Union[list[PerBlock], None] = None self.was_within_range = False self.prev_sub_idxs = None self.prev_batched_number = None @@ -336,19 +340,28 @@ def prepare_current_keyframe(self, patcher: MotionModelPatcher, x: Tensor, t: Te self.current_index = i self.current_keyframe = eval_kf self.current_used_steps = 0 - # keep track of scale and effect multivals, accounting for inherit_missing + # NOTE: handle possible inputs from keyframe, taking into account inherit_missing + # scale if self.current_keyframe.has_scale(): self.current_scale = self.current_keyframe.scale_multival elif not self.current_keyframe.inherit_missing: self.current_scale = None + # effect if self.current_keyframe.has_effect(): self.current_effect = self.current_keyframe.effect_multival elif not self.current_keyframe.inherit_missing: self.current_effect = None + # per_block_list + if self.current_keyframe.has_per_block_replace(): + self.current_per_block_list = self.current_keyframe.per_block_list + elif not self.current_keyframe.inherit_missing: + self.current_per_block_list = None + # cameractrl_effect if self.current_keyframe.has_cameractrl_effect(): self.current_cameractrl_effect = self.current_keyframe.cameractrl_multival elif not self.current_keyframe.inherit_missing: self.current_cameractrl_effect = None + # pia_input if self.current_keyframe.has_pia_input(): self.current_pia_input = self.current_keyframe.pia_input elif not self.current_keyframe.inherit_missing: @@ -364,12 +377,13 @@ def prepare_current_keyframe(self, patcher: MotionModelPatcher, x: Tensor, t: Te # combine model's scale and effect with keyframe's scale and effect self.combined_scale = get_combined_multival(self.scale_multival, self.current_scale) self.combined_effect = get_combined_multival(self.effect_multival, self.current_effect) + self.combined_per_block_list = get_combined_per_block_list(self.per_block_list, self.current_per_block_list) self.combined_cameractrl_effect = get_combined_multival(self.cameractrl_multival, self.current_cameractrl_effect) self.combined_pia_mask = get_combined_input(self.pia_input, self.current_pia_input, x) self.combined_pia_effect = get_combined_input_effect_multival(self.pia_input, self.current_pia_input) # apply scale and effect - patcher.model.set_scale(self.combined_scale, self.per_block_list) - patcher.model.set_effect(self.combined_effect, self.per_block_list) # TODO: set combined_per_block_list + patcher.model.set_scale(self.combined_scale, self.combined_per_block_list) + patcher.model.set_effect(self.combined_effect, self.combined_per_block_list) patcher.model.set_cameractrl_effect(self.combined_cameractrl_effect) # apply effect - if not within range, set effect to 0, effectively turning model off if curr_t > self.timestep_range[0] or curr_t < self.timestep_range[1]: @@ -378,7 +392,7 @@ def prepare_current_keyframe(self, patcher: MotionModelPatcher, x: Tensor, t: Te else: # if was not in range last step, apply effect to toggle AD status if not self.was_within_range: - patcher.model.set_effect(self.combined_effect, self.per_block_list) + patcher.model.set_effect(self.combined_effect, self.combined_per_block_list) self.was_within_range = True # update steps current keyframe is used self.current_used_steps += 1 @@ -386,6 +400,7 @@ def prepare_current_keyframe(self, patcher: MotionModelPatcher, x: Tensor, t: Te self.previous_t = curr_t def prepare_alcmi2v_features(self, patcher: MotionModelPatcher, x: Tensor, cond_or_uncond: list[int], ad_params: dict[str], latent_format): + '''Used for AnimateLCM-I2V''' # if no img_encoder, done if patcher.model.img_encoder is None: return @@ -412,6 +427,7 @@ def prepare_alcmi2v_features(self, patcher: MotionModelPatcher, x: Tensor, cond_ self.prev_batched_number = batched_number def prepare_camera_features(self, patcher: MotionModelPatcher, x: Tensor, cond_or_uncond: list[int], ad_params: dict[str]): + '''Used for CameraCtrl''' # if no camera_encoder, done if patcher.model.camera_encoder is None: return @@ -445,6 +461,7 @@ def prepare_camera_features(self, patcher: MotionModelPatcher, x: Tensor, cond_o self.prev_batched_number = batched_number def get_pia_c_concat(self, model: BaseModel, x: Tensor) -> Tensor: + '''Used for PIA''' # if have cached shape, check if matches - if so, return cached pia_latents if self.prev_pia_latents_shape is not None: if self.prev_pia_latents_shape[0] == x.shape[0] and self.prev_pia_latents_shape[2] == x.shape[2] and self.prev_pia_latents_shape[3] == x.shape[3]: @@ -505,6 +522,7 @@ def get_pia_c_concat(self, model: BaseModel, x: Tensor) -> Tensor: comfy.model_management.load_models_gpu(cached_loaded_models) def get_fancy_c_concat(self, model: BaseModel, x: Tensor) -> Tensor: + '''Used for FancyVideo''' # if have cached shape, check if matches - if so, return cached fancy_latents if self.prev_fancy_latents_shape is not None: if self.prev_fancy_latents_shape[0] == x.shape[0] and self.prev_fancy_latents_shape[-2] == x.shape[-2] and self.prev_fancy_latents_shape[-1] == x.shape[-1]: @@ -572,6 +590,7 @@ def cleanup(self, patcher: MotionModelPatcher): self.previous_t = -1 self.current_scale = None self.current_effect = None + self.current_per_block_list = None self.combined_scale = None self.combined_effect = None self.combined_per_block_list = None @@ -927,6 +946,13 @@ def validate_per_block_compatibility(motion_model: MotionModelPatcher, all_per_b raise Exception(f"Per-Block provided is meant for {all_per_blocks.sd_type}, but provided motion module is for {mm_info.sd_type}.") +def validate_per_block_compatibility_keyframes(motion_model: MotionModelPatcher, keyframes: ADKeyframeGroup): + if keyframes is None: + return + for keyframe in keyframes.keyframes: + validate_per_block_compatibility(motion_model, keyframe._per_block_replace) + + def interpolate_pe_to_length(model_dict: dict[str, Tensor], key: str, new_length: int): pe_shape = model_dict[key].shape temp_pe = rearrange(model_dict[key], "(t b) f d -> t b f d", t=1) diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index f14bae5..0fefe1d 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -1,7 +1,6 @@ import math from typing import Iterable, Tuple, Union, TYPE_CHECKING import re -from dataclasses import dataclass from collections.abc import Iterable as IterColl import torch @@ -15,7 +14,7 @@ from comfy.ldm.modules.diffusionmodules import openaimodel from comfy.ldm.modules.diffusionmodules.openaimodel import SpatialTransformer from comfy.controlnet import broadcast_image_to -from comfy.utils import repeat_to_batch_size +import comfy.utils import comfy.ops import comfy.model_management @@ -24,7 +23,9 @@ if TYPE_CHECKING: # avoids circular import from .adapter_cameractrl import CameraPoseEncoder from .adapter_fancyvideo import FancyVideoCondEmbedding, FancyVideoKeys, initialize_weights_to_zero -from .utils_motion import (CrossAttentionMM, MotionCompatibilityError, DummyNNModule, extend_to_batch_size, extend_list_to_batch_size, +from .utils_motion import (CrossAttentionMM, MotionCompatibilityError, DummyNNModule, + PerBlock, PerBlockId, + extend_to_batch_size, extend_list_to_batch_size, prepare_mask_batch, get_combined_multival) from .utils_model import BetaSchedules, ModelTypeSD from .logger import logger @@ -66,61 +67,6 @@ def get_string(self): return f"{self.mm_name}:{self.mm_version}:{self.mm_format}:{self.sd_type}" -####################### -# Facilitate Per-Block Effect and Scale Control -class PerAttn: - def __init__(self, attn_idx: Union[int, None], scale: Union[float, Tensor, None]): - self.attn_idx = attn_idx - self.scale = scale - - def matches(self, id: int): - if self.attn_idx is None: - return True - return self.attn_idx == id - - -class PerBlockId: - def __init__(self, block_type: str, block_idx: Union[int, None]=None, module_idx: Union[int, None]=None): - self.block_type = block_type - self.block_idx = block_idx - self.module_idx = module_idx - - def matches(self, other: 'PerBlockId') -> bool: - # block_type - if other.block_type != self.block_type: - return False - # block_idx - if other.block_idx is None: - return True - elif other.block_idx != self.block_idx: - return False - # module_idx - if other.module_idx is None: - return True - return other.module_idx == self.module_idx - - def __str__(self): - return f"PerBlockId({self.block_type},{self.block_idx},{self.module_idx})" - - -class PerBlock: - def __init__(self, id: PerBlockId, effect: Union[float, Tensor, None]=None, - scales: Union[list[Union[float, Tensor, None]], None]=None): - self.id = id - self.effect = effect - self.scales = scales - - def matches(self, id: PerBlockId): - return self.id.matches(id) - - -@dataclass -class AllPerBlocks: - per_block_list: list[PerBlock] - sd_type: Union[str, None] = None -#---------------------- -####################### - def is_hotshotxl(mm_state_dict: dict[str, Tensor]) -> bool: # use pos_encoder naming to determine if hotshotxl model for key in mm_state_dict.keys(): @@ -445,9 +391,7 @@ def set_camera_encoder(self, camera_encoder: 'CameraPoseEncoder'): self.camera_encoder = camera_encoder def init_conv_in(self, mm_state_dict: dict[str, Tensor]): - ''' - Used for PIA/FancyVideo - ''' + '''Used for PIA/FancyVideo''' del self.conv_in # hardcoded values, for now # dim=2, in_channels=9, model_channels=320, kernel=3, padding=1, @@ -459,9 +403,7 @@ def init_conv_in(self, mm_state_dict: dict[str, Tensor]): dtype=comfy.model_management.unet_dtype(), device=comfy.model_management.unet_offload_device()) def init_fps_embedding(self, mm_state_dict: dict[str, Tensor]): - ''' - Used for FancyVideo - ''' + '''Used for FancyVideo''' del self.fps_embedding in_channels = mm_state_dict["fps_embedding.linear.weight"].size(1) # expected to be 320 cond_embed_dim = mm_state_dict["fps_embedding.linear.weight"].size(0) # expected to be 1280 @@ -469,15 +411,23 @@ def init_fps_embedding(self, mm_state_dict: dict[str, Tensor]): self.fps_embedding.apply(initialize_weights_to_zero) def init_motion_embedding(self, mm_state_dict: dict[str, Tensor]): - ''' - Used for FancyVideo - ''' + '''Used for FancyVideo''' del self.motion_embedding in_channels = mm_state_dict["motion_embedding.linear.weight"].size(1) # expected to be 320 cond_embed_dim = mm_state_dict["motion_embedding.linear.weight"].size(0) # expected to be 1280 self.motion_embedding = FancyVideoCondEmbedding(in_channels=in_channels, cond_embed_dim=cond_embed_dim) self.motion_embedding.apply(initialize_weights_to_zero) + def init_motionctrl_cc_projections(self, state_dict: dict[str, Tensor]): + '''Used for MotionCtrl''' + for key, value in state_dict.items(): + if key.endswith('cc_projection.weight'): + in_features = value.shape[1] + out_features = value.shape[0] + ttb_key = key.split('.cc_projection')[0] + ttb: TemporalTransformerBlock = comfy.utils.get_attr(self, ttb_key) + ttb.init_cc_projection(in_features=in_features, out_features=out_features, ops=self.ops) + def get_fancyvideo_emb_patches(self, dtype, device, fps=25, motion_score=3.0): patches = [] if self.fps_embedding is not None: @@ -1150,7 +1100,7 @@ def get_scale_mask(self, idx: int, hidden_states: Tensor) -> Union[Tensor, None] # otherwise, calculate temp mask self.prev_hidden_states_batch = batch mask = prepare_mask_batch(self.raw_scale_masks[idx], shape=(self.full_length, 1, height, width)) - mask = repeat_to_batch_size(mask, self.full_length) + mask = extend_to_batch_size(mask, self.full_length) # if mask not the same amount length as full length, make it match if self.full_length != mask.shape[0]: mask = broadcast_image_to(mask, self.full_length, 1) @@ -1196,7 +1146,7 @@ def get_cameractrl_effect(self, hidden_states: Tensor) -> Union[float, Tensor, N # otherwise, calculate temp_cameractrl self.prev_cameractrl_hidden_states_batch = batch mask = prepare_mask_batch(self.raw_cameractrl_effect, shape=(self.full_length, 1, height, width)) - mask = repeat_to_batch_size(mask, self.full_length) + mask = extend_to_batch_size(mask, self.full_length) # if mask not the same amount length as full length, make it match if self.full_length != mask.shape[0]: mask = broadcast_image_to(mask, self.full_length, 1) @@ -1314,7 +1264,7 @@ def __init__( self.ff = FeedForward(dim, dropout=dropout, glu=(activation_fn == "geglu"), operations=ops) self.ff_norm = ops.LayerNorm(dim) # for MotionCtrl (CMCM) use - self.cc_projections: comfy.ops.disable_weight_init.Linear = None + self.cc_projection: comfy.ops.disable_weight_init.Linear = None def set_scale_multiplier(self, idx: int, multiplier: Union[float, None]): self.attention_blocks[idx].set_scale_multiplier(multiplier) @@ -1327,6 +1277,9 @@ def reset_temp_vars(self): for block in self.attention_blocks: block.reset_temp_vars() + def init_cc_projection(self, in_features: int, out_features: int, ops: comfy.ops.disable_weight_init): + self.cc_projection = ops.Linear(in_features=in_features, out_features=out_features) + def forward( self, hidden_states: Tensor, @@ -1366,13 +1319,13 @@ def forward( ) + hidden_states ) # do MotionCtrl-CMCM stuff if needed - if self.cc_projections is not None and count==0 and 'ADE_RT' in transformer_options: + if self.cc_projection is not None and count==0 and 'ADE_RT' in transformer_options: RT: Tensor = transformer_options['ADE_RT'] B, t, _ = RT.shape RT = RT.reshape(B*t, 1, -1) RT = RT.repeat(1, hidden_states.shape[1]) hidden_states = torch.cat([hidden_states, RT], dim=-1) - hidden_states = self.cc_projections(hidden_states) + hidden_states = self.cc_projection(hidden_states) count += 1 else: # views idea gotten from diffusers AnimateDiff FreeNoise implementation: @@ -1427,10 +1380,7 @@ def forward( del value_final del count_final - hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states - - output = hidden_states - return output + return self.ff(self.ff_norm(hidden_states)) + hidden_states class PositionalEncoding(nn.Module): @@ -1472,7 +1422,6 @@ def __init__( ): super().__init__(operations=ops, *args, **kwargs) assert attention_mode == "Temporal" - self.attention_mode = attention_mode self.is_cross_attention = kwargs["context_dim"] is not None @@ -1551,9 +1500,8 @@ def forward( transformer_options=transformer_options, ) - hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + return rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) - return hidden_states ############################################################################ ### EncoderOnly Version diff --git a/animatediff/nodes.py b/animatediff/nodes.py index 2167ee2..64befdc 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -1,6 +1,6 @@ import comfy.sample as comfy_sample -from .nodes_gen1 import (AnimateDiffLoaderGen1, LegacyAnimateDiffLoaderWithContext) +from .nodes_gen1 import (AnimateDiffLoaderGen1,) from .nodes_gen2 import (UseEvolvedSamplingNode, ApplyAnimateDiffModelNode, ApplyAnimateDiffModelBasicNode, ADKeyframeNode, LoadAnimateDiffModelNode) from .nodes_animatelcmi2v import (ApplyAnimateLCMI2VModel, LoadAnimateLCMI2VModelNode, LoadAnimateDiffAndInjectI2VNode, UpscaleAndVaeEncode) @@ -8,6 +8,7 @@ LoadCameraPosesFromFile, LoadCameraPosesFromPath, CameraCtrlPoseBasic, CameraCtrlPoseCombo, CameraCtrlPoseAdvanced, CameraCtrlManualAppendPose, CameraCtrlReplaceCameraParameters, CameraCtrlSetOriginalAspectRatio) +from .nodes_motionctrl import (LoadMotionCtrlCMCM, LoadMotionCtrlOMCM, ApplyAnimateDiffMotionCtrlModel) from .nodes_pia import (ApplyAnimateDiffPIAModel, LoadAnimateDiffAndInjectPIANode, InputPIA_MultivalNode, InputPIA_PaperPresetsNode, PIA_ADKeyframeNode) from .nodes_fancyvideo import (ApplyAnimateDiffFancyVideo,) from .nodes_hellomeme import (TestHMRefNetInjection,) @@ -44,7 +45,7 @@ PerBlock_SD15_LowLevelNode, PerBlock_SD15_MidLevelNode, PerBlock_SD15_FromFloatsNode, PerBlock_SDXL_LowLevelNode, PerBlock_SDXL_MidLevelNode, PerBlock_SDXL_FromFloatsNode) from .nodes_extras import AnimateDiffUnload, EmptyLatentImageLarge, CheckpointLoaderSimpleWithNoiseSelect, PerturbedAttentionGuidanceMultival, RescaleCFGMultival -from .nodes_deprecated import (AnimateDiffLoaderDEPR, AnimateDiffLoaderAdvancedDEPR, AnimateDiffCombineDEPR, +from .nodes_deprecated import (AnimateDiffLoaderDEPR, AnimateDiffLoaderAdvancedDEPR, LegacyAnimateDiffLoaderWithContextDEPR, AnimateDiffCombineDEPR, AnimateDiffModelSettingsDEPR, AnimateDiffModelSettingsSimpleDEPR, AnimateDiffModelSettingsAdvancedDEPR, AnimateDiffModelSettingsAdvancedAttnStrengthsDEPR) from .nodes_lora import AnimateDiffLoraLoader @@ -182,7 +183,6 @@ "ADE_RescaleCFGMultival": RescaleCFGMultival, # Gen1 Nodes "ADE_AnimateDiffLoaderGen1": AnimateDiffLoaderGen1, - "ADE_AnimateDiffLoaderWithContext": LegacyAnimateDiffLoaderWithContext, # Gen2 Nodes "ADE_UseEvolvedSampling": UseEvolvedSamplingNode, "ADE_ApplyAnimateDiffModelSimple": ApplyAnimateDiffModelBasicNode, @@ -193,6 +193,10 @@ "ADE_LoadAnimateLCMI2VModel": LoadAnimateLCMI2VModelNode, "ADE_UpscaleAndVAEEncode": UpscaleAndVaeEncode, "ADE_InjectI2VIntoAnimateDiffModel": LoadAnimateDiffAndInjectI2VNode, + # MotionCtrl Nodes + LoadMotionCtrlCMCM.NodeID: LoadMotionCtrlCMCM, + LoadMotionCtrlOMCM.NodeID: LoadMotionCtrlOMCM, + ApplyAnimateDiffMotionCtrlModel.NodeID: ApplyAnimateDiffMotionCtrlModel, # CameraCtrl Nodes "ADE_ApplyAnimateDiffModelWithCameraCtrl": ApplyAnimateDiffWithCameraCtrl, "ADE_LoadAnimateDiffModelWithCameraCtrl": LoadAnimateDiffModelWithCameraCtrl, @@ -216,6 +220,7 @@ # HelloMeme #TestHMRefNetInjection.NodeID: TestHMRefNetInjection, # Deprecated Nodes + "ADE_AnimateDiffLoaderWithContext": LegacyAnimateDiffLoaderWithContextDEPR, "AnimateDiffLoaderV1": AnimateDiffLoaderDEPR, "ADE_AnimateDiffLoaderV1Advanced": AnimateDiffLoaderAdvancedDEPR, "ADE_AnimateDiffCombine": AnimateDiffCombineDEPR, @@ -355,7 +360,6 @@ "ADE_RescaleCFGMultival": "RescaleCFG [Multival] πŸŽ­πŸ…πŸ…“", # Gen1 Nodes "ADE_AnimateDiffLoaderGen1": "AnimateDiff Loader πŸŽ­πŸ…πŸ…“β‘ ", - "ADE_AnimateDiffLoaderWithContext": "AnimateDiff Loader [Legacy] πŸŽ­πŸ…πŸ…“β‘ ", # Gen2 Nodes "ADE_UseEvolvedSampling": "Use Evolved Sampling πŸŽ­πŸ…πŸ…“β‘‘", "ADE_ApplyAnimateDiffModelSimple": "Apply AnimateDiff Model πŸŽ­πŸ…πŸ…“β‘‘", @@ -366,6 +370,10 @@ "ADE_LoadAnimateLCMI2VModel": "Load AnimateLCM-I2V Model πŸŽ­πŸ…πŸ…“β‘‘", "ADE_UpscaleAndVAEEncode": "Scale Ref Image and VAE Encode πŸŽ­πŸ…πŸ…“β‘‘", "ADE_InjectI2VIntoAnimateDiffModel": "πŸ§ͺInject I2V into AnimateDiff Model πŸŽ­πŸ…πŸ…“β‘‘", + # MotionCtrl Nodes + LoadMotionCtrlCMCM.NodeID: LoadMotionCtrlCMCM.NodeName, + LoadMotionCtrlOMCM.NodeID: LoadMotionCtrlOMCM.NodeName, + ApplyAnimateDiffMotionCtrlModel.NodeID: ApplyAnimateDiffMotionCtrlModel.NodeName, # CameraCtrl Nodes "ADE_ApplyAnimateDiffModelWithCameraCtrl": "Apply AnimateDiff+CameraCtrl Model πŸŽ­πŸ…πŸ…“β‘‘", "ADE_LoadAnimateDiffModelWithCameraCtrl": "Load AnimateDiff+CameraCtrl Model πŸŽ­πŸ…πŸ…“β‘‘", @@ -389,6 +397,7 @@ # HelloMeme TestHMRefNetInjection.NodeID: TestHMRefNetInjection.NodeName, # Deprecated Nodes + "ADE_AnimateDiffLoaderWithContext": "AnimateDiff Loader [Legacy] πŸŽ­πŸ…πŸ…“β‘ ", "AnimateDiffLoaderV1": "🚫AnimateDiff Loader [DEPRECATED] πŸŽ­πŸ…πŸ…“", "ADE_AnimateDiffLoaderV1Advanced": "🚫AnimateDiff Loader (Advanced) [DEPRECATED] πŸŽ­πŸ…πŸ…“", "ADE_AnimateDiffCombine": "🚫AnimateDiff Combine [DEPRECATED, Use Video Combine (VHS) Instead!] πŸŽ­πŸ…πŸ…“", diff --git a/animatediff/nodes_deprecated.py b/animatediff/nodes_deprecated.py index 50fe5c9..f7e1d9b 100644 --- a/animatediff/nodes_deprecated.py +++ b/animatediff/nodes_deprecated.py @@ -16,8 +16,11 @@ from .context import ContextOptionsGroup, ContextOptions, ContextSchedules from .logger import logger from .utils_model import Folders, BetaSchedules, get_available_motion_models -from .model_injection import ModelPatcherHelper, InjectionParams, MotionModelGroup, load_motion_module_gen1 +from .utils_motion import ADKeyframeGroup +from .motion_lora import MotionLoraList +from .model_injection import (ModelPatcherHelper, InjectionParams, MotionModelGroup, get_mm_attachment, load_motion_module_gen1) from .sampling import outer_sample_wrapper, sliding_calc_cond_batch +from .sample_settings import SampleSettings class AnimateDiffLoaderDEPR: @@ -148,7 +151,96 @@ def load_mm_and_inject_params(self, del motion_model return (model, latents) + + +class LegacyAnimateDiffLoaderWithContextDEPR: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "model_name": (get_available_motion_models(),), + "beta_schedule": (BetaSchedules.ALIAS_LIST, {"default": BetaSchedules.AUTOSELECT}), + #"apply_mm_groupnorm_hack": ("BOOLEAN", {"default": True}), + }, + "optional": { + "context_options": ("CONTEXT_OPTIONS",), + "motion_lora": ("MOTION_LORA",), + "ad_settings": ("AD_SETTINGS",), + "sample_settings": ("SAMPLE_SETTINGS",), + "motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001}), + "apply_v2_models_properly": ("BOOLEAN", {"default": True}), + "ad_keyframes": ("AD_KEYFRAMES",), + "deprecation_warning": ("ADEWARN", {"text": "Deprecated; use AnimateDiff Loader instead."}), + } + } + DEPRECATED = True + RETURN_TYPES = ("MODEL",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘  Gen1 nodes β‘ " + FUNCTION = "load_mm_and_inject_params" + + def load_mm_and_inject_params(self, + model: ModelPatcher, + model_name: str, beta_schedule: str,# apply_mm_groupnorm_hack: bool, + context_options: ContextOptionsGroup=None, motion_lora: MotionLoraList=None, ad_settings: AnimateDiffSettings=None, motion_model_settings: AnimateDiffSettings=None, + sample_settings: SampleSettings=None, motion_scale: float=1.0, apply_v2_models_properly: bool=False, ad_keyframes: ADKeyframeGroup=None, + ): + if ad_settings is not None: + motion_model_settings = ad_settings + # load motion module + motion_model = load_motion_module_gen1(model_name, model, motion_lora=motion_lora, motion_model_settings=motion_model_settings) + # set injection params + params = InjectionParams( + unlimited_area_hack=False, + apply_v2_properly=apply_v2_models_properly, + ) + if context_options: + params.set_context(context_options) + # set motion_scale and motion_model_settings + if not motion_model_settings: + motion_model_settings = AnimateDiffSettings() + motion_model_settings.attn_scale = motion_scale + params.set_motion_model_settings(motion_model_settings) + + attachment = get_mm_attachment(motion_model) + if params.motion_model_settings.mask_attn_scale is not None: + attachment.scale_multival = params.motion_model_settings.mask_attn_scale * params.motion_model_settings.attn_scale + else: + attachment.scale_multival = params.motion_model_settings.attn_scale + + attachment.keyframes = ad_keyframes.clone() if ad_keyframes else ADKeyframeGroup() + + # need to use a ModelPatcher that supports injection of motion modules into unet + model = model.clone() + helper = ModelPatcherHelper(model) + helper.set_all_properties( + outer_sampler_wrapper=outer_sample_wrapper, + calc_cond_batch_wrapper=sliding_calc_cond_batch, + params=params, + sample_settings=sample_settings, + motion_models=MotionModelGroup(motion_model), + ) + + sample_settings = helper.get_sample_settings() + if sample_settings.custom_cfg is not None: + logger.info("[Sample Settings] custom_cfg is set; will override any KSampler cfg values or patches.") + + if sample_settings.sigma_schedule is not None: + logger.info("[Sample Settings] sigma_schedule is set; will override beta_schedule.") + model.add_object_patch("model_sampling", sample_settings.sigma_schedule.clone().model_sampling) + else: + # save model sampling from BetaSchedule as object patch + # if autoselect, get suggested beta_schedule from motion model + if beta_schedule == BetaSchedules.AUTOSELECT and helper.get_motion_models(): + beta_schedule = helper.get_motion_models()[0].model.get_best_beta_schedule(log=True) + new_model_sampling = BetaSchedules.to_model_sampling(beta_schedule, model) + if new_model_sampling is not None: + model.add_object_patch("model_sampling", new_model_sampling) + + del motion_model + return (model,) + class AnimateDiffCombineDEPR: ffmpeg_warning_already_shown = False diff --git a/animatediff/nodes_gen1.py b/animatediff/nodes_gen1.py index ff5388f..1f694bb 100644 --- a/animatediff/nodes_gen1.py +++ b/animatediff/nodes_gen1.py @@ -1,20 +1,15 @@ -from pathlib import Path -import torch - -import comfy.sample as comfy_sample from comfy.model_patcher import ModelPatcher -from .ad_settings import AdjustGroup, AnimateDiffSettings, AdjustPE, AdjustWeight -from .context import ContextOptions, ContextOptionsGroup, ContextSchedules +from .ad_settings import AnimateDiffSettings +from .context import ContextOptionsGroup from .logger import logger -from .utils_model import BetaSchedules, get_available_motion_loras, get_available_motion_models, get_motion_lora_path -from .utils_motion import ADKeyframeGroup, get_combined_multival -from .motion_lora import MotionLoraInfo, MotionLoraList -from .motion_module_ad import AllPerBlocks +from .utils_model import BetaSchedules, get_available_motion_models +from .utils_motion import ADKeyframeGroup, AllPerBlocks, get_combined_multival +from .motion_lora import MotionLoraList from .model_injection import (ModelPatcherHelper, InjectionParams, MotionModelGroup, get_mm_attachment, - load_motion_lora_as_patches, load_motion_module_gen1, load_motion_module_gen2, validate_model_compatibility_gen2, - validate_per_block_compatibility) -from .sample_settings import SampleSettings, SeedNoiseGeneration + load_motion_lora_as_patches, load_motion_module_gen2, validate_model_compatibility_gen2, + validate_per_block_compatibility, validate_per_block_compatibility_keyframes) +from .sample_settings import SampleSettings from .sampling import outer_sample_wrapper, sliding_calc_cond_batch @@ -66,7 +61,8 @@ def load_mm_and_inject_params(self, validate_per_block_compatibility(motion_model=motion_model, all_per_blocks=per_block) attachment.per_block_list = per_block.per_block_list attachment.keyframes = ad_keyframes.clone() if ad_keyframes else ADKeyframeGroup() - + validate_per_block_compatibility_keyframes(motion_model=motion_model, keyframes=attachment.keyframes) + # create injection params params = InjectionParams(unlimited_area_hack=False) # apply context options @@ -112,91 +108,3 @@ def load_mm_and_inject_params(self, del motion_model return (model,) - - -class LegacyAnimateDiffLoaderWithContext: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL",), - "model_name": (get_available_motion_models(),), - "beta_schedule": (BetaSchedules.ALIAS_LIST, {"default": BetaSchedules.AUTOSELECT}), - #"apply_mm_groupnorm_hack": ("BOOLEAN", {"default": True}), - }, - "optional": { - "context_options": ("CONTEXT_OPTIONS",), - "motion_lora": ("MOTION_LORA",), - "ad_settings": ("AD_SETTINGS",), - "sample_settings": ("SAMPLE_SETTINGS",), - "motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001}), - "apply_v2_models_properly": ("BOOLEAN", {"default": True}), - "ad_keyframes": ("AD_KEYFRAMES",), - } - } - - DEPRECATED = True - RETURN_TYPES = ("MODEL",) - CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘  Gen1 nodes β‘ " - FUNCTION = "load_mm_and_inject_params" - - def load_mm_and_inject_params(self, - model: ModelPatcher, - model_name: str, beta_schedule: str,# apply_mm_groupnorm_hack: bool, - context_options: ContextOptionsGroup=None, motion_lora: MotionLoraList=None, ad_settings: AnimateDiffSettings=None, motion_model_settings: AnimateDiffSettings=None, - sample_settings: SampleSettings=None, motion_scale: float=1.0, apply_v2_models_properly: bool=False, ad_keyframes: ADKeyframeGroup=None, - ): - if ad_settings is not None: - motion_model_settings = ad_settings - # load motion module - motion_model = load_motion_module_gen1(model_name, model, motion_lora=motion_lora, motion_model_settings=motion_model_settings) - # set injection params - params = InjectionParams( - unlimited_area_hack=False, - apply_v2_properly=apply_v2_models_properly, - ) - if context_options: - params.set_context(context_options) - # set motion_scale and motion_model_settings - if not motion_model_settings: - motion_model_settings = AnimateDiffSettings() - motion_model_settings.attn_scale = motion_scale - params.set_motion_model_settings(motion_model_settings) - - attachment = get_mm_attachment(motion_model) - if params.motion_model_settings.mask_attn_scale is not None: - attachment.scale_multival = params.motion_model_settings.mask_attn_scale * params.motion_model_settings.attn_scale - else: - attachment.scale_multival = params.motion_model_settings.attn_scale - - attachment.keyframes = ad_keyframes.clone() if ad_keyframes else ADKeyframeGroup() - - # need to use a ModelPatcher that supports injection of motion modules into unet - model = model.clone() - helper = ModelPatcherHelper(model) - helper.set_all_properties( - outer_sampler_wrapper=outer_sample_wrapper, - calc_cond_batch_wrapper=sliding_calc_cond_batch, - params=params, - sample_settings=sample_settings, - motion_models=MotionModelGroup(motion_model), - ) - - sample_settings = helper.get_sample_settings() - if sample_settings.custom_cfg is not None: - logger.info("[Sample Settings] custom_cfg is set; will override any KSampler cfg values or patches.") - - if sample_settings.sigma_schedule is not None: - logger.info("[Sample Settings] sigma_schedule is set; will override beta_schedule.") - model.add_object_patch("model_sampling", sample_settings.sigma_schedule.clone().model_sampling) - else: - # save model sampling from BetaSchedule as object patch - # if autoselect, get suggested beta_schedule from motion model - if beta_schedule == BetaSchedules.AUTOSELECT and helper.get_motion_models(): - beta_schedule = helper.get_motion_models()[0].model.get_best_beta_schedule(log=True) - new_model_sampling = BetaSchedules.to_model_sampling(beta_schedule, model) - if new_model_sampling is not None: - model.add_object_patch("model_sampling", new_model_sampling) - - del motion_model - return (model,) diff --git a/animatediff/nodes_gen2.py b/animatediff/nodes_gen2.py index 19ed454..0f4b0c6 100644 --- a/animatediff/nodes_gen2.py +++ b/animatediff/nodes_gen2.py @@ -7,12 +7,12 @@ from .context import ContextOptionsGroup from .logger import logger from .utils_model import BIGMAX, BetaSchedules, get_available_motion_models -from .utils_motion import ADKeyframeGroup, ADKeyframe, InputPIA +from .utils_motion import ADKeyframeGroup, ADKeyframe, InputPIA, AllPerBlocks from .motion_lora import MotionLoraList -from .motion_module_ad import AllPerBlocks from .model_injection import (ModelPatcherHelper, InjectionParams, MotionModelGroup, MotionModelPatcher, get_mm_attachment, create_fresh_motion_module, - load_motion_module_gen2, load_motion_lora_as_patches, validate_model_compatibility_gen2, validate_per_block_compatibility) + load_motion_module_gen2, load_motion_lora_as_patches, validate_model_compatibility_gen2, + validate_per_block_compatibility, validate_per_block_compatibility_keyframes) from .sample_settings import SampleSettings from .sampling import outer_sample_wrapper, sliding_calc_cond_batch @@ -135,6 +135,7 @@ def apply_motion_model(self, motion_model: MotionModelPatcher, start_percent: fl validate_per_block_compatibility(motion_model=motion_model, all_per_blocks=per_block) attachment.per_block_list = per_block.per_block_list attachment.keyframes = ad_keyframes.clone() if ad_keyframes else ADKeyframeGroup() + validate_per_block_compatibility_keyframes(motion_model=motion_model, keyframes=attachment.keyframes) attachment.timestep_percent_range = (start_percent, end_percent) # add to beginning, so that after injection, it will be the earliest of prev_m_models to be run prev_m_models.add_to_start(mm=motion_model) @@ -211,6 +212,7 @@ def INPUT_TYPES(s): "prev_ad_keyframes": ("AD_KEYFRAMES", ), "scale_multival": ("MULTIVAL",), "effect_multival": ("MULTIVAL",), + "per_block_replace": ("PER_BLOCK",), "inherit_missing": ("BOOLEAN", {"default": True}, ), "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), }, @@ -227,6 +229,7 @@ def INPUT_TYPES(s): def load_keyframe(self, start_percent: float, prev_ad_keyframes=None, scale_multival: Union[float, torch.Tensor]=None, effect_multival: Union[float, torch.Tensor]=None, + per_block_replace: AllPerBlocks=None, cameractrl_multival: Union[float, torch.Tensor]=None, pia_input: InputPIA=None, inherit_missing: bool=True, guarantee_steps: int=1): if not prev_ad_keyframes: @@ -234,6 +237,7 @@ def load_keyframe(self, prev_ad_keyframes = prev_ad_keyframes.clone() keyframe = ADKeyframe(start_percent=start_percent, scale_multival=scale_multival, effect_multival=effect_multival, + per_block_replace=per_block_replace, cameractrl_multival=cameractrl_multival, pia_input=pia_input, inherit_missing=inherit_missing, guarantee_steps=guarantee_steps) prev_ad_keyframes.add(keyframe) diff --git a/animatediff/nodes_motionctrl.py b/animatediff/nodes_motionctrl.py index 7c84753..3d02e1f 100644 --- a/animatediff/nodes_motionctrl.py +++ b/animatediff/nodes_motionctrl.py @@ -2,15 +2,14 @@ from torch import Tensor from .ad_settings import AnimateDiffSettings -from .adapter_motionctrl import injection_motionctrl_cmcm, load_motionctrl_omcm +from .adapter_motionctrl import inject_motionctrl_cmcm, load_motionctrl_omcm -from .motion_module_ad import AllPerBlocks from .model_injection import MotionModelPatcher, MotionModelGroup, load_motion_module_gen2 from .motion_lora import MotionLoraList from .nodes_gen2 import ApplyAnimateDiffModelNode from .utils_model import get_available_motion_models -from .utils_motion import ADKeyframeGroup +from .utils_motion import ADKeyframeGroup, AllPerBlocks class LoadMotionCtrlCMCM: @@ -24,6 +23,7 @@ def INPUT_TYPES(s): "motionctrl_cmcm": (get_available_motion_models(),), }, "optional": { + "override_ad_weights": ("BOOLEAN", {"default": True}), "ad_settings": ("AD_SETTINGS",), } } @@ -33,9 +33,10 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘/MotionCtrl" FUNCTION = "load_motionctrl_cmcm" - def load_motionctrl_cmcm(self, model_name: str, motionctrl_cmcm: str, ad_settings: AnimateDiffSettings=None): + def load_motionctrl_cmcm(self, model_name: str, motionctrl_cmcm: str, + override_ad_weights=True, ad_settings: AnimateDiffSettings=None,): motion_model = load_motion_module_gen2(model_name=model_name, motion_model_settings=ad_settings) - motion_model = injection_motionctrl_cmcm(motion_model, cmcm_name=motionctrl_cmcm) + inject_motionctrl_cmcm(motion_model.model, cmcm_name=motionctrl_cmcm, apply_non_ccs=override_ad_weights) return (motion_model,) diff --git a/animatediff/nodes_per_block.py b/animatediff/nodes_per_block.py index ed077e7..549bfd6 100644 --- a/animatediff/nodes_per_block.py +++ b/animatediff/nodes_per_block.py @@ -2,9 +2,9 @@ from torch import Tensor from .documentation import short_desc, register_description, coll, DocHelper -from .motion_module_ad import PerBlock, PerBlockId, BlockType, AllPerBlocks +from .motion_module_ad import BlockType from .utils_model import ModelTypeSD -from .utils_motion import extend_list_to_batch_size +from .utils_motion import AllPerBlocks, PerBlock, PerBlockId, extend_list_to_batch_size class ADBlockHolder: @@ -119,7 +119,7 @@ def create_per_block(self, if block is not None: blocks.append(PerBlock(id=id, effect=block.effect, scales=block.scales)) if len(blocks) == 0: - return (None,) + blocks = None return (AllPerBlocks(blocks),) @@ -175,7 +175,7 @@ def create_per_block(self, if block is not None: blocks.append(PerBlock(id=id, effect=block.effect, scales=block.scales)) if len(blocks) == 0: - return (None,) + blocks = None return (AllPerBlocks(blocks, ModelTypeSD.SD1_5),) @@ -267,7 +267,7 @@ def create_per_block(self, if block is not None: blocks.append(PerBlock(id=id, effect=block.effect, scales=block.scales)) if len(blocks) == 0: - return (None,) + blocks = None return (AllPerBlocks(blocks, ModelTypeSD.SD1_5),) @@ -305,7 +305,7 @@ def create_per_block(self, effect_21_floats: Union[list[float], None]=None, scale_21_floats: Union[list[float], None]=None): if effect_21_floats is None and scale_21_floats is None: - return (None,) + return (AllPerBlocks(None, ModelTypeSD.SD1_5),) # SD1.5 has 21 blocks block_total = 21 holders = [ADBlockHolder() for _ in range(block_total)] @@ -366,7 +366,7 @@ def create_per_block(self, if block is not None: blocks.append(PerBlock(id=id, effect=block.effect, scales=block.scales)) if len(blocks) == 0: - return (None,) + blocks = None return (AllPerBlocks(blocks, ModelTypeSD.SDXL),) @@ -443,7 +443,7 @@ def create_per_block(self, if block is not None: blocks.append(PerBlock(id=id, effect=block.effect, scales=block.scales)) if len(blocks) == 0: - return (None,) + blocks = None return (AllPerBlocks(blocks, ModelTypeSD.SDXL),) @@ -481,7 +481,7 @@ def create_per_block(self, effect_16_floats: Union[list[float], None]=None, scale_16_floats: Union[list[float], None]=None): if effect_16_floats is None and scale_16_floats is None: - return (None,) + return (AllPerBlocks(None, ModelTypeSD.SDXL),) # SDXL has 16 blocks block_total = 16 holders = [ADBlockHolder() for _ in range(block_total)] diff --git a/animatediff/utils_motion.py b/animatediff/utils_motion.py index e6f69a0..122244f 100644 --- a/animatediff/utils_motion.py +++ b/animatediff/utils_motion.py @@ -4,6 +4,7 @@ from torch import Tensor, nn from abc import ABC, abstractmethod from collections.abc import Iterable +from dataclasses import dataclass import comfy.model_management as model_management import comfy.ops @@ -335,11 +336,77 @@ def get_combined_input_effect_multival(inputA: Union[InputPIA, None], inputB: Un return get_combined_multival(inputA.effect_multival, inputB.effect_multival) +####################### +# Facilitate Per-Block Effect and Scale Control +class PerAttn: + def __init__(self, attn_idx: Union[int, None], scale: Union[float, Tensor, None]): + self.attn_idx = attn_idx + self.scale = scale + + def matches(self, id: int): + if self.attn_idx is None: + return True + return self.attn_idx == id + + +class PerBlockId: + def __init__(self, block_type: str, block_idx: Union[int, None]=None, module_idx: Union[int, None]=None): + self.block_type = block_type + self.block_idx = block_idx + self.module_idx = module_idx + + def matches(self, other: 'PerBlockId') -> bool: + # block_type + if other.block_type != self.block_type: + return False + # block_idx + if other.block_idx is None: + return True + elif other.block_idx != self.block_idx: + return False + # module_idx + if other.module_idx is None: + return True + return other.module_idx == self.module_idx + + def __str__(self): + return f"PerBlockId({self.block_type},{self.block_idx},{self.module_idx})" + + +class PerBlock: + def __init__(self, id: PerBlockId, effect: Union[float, Tensor, None]=None, + scales: Union[list[Union[float, Tensor, None]], None]=None): + self.id = id + self.effect = effect + self.scales = scales + + def matches(self, id: PerBlockId): + return self.id.matches(id) + + +@dataclass +class AllPerBlocks: + per_block_list: list[PerBlock] + sd_type: Union[str, None] = None + + +def get_combined_per_block_list(listDefault: Union[list[PerBlock], None], listNew: Union[list[PerBlock], None]): + if listDefault is None: + return listNew + elif listNew is None: + return listDefault + else: + return listNew +#---------------------- +####################### + + class ADKeyframe: def __init__(self, start_percent: float = 0.0, scale_multival: Union[float, Tensor]=None, effect_multival: Union[float, Tensor]=None, + per_block_replace: AllPerBlocks=None, cameractrl_multival: Union[float, Tensor]=None, pia_input: InputPIA=None, inherit_missing: bool=True, @@ -350,18 +417,28 @@ def __init__(self, self.start_t = 999999999.9 self.scale_multival = scale_multival self.effect_multival = effect_multival + self._per_block_replace = per_block_replace self.cameractrl_multival = cameractrl_multival self.pia_input = pia_input self.inherit_missing = inherit_missing self.guarantee_steps = guarantee_steps self.default = default + @property + def per_block_list(self): + if self._per_block_replace is None: + return None + return self._per_block_replace.per_block_list + def has_scale(self): return self.scale_multival is not None def has_effect(self): return self.effect_multival is not None + def has_per_block_replace(self): + return self._per_block_replace is not None + def has_cameractrl_effect(self): return self.cameractrl_multival is not None From 93cafde180924deec5852051c60c7192fbbc58f1 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 29 Dec 2024 14:42:54 -0600 Subject: [PATCH 06/12] Got MotionCtrl's camere control working, fixed frame_length affecting CameraCtrl Pose speed; still need to fix CameraCtrl Poses not adding to each other as expected --- animatediff/adapter_motionctrl.py | 11 +++++++ animatediff/model_injection.py | 47 ++++++++++++++++++++++++++++ animatediff/motion_module_ad.py | 22 ++++++++++--- animatediff/nodes.py | 3 +- animatediff/nodes_cameractrl.py | 34 ++++++++++++++++++--- animatediff/nodes_motionctrl.py | 51 +++++++++++++++++++++++++++++-- animatediff/sampling.py | 5 +-- 7 files changed, 159 insertions(+), 14 deletions(-) diff --git a/animatediff/adapter_motionctrl.py b/animatediff/adapter_motionctrl.py index fef7d1d..01bf922 100644 --- a/animatediff/adapter_motionctrl.py +++ b/animatediff/adapter_motionctrl.py @@ -1,6 +1,7 @@ # main code adapted from https://github.com/TencentARC/MotionCtrl/tree/animatediff from __future__ import annotations from torch import nn, Tensor +import torch from comfy.model_patcher import ModelPatcher import comfy.model_management @@ -83,6 +84,16 @@ def _remove_module_prefix(state_dict: dict[str, Tensor]): state_dict.pop(key) +def convert_cameractrl_poses_to_RT(poses: list[list[float]]): + tensors = [] + for pose in poses: + new_tensor = torch.tensor(pose[7:]) + new_tensor = new_tensor.unsqueeze(0) + tensors.append(new_tensor) + RT = torch.cat(tensors, dim=0) + return RT + + class ObjectControlModelPatcher(ModelPatcher): '''Class only used for type hints.''' def __init__(self): diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 954d7b5..0a4aca7 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -291,6 +291,12 @@ def __init__(self): self.prev_fancy_latents_shape: tuple = None self.fancy_multival: Union[float, Tensor] = None + # MotionCtrl + self.orig_RT: Tensor = None + self.RT: Tensor = None + self.prev_RT_shape: tuple = None + self.prev_RT_uuids: list = None + # temporary variables self.current_used_steps = 0 self.current_keyframe: ADKeyframe = None @@ -460,6 +466,43 @@ def prepare_camera_features(self, patcher: MotionModelPatcher, x: Tensor, cond_o self.prev_sub_idxs = sub_idxs self.prev_batched_number = batched_number + def prepare_motionctrl_camera(self, patcher: MotionModelPatcher, x: Tensor, transformer_options: dict[str]): + '''Used for MotionCtrl''' + # if no cc enabled, done + if not patcher.model.is_motionctrl_cc_enabled(): + if "ADE_RT" in transformer_options: + transformer_options.pop("ADE_RT") + return + cond_or_uncond: list[int] = transformer_options["cond_or_uncond"] + uuids: list = transformer_options["uuids"] + batched_number = len(cond_or_uncond) + ad_params = transformer_options["ad_params"] + full_length = ad_params["full_length"] + sub_idxs = ad_params["sub_idxs"] + goal_length = x.size(0) // batched_number + if self.prev_RT_shape != x.shape or sub_idxs != self.prev_sub_idxs or uuids != self.prev_RT_uuids: + real_RT = self.orig_RT.clone().to(dtype=x.dtype, device=x.device) # [t, 12] + # make sure RT is of the valid length + real_RT = extend_to_batch_size(real_RT, full_length) + if sub_idxs is not None: + real_RT = real_RT[sub_idxs] + real_RT = real_RT.unsqueeze(0) # [1, t, 12] + # match batch length - conds get real_RT, unconds get empty + if batched_number > 1: + batched_RTs = [] + for condtype in cond_or_uncond: + if condtype == 0: # cond + batched_RTs.append(real_RT) + else: # uncond + batched_RTs.append(torch.zeros_like(real_RT)) + real_RT = torch.cat(batched_RTs, dim=0) + self.RT = real_RT.to(dtype=x.dtype, device=x.device) + self.prev_RT_shape = x.shape + transformer_options["ADE_RT"] = self.RT + self.prev_sub_idxs = sub_idxs + self.prev_batched_number = batched_number + + def get_pia_c_concat(self, model: BaseModel, x: Tensor) -> Tensor: '''Used for PIA''' # if have cached shape, check if matches - if so, return cached pia_latents @@ -583,6 +626,10 @@ def cleanup(self, patcher: MotionModelPatcher): # PIA self.combined_pia_mask = None self.combined_pia_effect = None + # MotionCtrl + self.RT = None + self.prev_RT_shape = None + self.prev_RT_uuids = None # Default self.current_used_steps = 0 self.current_keyframe = None diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index 0fefe1d..d6af306 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -1,3 +1,4 @@ +from __future__ import annotations import math from typing import Iterable, Tuple, Union, TYPE_CHECKING import re @@ -366,7 +367,7 @@ def __init__(self, mm_state_dict: dict[str, Tensor], mm_info: AnimateDiffInfo, i if has_img_encoder(mm_state_dict): self.init_img_encoder() # CameraCtrl stuff - self.camera_encoder: 'CameraPoseEncoder' = None + self.camera_encoder: CameraPoseEncoder = None # PIA/FancyVideo stuff - create conv_in if keys are present for it self.conv_in: comfy.ops.disable_weight_init.Conv2d = None self.orig_conv_in: comfy.ops.disable_weight_init.Conv2d = None @@ -382,11 +383,15 @@ def __init__(self, mm_state_dict: dict[str, Tensor], mm_info: AnimateDiffInfo, i # get_unet_func initialization self.get_unet_func = init_kwargs.get(InitKwargs.GET_UNET_FUNC, get_unet_default) + def needs_apply_model_wrapper(self): + '''Returns true of AnimateLCM-I2V, CameraCtrl, or MotionCtrl is in use.''' + return self.img_encoder is not None or self.camera_encoder is not None or self.is_motionctrl_cc_enabled() + def init_img_encoder(self): del self.img_encoder self.img_encoder = AdapterEmbed(cin=4, channels=self.layer_channels, nums_rb=2, ksize=1, sk=True, use_conv=False, ops=self.ops) - def set_camera_encoder(self, camera_encoder: 'CameraPoseEncoder'): + def set_camera_encoder(self, camera_encoder: CameraPoseEncoder): del self.camera_encoder self.camera_encoder = camera_encoder @@ -428,6 +433,13 @@ def init_motionctrl_cc_projections(self, state_dict: dict[str, Tensor]): ttb: TemporalTransformerBlock = comfy.utils.get_attr(self, ttb_key) ttb.init_cc_projection(in_features=in_features, out_features=out_features, ops=self.ops) + def is_motionctrl_cc_enabled(self): + '''Used for MotionCtrl''' + if self.down_blocks: + ttb: TemporalTransformerBlock = self.down_blocks[0].motion_modules[0].temporal_transformer.transformer_blocks[0] + return ttb.cc_projection is not None + return False + def get_fancyvideo_emb_patches(self, dtype, device, fps=25, motion_score=3.0): patches = [] if self.fps_embedding is not None: @@ -1320,12 +1332,12 @@ def forward( ) # do MotionCtrl-CMCM stuff if needed if self.cc_projection is not None and count==0 and 'ADE_RT' in transformer_options: - RT: Tensor = transformer_options['ADE_RT'] + RT: Tensor = transformer_options['ADE_RT'].to(dtype=hidden_states.dtype) B, t, _ = RT.shape RT = RT.reshape(B*t, 1, -1) - RT = RT.repeat(1, hidden_states.shape[1]) + RT = RT.repeat(1, hidden_states.shape[1], 1) hidden_states = torch.cat([hidden_states, RT], dim=-1) - hidden_states = self.cc_projection(hidden_states) + hidden_states = self.cc_projection(hidden_states).to(dtype=hidden_states.dtype) count += 1 else: # views idea gotten from diffusers AnimateDiff FreeNoise implementation: diff --git a/animatediff/nodes.py b/animatediff/nodes.py index 64befdc..f83b32a 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -8,7 +8,7 @@ LoadCameraPosesFromFile, LoadCameraPosesFromPath, CameraCtrlPoseBasic, CameraCtrlPoseCombo, CameraCtrlPoseAdvanced, CameraCtrlManualAppendPose, CameraCtrlReplaceCameraParameters, CameraCtrlSetOriginalAspectRatio) -from .nodes_motionctrl import (LoadMotionCtrlCMCM, LoadMotionCtrlOMCM, ApplyAnimateDiffMotionCtrlModel) +from .nodes_motionctrl import (LoadMotionCtrlCMCM, LoadMotionCtrlOMCM, ApplyAnimateDiffMotionCtrlModel, LoadMotionCtrlCameraPosesFromFile) from .nodes_pia import (ApplyAnimateDiffPIAModel, LoadAnimateDiffAndInjectPIANode, InputPIA_MultivalNode, InputPIA_PaperPresetsNode, PIA_ADKeyframeNode) from .nodes_fancyvideo import (ApplyAnimateDiffFancyVideo,) from .nodes_hellomeme import (TestHMRefNetInjection,) @@ -197,6 +197,7 @@ LoadMotionCtrlCMCM.NodeID: LoadMotionCtrlCMCM, LoadMotionCtrlOMCM.NodeID: LoadMotionCtrlOMCM, ApplyAnimateDiffMotionCtrlModel.NodeID: ApplyAnimateDiffMotionCtrlModel, + LoadMotionCtrlCameraPosesFromFile.NodeID: LoadMotionCtrlCameraPosesFromFile, # CameraCtrl Nodes "ADE_ApplyAnimateDiffModelWithCameraCtrl": ApplyAnimateDiffWithCameraCtrl, "ADE_LoadAnimateDiffModelWithCameraCtrl": LoadAnimateDiffModelWithCameraCtrl, diff --git a/animatediff/nodes_cameractrl.py b/animatediff/nodes_cameractrl.py index 7b7bb9a..b91c033 100644 --- a/animatediff/nodes_cameractrl.py +++ b/animatediff/nodes_cameractrl.py @@ -115,13 +115,13 @@ def compute_R_from_rad_angle(angles: np.ndarray): R = np.dot(Rz, np.dot(Ry, Rx)) return R -def get_camera_motion(angle: np.ndarray, T: np.ndarray, speed: float, n=16): +def get_camera_motion(angle: np.ndarray, T: np.ndarray, speed: float, n=16, base=16): RT = [] for i in range(n): - _angle = (i/n)*speed*(CAM.BASE_ANGLE)*angle + _angle = (i/base)*speed*(CAM.BASE_ANGLE)*angle R = compute_R_from_rad_angle(_angle) # _T = (i/n)*speed*(T.reshape(3,1)) - _T=(i/n)*speed*(CAM.BASE_T_NORM)*(T.reshape(3,1)) + _T=(i/base)*speed*(CAM.BASE_T_NORM)*(T.reshape(3,1)) _RT = np.concatenate([R,_T], axis=1) RT.append(_RT) RT = np.stack(RT) @@ -143,6 +143,22 @@ def combine_RTs(RT_0: np.ndarray, RT_1: np.ndarray): return np.concatenate([RT_0, RT_1], axis=0) +def stack_RTs(RT_0: np.ndarray, RT_1: np.ndarray): + RT_target = copy.deepcopy(RT_1) + static_motion = CAM.get(CAM.STATIC) + RT_static = get_camera_motion(static_motion.rotate, static_motion.translate, 1.0, 1) + RT_offset = RT_0[-1] - RT_static[-1] + + temp = [] + for sub_RT in RT_target: + temp.append(sub_RT + RT_offset) + + RT_1 = np.stack(temp) + RT_0 = RT_0[:-1] + + return np.concatenate([RT_0, RT_1], axis=0) + + def set_original_pose_dims(poses: list[list[float]], pose_width, pose_height): # indexes 5 and 6 are not used for anything in the poses, so can use 5 and 6 to set original pose width/height new_poses = copy.deepcopy(poses) @@ -157,7 +173,17 @@ def combine_poses(poses0: list[list[float]], poses1: list[list[float]]): inter_poses = ndarray_to_poses(new_RT) # maintain fx, fy, cx, and cy values by pasting only the movement portion of poses for i in range(len(new_poses)): - new_poses[7:] = inter_poses[7:] + new_poses[i][7:] = inter_poses[i][7:] + return new_poses + + +def combine_poses_redux(poses0: list[list[float]], poses1: list[list[float]]): + new_poses = copy.deepcopy(poses0[:-1]) + copy.deepcopy(poses1) + new_RT = stack_RTs(poses_to_ndarray(poses0), poses_to_ndarray(poses1)) + inter_poses = ndarray_to_poses(new_RT) + # maintain fx, fy, cx, and cy values by pasting only the movement portion of poses + for i in range(len(new_poses)): + new_poses[i][7:] = inter_poses[i][7:] return new_poses diff --git a/animatediff/nodes_motionctrl.py b/animatediff/nodes_motionctrl.py index 3d02e1f..81c057f 100644 --- a/animatediff/nodes_motionctrl.py +++ b/animatediff/nodes_motionctrl.py @@ -1,10 +1,16 @@ import torch from torch import Tensor +import numpy as np +import os +import json + +import folder_paths from .ad_settings import AnimateDiffSettings -from .adapter_motionctrl import inject_motionctrl_cmcm, load_motionctrl_omcm +from .adapter_motionctrl import (ObjectControlModelPatcher, inject_motionctrl_cmcm, load_motionctrl_omcm, + convert_cameractrl_poses_to_RT) -from .model_injection import MotionModelPatcher, MotionModelGroup, load_motion_module_gen2 +from .model_injection import MotionModelPatcher, MotionModelGroup, load_motion_module_gen2, get_mm_attachment from .motion_lora import MotionLoraList from .nodes_gen2 import ApplyAnimateDiffModelNode @@ -60,6 +66,33 @@ def load_motionctrl_omcm(self, motionctrl_omcm: str): return (omcm_modelpatcher,) +class LoadMotionCtrlCameraPosesFromFile: + NodeID = "ADE_LoadMotionCtrlCameraPosesFromFile" + NodeName = "Load MotionCtrl Camera Poses πŸŽ­πŸ…πŸ…“" + @classmethod + def INPUT_TYPES(s): + input_dir = folder_paths.get_input_directory() + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] + files = [f for f in files if f.endswith(".json")] + return { + "required": { + "pose_filename": (sorted(files),), + } + } + + RETURN_TYPES = ("CAMERA_MOTIONCTRL",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘/MotionCtrl" + FUNCTION = "load_camera_poses" + + def load_camera_poses(self, pose_filename): + file_path = folder_paths.get_annotated_filepath(pose_filename) + with open(file_path, 'r') as f: + RT = json.load(f) + RT = np.array(RT) + RT = torch.tensor(RT).float() # [t, 12] + return (RT,) + + class ApplyAnimateDiffMotionCtrlModel: NodeID = "ADE_ApplyAnimateDiffModelWithMotionCtrl" NodeName = "Apply AnimateDiff+MotionCtrl Model πŸŽ­πŸ…πŸ…“β‘‘" @@ -73,6 +106,7 @@ def INPUT_TYPES(s): }, "optional": { "omcm_motionctrl": ("OMCM_MOTIONCTRL",), + "cameractrl_poses": ("CAMERACTRL_POSES",), "motion_lora": ("MOTION_LORA",), "scale_multival": ("MULTIVAL",), "effect_multival": ("MULTIVAL",), @@ -90,6 +124,8 @@ def INPUT_TYPES(s): FUNCTION = "apply_motion_model" def apply_motion_model(self, motion_model: MotionModelPatcher, start_percent: float=0.0, end_percent: float=1.0, + omcm_motionctrl: ObjectControlModelPatcher=None, + cameractrl_poses: list[list[float]]=None, motion_lora: MotionLoraList=None, ad_keyframes: ADKeyframeGroup=None, scale_multival=None, effect_multival=None, per_block: AllPerBlocks=None, prev_m_models: MotionModelGroup=None,): @@ -99,5 +135,16 @@ def apply_motion_model(self, motion_model: MotionModelPatcher, start_percent: fl # most recent added model will always be first in list curr_model = new_m_models.models[0] # check if model has CMCM; if so, make sure something is provided for it + if curr_model.model.is_motionctrl_cc_enabled(): + attachment = get_mm_attachment(curr_model) + if cameractrl_poses is not None: + RT = convert_cameractrl_poses_to_RT(cameractrl_poses) + attachment.orig_RT = RT + else: + attachment.orig_RT = torch.zeros((1, 12)) + # attachment.orig_RT = cameractrl_poses + # else: + # attachment.orig_RT = torch.zeros([]) + # check if OMCM is provided; if so, make sure something is provided for it return (new_m_models,) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index f0fdf84..a48af04 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -193,6 +193,7 @@ def _apply_model_wrapper(executor, *args, **kwargs): attachment = get_mm_attachment(motion_model) attachment.prepare_alcmi2v_features(motion_model, x=x, cond_or_uncond=cond_or_uncond, ad_params=ad_params, latent_format=executor.class_obj.latent_format) attachment.prepare_camera_features(motion_model, x=x, cond_or_uncond=cond_or_uncond, ad_params=ad_params) + attachment.prepare_motionctrl_camera(motion_model, x=x, transformer_options=transformer_options) del x return executor(*args, **kwargs) @@ -285,9 +286,9 @@ def inject_functions(self, helper: ModelPatcherHelper, params: InjectionParams, helper.model.model.memory_required = unlimited_memory_required except Exception: pass - # if img_encoder or camera_encoder present, inject apply_model to handle correctly + # if AnimateLCM-I2V, CameraCtrl, or MotionCtrl present, inject apply_model to handle correctly for motion_model in helper.get_motion_models(): - if (motion_model.model.img_encoder is not None) or (motion_model.model.camera_encoder is not None): + if motion_model.model.needs_apply_model_wrapper(): create_special_model_apply_model_wrapper(model_options) break del info From 856d850437d9b21729d25888bcfcc0b217beff6d Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 30 Dec 2024 23:32:47 -0600 Subject: [PATCH 07/12] Fixed keyframes not working as expected when sampling is spread across multiple nodes/calls, removed conditioning.py since most of the code there was already deprecated in favor of vanilla ComfyUI after my PR was merged in early December --- animatediff/conditioning.py | 303 ------------------------------ animatediff/context.py | 20 +- animatediff/context_extras.py | 30 ++- animatediff/model_injection.py | 14 +- animatediff/nodes_conditioning.py | 7 +- animatediff/sample_settings.py | 16 +- animatediff/sampling.py | 10 +- animatediff/utils_motion.py | 6 + 8 files changed, 70 insertions(+), 336 deletions(-) delete mode 100644 animatediff/conditioning.py diff --git a/animatediff/conditioning.py b/animatediff/conditioning.py deleted file mode 100644 index 455257a..0000000 --- a/animatediff/conditioning.py +++ /dev/null @@ -1,303 +0,0 @@ -from torch import Tensor - -from comfy.model_base import BaseModel - -from .utils_motion import get_sorted_list_via_attr - - -class LoraHookMode: - MIN_VRAM = "min_vram" - MAX_SPEED = "max_speed" - #MIN_VRAM_LOWVRAM = "min_vram_lowvram" - #MAX_SPEED_LOWVRAM = "max_speed_lowvram" - - -# Acts simply as a way to track unique LoraHooks -class HookRef: - pass - - -class LoraHook: - def __init__(self, lora_name: str): - self.lora_name = lora_name - self.lora_keyframe = LoraHookKeyframeGroup() - self.hook_ref = HookRef() - - def initialize_timesteps(self, model: BaseModel): - self.lora_keyframe.initialize_timesteps(model) - - def reset(self): - self.lora_keyframe.reset() - - - def get_copy(self): - ''' - Copies LoraHook, but maintains same HookRef - ''' - c = LoraHook(lora_name=self.lora_name) - c.lora_keyframe = self.lora_keyframe - c.hook_ref = self.hook_ref # same instance that acts as ref - return c - - @property - def strength(self): - return self.lora_keyframe.strength - - def __eq__(self, other: 'LoraHook'): - return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref - - def __hash__(self): - return hash(self.hook_ref) - - -class LoraHookGroup: - ''' - Stores LoRA hooks to apply for conditioning - ''' - def __init__(self): - self.hooks: list[LoraHook] = [] - - def names(self): - names = [] - for hook in self.hooks: - names.append(hook.lora_name) - return ",".join(names) - - def add(self, hook: LoraHook): - if hook not in self.hooks: - self.hooks.append(hook) - - def is_empty(self): - return len(self.hooks) == 0 - - def contains(self, lora_hook: LoraHook): - return lora_hook in self.hooks - - def clone(self): - cloned = LoraHookGroup() - for hook in self.hooks: - cloned.add(hook.get_copy()) - return cloned - - def clone_and_combine(self, other: 'LoraHookGroup'): - cloned = self.clone() - for hook in other.hooks: - cloned.add(hook.get_copy()) - return cloned - - def set_keyframes_on_hooks(self, hook_kf: 'LoraHookKeyframeGroup'): - hook_kf = hook_kf.clone() - for hook in self.hooks: - hook.lora_keyframe = hook_kf - - @staticmethod - def combine_all_lora_hooks(lora_hooks_list: list['LoraHookGroup'], require_count=1) -> 'LoraHookGroup': - actual: list[LoraHookGroup] = [] - for group in lora_hooks_list: - if group is not None: - actual.append(group) - if len(actual) < require_count: - raise Exception(f"Need at least {require_count} LoRA Hooks to combine, but only had {len(actual)}.") - # if only 1 hook, just return itself without any cloning - if len(actual) == 1: - return actual[0] - final_hook: LoraHookGroup = None - for hook in actual: - if final_hook is None: - final_hook = hook.clone() - else: - final_hook = final_hook.clone_and_combine(hook) - return final_hook - - -class LoraHookKeyframe: - def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1): - self.strength = strength - # scheduling - self.start_percent = float(start_percent) - self.start_t = 999999999.9 - self.guarantee_steps = guarantee_steps - - def clone(self): - c = LoraHookKeyframe(strength=self.strength, - start_percent=self.start_percent, guarantee_steps=self.guarantee_steps) - c.start_t = self.start_t - return c - -class LoraHookKeyframeGroup: - def __init__(self): - self.keyframes: list[LoraHookKeyframe] = [] - self._current_keyframe: LoraHookKeyframe = None - self._current_used_steps: int = 0 - self._current_index: int = 0 - self._curr_t: float = -1 - - def reset(self): - self._current_keyframe = None - self._current_used_steps = 0 - self._current_index = 0 - self._curr_t = -1 - self._set_first_as_current() - - def add(self, keyframe: LoraHookKeyframe): - # add to end of list, then sort - self.keyframes.append(keyframe) - self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent") - self._set_first_as_current() - - def _set_first_as_current(self): - if len(self.keyframes) > 0: - self._current_keyframe = self.keyframes[0] - else: - self._current_keyframe = None - - def has_index(self, index: int) -> int: - return index >= 0 and index < len(self.keyframes) - - def is_empty(self) -> bool: - return len(self.keyframes) == 0 - - def clone(self): - cloned = LoraHookKeyframeGroup() - for keyframe in self.keyframes: - cloned.keyframes.append(keyframe) - cloned._set_first_as_current() - return cloned - - def initialize_timesteps(self, model: BaseModel): - for keyframe in self.keyframes: - keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) - - def prepare_current_keyframe(self, curr_t: float) -> bool: - if self.is_empty(): - return False - if curr_t == self._curr_t: - return False - prev_index = self._current_index - # if met guaranteed steps, look for next keyframe in case need to switch - if self._current_used_steps >= self._current_keyframe.guarantee_steps: - # if has next index, loop through and see if need t oswitch - if self.has_index(self._current_index+1): - for i in range(self._current_index+1, len(self.keyframes)): - eval_c = self.keyframes[i] - # check if start_t is greater or equal to curr_t - # NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling - if eval_c.start_t >= curr_t: - self._current_index = i - self._current_keyframe = eval_c - self._current_used_steps = 0 - # if guarantee_steps greater than zero, stop searching for other keyframes - if self._current_keyframe.guarantee_steps > 0: - break - # if eval_c is outside the percent range, stop looking further - else: break - # update steps current context is used - self._current_used_steps += 1 - # update current timestep this was performed on - self._curr_t = curr_t - # return True if keyframe changed, False if no change - return prev_index != self._current_index - - # properties shadow those of LoraHookKeyframe - @property - def strength(self): - if self._current_keyframe is not None: - return self._current_keyframe.strength - return 1.0 - - -class COND_CONST: - KEY_LORA_HOOK = "lora_hook" - KEY_DEFAULT_COND = "default_cond" - - COND_AREA_DEFAULT = "default" - COND_AREA_MASK_BOUNDS = "mask bounds" - _LIST_COND_AREA = [COND_AREA_DEFAULT, COND_AREA_MASK_BOUNDS] - - -class TimestepsCond: - def __init__(self, start_percent: float, end_percent: float): - self.start_percent = start_percent - self.end_percent = end_percent - - -def conditioning_set_values(conditioning, values={}): - c = [] - for t in conditioning: - n = [t[0], t[1].copy()] - for k in values: - n[1][k] = values[k] - c.append(n) - return c - -def set_lora_hook_for_conditioning(conditioning, lora_hook: LoraHookGroup): - if lora_hook is None: - return conditioning - return conditioning_set_values(conditioning, {COND_CONST.KEY_LORA_HOOK: lora_hook}) - -def set_timesteps_for_conditioning(conditioning, timesteps_cond: TimestepsCond): - if timesteps_cond is None: - return conditioning - return conditioning_set_values(conditioning, {"start_percent": timesteps_cond.start_percent, - "end_percent": timesteps_cond.end_percent}) - -def set_mask_for_conditioning(conditioning, mask: Tensor, set_cond_area: str, strength: float): - if mask is None: - return conditioning - set_area_to_bounds = False - if set_cond_area != COND_CONST.COND_AREA_DEFAULT: - set_area_to_bounds = True - if len(mask.shape) < 3: - mask = mask.unsqueeze(0) - - return conditioning_set_values(conditioning, {"mask": mask, - "set_area_to_bounds": set_area_to_bounds, - "mask_strength": strength}) - -def combine_conditioning(conds: list): - combined_conds = [] - for cond in conds: - combined_conds.extend(cond) - return combined_conds - -def set_mask_conds(conds: list, strength: float, set_cond_area: str, - opt_mask: Tensor=None, opt_lora_hook: LoraHookGroup=None, opt_timesteps: TimestepsCond=None): - masked_conds = [] - for c in conds: - # first, apply lora_hook to conditioning, if provided - c = set_lora_hook_for_conditioning(c, opt_lora_hook) - # next, apply mask to conditioning - c = set_mask_for_conditioning(conditioning=c, mask=opt_mask, strength=strength, set_cond_area=set_cond_area) - # apply timesteps, if present - c = set_timesteps_for_conditioning(conditioning=c, timesteps_cond=opt_timesteps) - # finally, apply mask to conditioning and store - masked_conds.append(c) - return masked_conds - -def set_mask_and_combine_conds(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default", - opt_mask: Tensor=None, opt_lora_hook: LoraHookGroup=None, opt_timesteps: TimestepsCond=None): - combined_conds = [] - for c, masked_c in zip(conds, new_conds): - # first, apply lora_hook to new conditioning, if provided - masked_c = set_lora_hook_for_conditioning(masked_c, opt_lora_hook) - # next, apply mask to new conditioning, if provided - masked_c = set_mask_for_conditioning(conditioning=masked_c, mask=opt_mask, set_cond_area=set_cond_area, strength=strength) - # apply timesteps, if present - masked_c = set_timesteps_for_conditioning(conditioning=masked_c, timesteps_cond=opt_timesteps) - # finally, combine with existing conditioning and store - combined_conds.append(combine_conditioning([c, masked_c])) - return combined_conds - -def set_unmasked_and_combine_conds(conds: list, new_conds: list, - opt_lora_hook: LoraHookGroup, opt_timesteps: TimestepsCond=None): - combined_conds = [] - for c, new_c in zip(conds, new_conds): - # first, apply lora_hook to new conditioning, if provided - new_c = set_lora_hook_for_conditioning(new_c, opt_lora_hook) - # next, add default_cond key to cond so that during sampling, it can be identified - new_c = conditioning_set_values(new_c, {COND_CONST.KEY_DEFAULT_COND: True}) - # apply timesteps, if present - new_c = set_timesteps_for_conditioning(conditioning=new_c, timesteps_cond=opt_timesteps) - # finally, combine with existing conditioning and store - combined_conds.append(combine_conditioning([c, new_c])) - return combined_conds diff --git a/animatediff/context.py b/animatediff/context.py index df116d1..106937d 100644 --- a/animatediff/context.py +++ b/animatediff/context.py @@ -12,6 +12,7 @@ from comfy.model_patcher import ModelPatcher from .context_extras import ContextExtrasGroup +from .utils_model import BIGMAX from .utils_motion import get_sorted_list_via_attr @@ -65,6 +66,12 @@ def step(self, value: int): if self.view_options: self.view_options.step = value + def get_effective_guarantee_steps(self, max_sigma: torch.Tensor): + '''If keyframe starts before current sampling range (max_sigma), treat as 0.''' + if self.start_t > max_sigma: + return 0 + return self.guarantee_steps + def clone(self): n = ContextOptions(context_length=self.context_length, context_stride=self.context_stride, context_overlap=self.context_overlap, context_schedule=self.context_schedule, @@ -141,18 +148,19 @@ def initialize_timesteps(self, model: BaseModel): context.start_t = model.model_sampling.percent_to_sigma(context.start_percent) self.extras.initialize_timesteps(model) - def prepare_current(self, t: Tensor): - self.prepare_current_context(t) - self.extras.prepare_current(t) + def prepare_current(self, t: Tensor, transformer_options): + self.prepare_current_context(t, transformer_options) + self.extras.prepare_current(t, transformer_options) - def prepare_current_context(self, t: Tensor): + def prepare_current_context(self, t: Tensor, transformer_options: dict[str, Tensor]): curr_t: float = t[0] # if same as previous, do nothing as step already accounted for if curr_t == self._previous_t: return prev_index = self._current_index + max_sigma = torch.max(transformer_options.get("sigmas", BIGMAX)) # if met guaranteed steps, look for next context in case need to switch - if self._current_used_steps >= self._current_context.guarantee_steps: + if self._current_used_steps >= self._current_context.get_effective_guarantee_steps(max_sigma): # if has next index, loop through and see if need to switch if self.has_index(self._current_index+1): for i in range(self._current_index+1, len(self.contexts)): @@ -164,7 +172,7 @@ def prepare_current_context(self, t: Tensor): self._current_context = eval_c self._current_used_steps = 0 # if guarantee_steps greater than zero, stop searching for other keyframes - if self._current_context.guarantee_steps > 0: + if self._current_context.get_effective_guarantee_steps(max_sigma) > 0: break # if eval_c is outside the percent range, stop looking further else: diff --git a/animatediff/context_extras.py b/animatediff/context_extras.py index 2f19856..dc7d959 100644 --- a/animatediff/context_extras.py +++ b/animatediff/context_extras.py @@ -5,6 +5,7 @@ from comfy.model_base import BaseModel +from .utils_model import BIGMAX from .utils_motion import (prepare_mask_batch, extend_to_batch_size, get_combined_multival, resize_multival, get_sorted_list_via_attr) @@ -25,7 +26,7 @@ def initialize_timesteps(self, model: BaseModel): self.start_t = model.model_sampling.percent_to_sigma(self.start_percent) self.end_t = model.model_sampling.percent_to_sigma(self.end_percent) - def prepare_current(self, t: Tensor): + def prepare_current(self, t: Tensor, transformer_options: dict[str, Tensor]): self.curr_t = t[0] def should_run(self): @@ -260,6 +261,12 @@ def __init__(self, mult=1.0, mult_multival: Union[float, Tensor]=None, start_per self.guarantee_steps = guarantee_steps self.inherit_missing = inherit_missing + def get_effective_guarantee_steps(self, max_sigma: torch.Tensor): + '''If keyframe starts before current sampling range (max_sigma), treat as 0.''' + if self.start_t > max_sigma: + return 0 + return self.guarantee_steps + def clone(self): c = NaiveReuseKeyframe(mult=self.mult, mult_multival=self.mult_multival, start_percent=self.start_percent, guarantee_steps=self.guarantee_steps) @@ -330,7 +337,7 @@ def initialize_timesteps(self, model: BaseModel): for keyframe in self.keyframes: keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) - def prepare_current_keyframe(self, t: Tensor): + def prepare_current_keyframe(self, t: Tensor, transformer_options: dict[str, Tensor]): if self.is_empty(): return curr_t: float = t[0] @@ -338,8 +345,9 @@ def prepare_current_keyframe(self, t: Tensor): if curr_t == self._previous_t: return prev_index = self._current_index + max_sigma = torch.max(transformer_options.get("sigmas", BIGMAX)) # if met guaranteed steps, look for next keyframe in case need to switch - if self._current_used_steps >= self._current_keyframe.guarantee_steps: + if self._current_used_steps >= self._current_keyframe.get_effective_guarantee_steps(max_sigma): # if has next index, loop through and see if need t oswitch if self.has_index(self._current_index+1): for i in range(self._current_index+1, len(self.keyframes)): @@ -351,7 +359,7 @@ def prepare_current_keyframe(self, t: Tensor): self._current_keyframe = eval_c self._current_used_steps = 0 # if guarantee_steps greater than zero, stop searching for other keyframes - if self._current_keyframe.guarantee_steps > 0: + if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0: break # if eval_c is outside the percent range, stop looking further else: break @@ -394,9 +402,9 @@ def initialize_timesteps(self, model: BaseModel): super().initialize_timesteps(model) self.keyframe.initialize_timesteps(model) - def prepare_current(self, t: Tensor): - super().prepare_current(t) - self.keyframe.prepare_current_keyframe(t) + def prepare_current(self, t: Tensor, transformer_options: dict[str, Tensor]): + super().prepare_current(t, transformer_options) + self.keyframe.prepare_current_keyframe(t, transformer_options) def get_effective_weighted_mean(self, x: Tensor, idxs: list[int]): if self.orig_multival is None and self.keyframe.mult_multival is None: @@ -427,6 +435,10 @@ def should_run(self): #-------------------------------- +################################ +# DenoiseReuse + + class ContextExtrasGroup: def __init__(self): self.context_ref: ContextRef = None @@ -444,9 +456,9 @@ def initialize_timesteps(self, model: BaseModel): for extra in self.get_extras_list(): extra.initialize_timesteps(model) - def prepare_current(self, t: Tensor): + def prepare_current(self, t: Tensor, transformer_options): for extra in self.get_extras_list(): - extra.prepare_current(t) + extra.prepare_current(t, transformer_options) def should_run_context_ref(self): if not self.context_ref: diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 0a4aca7..6ce1144 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -29,9 +29,8 @@ PerBlock, AllPerBlocks, get_combined_per_block_list, get_combined_multival, get_combined_input, get_combined_input_effect_multival, ade_broadcast_image_to, extend_to_batch_size, prepare_mask_batch) -from .conditioning import HookRef, LoraHook, LoraHookGroup, LoraHookMode from .motion_lora import MotionLoraInfo, MotionLoraList -from .utils_model import get_motion_lora_path, get_motion_model_path, get_sd_model_type, vae_encode_raw_batched +from .utils_model import get_motion_lora_path, get_motion_model_path, get_sd_model_type, vae_encode_raw_batched, BIGMAX from .sample_settings import SampleSettings, SeedNoiseGeneration from .dinklink import DinkLinkConst, get_dinklink, get_acn_outer_sample_wrapper @@ -328,14 +327,15 @@ def initialize_timesteps(self, model: BaseModel): for keyframe in self.keyframes.keyframes: keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) - def prepare_current_keyframe(self, patcher: MotionModelPatcher, x: Tensor, t: Tensor): + def prepare_current_keyframe(self, patcher: MotionModelPatcher, x: Tensor, t: Tensor, transformer_options: dict[str, Tensor]): curr_t: float = t[0] # if curr_t was previous_t, then do nothing (already accounted for this step) if curr_t == self.previous_t: return prev_index = self.current_index + max_sigma = torch.max(transformer_options.get("sigmas", BIGMAX)) # if met guaranteed steps, look for next keyframe in case need to switch - if self.current_keyframe is None or self.current_used_steps >= self.current_keyframe.guarantee_steps: + if self.current_keyframe is None or self.current_used_steps >= self.current_keyframe.get_effective_guarantee_steps(max_sigma): # if has next index, loop through and see if need to switch if self.keyframes.has_index(self.current_index+1): for i in range(self.current_index+1, len(self.keyframes)): @@ -373,7 +373,7 @@ def prepare_current_keyframe(self, patcher: MotionModelPatcher, x: Tensor, t: Te elif not self.current_keyframe.inherit_missing: self.current_pia_input = None # if guarantee_steps greater than zero, stop searching for other keyframes - if self.current_keyframe.guarantee_steps > 0: + if self.current_keyframe.get_effective_guarantee_steps(max_sigma) > 0: break # if eval_kf is outside the percent range, stop looking further else: @@ -723,10 +723,10 @@ def cleanup(self): for motion_model in self.models: motion_model.cleanup() - def prepare_current_keyframe(self, x: Tensor, t: Tensor): + def prepare_current_keyframe(self, x: Tensor, t: Tensor, transformer_options: dict[str, Tensor]): for motion_model in self.models: attachment = get_mm_attachment(motion_model) - attachment.prepare_current_keyframe(motion_model, x=x, t=t) + attachment.prepare_current_keyframe(motion_model, x=x, t=t, transformer_options=transformer_options) def get_special_models(self): pia_motion_models: list[MotionModelPatcher] = [] diff --git a/animatediff/nodes_conditioning.py b/animatediff/nodes_conditioning.py index 54af3e5..7bd09f6 100644 --- a/animatediff/nodes_conditioning.py +++ b/animatediff/nodes_conditioning.py @@ -12,7 +12,6 @@ import comfy.hooks import comfy.utils -from .conditioning import (COND_CONST) from .utils_model import BIGMAX, InterpolationMethod from .logger import logger @@ -25,6 +24,12 @@ #------------------------------------------------------------------ #------------------------------------------------------------------ #------------------------------------------------------------------ +class COND_CONST: + COND_AREA_DEFAULT = "default" + COND_AREA_MASK_BOUNDS = "mask bounds" + _LIST_COND_AREA = [COND_AREA_DEFAULT, COND_AREA_MASK_BOUNDS] + + class CreateLoraHookKeyframeInterpolationDEPR: @classmethod def INPUT_TYPES(s): diff --git a/animatediff/sample_settings.py b/animatediff/sample_settings.py index b7eeb38..d504875 100644 --- a/animatediff/sample_settings.py +++ b/animatediff/sample_settings.py @@ -13,9 +13,8 @@ from comfy.sd import VAE from . import freeinit -from .conditioning import LoraHookMode from .context import ContextOptions, ContextOptionsGroup -from .utils_model import SigmaSchedule +from .utils_model import SigmaSchedule, BIGMAX from .utils_motion import extend_to_batch_size, get_sorted_list_via_attr, prepare_mask_batch from .logger import logger @@ -611,6 +610,12 @@ def __init__(self, cfg_multival: Union[float, Tensor], start_percent=0.0, guaran self.start_t = 999999999.9 self.guarantee_steps = guarantee_steps + def get_effective_guarantee_steps(self, max_sigma: torch.Tensor): + '''If keyframe starts before current sampling range (max_sigma), treat as 0.''' + if self.start_t > max_sigma: + return 0 + return self.guarantee_steps + def clone(self): c = CustomCFGKeyframe(cfg_multival=self.cfg_multival, start_percent=self.start_percent, guarantee_steps=self.guarantee_steps) @@ -661,14 +666,15 @@ def initialize_timesteps(self, model: BaseModel): for keyframe in self.keyframes: keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) - def prepare_current_keyframe(self, t: Tensor): + def prepare_current_keyframe(self, t: Tensor, transformer_options: dict[str, Tensor]): curr_t: float = t[0] # if curr_t same as before, do nothing as step already accounted for if curr_t == self._previous_t: return prev_index = self._current_index + max_sigma = torch.max(transformer_options.get("sigmas", BIGMAX)) # if met guaranteed steps, look for next keyframe in case need to switch - if self._current_used_steps >= self._current_keyframe.guarantee_steps: + if self._current_used_steps >= self._current_keyframe.get_effective_guarantee_steps(max_sigma): # if has next index, loop through and see if need t oswitch if self.has_index(self._current_index+1): for i in range(self._current_index+1, len(self.keyframes)): @@ -680,7 +686,7 @@ def prepare_current_keyframe(self, t: Tensor): self._current_keyframe = eval_c self._current_used_steps = 0 # if guarantee_steps greater than zero, stop searching for other keyframes - if self._current_keyframe.guarantee_steps > 0: + if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0: break # if eval_c is outside the percent range, stop looking further else: break diff --git a/animatediff/sampling.py b/animatediff/sampling.py index a48af04..7969ce2 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -54,13 +54,13 @@ def initialize(self, model: BaseModel): if self.sample_settings.custom_cfg is not None: self.sample_settings.custom_cfg.initialize_timesteps(model) - def prepare_current_keyframes(self, x: Tensor, timestep: Tensor): + def prepare_current_keyframes(self, x: Tensor, timestep: Tensor, transformer_options: dict[str, Tensor]): if self.motion_models is not None: - self.motion_models.prepare_current_keyframe(x=x, t=timestep) + self.motion_models.prepare_current_keyframe(x=x, t=timestep, transformer_options=transformer_options) if self.params.context_options is not None: - self.params.context_options.prepare_current(t=timestep) + self.params.context_options.prepare_current(t=timestep, transformer_options=transformer_options) if self.sample_settings.custom_cfg is not None: - self.sample_settings.custom_cfg.prepare_current_keyframe(t=timestep) + self.sample_settings.custom_cfg.prepare_current_keyframe(t=timestep, transformer_options=transformer_options) def perform_special_model_features(self, model: BaseModel, conds: list, x_in: Tensor, model_options: dict[str]): if self.motion_models is not None: @@ -540,7 +540,7 @@ def ad_callback(step, x0, x, total_steps): def evolved_sampling_function(model, x: Tensor, timestep: Tensor, uncond, cond, cond_scale, model_options: dict={}, seed=None): ADGS: AnimateDiffGlobalState = model_options["transformer_options"]["ADGS"] ADGS.initialize(model) - ADGS.prepare_current_keyframes(x=x, timestep=timestep) + ADGS.prepare_current_keyframes(x=x, timestep=timestep, transformer_options=model_options["transformer_options"]) try: # add AD/evolved-sampling params to model_options (transformer_options) model_options = model_options.copy() diff --git a/animatediff/utils_motion.py b/animatediff/utils_motion.py index 122244f..362c56b 100644 --- a/animatediff/utils_motion.py +++ b/animatediff/utils_motion.py @@ -445,6 +445,12 @@ def has_cameractrl_effect(self): def has_pia_input(self): return self.pia_input is not None + def get_effective_guarantee_steps(self, max_sigma: torch.Tensor): + '''If keyframe starts before current sampling range (max_sigma), treat as 0.''' + if self.start_t > max_sigma: + return 0 + return self.guarantee_steps + class ADKeyframeGroup: def __init__(self): From d6d5a9241cce9dfc5d2f8e99b4be77de468a6354 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 4 Jan 2025 00:27:03 -0600 Subject: [PATCH 08/12] Fix ContextRef for newest Advanced-ControlNet --- animatediff/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 7969ce2..1afec5f 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -724,7 +724,7 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list # check if ContextRef ReferenceAdvanced ACN objs should_run actually_should_run = True for refcn in model_options["transformer_options"][CONTEXTREF_CONTROL_LIST_ALL]: - refcn.prepare_current_timestep(timestep) + refcn.prepare_current_timestep(timestep, model_options["transformer_options"]) if not refcn.should_run(): actually_should_run = False if actually_should_run: From 06e41b1ad643016179234482fb83398997312a8f Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 4 Jan 2025 00:38:42 -0600 Subject: [PATCH 09/12] Make ContextRef be backwards compatible with previous versions of Advanced-ControlNet (DinkLink versioning coming in clutch) --- animatediff/dinklink.py | 11 +++++++++++ animatediff/sampling.py | 7 ++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/animatediff/dinklink.py b/animatediff/dinklink.py index 66ffdd6..247a226 100644 --- a/animatediff/dinklink.py +++ b/animatediff/dinklink.py @@ -60,3 +60,14 @@ def get_acn_outer_sample_wrapper(throw_exception=True): raise Exception("Advanced-ControlNet nodes need to be installed to make use of ContextRef; " + \ "they are either not installed or are of an insufficient version.") return None + +def get_acn_dinklink_version(throw_exception=True): + d = get_dinklink() + try: + link_acn = d[DinkLinkConst.ACN] + return link_acn[DinkLinkConst.VERSION] + except KeyError: + if throw_exception: + raise Exception("Advanced-ControlNet nodes need to be installed to make use of ContextRef; " + \ + "they are either not installed or are of an insufficient version.") + return None diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 1afec5f..ec2174b 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -28,6 +28,7 @@ from .motion_module_ad import AnimateDiffFormat, AnimateDiffInfo, AnimateDiffVersion from .adapter_hellomeme import HMRefConst, HMRefStates, get_hmref_attachment, create_hmref_apply_model_wrapper from .logger import logger +from .dinklink import get_acn_dinklink_version ################################################################################## @@ -724,7 +725,11 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list # check if ContextRef ReferenceAdvanced ACN objs should_run actually_should_run = True for refcn in model_options["transformer_options"][CONTEXTREF_CONTROL_LIST_ALL]: - refcn.prepare_current_timestep(timestep, model_options["transformer_options"]) + acn_dl_version = get_acn_dinklink_version() + if acn_dl_version > 10000: + refcn.prepare_current_timestep(timestep, model_options["transformer_options"]) + else: + refcn.prepare_current_timestep(timestep) if not refcn.should_run(): actually_should_run = False if actually_should_run: From 4b615267dc62f4493942f0b012d2c56c412e6bb8 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 4 Jan 2025 00:42:21 -0600 Subject: [PATCH 10/12] Commented out MotionCtrl nodes for now --- animatediff/nodes.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/animatediff/nodes.py b/animatediff/nodes.py index f83b32a..ef315d2 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -194,10 +194,10 @@ "ADE_UpscaleAndVAEEncode": UpscaleAndVaeEncode, "ADE_InjectI2VIntoAnimateDiffModel": LoadAnimateDiffAndInjectI2VNode, # MotionCtrl Nodes - LoadMotionCtrlCMCM.NodeID: LoadMotionCtrlCMCM, - LoadMotionCtrlOMCM.NodeID: LoadMotionCtrlOMCM, - ApplyAnimateDiffMotionCtrlModel.NodeID: ApplyAnimateDiffMotionCtrlModel, - LoadMotionCtrlCameraPosesFromFile.NodeID: LoadMotionCtrlCameraPosesFromFile, + #LoadMotionCtrlCMCM.NodeID: LoadMotionCtrlCMCM, + #LoadMotionCtrlOMCM.NodeID: LoadMotionCtrlOMCM, + #ApplyAnimateDiffMotionCtrlModel.NodeID: ApplyAnimateDiffMotionCtrlModel, + #LoadMotionCtrlCameraPosesFromFile.NodeID: LoadMotionCtrlCameraPosesFromFile, # CameraCtrl Nodes "ADE_ApplyAnimateDiffModelWithCameraCtrl": ApplyAnimateDiffWithCameraCtrl, "ADE_LoadAnimateDiffModelWithCameraCtrl": LoadAnimateDiffModelWithCameraCtrl, From af97ac83aaa38ea5f83a301ffa0cc84c4ded4d27 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 4 Jan 2025 00:49:42 -0600 Subject: [PATCH 11/12] version bump --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 976f12b..cda5fc0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui-animatediff-evolved" description = "Improved AnimateDiff integration for ComfyUI." -version = "1.3.3" +version = "1.3.4" license = { file = "LICENSE" } dependencies = [] From 297f96279397df610b213aff7d8838f21338aff9 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 4 Jan 2025 00:50:09 -0600 Subject: [PATCH 12/12] proper version bump --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index cda5fc0..b4c9a8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui-animatediff-evolved" description = "Improved AnimateDiff integration for ComfyUI." -version = "1.3.4" +version = "1.4.0" license = { file = "LICENSE" } dependencies = []