Only one GPU is required.
IDFF simultaneously learns an implicit flow and a scoring model that are combined during sampling. This structure lets IDFF reduce the number of function evaluations (NFE) by more than 10x compared to traditional CFMs, enabling rapid sampling and efficient handling of image and time-series generation tasks. See below for an illustration:
Image generation using IDFF (with NFE=10) applied to multiple datasets: CIFAR10, CelebA-64, ImageNet-64, LSUN-Bedroom, LSUN-Church, and CelebA-HQ.
A) Comparison of trajectory sampling between IDFF and OT-CFMs: The figure displays 4096 final samples generated by IDFF. As shown, IDFF takes larger steps toward the target distribution, guided by the momentum term. While OT-CFM follows a nearly straight path to reach the target distribution, it requires a higher number of function evaluations (NFEs). B) OT-CFMs sampling process. C) IDFF sampling process.
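As a rough illustration of this kind of few-step sampling (not the actual IDFF sampler), the sketch below integrates the straight-line conditional velocity field of a basic CFM with a plain Euler scheme at NFE=10; the velocity field and target distribution are hypothetical stand-ins for a trained model:

```python
import numpy as np

# Toy sketch of few-step flow sampling (hypothetical, NOT the IDFF model):
# Euler-integrate the straight-line (OT-style) conditional velocity of a
# basic CFM, using only NFE=10 function evaluations.

def velocity(x, t, target):
    # Straight-line conditional velocity pointing from x toward the target.
    return (target - x) / max(1.0 - t, 1e-3)

def sample(nfe=10, n=4096, seed=0):
    rng = np.random.default_rng(seed)
    # Hypothetical target distribution: a Gaussian blob around (2, -1).
    target = 0.5 * rng.standard_normal((n, 2)) + np.array([2.0, -1.0])
    x = rng.standard_normal((n, 2))  # start from standard Gaussian noise
    dt = 1.0 / nfe
    for k in range(nfe):
        x = x + dt * velocity(x, k * dt, target)  # Euler step
    return x

samples = sample()  # 4096 samples after only 10 function evaluations
```

With this straight-line field, ten Euler steps already land on the target distribution; IDFF's momentum-guided steps play an analogous role for learned, curved flows.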
For more information, please see our paper, Implicit Dynamical Flow Fusion (IDFF) for Generative Modeling.
- A simple toy example is available in `2D-toys/2D-toy-examples.ipynb`.
To run the image generation examples, execute the following command inside each dataset directory:

    python simple_gen_test.py
To run the code, either download a pre-trained model from the links below or train one from scratch. Place each pre-trained checkpoint in the `results/IDFF-2.0-0.2` directory of the corresponding example.
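As a sanity check before launching `simple_gen_test.py`, a tiny helper (hypothetical, not part of the repo) can build and verify the expected checkpoint location:

```python
import os

# Hypothetical helper (not part of the repo): build the checkpoint path that
# matches the results/IDFF-2.0-0.2 layout described above.

def checkpoint_path(example_dir, checkpoint_name):
    """Return the path where a downloaded IDFF checkpoint should be placed."""
    return os.path.join(example_dir, "results", "IDFF-2.0-0.2", checkpoint_name)

path = checkpoint_path("cifar10", "IDFF_cifar10_weights_step_final.pt")
if not os.path.exists(path):
    # Warn early instead of failing deep inside the generation script.
    print(f"missing checkpoint: {path}")
```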
Exp | Args | FID | NFE | Checkpoints | Samples |
---|---|---|---|---|---|
CIFAR-10 | cifar10/simple_gen_test.py | 5.87 | 10 | IDFF_cifar10_weights_step_final.pt | cifar10_samples |
CelebA-64 | celebA/simple_gen_test.py | 11.83 | 10 | IDFF_celeba_weights_step_final.pt | celeba_64_samples |
CelebA-256 | celebA_HQ/simple_gen_test.py | --- | 10 | IDFF_celeba_256_weights_step_final.pt | celeba_256_samples |
LSUN-Bed | lsun_bed/simple_gen_test.py | --- | 10 | IDFF_lsun_bed_weights_step_final.pt | lsun_bed_samples |
LSUN-Church | lsun_church/simple_gen_test.py | 12.86 | 10 | IDFF_lsun_church_weights_step_final.pt | lsun_church_samples |
Exp | Args | Checkpoints |
---|---|---|
SST | sst/sst_forecaster.py | IDFF_sst_weights_step_final.pt |
MD | timeseries_examples/MD_simulation.py | IDFF_MD_v1.pt |
For the CelebA-HQ 256 and LSUN datasets, we prepared the data following NVAE's instructions.
The datasets for the SST and MD experiments are provided here.
To compute the FID score over 50K samples, first generate them. For CIFAR-10, for example:

    python gen_cifar10.py
To generate the reference set of real CIFAR-10 samples:

    python gen_true_cifar10.py
Finally, use pytorch_fid to compute the FID between the two sample sets:

    python -m pytorch_fid /path_to_original_samples /path_to_generated_samples
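Under the hood, pytorch_fid fits a Gaussian to the InceptionV3 features of each sample set and reports the Fréchet distance between the two Gaussians: FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)). A minimal numpy sketch of that distance, using random stand-in features instead of real Inception activations:

```python
import numpy as np

# Frechet distance between two Gaussians fitted to feature sets, the quantity
# pytorch_fid reports. Features here are random stand-ins, NOT Inception
# activations, so the numeric value is only illustrative.

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)).
    # Use Tr((S1 S2)^(1/2)) = Tr((S1^(1/2) S2 S1^(1/2))^(1/2)), which lets us
    # work with a symmetric PSD matrix and plain eigendecompositions.
    diff = mu1 - mu2
    w1, v1 = np.linalg.eigh(sigma1)
    s1_half = (v1 * np.sqrt(np.clip(w1, 0, None))) @ v1.T
    inner = s1_half @ sigma2 @ s1_half
    w = np.clip(np.linalg.eigvalsh(inner), 0, None)
    tr_sqrt = np.sum(np.sqrt(w))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_sqrt)

rng = np.random.default_rng(0)
real = rng.standard_normal((5000, 8))        # stand-in "real" features
fake = rng.standard_normal((5000, 8)) + 0.5  # mean-shifted "generated" features
fid = frechet_distance(real.mean(0), np.cov(real, rowvar=False),
                       fake.mean(0), np.cov(fake, rowvar=False))
```

A pure mean shift of 0.5 in 8 dimensions gives a distance close to 8 * 0.25 = 2, which matches the first term of the formula; identical distributions give a distance of (numerically) zero.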
If you use this repository to produce published results or integrate it into other software, please cite our paper:
    @article{xx,
      title={xx},
      author={xx},
      journal={xx},
      year={xx}
    }