A small modification to bpnetlite's BPNet to accommodate large validation datasets.

adamyhe/PersonalBPNet

PersonalBPNet

Redid the validation loop to work with a PyTorch DataLoader (e.g., one generated by GenVarLoader) rather than requiring the whole validation set to be loaded into memory at once. In addition to the model state dict, model checkpoints also save the optimizer state dict, the epoch number, and the number of steps since the last improvement, so that training can be resumed from a checkpoint with the correct optimizer, early stopping, and epoch states.
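
The two ideas above (streaming validation and resumable checkpoints) can be sketched in plain PyTorch as follows. This is a minimal illustration assuming a generic model, a loss function that returns a per-batch mean, and a DataLoader yielding (X, y) pairs; the function names and checkpoint key names are hypothetical and are not necessarily what PersonalBPNet uses internally.

```python
import torch
from torch import nn


@torch.no_grad()
def validation_loss(model: nn.Module, loader, loss_fn) -> float:
    """Average loss over a validation DataLoader, one batch at a time.

    Only one batch is held in memory at once, so the validation set never
    has to be materialized as a single tensor.
    """
    was_training = model.training
    model.eval()
    total, n = 0.0, 0
    for X, y in loader:
        # loss_fn is assumed to return the mean over the batch, so weight
        # it by batch size to get an exact average over all examples
        total += loss_fn(model(X), y).item() * X.shape[0]
        n += X.shape[0]
    model.train(was_training)  # restore the caller's train/eval mode
    return total / n


def save_checkpoint(path, model, optimizer, epoch: int, early_stop_count: int):
    # Hypothetical key names; the package's actual checkpoint layout may differ.
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,
            "early_stop_count": early_stop_count,
        },
        path,
    )


def load_checkpoint(path, model, optimizer):
    """Restore model and optimizer in place; return (epoch, early_stop_count)."""
    ckpt = torch.load(path)
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return ckpt["epoch"], ckpt["early_stop_count"]
```

Saving the optimizer state matters for optimizers such as Adam, whose running moment estimates would otherwise restart from zero on resume.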

Additionally, we include a PyTorch implementation of CLIPNET, which is essentially BPNet with added batch norm layers, similar to the original CLIPNET implementation in TensorFlow.
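
To illustrate what "BPNet with added batch norm layers" could look like, here is a hypothetical sketch of a BPNet-style dilated residual layer with a BatchNorm1d inserted. The conv, then batch norm, then ReLU ordering and the class name are assumptions made for illustration; they are not taken from the CLIPNET source.

```python
import torch
from torch import nn


class DilatedResidualBN(nn.Module):
    """One BPNet-style dilated residual layer with an added BatchNorm1d.

    Hypothetical sketch: the placement of batch norm is an assumption.
    """

    def __init__(self, n_filters: int, dilation: int, kernel_size: int = 3):
        super().__init__()
        # For odd kernel_size, this padding keeps the sequence length fixed
        padding = dilation * (kernel_size - 1) // 2
        self.conv = nn.Conv1d(
            n_filters, n_filters, kernel_size, padding=padding, dilation=dilation
        )
        self.bn = nn.BatchNorm1d(n_filters)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection, as in BPNet's dilated convolution tower
        return x + self.relu(self.bn(self.conv(x)))


# A small tower with doubling dilation rates (2, 4, 8, 16)
tower = nn.Sequential(*[DilatedResidualBN(64, 2**i) for i in range(1, 5)])
```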

Installation and programmatic usage

Clone and install the GitHub repo:

git clone git@github.com:adamyhe/PersonalBPNet.git
cd PersonalBPNet
pip install -e . # for editable mode.

Then the PersonalBPNet and CLIPNET classes can be directly imported:

from personal_bpnet import PersonalBPNet, CLIPNET

We also provide a PauseNet class, a wrapper around bpnetlite.bpnet.BPNet, PersonalBPNet, or CLIPNET models that transforms them to predict a single scalar output per input sequence. It is designed for fine-tuning the base-resolution models to predict regulatory phenotypes that can only be represented as a single scalar value per region (e.g., pausing index, for which the class is named). The intended use for this class is as follows:

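import torch
from personal_bpnet import CLIPNET, PauseNet

# Load pretrained weights from a saved state dict. init_args must match the
# arguments used to construct the pretrained model.
# If you saved the full model object instead, load it directly with
# pretrain = torch.load("weights.torch", weights_only=False)
pretrain = CLIPNET(**init_args)
pretrain.load_state_dict(torch.load("weights.torch"))

# Wrap the pretrained model and fine-tune it on scalar targets.
model = PauseNet(pretrain)
model.fit(**params)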
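
As a rough picture of how a base-resolution model can be turned into a scalar predictor, the sketch below wraps a model that returns a (profile, counts) tuple, as bpnetlite's BPNet does, and adds a small linear head. Everything here (ScalarHead, ToyBase, mean pooling over positions, the linear head) is a hypothetical illustration, not PauseNet's actual architecture.

```python
import torch
from torch import nn


class ScalarHead(nn.Module):
    """Hypothetical wrapper mapping a base-resolution model to one scalar per sequence.

    Assumes base(X) returns (profile, counts) with profile of shape
    (batch, n_outputs, length) and counts of shape (batch, n_counts).
    """

    def __init__(self, base: nn.Module, n_outputs: int, n_counts: int = 1):
        super().__init__()
        self.base = base
        self.head = nn.Linear(n_outputs + n_counts, 1)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        profile, counts = self.base(X)
        pooled = profile.mean(dim=-1)  # (batch, n_outputs)
        features = torch.cat([pooled, counts], dim=-1)
        return self.head(features).squeeze(-1)  # (batch,)


class ToyBase(nn.Module):
    """Stand-in for a pretrained model, for demonstration only."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(4, 2, kernel_size=5, padding=2)  # one-hot DNA in
        self.count = nn.Linear(2, 1)

    def forward(self, X):
        profile = self.conv(X)
        return profile, self.count(profile.mean(dim=-1))


model = ScalarHead(ToyBase(), n_outputs=2)
scalar = model(torch.randn(3, 4, 50))  # shape (3,)
```

Because the base model's parameters stay trainable, gradients from the scalar loss flow back into the pretrained layers; freezing some of them instead is a separate design choice.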

This package is in active development and may change drastically. Models have not yet been extensively benchmarked, and there may be many typos or copy-paste errors. A personalized ChromBPNet fitting method has not been included, as I personally have not had success training such models.

Command line interface

For convenience, prediction and attribution (DeepLIFT/SHAP) methods for CLIPNET or PauseNet models can be accessed via a CLI:

clipnet predict -h
clipnet predict_tss -h
clipnet attribute -h

pausenet predict -h
pausenet attribute -h
