GradNCP

Official PyTorch implementation of "Learning Large-scale Neural Fields via Context Pruned Meta-Learning" (NeurIPS 2023) by Jihoon Tack, Subin Kim, Sihyun Yu, Jaeho Lee, Jinwoo Shin, Jonathan Richard Schwarz.

TL;DR: We propose an efficient meta-learning framework for scalable neural field learning that prunes the context set online.
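To make the idea of pruning a context set concrete, here is a minimal sketch in plain Python. It is not the repository's implementation: this README does not describe the pruning criterion, so the sketch uses per-point squared error as a hypothetical stand-in score, and the names `prune_context`, `predict`, and `keep_ratio` are illustrative assumptions.

```python
def prune_context(coords, targets, predict, keep_ratio=0.25):
    """Keep the context points with the highest per-point score.

    Assumption: the score is the squared error of `predict` at each
    coordinate. This is a stand-in, not the criterion used in the paper.
    """
    scores = [(predict(x) - y) ** 2 for x, y in zip(coords, targets)]
    k = max(1, int(len(coords) * keep_ratio))
    # Indices of the k highest-scoring points, restored to original order.
    top = sorted(range(len(coords)), key=lambda i: scores[i], reverse=True)[:k]
    top.sort()
    return [coords[i] for i in top], [targets[i] for i in top]


# Toy 1D signal: a model that always predicts 0 is most wrong at x=2 and x=3.
coords = [0, 1, 2, 3]
targets = [0.0, 0.0, 5.0, 1.0]
kept_coords, kept_targets = prune_context(
    coords, targets, predict=lambda x: 0.0, keep_ratio=0.5
)
print(kept_coords, kept_targets)
```

An adaptation step would then be computed on the kept points only, which is where the memory and compute savings would come from in this kind of scheme.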

1. Dependencies

conda create -n gradncp python=3.8 -y
conda activate gradncp

pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
pip install einops pyyaml tensorboardX tensorboard natsort pyspng av pytorch_msssim lpips
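After installing, a quick check that every package is importable can save a failed run later. The sketch below uses only the standard library; the module list mirrors the pip commands above, with the assumption that `pyyaml` is imported as `yaml` and the other packages are imported under their install names. The helper name `missing_modules` is illustrative.

```python
import importlib.util

# Import names for the packages installed above (pyyaml is imported as yaml).
REQUIRED_MODULES = [
    "torch", "torchvision", "torchaudio",
    "einops", "yaml", "tensorboardX", "tensorboard",
    "natsort", "pyspng", "av", "pytorch_msssim", "lpips",
]


def missing_modules(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(REQUIRED_MODULES)
print("missing:", ", ".join(missing) if missing else "none")
```

This only checks that the modules can be located; it does not verify versions or CUDA availability.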

2. Dataset

  • The default dataset path is /data; to change it, edit DATA_PATH in data.dataset.py (e.g., DATA_PATH = './PATH_TO_DATA')
  • Download CelebA, CelebA-HQ, AFHQ, Imagenette-320, ImageNet, Text, UCF-101, Librispeech, ERA5

3. How to run?

Train

# Learnit
CUDA_VISIBLE_DEVICES=0 python main.py --configs ./configs/main/maml_celeba.yaml

# Ours
CUDA_VISIBLE_DEVICES=0 python main.py --configs ./configs/main/ours_celeba.yaml
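To run both training configurations back to back, the two commands above can be wrapped in a small launcher. This is a sketch, not part of the repository: it uses only the `--configs` flag shown above, the helper names are illustrative, and it defaults to a dry run that prints the commands without starting training.

```python
import os
import subprocess
import sys

# Config paths copied from the training commands above.
TRAIN_CONFIGS = [
    "./configs/main/maml_celeba.yaml",  # Learnit
    "./configs/main/ours_celeba.yaml",  # Ours
]


def build_train_command(config_path):
    """Return the argv list for one training run."""
    return [sys.executable, "main.py", "--configs", config_path]


def run_training(config_path, gpu="0", dry_run=True):
    """Print the command; launch it only when dry_run is False."""
    cmd = build_train_command(config_path)
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=gpu)
    print(f"CUDA_VISIBLE_DEVICES={gpu}", " ".join(cmd))
    if not dry_run:
        # check=True stops the sequence if a run exits with an error.
        subprocess.run(cmd, env=env, check=True)
    return cmd


for config in TRAIN_CONFIGS:
    run_training(config)  # dry run; pass dry_run=False to actually train
```

Run it from the repository root so that `main.py` and the relative config paths resolve.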

Evaluation

  • Pass the checkpoint to evaluate via --load_path <PATH TO CHECKPOINT>. Example of <PATH TO CHECKPOINT>: ./logs/maml_celeba/best.pth

# Learnit
CUDA_VISIBLE_DEVICES=0 python eval.py --configs ./configs/evaluation/eval_celeba.yaml --load_path ./logs/xxxx/best.model

# Ours (CelebA example)
CUDA_VISIBLE_DEVICES=0 python eval.py --configs ./configs/evaluation/eval_celeba_ours.yaml --load_path ./logs/xxxx/best.model
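The checkpoint examples above use two different file names (best.pth and best.model), so it can help to list what a training run actually wrote before filling in --load_path. The sketch below assumes checkpoints are stored one directory below ./logs with a name starting with `best.`, as in the examples; the helper name `find_best_checkpoints` is illustrative.

```python
from pathlib import Path


def find_best_checkpoints(log_root="./logs"):
    """Return sorted paths of files named best.* one level below log_root."""
    root = Path(log_root)
    if not root.is_dir():
        return []
    return sorted(p for p in root.glob("*/best.*") if p.is_file())


for ckpt in find_best_checkpoints():
    print(ckpt)  # candidate value for --load_path
```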

Reference

This code is mainly built upon the JAX Learnit, JAX Functa, PyTorch Siren, PyTorch MetaSDF, PyTorch Meta-SparseINR, and PyTorch COIN++ repositories.

Citation

@inproceedings{tack2023learning,
  title={Learning Large-scale Neural Fields via Context Pruned Meta-Learning},
  author={Tack, Jihoon and Kim, Subin and Yu, Sihyun and Lee, Jaeho and Shin, Jinwoo and Schwarz, Jonathan Richard},
  booktitle={Advances in Neural Information Processing Systems},
  year={2023}
}