Skip to content

v0.3.2

Compare
Choose a tag to compare
@github-actions github-actions released this 19 Nov 15:42
· 87 commits to main since this release
6668e18

🚀 Support multi-GPU parallel inference speedup for CogVideoX

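Everything works out of the box! Launch the example below with `torchrun`, one process per GPU (e.g. `torchrun --nproc_per_node=2 <script>.py`).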

import os

import torch
import torch._inductor.config
import torch.distributed as dist
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Example: run CogVideoX-5b text-to-video inference across several GPUs with
# ParaAttn context parallelism. Each process (one per GPU, launched by e.g.
# torchrun) loads the pipeline, parallelizes it over a device mesh, and runs
# inference; rank 0 alone writes the resulting video to disk.
dist.init_process_group()

try:
    # Bind this process to its node-local GPU. torchrun exports LOCAL_RANK;
    # fall back to the global rank for launchers that do not set it. Using the
    # global rank alone names a nonexistent device (e.g. cuda:8 on an 8-GPU
    # node) as soon as the job spans more than one node.
    local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))
    # Make this GPU the current device so NCCL collectives issued without an
    # explicit device do not all land on cuda:0.
    torch.cuda.set_device(local_rank)

    pipe = CogVideoXPipeline.from_pretrained(
        "THUDM/CogVideoX-5b",
        torch_dtype=torch.bfloat16,
    ).to(f"cuda:{local_rank}")

    # Optional CPU offloading, left disabled in this example because the
    # pipeline is placed on this rank's GPU explicitly above.
    # NOTE(review): confirm compatibility with parallelize_pipe before enabling.
    # pipe.enable_model_cpu_offload()
    # pipe.enable_sequential_cpu_offload()

    # Per the diffusers docs, sliced/tiled VAE decoding lowers peak memory
    # when decoding the generated frames.
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()

    from para_attn.context_parallel import init_context_parallel_mesh
    from para_attn.context_parallel.diffusers_adapters import parallelize_pipe

    # Shard the pipeline across ranks on a mesh whose batch and ring
    # dimensions are each capped at 2.
    parallelize_pipe(
        pipe,
        mesh=init_context_parallel_mesh(
            pipe.device.type,
            max_batch_dim_size=2,
            max_ring_dim_size=2,
        ),
    )

    # Ask inductor to reorder compute and communication so they can overlap,
    # then compile only the transformer (CUDA graphs disabled by the mode).
    torch._inductor.config.reorder_for_compute_comm_overlap = True
    pipe.transformer = torch.compile(
        pipe.transformer, mode="max-autotune-no-cudagraphs"
    )

    prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
    video = pipe(
        prompt=prompt,
        num_videos_per_prompt=1,
        num_inference_steps=50,
        num_frames=49,
        guidance_scale=6,
        # For reproducible output, pass a seeded generator.
        # NOTE(review): presumably every rank must start from the same
        # latents; confirm para_attn synchronizes them, or seed all ranks
        # identically as below.
        # generator=torch.Generator(device=pipe.device).manual_seed(42),
    ).frames[0]

    # Every rank runs the pipeline; only rank 0 writes the video file.
    if dist.get_rank() == 0:
        print("Saving video to cogvideox.mp4")
        export_to_video(video, "cogvideox.mp4", fps=8)
finally:
    # Tear the process group down even if loading or inference raised.
    dist.destroy_process_group()