Fine-tune the size of the dataset and the value of h_div_w #33

Open

l-dawei opened this issue Jan 8, 2025 · 18 comments


l-dawei commented Jan 8, 2025

Dear authors,

Thanks for publishing the codebase and checkpoints, and for the great work! I am interested in the training data used to train Infinity. I fine-tuned on the 10 sample images you provided but got poor image results. How large a dataset is required for fine-tuning? And how were the possible values of h_div_w obtained?

Best wishes,
Antonsen
[screenshot: generated result]
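The h_div_w question is answered implicitly by the inference snippets later in this thread: the possible values form a predefined template list, and a requested ratio is snapped to the nearest template. A minimal sketch, assuming the names used by those snippets (h_div_w_templates and dynamic_resolution_h_w come in via the tools.run_infinity star import):

```python
import numpy as np
from tools.run_infinity import *  # provides h_div_w_templates and dynamic_resolution_h_w

# Sketch: valid h_div_w values come from a fixed template list shipped with the
# repo; any requested height:width ratio is snapped to the nearest template.
h_div_w = 0.75
h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates - h_div_w))]
scale_schedule = dynamic_resolution_h_w[h_div_w_template_]['0.06M']['scales']
print(h_div_w_template_, scale_schedule)
```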

JeyesHan (Collaborator) commented Jan 8, 2025

It seems there are bugs in your fine-tuning. Fine-tuning Infinity on only 10 images will quickly overfit and generate images nearly identical to the training set. Could you please post your training logs and testing logs here?

l-dawei (Author) commented Jan 8, 2025

Thank you for your quick reply.

My training log is shown below:
[screenshot: training log]

My command-line output is shown below:
[screenshot: command-line output]

I would appreciate a quick answer to my question.

Best wishes,
Antonsen

xuanyuzhang21 commented Jan 8, 2025

I encountered the same situation: when setting rush_resume and fine-tuning on 1200 images from the pre-trained weights, the loss converges normally, but the generated images are completely corrupted. Can you provide some help?
[screenshot: corrupted generation results]

JeyesHan (Collaborator) commented Jan 8, 2025

@l-dawei Could you please post your inference code here?

JeyesHan (Collaborator) commented Jan 8, 2025

> I encountered the same situation: when setting rush_resume and fine-tuning on 1200 images from the pre-trained weights, the loss converges normally, but the generated images are completely corrupted. Can you provide some help?

@xuanyuzhang21 Could you please post your training logs (b1_stdout.txt) and inference code here?

@xuanyuzhang21

[attachment: b1_stdout.txt]

```python
import argparse
import os
import os.path as osp
import random

import torch
torch.cuda.set_device(0)
import cv2
import numpy as np
from tools.run_infinity import *

model_path='local_output/debug/ar-ckpt-giter010K-ep16-iter516-last.pth'
vae_path='weights/infinity_vae_d32_reg.pth'
text_encoder_ckpt = 'weights/models--google--flan-t5-xl/snapshots/7d6315df2c2fb742f0f5b556879d730926ca9001'
args=argparse.Namespace(
    pn='0.06M',
    model_path=model_path,
    cfg_insertion_layer=0,
    vae_type=32,
    vae_path=vae_path,
    add_lvl_embeding_only_first_block=1,
    use_bit_label=1,
    model_type='infinity_2b',
    rope2d_each_sa_layer=1,
    rope2d_normalized_by_hw=2,
    use_scale_schedule_embedding=0,
    sampling_per_bits=1,
    text_encoder_ckpt=text_encoder_ckpt,
    text_channels=2048,
    apply_spatial_patchify=0,
    h_div_w_template=1.000,
    use_flex_attn=0,
    cache_dir='/dev/shm',
    checkpoint_type='torch',
    seed=0,
    bf16=1,
    save_file='tmp.jpg',
    enable_model_cache=False, 
)

# load text encoder
text_tokenizer, text_encoder = load_tokenizer(t5_path=args.text_encoder_ckpt)
# load vae
vae = load_visual_tokenizer(args)
# load infinity
infinity = load_transformer(vae, args)

prompt = """The picture shows a lush, green garden with a bright sun shining through the trees, casting a warm glow over the scene."""
cfg = 3
tau = 0.5
h_div_w = 0.5 # aspect ratio, height:width
seed = random.randint(0, 10000)
enable_positive_prompt=0

h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates-h_div_w))]
scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']
scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]
generated_image = gen_one_img(
    infinity,
    vae,
    text_tokenizer,
    text_encoder,
    prompt,
    g_seed=seed,
    gt_leak=0,
    gt_ls_Bl=None,
    cfg_list=cfg,
    tau_list=tau,
    scale_schedule=scale_schedule,
    cfg_insertion_layer=[args.cfg_insertion_layer],
    vae_type=args.vae_type,
    sampling_per_bits=args.sampling_per_bits,
    enable_positive_prompt=enable_positive_prompt,
)
args.save_file = 'ipynb_tmp.jpg'
os.makedirs(osp.dirname(osp.abspath(args.save_file)), exist_ok=True)
cv2.imwrite(args.save_file, generated_image.cpu().numpy())
print(f'Save to {osp.abspath(args.save_file)}')
```

This is my training log and inference code. Thanks~

l-dawei (Author) commented Jan 8, 2025

This is my infer.sh; I did not modify Infinity/tools/run_infinity.py. Thanks a lot.
[screenshot: infer.sh]

JeyesHan (Collaborator) commented Jan 8, 2025

@l-dawei Your training --pn is 0.06M, but your inference --pn is 1M.

l-dawei (Author) commented Jan 8, 2025

> @l-dawei Your training --pn is 0.06M, but your inference --pn is 1M.

But after I changed --pn in the infer.sh file to 0.06M, the generated image was still bad. Looking forward to your reply.

JeyesHan (Collaborator) commented Jan 8, 2025

@xuanyuzhang21 @l-dawei
I see you both enabled --apply_spatial_patchify=1 during training. Therefore, you should also enable apply_spatial_patchify=1 during inference. You can try that.
However, we do not recommend enabling apply_spatial_patchify=1 during training: the released model was trained without "spatial_patchify". I have updated scripts/train.sh; you can pull the latest commit for your next run.
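One way to catch this kind of train/infer flag drift early is to compare the inference args against the namespace saved in the raw fine-tuning checkpoint. This is only a sketch under an assumption: the unexpected_keys list posted later in this thread suggests the raw checkpoint stores the training namespace under an 'args' key, but the exact layout may differ.

```python
import argparse
import torch

# Sketch: detect a train/infer mismatch for the two flags discussed in this
# thread. Assumes (per the unexpected_keys list posted below) that the raw
# fine-tuning checkpoint carries the training argparse.Namespace under 'args'.
infer_args = argparse.Namespace(pn='0.06M', apply_spatial_patchify=0)  # values from the inference snippets
ckpt = torch.load('local_output/debug/ar-ckpt-giter010K-ep16-iter516-last.pth', map_location='cpu')
train_args = ckpt.get('args')
if train_args is not None:
    for flag in ('pn', 'apply_spatial_patchify'):
        t, i = getattr(train_args, flag, None), getattr(infer_args, flag, None)
        if t != i:
            print(f'{flag}: trained with {t!r}, inferring with {i!r}')
```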

l-dawei (Author) commented Jan 8, 2025

@JeyesHan

I pulled the latest code and found a small bug: scale_schedule[-1] in Infinity/trainer.py is sometimes a tuple, so I changed the code to use numpy's prod() function. I then fine-tuned with the latest code, but the resulting images were still bad. Looking forward to your reply.
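To illustrate the fix described above (a sketch; the exact expression in trainer.py may differ):

```python
import numpy as np

# scale_schedule entries are (t, h, w); the last one may be a plain tuple.
scale_schedule = [(1, 1, 1), (1, 2, 2), (1, 16, 16)]

# A tuple has no .prod() method, so array-style code raises:
#   scale_schedule[-1].prod()  ->  AttributeError
# np.prod accepts tuples and arrays alike, which is the fix described above:
num_tokens_last_scale = np.prod(scale_schedule[-1])  # 1 * 16 * 16 = 256
print(num_tokens_last_scale)
```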
My b1_stdout.txt is as follows (sorry, I can only post a screenshot because I cannot transfer files off the server):
[screenshot: b1_stdout.txt]

My infer.sh is as follows:
[screenshot: infer.sh]

JeyesHan (Collaborator) commented Jan 8, 2025

@l-dawei 100 iterations may be too few for fine-tuning a 1024 model down to 256. What about more iterations? What is your current training accuracy?

l-dawei (Author) commented Jan 9, 2025

@JeyesHan
OK, I increased the number of epochs to 500 to try. After the 58th epoch, the training Accm and Acct are 98.96 and 99.08.

xuanyuzhang21 commented Jan 9, 2025

```
_IncompatibleKeys(
    missing_keys=[
        'cfg_uncond', 'pos_start', 'text_norm.weight',
        'text_proj_for_sos.ca.mat_q', 'text_proj_for_sos.ca.v_bias',
        'text_proj_for_sos.ca.zero_k_bias', 'text_proj_for_sos.ca.mat_kv.weight',
        'text_proj_for_sos.ca.proj.weight', 'text_proj_for_sos.ca.proj.bias',
        'text_proj_for_ca.0.weight', 'text_proj_for_ca.0.bias',
        'text_proj_for_ca.2.weight', 'text_proj_for_ca.2.bias',
        'lvl_embed.weight', 'word_embed.weight', 'word_embed.bias',
        'shared_ada_lin.1.weight', 'shared_ada_lin.1.bias',
        'head_nm.ada_lin.1.weight', 'head_nm.ada_lin.1.bias',
        'head.weight', 'head.bias',
        # ... plus every block parameter, repeated for block_chunks.{0..7}.module.{0..3}:
        #     ada_gss,
        #     sa.{scale_mul_1H11, q_bias, v_bias, zero_k_bias, mat_qkv.weight, proj.weight, proj.bias},
        #     ca.{v_bias, zero_k_bias, mat_q.weight, mat_q.bias, mat_kv.weight, proj.weight, proj.bias},
        #     ffn.{fc1.weight, fc1.bias, fc2.weight, fc2.bias},
        #     ca_norm.{weight, bias}
    ],
    unexpected_keys=['args', 'gpt_training', 'arch', 'epoch', 'iter', 'trainer', 'acc_str', 'milestones'])
```

When I use the following inference code, it reports the incompatible-keys error above.

```python
import argparse
import os
import os.path as osp
import random

import torch
torch.cuda.set_device(0)
import cv2
import numpy as np
from tools.run_infinity import *

model_path="/data03/zxy/Infinity/local_output/debug/ar-ckpt-giter004K-ep3-iter756-last.pth"
vae_path='weights/infinity_vae_d32_reg.pth'
text_encoder_ckpt = 'weights/models--google--flan-t5-xl/snapshots/7d6315df2c2fb742f0f5b556879d730926ca9001'
args=argparse.Namespace(
    pn='0.06M',
    model_path=model_path,
    cfg_insertion_layer=0,
    vae_type=32,
    vae_path=vae_path,
    add_lvl_embeding_only_first_block=1,
    use_bit_label=1,
    model_type='infinity_2b',
    rope2d_each_sa_layer=1,
    rope2d_normalized_by_hw=2,
    use_scale_schedule_embedding=0,
    sampling_per_bits=1,
    text_encoder_ckpt=text_encoder_ckpt,
    text_channels=2048,
    apply_spatial_patchify=0,
    h_div_w_template=1.000,
    use_flex_attn=0,
    cache_dir='/dev/shm',
    checkpoint_type='torch',
    seed=0,
    bf16=1,
    save_file='tmp.jpg',
    enable_model_cache=False, 
)

# load text encoder
text_tokenizer, text_encoder = load_tokenizer(t5_path=args.text_encoder_ckpt)
# load vae
vae = load_visual_tokenizer(args)
# load infinity
infinity = load_transformer(vae, args)

prompt = """jenson button's autobiography, life to the limit"""
cfg = 3
tau = 0.5
h_div_w = 1.5 # aspect ratio, height:width
seed = random.randint(0, 10000)
enable_positive_prompt=0

h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates-h_div_w))]
scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']
scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]
generated_image = gen_one_img(
    infinity,
    vae,
    text_tokenizer,
    text_encoder,
    prompt,
    g_seed=seed,
    gt_leak=0,
    gt_ls_Bl=None,
    cfg_list=cfg,
    tau_list=tau,
    scale_schedule=scale_schedule,
    cfg_insertion_layer=[args.cfg_insertion_layer],
    vae_type=args.vae_type,
    sampling_per_bits=args.sampling_per_bits,
    enable_positive_prompt=enable_positive_prompt,
)
args.save_file = 'ipynb_tmp.jpg'
os.makedirs(osp.dirname(osp.abspath(args.save_file)), exist_ok=True)
cv2.imwrite(args.save_file, generated_image.cpu().numpy())
print(f'Save to {osp.abspath(args.save_file)}')
```

The format of the fine-tuning checkpoint "local_output/debug/ar-ckpt-giter004K-ep3-iter756-last.pth" does not seem to be compatible with the Infinity transformer. Is any operation required to convert these weights for testing?

JeyesHan (Collaborator) commented Jan 9, 2025

@xuanyuzhang21 You need to convert ar-ckpt-giter004K-ep3-iter756-last.pth to slim-ckpt-giter004K-ep3-iter756-last.pth first. Please refer to https://github.com/FoundationVision/Infinity/blob/main/tools/run_infinity.py#L137.

Also, enable_model_cache=True will do this conversion automatically.
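For readers wondering what the slim conversion amounts to: it strips the training state and keeps only the model weights. A minimal sketch, assuming the raw checkpoint's top-level keys match the unexpected_keys above ('args', 'epoch', 'trainer', ...); the nesting of the weights inside 'trainer' is a guess, so refer to the linked run_infinity.py for the repo's actual implementation.

```python
import torch

# Sketch: ar-ckpt -> slim-ckpt. The raw training checkpoint bundles bookkeeping
# ('args', 'gpt_training', 'arch', 'epoch', 'iter', 'trainer', ...); a slim
# checkpoint keeps only the transformer state_dict. The 'trainer'/'unwrapped_model'
# key path below is hypothetical; see tools/run_infinity.py#L137 for the real logic.
raw = torch.load('local_output/debug/ar-ckpt-giter004K-ep3-iter756-last.pth', map_location='cpu')
state_dict = raw['trainer']['unwrapped_model']  # hypothetical key path; adjust to the actual layout
torch.save(state_dict, 'local_output/debug/slim-ckpt-giter004K-ep3-iter756-last.pth')
```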

@xuanyuzhang21

Thanks for your help! It works~

l-dawei (Author) commented Jan 10, 2025

@xuanyuzhang21
Hello, may I ask whether inference works normally after your retraining or fine-tuning?

xuanyuzhang21 commented Jan 10, 2025

@l-dawei Yes. My training result is normal now.
