Skip to content

Commit

Permalink
use checkpoints for training
Browse files Browse the repository at this point in the history
  • Loading branch information
mjyoussef committed Apr 22, 2024
1 parent a318419 commit 4fcfe21
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions embedding/training.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import argparse
import torch
import torch.optim as optim
from utils import GraphPairsDataset, contrastive_loss, create_loaders
from torch_geometric.datasets import WikiCS, Amazon, Coauthor
from model import GraphSAGE
import json

def train_routine(params, model, train_loader, val_loader):
for epoch in range(params['epochs']):
for epoch in range(params['start_epoch'], params['epochs']):
# training
model.train()
optimizer = optim.Adam(model.parameters(), lr=params['lr'])
Expand All @@ -21,6 +22,12 @@ def train_routine(params, model, train_loader, val_loader):
print(f"Training: epoch {epoch}, batch {batch_idx}")
batch_idx += 1

# save the model
torch.save({
'model_state_dict': model.state_dict(),
'epoch': epoch + 1,
}, './checkpoints/model.pth')

# validation
model.eval()
avg_val_loss = 0
Expand All @@ -47,7 +54,7 @@ def test_routine(params, model, test_loader):
avg_test_loss /= test_loader_size
print(f"Average test loss: {avg_test_loss}")

def train_and_test(dataset_name, params):
def train_and_test(dataset_name, params, resume=False):
graphs = None
if (dataset_name == 'WikiCS'):
graphs = WikiCS('../datasets/WikiCS')
Expand Down Expand Up @@ -78,6 +85,15 @@ def train_and_test(dataset_name, params):
params['k'],
)

# Resume from a saved checkpoint if requested.
# The checkpoint was written by train_routine as a plain dict:
#   {'model_state_dict': ..., 'epoch': ...}
# so fields must be read with subscript access — `checkpoint.epoch`
# would raise AttributeError on a dict.
start_epoch = 0
if resume:
    checkpoint = torch.load('./checkpoints/model.pth')
    # 'epoch' was saved as epoch + 1, i.e. the next epoch to run.
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model_state_dict'])

# train_routine reads this to pick up training where it left off.
params['start_epoch'] = start_epoch

train_routine(params, model, train_loader, val_loader)

test_routine(params, model, test_loader)
Expand All @@ -87,9 +103,10 @@ def train_and_test(dataset_name, params):
# Parse CLI arguments: which dataset to train on, and whether to
# resume from the saved checkpoint.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, required=True,
                    choices=['WikiCS', 'AmazonPhoto', 'CoauthorCS'])
# NOTE: `type=bool` is broken in argparse — bool('False') is True, so any
# supplied value (even "False") enabled resuming. A store_true flag gives
# the intended semantics: absent -> False, `--resume` -> True.
parser.add_argument('--resume', action='store_true',
                    help='resume training from ./checkpoints/model.pth')
args = parser.parse_args()

# Load the hyperparameters for the chosen dataset.
# json.load takes an open file object, not a path string, so the
# config file must be opened first.
with open(f'../configs/{args.dataset}.json') as f:
    params = json.load(f)

train_and_test(args.dataset, params, args.resume)

0 comments on commit 4fcfe21

Please sign in to comment.