-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Making NLL Pipeline More Clear + LJ55 Fixes (#3)
* Fix that negative time wasn't getting used * Code changes * Fix ODE tolerance and flag using exact likelihood (too slow) * Minor * Env minor * More config changes * Add README updates * A readme minor * Add data_path_train arg for gmm energy * Small conf changes for what metric to monitor on * Minor change to readme * Add a note about potentially trying multiple checkpoints * Minor * Minors * Run pre-commit * Minor * Switch bflow dependency * Pin np version * More minors * Update requirements.txt * remove broken docformatter * Update distribution_distances.py * Update distribution_distances.py * Update distribution_distances.py * Update .pre-commit-config.yaml * Update distribution_distances.py Change to have some tolerance for Mac which is more precise * RM notebook linting * Change mdformat version * small pre-commit edits * More pre-commit changes * Minor --------- Co-authored-by: Alexander Tong <alexandertongdev@gmail.com>
- Loading branch information
Showing
18 changed files
with
525 additions
and
104 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,3 +11,4 @@ plot_samples_epoch_period: 1 | |
|
||
should_unnormalize: true | ||
data_normalization_factor: 50 | ||
data_path_train: null |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# @package _global_ | ||
|
||
# to execute this experiment run: | ||
# python train.py experiment=example | ||
|
||
# all parameters below will be merged with parameters from default configurations set above | ||
# this allows you to overwrite only specified parameters | ||
|
||
defaults: | ||
- override /energy: lj55 | ||
- override /model/net: egnn | ||
|
||
tags: ["LJ55", "CFM", "NLL_eval"] | ||
|
||
seed: 12345 | ||
|
||
logger: | ||
wandb: | ||
tags: ${tags} | ||
group: "lj55" | ||
|
||
# You need to fill energy.data_path_train at the command line with samples generated | ||
# by running dem/eval.py on a checkpoint | ||
energy: | ||
data_path_train: null | ||
|
||
data: | ||
n_val_batches_per_epoch: 4 | ||
|
||
trainer: | ||
check_val_every_n_epoch: 1 | ||
max_epochs: 2000 | ||
|
||
model: | ||
debug_use_train_data: true | ||
use_otcfm: false | ||
use_ema: false | ||
nll_with_cfm: true | ||
logz_with_cfm: true | ||
tol: 1e-3 | ||
use_exact_likelihood: false | ||
net: | ||
n_particles: 55 | ||
n_layers: 5 | ||
hidden_nf: 128 | ||
|
||
noise_schedule: | ||
_target_: dem.models.components.noise_schedules.GeometricNoiseSchedule | ||
sigma_min: 0.5 | ||
sigma_max: 4 | ||
|
||
partial_prior: | ||
_target_: dem.energies.base_prior.MeanFreePrior | ||
_partial_: true | ||
n_particles: 55 | ||
spatial_dim: 3 | ||
|
||
optimizer: | ||
lr: 1e-3 | ||
|
||
lambda_weighter: | ||
_target_: dem.models.components.lambda_weighter.NoLambdaWeighter | ||
_partial_: true | ||
|
||
clipper: | ||
_target_: dem.models.components.clipper.Clipper | ||
should_clip_scores: True | ||
should_clip_log_rewards: False | ||
max_score_norm: 20 | ||
min_log_reward: null | ||
|
||
diffusion_scale: 0.5 | ||
|
||
num_init_samples: 1024 | ||
num_samples_to_generate_per_epoch: 128 | ||
num_samples_to_sample_from_buffer: 512 | ||
num_samples_to_save: 1000 | ||
eval_batch_size: 16 | ||
|
||
init_from_prior: true | ||
|
||
nll_integration_method: dopri5 | ||
|
||
negative_time: false | ||
num_negative_time_steps: 10 | ||
|
||
callbacks: | ||
model_checkpoint: | ||
monitor: "val/nll" | ||
save_top_k: -1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.