A small modification to bpnetlite's BPNet to accommodate large validation datasets.
Redid the validation loop to work with a PyTorch DataLoader (e.g., one generated by GenVarLoader), rather than loading the whole validation set into memory at once. Also, model checkpoints save the optimizer state dict, epoch number, and number of steps since the last improvement in addition to the model state dict, so that training can be resumed from a checkpoint with the correct optimizer, epoch, and early-stopping states.
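The two ideas above, streaming validation over a DataLoader and checkpoints that carry enough state to resume training, can be sketched in plain PyTorch. This is a minimal illustration, not this package's code: the function names (`validate`, `save_checkpoint`, `load_checkpoint`), the checkpoint key names, and the toy linear model with a summed mean-squared-error loss are all assumptions chosen for the example.

```python
import os
import tempfile

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def validate(model, loader, loss_fn):
    """Average loss per sample over a DataLoader, one batch in memory at a time.

    Assumes loss_fn uses reduction="sum" so batch totals can be added exactly.
    """
    was_training = model.training
    model.eval()
    total_loss, n_seen = 0.0, 0
    for X, y in loader:
        total_loss += loss_fn(model(X), y).item()
        n_seen += X.shape[0]
    model.train(was_training)  # restore whatever mode the model was in
    return total_loss / n_seen


def save_checkpoint(path, model, optimizer, epoch, early_stop_count):
    """Bundle everything needed to resume training (hypothetical key names)."""
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,
            "early_stop_count": early_stop_count,  # steps since last improvement
        },
        path,
    )


def load_checkpoint(path, model, optimizer):
    """Restore model and optimizer in place; return (epoch, early_stop_count)."""
    ckpt = torch.load(path)
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return ckpt["epoch"], ckpt["early_stop_count"]


# Toy usage: a linear model and a random validation set served in batches
model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
val_loader = DataLoader(
    TensorDataset(torch.randn(100, 4), torch.randn(100, 1)), batch_size=32
)
val_loss = validate(model, val_loader, nn.MSELoss(reduction="sum"))

path = os.path.join(tempfile.mkdtemp(), "checkpoint.torch")
save_checkpoint(path, model, optimizer, epoch=3, early_stop_count=1)
epoch, early_stop_count = load_checkpoint(path, model, optimizer)
```

Summing per-batch losses and dividing by the sample count gives the same value as computing the loss on the full set at once, which is why only one batch ever needs to be resident in memory.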
Additionally, we include a PyTorch implementation of CLIPNET, which is essentially BPNet with added batch norm layers, similar to what was done in the original CLIPNET implementation in TensorFlow.
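To make "BPNet with added batch norm layers" concrete, here is a hypothetical BPNet-style dilated residual layer with a `BatchNorm1d` after the convolution. The class name, channel count, kernel size, and the exact placement of the batch norm are illustrative assumptions, not taken from the CLIPNET code in this package.

```python
import torch
from torch import nn


class DilatedResidualBN(nn.Module):
    """Hypothetical BPNet-style dilated residual layer with batch norm."""

    def __init__(self, n_filters=64, kernel_size=3, dilation=2):
        super().__init__()
        # For odd kernel sizes, this padding keeps the sequence length unchanged
        self.conv = nn.Conv1d(
            n_filters,
            n_filters,
            kernel_size,
            padding=dilation * (kernel_size - 1) // 2,
            dilation=dilation,
        )
        self.bn = nn.BatchNorm1d(n_filters)
        self.act = nn.ReLU()

    def forward(self, X):
        # X has shape (batch, n_filters, length); the residual add preserves it
        return X + self.act(self.bn(self.conv(X)))


block = DilatedResidualBN()
X = torch.randn(8, 64, 1000)
print(block(X).shape)  # torch.Size([8, 64, 1000])
```

Because batch norm uses batch statistics in training mode and running statistics in evaluation mode, models with these layers should be switched to `model.eval()` before prediction.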
Clone and install the GitHub repo:
```bash
git clone git@github.com:adamyhe/PersonalBPNet.git
cd PersonalBPNet
pip install -e . # for editable mode.
```
Then the `PersonalBPNet` and `CLIPNET` classes can be directly imported:

```python
from personal_bpnet import PersonalBPNet, CLIPNET
```
We also provide a `PauseNet` class. It is a wrapper around `bpnetlite.bpnet.BPNet`, `PersonalBPNet`, or `CLIPNET` models that transforms them to predict a single scalar output per input sequence. It is intended for fine-tuning the base-resolution models to predict regulatory phenotypes that can only be represented as a single scalar value per region (e.g., pausing index, for which this class is named). The intended use for this class is as follows:
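As an example of such a scalar phenotype, a pausing index is commonly defined as the ratio of read density in a promoter-proximal window to read density in the downstream gene body. The sketch below computes one from a base-resolution coverage track with NumPy. The function name, the window boundaries (0 to 100 bp and 300 bp onward, relative to the TSS), and the pseudocount are hypothetical choices for illustration; this package may define its training targets differently.

```python
import numpy as np


def pausing_index(coverage, tss, pause_window=(0, 100), body_start=300, pseudocount=1.0):
    """Mean promoter-proximal coverage divided by mean gene-body coverage.

    coverage: 1D array of per-base read counts along the gene (sense strand).
    tss: index of the transcription start site within `coverage`.
    Window offsets are relative to the TSS and are hypothetical defaults.
    """
    coverage = np.asarray(coverage, dtype=float)
    pause = coverage[tss + pause_window[0] : tss + pause_window[1]]
    body = coverage[tss + body_start :]
    # The pseudocount avoids division by zero for genes with no gene-body reads
    return (pause.mean() + pseudocount) / (body.mean() + pseudocount)


# A gene with a strong promoter-proximal peak and no gene-body reads
cov = np.zeros(1000)
cov[10:60] = 10
print(pausing_index(cov, tss=0))  # 6.0
```

Each region then contributes one number, which is the kind of target a scalar-output wrapper is meant to learn.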
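```python
import torch
from personal_bpnet import CLIPNET, PauseNet

# Load a pretrained CLIPNET from a saved state dict.
# init_args must match the arguments the model was created with.
# If you saved the full model instead, use pretrain = torch.load("weights.torch")
# (PyTorch >= 2.6 also requires weights_only=False to load a full-model pickle).
pretrain = CLIPNET(**init_args)
pretrain.load_state_dict(torch.load("weights.torch"))

# Wrap the base-resolution model and fine-tune it; params holds the fit() kwargs.
model = PauseNet(pretrain)
model.fit(**params)
```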
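For intuition about what "transforming a base-resolution model into a scalar predictor" can mean, here is one hypothetical design: run the base model, turn its profile logits into per-position probabilities, concatenate the log counts, and apply a learned linear readout. This is not `PauseNet`'s implementation; the `ScalarHead` and `ToyBase` classes are invented for illustration, and the sketch assumes the base model returns a `(profile_logits, log_counts)` tuple with log counts of shape `(batch, 1)`.

```python
import torch
from torch import nn


class ToyBase(nn.Module):
    """Stand-in for a BPNet-style model returning (profile logits, log counts)."""

    def __init__(self, n_outputs=2, out_window=50):
        super().__init__()
        self.out_window = out_window
        self.conv = nn.Conv1d(4, n_outputs, kernel_size=1)
        self.counts = nn.Linear(4, 1)

    def forward(self, X):
        # X: one-hot-like input of shape (batch, 4, length); keep the central window
        trim = (X.shape[-1] - self.out_window) // 2
        profile = self.conv(X)[..., trim : trim + self.out_window]
        log_counts = self.counts(X.mean(dim=-1))  # shape (batch, 1)
        return profile, log_counts


class ScalarHead(nn.Module):
    """Hypothetical wrapper: base-resolution model -> one scalar per sequence."""

    def __init__(self, base, n_outputs, out_window):
        super().__init__()
        self.base = base
        # Linear readout over the flattened profile probabilities plus log counts
        self.head = nn.Linear(n_outputs * out_window + 1, 1)

    def forward(self, X):
        profile_logits, log_counts = self.base(X)
        # Softmax over positions within each output track, then flatten tracks
        probs = torch.softmax(profile_logits, dim=-1).flatten(start_dim=1)
        feats = torch.cat([probs, log_counts], dim=1)
        return self.head(feats).squeeze(-1)


model = ScalarHead(ToyBase(), n_outputs=2, out_window=50)
print(model(torch.randn(8, 4, 100)).shape)  # torch.Size([8])
```

Keeping the pretrained base inside the wrapper means fine-tuning can update both the readout and the base-resolution weights, or the base can be frozen if only the head should be trained.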
This package is currently in active development and may change drastically. Models have not been extensively benchmarked yet, and there may be many typos or copy-paste errors. A personalized ChromBPNet fitting method has not been included, as I personally have not had success training such models.
For convenience, prediction and attribution (DeepLIFT/SHAP) methods for CLIPNET or PauseNet models can be accessed via a CLI:
```bash
clipnet predict -h
clipnet predict_tss -h
clipnet attribute -h
pausenet predict -h
pausenet attribute -h
```