diff --git a/embedding/training.py b/embedding/training.py index d018063..cbf3d38 100644 --- a/embedding/training.py +++ b/embedding/training.py @@ -1,4 +1,5 @@ import argparse +import torch import torch.optim as optim from utils import GraphPairsDataset, contrastive_loss, create_loaders from torch_geometric.datasets import WikiCS, Amazon, Coauthor @@ -6,7 +7,7 @@ import json def train_routine(params, model, train_loader, val_loader): - for epoch in range(params['epochs']): + for epoch in range(params['start_epoch'], params['epochs']): # training model.train() optimizer = optim.Adam(model.parameters(), lr=params['lr']) @@ -21,6 +22,12 @@ def train_routine(params, model, train_loader, val_loader): print(f"Training: epoch {epoch}, batch {batch_idx}") batch_idx += 1 + # save the model + torch.save({ + 'model_state_dict': model.state_dict(), + 'epoch': epoch + 1, + }, './checkpoints/model.pth') + # validation model.eval() avg_val_loss = 0 @@ -47,7 +54,7 @@ def test_routine(params, model, test_loader): avg_test_loss /= test_loader_size print(f"Average test loss: {avg_test_loss}") -def train_and_test(dataset_name, params): +def train_and_test(dataset_name, params, resume=False): graphs = None if (dataset_name == 'WikiCS'): graphs = WikiCS('../datasets/WikiCS') @@ -78,6 +85,15 @@ def train_and_test(dataset_name, params): params['k'], ) + # load saved model if resuming training + start_epoch = 0 + if resume: + checkpoint = torch.load('./checkpoints/model.pth') + start_epoch = checkpoint.epoch + model.load_state_dict(checkpoint['model_state_dict']) + + params['start_epoch'] = start_epoch + train_routine(params, model, train_loader, val_loader) test_routine(params, model, test_loader) @@ -87,9 +103,10 @@ def train_and_test(dataset_name, params): # get the dataset parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, required=True, choices=['WikiCS', 'AmazonPhoto', 'CoauthorCS']) + parser.add_argument('--resume', type=bool, default=False, required=False) args = parser.parse_args() # load the hyperparameters params = json.load(f'../configs/{args.dataset}.json') - train_and_test(args.dataset, params) \ No newline at end of file + train_and_test(args.dataset, params, args.resume) \ No newline at end of file