Implicit Dynamical Flow Fusion (IDFF) for Generative Modeling

MrRezaeiUofT/IDFF

Only one GPU is required.

IDFF simultaneously learns an implicit flow and a scoring model, which are combined during sampling. This structure allows IDFF to reduce the number of function evaluations (NFE) by more than a factor of 10 compared to traditional conditional flow matching (CFM) models, enabling rapid sampling and efficient generation of image and time-series data. See below for an illustration:

[Figure: IDFF]

Image generation using IDFF (with NFE=10) applied to multiple datasets: CIFAR10, CelebA-64, ImageNet-64, LSUN-Bedroom, LSUN-Church, and CelebA-HQ.

[Figure: IDFF]

A) Comparison of sampling trajectories for IDFF and OT-CFM, showing 4096 final samples generated by IDFF. Guided by the momentum term, IDFF takes larger steps toward the target distribution. OT-CFM follows a nearly straight path to the target distribution but requires more function evaluations (NFE). B) OT-CFM sampling process. C) IDFF sampling process.
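To make the NFE budget concrete, here is a minimal sketch that counts function evaluations in a fixed-step Euler sampler. It is not the IDFF sampler: `toy_velocity` is a hypothetical stand-in for a trained network, and the names `CountingField` and `euler_sample` are assumptions introduced only for illustration.

```python
class CountingField:
    """Wraps a vector field and counts how many times it is evaluated."""

    def __init__(self, fn):
        self.fn = fn
        self.nfe = 0

    def __call__(self, x, t):
        self.nfe += 1
        return self.fn(x, t)


def toy_velocity(x, t, target=2.0):
    # Hypothetical toy field that pulls x toward `target`;
    # in practice this would be a forward pass of a trained model.
    return target - x


def euler_sample(field, x0, n_steps):
    # Fixed-step Euler integration from t=0 to t=1:
    # one field evaluation per step, so NFE equals n_steps.
    x = x0
    dt = 1.0 / n_steps
    for k in range(n_steps):
        x = x + dt * field(x, k * dt)
    return x


field = CountingField(toy_velocity)
x_final = euler_sample(field, x0=0.0, n_steps=10)
print("NFE:", field.nfe)
```

With a fixed-step solver like this one, the NFE is simply the number of steps, so a sampler that produces good samples at NFE=10 calls the network ten times per batch.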

For more information, please see our paper, Implicit Dynamical Flow Fusion (IDFF) for Generative Modeling.

Usage

  • A simple toy example is available in 2D-toys/2D-toy-examples.ipynb.

  • To run the image generation examples, execute the following command inside each dataset's directory:

python simple_gen_test.py 

To run the code, you must either download the pre-trained model from the links below or train it from scratch.

Place each pre-trained model in the results/IDFF-2.0-0.2 directory of the corresponding example.

Pretrained Image Generation Models

| Experiment | Script | FID | NFE | Checkpoint | Samples |
|---|---|---|---|---|---|
| CIFAR-10 | cifar10/simple_gen_test.py | 5.87 | 10 | IDFF_cifar10_weights_step_final.pt | cifar10_samples |
| CelebA-64 | celebA/simple_gen_test.py | 11.83 | 10 | IDFF_celeba_weights_step_final.pt | celeba_64_samples |
| CelebA-256 | celebA_HQ/simple_gen_test.py | --- | 10 | IDFF_celeba_256_weights_step_final.pt | celeba_256_samples |
| LSUN-Bed | lsun_bed/simple_gen_test.py | --- | 10 | IDFF_lsun_bed_weights_step_final.pt | lsun_bed_samples |
| LSUN-Church | lsun_church/simple_gen_test.py | 12.86 | 10 | IDFF_lsun_church_weights_step_final.pt | lsun_church_samples |
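Before launching an example, it can help to confirm that its checkpoint is in place. The sketch below is a hypothetical helper, not part of the repository: it assumes each checkpoint lives at `<example>/results/IDFF-2.0-0.2/<checkpoint>` relative to the repository root (one reading of the placement rule above), and the function names are illustrative. Directory and file names are copied from the table.

```python
from pathlib import Path
import subprocess
import sys

# Example directories and checkpoint file names, copied from the table above.
EXAMPLES = {
    "cifar10": "IDFF_cifar10_weights_step_final.pt",
    "celebA": "IDFF_celeba_weights_step_final.pt",
    "celebA_HQ": "IDFF_celeba_256_weights_step_final.pt",
    "lsun_bed": "IDFF_lsun_bed_weights_step_final.pt",
    "lsun_church": "IDFF_lsun_church_weights_step_final.pt",
}


def checkpoint_path(repo_root, example):
    # Assumed layout: <repo_root>/<example>/results/IDFF-2.0-0.2/<checkpoint>
    return Path(repo_root) / example / "results" / "IDFF-2.0-0.2" / EXAMPLES[example]


def runnable_examples(repo_root):
    """Return the examples whose checkpoint file is already in place."""
    return [name for name in EXAMPLES if checkpoint_path(repo_root, name).is_file()]


def run_example(repo_root, example):
    # Equivalent to running `python simple_gen_test.py` inside the example directory.
    return subprocess.run(
        [sys.executable, "simple_gen_test.py"],
        cwd=Path(repo_root) / example,
        check=True,
    )


if __name__ == "__main__":
    root = Path(".")
    for name in runnable_examples(root):
        run_example(root, name)
```

Examples without a checkpoint are skipped rather than failing, so the same script works whether one model or all five have been downloaded.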

Pretrained Time-series Generation Models

Sea Surface Temperature (SST) Forecasting

| Experiment | Script | Checkpoint |
|---|---|---|
| SST | sst/sst_forecaster.py | IDFF_sst_weights_step_final.pt |

[Figure: sst_IDFF]

[Figure: sst_IDFF_result]

Molecular Dynamics (MD)

| Experiment | Script | Checkpoint |
|---|---|---|
| MD | timeseries_examples/MD_simulation.py | IDFF_MD_v1.pt |

[Figure: MD_IDFF]

[Figure: MD_IDFF_result]

Dataset preparation

For the CelebA-HQ 256 and LSUN datasets, we followed NVAE's dataset preparation instructions.

The datasets for the SST and MD experiments are provided here.

FID

To compute the FID score, first generate 50K samples. For CIFAR-10, for instance, run:

python gen_cifar10.py 

To produce the reference set of original CIFAR-10 images:

python gen_true_cifar10.py 

Finally, use pytorch_fid to compute the FID between the two sets of samples with the following command:

python -m pytorch_fid /path_to_original_samples /path_to_generated_samples
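A quick sanity check before computing FID is to confirm that the generated folder really holds the expected number of samples. The following is a minimal sketch under stated assumptions: it presumes the generation script writes one image file per sample into a flat folder (the repository may store samples differently), and `count_images` and `fid_command` are hypothetical helper names. It only builds the same pytorch_fid command shown above.

```python
from pathlib import Path
import sys

# Assumption: samples are stored as individual image files with these suffixes.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def count_images(folder):
    """Count image files directly inside `folder` (non-recursive)."""
    return sum(1 for p in Path(folder).iterdir() if p.suffix.lower() in IMAGE_SUFFIXES)


def fid_command(real_dir, generated_dir, expected=50_000):
    """Build the pytorch_fid command after checking the generated sample count."""
    n = count_images(generated_dir)
    if n != expected:
        raise ValueError(f"expected {expected} generated samples, found {n}")
    # Same invocation as: python -m pytorch_fid <real_dir> <generated_dir>
    return [sys.executable, "-m", "pytorch_fid", str(real_dir), str(generated_dir)]
```

Passing the returned list to `subprocess.run(..., check=True)` runs the same command as above, but only once the sample count matches.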

Citation

If you use this repository to produce published results or integrate it into other software, please cite our paper.

@article{xx,
  title={xx},
  author={xx},
  journal={xx},
  year={xx}
}