# randar_xl_0.7b_llamagen.yaml
# Autoregressive model section. `target` is a dotted class path and `params`
# holds its settings (presumably passed to the class as keyword arguments --
# verify against the config loader).
ar_model:
  target: RandAR.model.randar_gpt.RandARTransformer
  params:
    n_layer: 36
    n_head: 20
    dim: 1280
    model_type: c2i
    vocab_size: 16384  # same value as tokenizer codebook_size
    block_size: 256  # latent_size ** 2
    num_classes: 1000
    cls_token_num: 1
    resid_dropout_p: 0.1
    ffn_dropout_p: 0.1
    drop_path_rate: 0.0
    token_dropout_p: 0.1
    position_order: random  # or scan
    grad_checkpointing: true
    zero_class_qk: true
    # NOTE(review): the original indentation of this key was lost; it is kept
    # alongside the other params here. Confirm whether it belongs under
    # `params` or directly under `ar_model`.
    num_inference_steps: 88
# Training dataset section: dotted class path plus its settings.
dataset:
  # NOTE(review): the class is named INatLatentDataset while root_dir points
  # at ImageNet codes -- confirm this pairing is intended.
  target: RandAR.dataset.imagenet.INatLatentDataset
  params:
    root_dir: /tmp/imagenet-llamagen-adm-256_codes
# Image tokenizer (VQ model) section: dotted class path plus its settings.
tokenizer:
  target: RandAR.model.tokenizer.VQModel
  params:
    codebook_size: 16384  # same value as ar_model vocab_size
    codebook_embed_dim: 8
    codebook_l2_norm: true
    codebook_show_usage: true
    commit_loss_beta: 0.25
    entropy_loss_ratio: 0.0
    encoder_ch_mult: [1, 1, 2, 2, 4]
    decoder_ch_mult: [1, 1, 2, 2, 4]
    z_channels: 256
    dropout_p: 0.0
# Accelerator settings (gradient accumulation, precision, logging backend).
accelerator:
  gradient_accumulation_steps: 1
  mixed_precision: bf16
  log_with: wandb
# Optimizer hyper-parameters and gradient clipping/skipping thresholds.
optimizer:
  lr: 0.0004
  weight_decay: 0.05  # 5e-2
  beta1: 0.9
  beta2: 0.95
  max_grad_norm: 1.0
  # NOTE(review): presumably these control skipping of steps with abnormally
  # large gradient norms -- confirm exact semantics in the training script.
  skip_grad_iter: 100
  skip_grad_norm: 10
# Learning-rate schedule settings.
lr_scheduler:
  type: cosine
  warm_up_iters: 50000
  min_lr_ratio: 0.05
  num_cycles: 0.5
# training related parameters
exp_name: "randar_xl_0.7b"
max_iters: 360000  # 100k first
global_batch_size: 1024
ema: false
ema_decay: 0.9999
num_workers: 8
log_every: 100
ckpt_every: 10000
visualize_every: 10  # every 10 log intervals, do one visualization
keep_last_k: -1
resume_from: null  # not used, might remove
global_seed: 0
results_dir: "results/"  # will append exp_name later => results_dir/exp_name
# Path to resume from. Written as `null` because a bare `None` is read by YAML
# as the string "None", not as an empty value.
gpt_ckpt: null

# for eval fid:
vq_ckpt: "./checkpoints/vq_ds16_c2i.pt"
fid_ref_path: "/mnt/localssd/dataset/VIRTUAL_imagenet256_labeled.npz"
fid_save_dir: "/mnt/localssd/results_tmp/"
fid_num_samples: 10000
# NOTE(review): this is larger than max_iters (360000), so presumably no FID
# evaluation fires during this run -- confirm that is intended.
fid_eval_every: 500000

# for wandb
wandb_entity: "RandAR"
wandb_offline: false