Skip to content

Commit

Permalink
Merge PR #503 from Kosinkadink/develop
Browse files Browse the repository at this point in the history
Fixed PromptScheduling AssertionError
  • Loading branch information
Kosinkadink authored Dec 3, 2024
2 parents edb939f + 268f7fb commit 406155c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
22 changes: 17 additions & 5 deletions animatediff/scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from torch import Tensor
import torch.nn.functional as F
from dataclasses import dataclass
from dataclasses import dataclass, replace

from comfy.sd import CLIP
from comfy.utils import ProgressBar
Expand Down Expand Up @@ -291,16 +291,21 @@ def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP
prev_holder: Union[CondHolder, None] = None
for idx, pair in enumerate(pairs):
holder = None
is_over_length = False
# if no last pair is set, then use first provided val up to the idx
if prev_holder is None:
for i in range(idx, pair.idx+1):
if i >= length:
is_over_length = True
continue
real_prompt = apply_values_replace_to_prompt(pair.val, i, values_replace=values_replace)
if holder is None or holder.prompt != real_prompt:
cond, pooled = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True)
cond = pad_cond(cond, target_length=max_size)
holder = CondHolder(idx=i, prompt=real_prompt, raw_prompt=pair.val, cond=cond, pooled=pooled, hold=pair.hold)
else:
holder = replace(holder)
holder.idx = i
real_cond[i] = cond
real_pooled[i] = pooled
real_holders[i] = holder
Expand All @@ -326,6 +331,7 @@ def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP
# however, need to check if real_prompt remains the same
for i in range(prev_holder.idx+1, pair.idx):
if i >= length:
is_over_length = True
continue
if holder is None:
holder = prev_holder
Expand All @@ -334,6 +340,9 @@ def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP
cond, pooled = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True)
cond = pad_cond(cond, target_length=max_size)
holder = CondHolder(idx=i, prompt=real_prompt, raw_prompt=pair.val, cond=cond, pooled=pooled, hold=pair.hold)
else:
holder = replace(holder)
holder.idx = i
real_cond[i] = holder.cond
real_pooled[i] = holder.pooled
real_holders[i] = holder
Expand Down Expand Up @@ -361,16 +370,17 @@ def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP
cond_from = None
holder = None
interm_holder = prev_holder
for idx, weight in zip(interp_idxs, interp_weights):
if idx >= length:
for raw_idx, weight in zip(interp_idxs, interp_weights):
if raw_idx >= length:
is_over_length = True
continue
idx_int = round(float(idx))
idx_int = round(float(raw_idx))
# calculate cond_to stuff if not done yet
real_prompt = apply_values_replace_to_prompt(pair.val, idx_int, values_replace=values_replace)
if holder is None or holder.prompt != real_prompt:
cond_to, pooled_to = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True)
cond_to = pad_cond(cond_to, target_length=max_size)
holder = CondHolder(idx=pair.idx, prompt=real_prompt, raw_prompt=pair.val, cond=cond_to, pooled=pooled_to, hold=pair.hold)
holder = CondHolder(idx=idx_int, prompt=real_prompt, raw_prompt=pair.val, cond=cond_to, pooled=pooled_to, hold=pair.hold)
# calculate interm_holder stuff if needed
real_prompt = apply_values_replace_to_prompt(interm_holder.raw_prompt, idx_int, values_replace=values_replace)
if interm_holder.prompt != real_prompt:
Expand All @@ -394,6 +404,8 @@ def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP
real_holders[idx_int] = interm_holder
pbar.update(1)
comfy.model_management.throw_exception_if_processing_interrupted()
if is_over_length:
break
assert holder is not None
prev_holder = holder

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "comfyui-animatediff-evolved"
description = "Improved AnimateDiff integration for ComfyUI."
version = "1.3.1"
version = "1.3.2"
license = { file = "LICENSE" }
dependencies = []

Expand Down

0 comments on commit 406155c

Please sign in to comment.