Skip to content

Commit

Permalink
Merge PR #431 from Kosinkadink/develop - ContextRef and NaiveReuse
Browse files Browse the repository at this point in the history
Created Context Extras for Increased Consistency Across Context Windows
  • Loading branch information
Kosinkadink authored Aug 8, 2024
2 parents f297a20 + 1de8a7c commit 925511d
Show file tree
Hide file tree
Showing 18 changed files with 1,417 additions and 104 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ NOTE: you can also use custom locations for models/motion loras by making use of
- Per-frame GLIGEN coordinates control
- Currently requires GLIGENTextBoxApplyBatch from KJNodes to do so, but I will add native nodes to do this soon.
- Image Injection mid-sampling
- ContextRef and NaiveReuse (novel cross-context consistency techniques)

## Upcoming Features
- Example workflows for **every feature** in AnimateDiff-Evolved repo, and hopefully a long Youtube video showing all features (Goal: before Elden Ring DLC releases. Working on it right now.)
Expand Down
21 changes: 19 additions & 2 deletions animatediff/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher

from .context_extras import ContextExtrasGroup
from .utils_motion import get_sorted_list_via_attr


class ContextFuseMethod:
FLAT = "flat"
PYRAMID = "pyramid"
Expand Down Expand Up @@ -76,17 +78,21 @@ def clone(self):
class ContextOptionsGroup:
def __init__(self):
self.contexts: list[ContextOptions] = []
self.extras = ContextExtrasGroup()
self._current_context: ContextOptions = None
self._current_used_steps: int = 0
self._current_index: int = 0
self._previous_t = -1
self._step = 0

def reset(self):
self._current_context = None
self._current_used_steps = 0
self._current_index = 0
self._previous_t = -1
self.step = 0
self._set_first_as_current()
self.extras.cleanup()

@property
def step(self):
Expand Down Expand Up @@ -121,9 +127,10 @@ def has_index(self, index: int) -> int:

def is_empty(self) -> bool:
return len(self.contexts) == 0

def clone(self):
cloned = ContextOptionsGroup()
cloned.extras = self.extras.clone()
for context in self.contexts:
cloned.contexts.append(context)
cloned._set_first_as_current()
Expand All @@ -132,9 +139,17 @@ def clone(self):
def initialize_timesteps(self, model: BaseModel):
for context in self.contexts:
context.start_t = model.model_sampling.percent_to_sigma(context.start_percent)
self.extras.initialize_timesteps(model)

def prepare_current(self, t: Tensor):
self.prepare_current_context(t)
self.extras.prepare_current(t)

def prepare_current_context(self, t: Tensor):
curr_t: float = t[0]
# if same as previous, do nothing as step already accounted for
if curr_t == self._previous_t:
return
prev_index = self._current_index
# if met guaranteed steps, look for next context in case need to switch
if self._current_used_steps >= self._current_context.guarantee_steps:
Expand All @@ -156,6 +171,8 @@ def prepare_current_context(self, t: Tensor):
break
# update steps current context is used
self._current_used_steps += 1
# update previous_t
self._previous_t = curr_t

def _set_first_as_current(self):
if len(self.contexts) > 0:
Expand Down Expand Up @@ -620,7 +637,7 @@ def generate_context_visualization(context_opts: ContextOptionsGroup, model: Mod

for i, t in enumerate(sigmas):
# make context_opts reflect current step/sigma
context_opts.prepare_current_context([t])
context_opts.prepare_current([t])
context_opts.step = start_step+i

# check if context should even be active in this case
Expand Down
Loading

0 comments on commit 925511d

Please sign in to comment.