Making NLL Pipeline More Clear + LJ55 Fixes (#3)
* Fix that negative time wasn't getting used

* Code changes

* Fix ODE tolerance and flag using exact likelihood (too slow)

* Minor

* Env minor

* More config changes

* Add README updates

* A readme minor

* Add data_path_train arg for gmm energy

* Small conf changes for what metric to monitor on

* Minor change to readme

* Add a note about potentially trying multiple checkpoints

* Minor

* Minors

* Run pre-commit

* Minor

* Switch bflow dependency

* Pin np version

* More minors

* Update requirements.txt

* remove broken docformatter

* Update distribution_distances.py

* Update distribution_distances.py

* Update distribution_distances.py

* Update .pre-commit-config.yaml

* Update distribution_distances.py

Change to have some tolerance for Mac which is more precise

* RM notebook linting

* Change mdformat version

* small pre-commit edits

* More pre-commit changes

* Minor

---------

Co-authored-by: Alexander Tong <alexandertongdev@gmail.com>
jarridrb and atong01 authored Jan 24, 2025
1 parent 77fd2ec commit 3b4d705
Showing 18 changed files with 525 additions and 104 deletions.
36 changes: 3 additions & 33 deletions .pre-commit-config.yaml
@@ -1,6 +1,5 @@
default_language_version:
python: python3
node: 16.14.2

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
@@ -39,29 +38,15 @@ repos:
- id: pyupgrade
args: [--py38-plus]

# python docstring formatting
- repo: https://github.com/myint/docformatter
rev: v1.7.4
hooks:
- id: docformatter
args:
[
--in-place,
--wrap-summaries=99,
--wrap-descriptions=99,
--style=sphinx,
--black,
]

# python docstring coverage checking
- repo: https://github.com/econchick/interrogate
rev: 1.5.0 # or master if you're bold
rev: 1.7.0 # or master if you're bold
hooks:
- id: interrogate
args:
[
--verbose,
--fail-under=20,
--fail-under=15,
--ignore-init-module,
--ignore-init-method,
--ignore-module,
@@ -106,7 +91,7 @@ repos:

# md formatting
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.16
rev: 0.7.21
hooks:
- id: mdformat
args: ["--number"]
@@ -131,18 +116,3 @@
rev: 0.6.1
hooks:
- id: nbstripout

# jupyter notebook linting
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.6.3
hooks:
- id: nbqa-black
args: ["--line-length=99"]
- id: nbqa-isort
args: ["--profile=black"]
- id: nbqa-flake8
args:
[
"--extend-ignore=E203,E402,E501,F401,F841",
"--exclude=logs/*,data/*",
]
58 changes: 58 additions & 0 deletions README.md
@@ -59,6 +59,64 @@ which we only include a config for iDEM as pDEM had convergence issues on this d

The current repository contains code for experiments for iDEM and pDEM as specified in our paper.

## Update January 2025

In this update we provide code and more detailed instructions for running the CFM models, including log Z and ESS computation.
In doing so, we also found a few bugs in the public code implementation for LJ55 (note that this codebase is an adaptation of
a large number of notebooks used for the paper), which we have fixed in a set of code updates just merged to the repository.

### CFM Pipeline for Computing NLL

We will use LJ55 as the running example for the pipeline. First, run the training script as usual:

```bash
python dem/train.py experiment=lj55_idem
```

After training is complete, find the epochs with the best `val/2-Wasserstein` values in wandb. We will use the
best checkpoint to generate a training dataset for CFM with the following command, which also logs the
2-Wasserstein and total variation distances between the dataset generated from the trained iDEM model and the
test set. To run it, you must provide the eval script with the checkpoint path you are using:

```bash
python dem/eval.py experiment=lj55_idem ckpt_path=<path_to_ckpt>
```

This will take some time to run and will generate a file named `samples_<n_samples_to_generate>.pt` in the hydra
runtime directory for the eval run. We can now use these samples to train a CFM model. We provide a config `lj55_idem_cfm`
which enables the CFM pipeline by default for the LJ55 task, though doing so for other tasks is also simple: the main
config changes required are to set `model.debug_use_train_data=true`, `model.nll_with_cfm=true`, and
`model.logz_with_cfm=true` (an illustrative command for another task is sketched after the training command below).
To point the CFM training run at the dataset generated from iDEM samples, set the `energy.data_path_train` attribute
to the path of the generated samples. CFM training in this example can then be done with

```bash
python dem/train.py experiment=lj55_idem_cfm energy.data_path_train=<path_to_samples>
```
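
For other tasks, a rough sketch of the equivalent run might look like the following, assuming the task's experiment
config exposes the same model and energy options; the experiment name and sample path are placeholders:

```bash
# Illustrative only -- not a provided config. Substitute a real experiment name
# and the path to samples generated by dem/eval.py.
python dem/train.py experiment=<other_task_idem> \
  model.debug_use_train_data=true \
  model.nll_with_cfm=true \
  model.logz_with_cfm=true \
  energy.data_path_train=<path_to_samples>
```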

Finally, to evaluate test-set NLL, take the checkpoint of the CFM run with the best `val/nll` and run the eval script
again:

```bash
python dem/eval.py experiment=lj55_idem_cfm ckpt_path=<path_to_cfm_ckpt>
```

Note that you may need to try a couple of different checkpoints from the original
`python dem/train.py experiment=lj55_idem` run for generating samples and downstream CFM training/eval in
order to get the best combination of eval metrics.

### ESS Computation Considerations

In preparing this update we noticed that our original evaluation of ESS used a batch size of 16 on all tasks. We recommend
that users of our repository instead evaluate ESS on a larger batch size (the updated code defaults to 1000). To reproduce
the results in the paper, you can either set this back to 16 or look at the wandb metrics logged during validation when
training the CFM model, which evaluate on a batch size of 16.
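
As an illustration only, assuming the ESS batch size is controlled by the `model.eval_batch_size` option visible in the
experiment configs (verify the exact option in your checkout), reproducing the paper's setting could look like:

```bash
# Assumed override; the option governing the ESS batch size may be named differently.
python dem/eval.py experiment=lj55_idem_cfm ckpt_path=<path_to_cfm_ckpt> model.eval_batch_size=16
```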

### LJ55 Negative Time

For LJ55 in our original manuscript we used 10 steps of "negative time" (described in Section 4 of the manuscript)
during inference, where we continued SDE inference for 10 extra steps using the true score at time 0. The repository
configs had the flag for this turned on, but the code ignored the flag. This has been corrected in the update; a rough
sketch of the idea is shown below.
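
For intuition only, here is a minimal sketch of negative-time inference; it is not the repository implementation, and
the step size, diffusion scale, and function names are illustrative assumptions:

```python
import torch


def negative_time_steps(x, score_t0, num_steps=10, dt=1e-3, diffusion_scale=0.5):
    """Continue SDE inference past t=0 for a few extra steps using the true score at time 0."""
    for _ in range(num_steps):
        noise = torch.randn_like(x)
        # Langevin-style update driven by the time-0 score.
        x = x + (diffusion_scale**2) * score_t0(x) * dt + diffusion_scale * (dt**0.5) * noise
    return x
```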

## Citations

If this codebase is useful towards other research efforts please consider citing us.
8 changes: 4 additions & 4 deletions configs/callbacks/default.yaml
@@ -1,17 +1,17 @@
defaults:
#- model_checkpoint
- model_checkpoint
- model_summary
- rich_progress_bar
- _self_

model_checkpoint:
dirpath: ${paths.output_dir}/checkpoints
filename: "epoch_{epoch:03d}"
monitor: "val/nll"
monitor: "val/2-Wasserstein"
mode: "min"
every_n_epochs: 50 # number of epochs between checkpoints
every_n_epochs: ${trainer.check_val_every_n_epoch} # number of epochs between checkpoints
save_last: True
save_top_k: 3
save_top_k: 5
auto_insert_metric_name: False
verbose: true

2 changes: 1 addition & 1 deletion configs/callbacks/model_checkpoint.yaml
@@ -4,7 +4,7 @@ model_checkpoint:
_target_: lightning.pytorch.callbacks.ModelCheckpoint
dirpath: null # directory to save the model file
filename: null # checkpoint filename
monitor: "val/loss" # name of the logged metric which determines when model is improving
monitor: "val/2-Wasserstein" # name of the logged metric which determines when model is improving
verbose: False # verbosity mode
save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
save_top_k: 1 # save k best models (determined by above metric)
1 change: 1 addition & 0 deletions configs/energy/gmm.yaml
@@ -11,3 +11,4 @@ plot_samples_epoch_period: 1

should_unnormalize: true
data_normalization_factor: 50
data_path_train: null
3 changes: 2 additions & 1 deletion configs/energy/lj13.yaml
@@ -4,7 +4,8 @@ dimensionality: 39
n_particles: 13
data_path: "data/test_split_LJ13-1000.npy"
data_path_train: "data/train_split_LJ13-1000.npy"
data_path_val: "data/test_split_LJ13-1000.npy"
data_path_val: "data/val_split_LJ13-1000.npy"
data_path_test: "data/test_split_LJ13-1000.npy"

device: ${trainer.accelerator}

1 change: 1 addition & 0 deletions configs/energy/lj55.yaml
@@ -5,6 +5,7 @@ n_particles: 55
data_path: "data/test_split_LJ55-1000-part1.npy"
data_path_train: "data/train_split_LJ55-1000-part1.npy"
data_path_val: "data/val_split_LJ55-1000-part1.npy"
data_path_test: "data/test_split_LJ55-1000-part1.npy"

device: ${trainer.accelerator}

11 changes: 8 additions & 3 deletions configs/experiment/lj55_idem.yaml
@@ -23,7 +23,7 @@ data:
n_val_batches_per_epoch: 4

trainer:
check_val_every_n_epoch: 50
check_val_every_n_epoch: 5
max_epochs: 2000

model:
@@ -59,12 +59,17 @@ model:
num_init_samples: 1024
num_samples_to_generate_per_epoch: 128
num_samples_to_sample_from_buffer: 128
eval_batch_size: 16
eval_batch_size: 128

init_from_prior: true
num_samples_to_save: 10000

nll_integration_method: dopri5

negative_time: True
num_negative_time_steps: 100
num_negative_time_steps: 10

callbacks:
model_checkpoint:
monitor: "val/2-Wasserstein"
save_top_k: -1
90 changes: 90 additions & 0 deletions configs/experiment/lj55_idem_cfm.yaml
@@ -0,0 +1,90 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=example

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

defaults:
- override /energy: lj55
- override /model/net: egnn

tags: ["LJ55", "CFM", "NLL_eval"]

seed: 12345

logger:
wandb:
tags: ${tags}
group: "lj55"

# You need to fill energy.data_path_train at the command line with samples generated
# by running dem/eval.py on a checkpoint
energy:
data_path_train: null

data:
n_val_batches_per_epoch: 4

trainer:
check_val_every_n_epoch: 1
max_epochs: 2000

model:
debug_use_train_data: true
use_otcfm: false
use_ema: false
nll_with_cfm: true
logz_with_cfm: true
tol: 1e-3
use_exact_likelihood: false
net:
n_particles: 55
n_layers: 5
hidden_nf: 128

noise_schedule:
_target_: dem.models.components.noise_schedules.GeometricNoiseSchedule
sigma_min: 0.5
sigma_max: 4

partial_prior:
_target_: dem.energies.base_prior.MeanFreePrior
_partial_: true
n_particles: 55
spatial_dim: 3

optimizer:
lr: 1e-3

lambda_weighter:
_target_: dem.models.components.lambda_weighter.NoLambdaWeighter
_partial_: true

clipper:
_target_: dem.models.components.clipper.Clipper
should_clip_scores: True
should_clip_log_rewards: False
max_score_norm: 20
min_log_reward: null

diffusion_scale: 0.5

num_init_samples: 1024
num_samples_to_generate_per_epoch: 128
num_samples_to_sample_from_buffer: 512
num_samples_to_save: 1000
eval_batch_size: 16

init_from_prior: true

nll_integration_method: dopri5

negative_time: false
num_negative_time_steps: 10

callbacks:
model_checkpoint:
monitor: "val/nll"
save_top_k: -1
6 changes: 5 additions & 1 deletion configs/model/dem.yaml
@@ -67,7 +67,7 @@ use_richardsons: false

cfm_loss_weight: 1.0
use_ema: false
use_exact_likelihood: True
use_exact_likelihood: true

# train cfm only on train data and not dem
debug_use_train_data: false
@@ -83,3 +83,7 @@ num_samples_to_save: 100000

negative_time: false
num_negative_time_steps: 100

nll_batch_size: 256

seed: ${seed}
21 changes: 19 additions & 2 deletions dem/energies/gmm_energy.py
@@ -1,6 +1,7 @@
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from fab.target_distributions import gmm
from fab.utils.plotting import plot_contours, plot_marginal_pair
@@ -27,6 +28,7 @@ def __init__(
train_set_size=100000,
test_set_size=2000,
val_set_size=2000,
data_path_train=None,
):
use_gpu = device != "cpu"
torch.manual_seed(0) # seed of 0 for GMM problem
@@ -51,6 +53,8 @@ def __init__(
self.test_set_size = test_set_size
self.val_set_size = val_set_size

self.data_path_train = data_path_train

self.name = "gmm"

super().__init__(
@@ -65,8 +69,21 @@ def setup_test_set(self):
return self.gmm.test_set

def setup_train_set(self):
train_samples = self.gmm.sample((self.train_set_size,))
return self.normalize(train_samples)
if self.data_path_train is None:
train_samples = self.normalize(self.gmm.sample((self.train_set_size,)))

else:
# Assume the samples we are loading from disk are already normalized.
# This breaks if they are not.

if self.data_path_train.endswith(".pt"):
data = torch.load(self.data_path_train).cpu().numpy()
else:
data = np.load(self.data_path_train, allow_pickle=True)

train_samples = torch.tensor(data, device=self.device)

return train_samples

def setup_val_set(self):
val_samples = self.gmm.sample((self.val_set_size,))
