# play.py (forked from myscience/open-genie)
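
"""Generate a short video with a trained Genie world model, starting from a
single prompt image and a sequence of latent action indices."""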

import argparse

import torch
from PIL import Image
from torchvision import transforms
from torchvision.io import write_video

from genie import Genie, VideoTokenizer


def load_image(image_path: str, size: tuple = (64, 64)) -> torch.Tensor:
    """Load an image and preprocess it into a [1, C, H, W] tensor in [-1, 1]."""
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),                        # [0, 255] -> [0, 1]
        transforms.Normalize(mean=[0.5], std=[0.5]),  # [0, 1] -> [-1, 1]
    ])
    img = Image.open(image_path).convert('RGB')
    return transform(img).unsqueeze(0)  # Add batch dimension
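
# Quick shape sanity check (hypothetical example; 'prompt_frame.png' is an
# illustrative placeholder, not a file shipped with the repository):
#
#   >>> load_image('prompt_frame.png').shape
#   torch.Size([1, 3, 64, 64])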


def main():
    parser = argparse.ArgumentParser(description='Generate video with Genie from an initial image')
    parser.add_argument('--image', type=str, required=True, help='Path to initial image')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to Genie checkpoint')
    parser.add_argument('--tokenizer_checkpoint', type=str, required=True, help='Path to tokenizer checkpoint')
    parser.add_argument('--output', type=str, default='generated_video.mp4', help='Output video path')
    parser.add_argument('--num_frames', type=int, default=16, help='Number of frames to generate')
    parser.add_argument('--steps_per_frame', type=int, default=25, help='MaskGIT steps per frame')
    parser.add_argument('--action_sequence', type=str, default=None,
                        help='Comma-separated sequence of action indices. If not provided, random actions will be used.')
    args = parser.parse_args()

    # Load the initial image
    prompt = load_image(args.image)

    # Load the tokenizer model directly from its checkpoint
    tokenizer = VideoTokenizer.load_from_checkpoint(
        args.tokenizer_checkpoint,
        map_location=torch.device('cpu'),
        strict=False,
    )
    tokenizer.eval()

    # Load the trained Genie model with the tokenizer
    model = Genie.load_from_checkpoint(
        args.checkpoint,
        tokenizer=tokenizer,
        tokenizer_ckpt_path=args.tokenizer_checkpoint,
        map_location=torch.device('cpu'),
    )
    model.eval()

    # Move the model and data to the appropriate device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    prompt = prompt.to(device)

    # Build the action sequence: one latent action index per generated frame
    if args.action_sequence:
        actions = torch.tensor([int(x) for x in args.action_sequence.split(',')]).unsqueeze(0).to(device)
    else:
        # Sample random actions from the latent action codebook if none were provided
        actions = torch.randint(0, model.latent_action.d_codebook,
                                size=(1, args.num_frames),
                                device=device)

    # Generate the video
    with torch.no_grad():
        video = model(
            prompt=prompt,
            actions=actions,
            num_frames=args.num_frames,
            steps_per_frame=args.steps_per_frame,
        )

    # The model is expected to return a video tensor of shape [B, C, T, H, W]
    # with values in [-1, 1], matching the tokenizer normalization above
    print(f"Generated video shape: {video.shape}")

    # Map from [-1, 1] to the uint8 range [0, 255]
    video = ((video + 1) * 127.5).clamp(0, 255).to(torch.uint8)

    # Convert to [T, H, W, C] format for video saving
    video = video.squeeze(0)           # Remove batch dimension: [C, T, H, W]
    video = video.permute(1, 2, 3, 0)  # [T, H, W, C]
    print(f"Final video shape: {video.shape}")

    # Save the video (write_video expects a CPU uint8 tensor in [T, H, W, C])
    write_video(args.output, video.cpu(), fps=30)


if __name__ == '__main__':
    main()
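
# Example invocation (the paths and action indices below are illustrative
# placeholders, not files shipped with this repository):
#
#   python play.py \
#       --image prompt_frame.png \
#       --checkpoint checkpoints/genie.ckpt \
#       --tokenizer_checkpoint checkpoints/tokenizer.ckpt \
#       --output generated_video.mp4 \
#       --num_frames 16 \
#       --action_sequence 0,3,3,1,0,2,2,1,0,0,1,2,3,1,0,2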