
Could you provide the training script ? #42

Open
zt1112 opened this issue Nov 27, 2020 · 3 comments

@zt1112 commented Nov 27, 2020

I want to train a model on the KITTI dataset. Could you provide a training script? Thanks!

@LulaSan commented May 6, 2021

It would also be useful for me.

@PouyaNF commented Jun 25, 2022

It would be useful for me too.

@lovelyzzkei commented

For those who have the same question, I have attached the training script I used, for convenience.
I referred to https://github.com/fangchangma/sparse-to-dense.pytorch from #60.

import os
import csv
import time

import torch

import models
import utils
# assumption: the metrics and criteria modules follow sparse-to-dense.pytorch,
# which this script is adapted from
from metrics import AverageMeter, Result
from criteria import MaskedMSELoss, MaskedL1Loss

# metrics written to the csv logs (matches the writerow call in train())
fieldnames = ['mse', 'rmse', 'absrel', 'lg10', 'mae',
              'delta1', 'delta2', 'delta3', 'data_time', 'gpu_time']

def main():
    # shared with train() below
    global output_directory, train_csv, test_csv, best_txt
    ...
    # existing code (argument parsing, evaluation path, val_loader, etc.)
    ...
    if args.train:
        if args.data == 'nyudepthv2':
            from dataloaders.nyu import NYUDataset
            # traindir is assumed to be set in the elided code,
            # e.g. os.path.join('data', args.data, 'train')
            train_dataset = NYUDataset(traindir, split='train', modality=args.modality)
        else:
            raise RuntimeError('Dataset not found.')

        if args.criterion == 'l2':
            criterion = MaskedMSELoss().cuda()
        elif args.criterion == 'l1':
            criterion = MaskedL1Loss().cuda()
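        # Note (assumption): the masked losses mirror criteria.py in
        # sparse-to-dense.pytorch, averaging the error only over pixels
        # with valid ground-truth depth, roughly:
        #     valid_mask = (target > 0).detach()
        #     diff = (target - pred)[valid_mask]
        #     loss = diff.abs().mean()  # L1 case; diff.pow(2).mean() for L2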

        # training data loader; the validation loader (batch size 1) is
        # created in the elided code above
        train_loader = torch.utils.data.DataLoader(train_dataset,
            batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
        print("=> Train data loader created.")
        model = models.MobileNetSkipAdd(output_size=(1, 224, 224, 1), pretrained=args.pretrained)
        model = model.cuda()
        
        start_epoch = 0
        best_result = Result()
        best_result.set_to_worst()  # as in sparse-to-dense.pytorch; needed before the is_best check below
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
            momentum=args.momentum, weight_decay=args.weight_decay)
        
        # create the results folder if it does not already exist
        output_directory = utils.get_output_directory(args)
        if not os.path.exists(output_directory):
            os.makedirs(output_directory)
        train_csv = os.path.join(output_directory, 'train.csv')
        test_csv = os.path.join(output_directory, 'test.csv')
        best_txt = os.path.join(output_directory, 'best.txt')

        # create new csv files with only the header row
        if not args.resume:
            with open(train_csv, 'w') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()
            with open(test_csv, 'w') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()
        
        for epoch in range(start_epoch, args.epochs):
            utils.adjust_learning_rate(optimizer, epoch, args.lr)
            train(train_loader, model, criterion, optimizer, epoch)  # train for one epoch
            # val_loader and validate() come from the repo's existing evaluation code
            result, img_merge = validate(val_loader, model, epoch)  # evaluate on validation set

            # remember best rmse and save checkpoint
            is_best = result.rmse < best_result.rmse
            if is_best:
                best_result = result
                with open(best_txt, 'w') as txtfile:
                    txtfile.write("epoch={}\nmse={:.3f}\nrmse={:.3f}\nabsrel={:.3f}\nlg10={:.3f}\nmae={:.3f}\ndelta1={:.3f}\nt_gpu={:.4f}\n".
                        format(epoch, result.mse, result.rmse, result.absrel, result.lg10, result.mae, result.delta1, result.gpu_time))
                if img_merge is not None:
                    img_filename = output_directory + '/comparison_best.png'
                    utils.save_image(img_merge, img_filename)

            utils.save_checkpoint({
                'args': args,
                'epoch': epoch,
                'model': model,
                'best_result': best_result,
                'optimizer' : optimizer,
            }, is_best, epoch, output_directory)
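
# Assumption: utils.adjust_learning_rate follows the step schedule used by
# sparse-to-dense.pytorch; a minimal sketch in case your utils module lacks it
# (the exact decay schedule is a guess, adjust to taste):
def adjust_learning_rate(optimizer, epoch, lr_init):
    lr = lr_init * (0.1 ** (epoch // 5))  # decay the LR by 10x every 5 epochs
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr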

def train(train_loader, model, criterion, optimizer, epoch):
    # relies on module-level globals set up in main(): args, fieldnames,
    # output_directory, train_csv (same pattern as sparse-to-dense.pytorch)
    average_meter = AverageMeter()
    model.train()  # switch to train mode
    end = time.time()
    for i, (input, target) in enumerate(train_loader):

        input, target = input.cuda(), target.cuda()
        torch.cuda.synchronize()
        data_time = time.time() - end

        # compute pred
        end = time.time()
        pred = model(input)

        loss = criterion(pred, target)
        optimizer.zero_grad()
        loss.backward() # compute gradient and do SGD step
        optimizer.step()
        torch.cuda.synchronize()
        gpu_time = time.time() - end

        # measure accuracy and record loss
        result = Result()
        result.evaluate(pred.data, target.data)
        average_meter.update(result, gpu_time, data_time, input.size(0))
        end = time.time()

        if (i + 1) % args.print_freq == 0:
            print('=> output: {}'.format(output_directory))
            print('Train Epoch: {0} [{1}/{2}]\t'
                  't_Data={data_time:.3f}({average.data_time:.3f}) '
                  't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
                  'RMSE={result.rmse:.2f}({average.rmse:.2f}) '
                  'MAE={result.mae:.2f}({average.mae:.2f}) '
                  'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
                  'REL={result.absrel:.3f}({average.absrel:.3f}) '
                  'Lg10={result.lg10:.3f}({average.lg10:.3f}) '.format(
                  epoch, i+1, len(train_loader), data_time=data_time,
                  gpu_time=gpu_time, result=result, average=average_meter.average()))

    avg = average_meter.average()
    with open(train_csv, 'a') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writerow({'mse': avg.mse, 'rmse': avg.rmse, 'absrel': avg.absrel, 'lg10': avg.lg10,
            'mae': avg.mae, 'delta1': avg.delta1, 'delta2': avg.delta2, 'delta3': avg.delta3,
            'gpu_time': avg.gpu_time, 'data_time': avg.data_time})
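
Since the original question was about KITTI, the dataset branch in main() could be extended along these lines. This is only a sketch: it assumes you add a KITTIDataset dataloader to this repo (hypothetical here; the KITTI dataloader in sparse-to-dense.pytorch is a reasonable starting point):

        elif args.data == 'kitti':
            # hypothetical dataloaders/kitti.py, modeled on dataloaders/nyu.py
            from dataloaders.kitti import KITTIDataset
            train_dataset = KITTIDataset(traindir, split='train', modality=args.modality)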

