def pad_sequence(sequences, batch_first=False, padding_value=0.0):
    """Pad variable-length tensors to a common length and stack them.

    Each tensor's first dimension is treated as its length. The trailing
    dimensions of the output are taken from the first tensor in
    ``sequences``; the output tensor is allocated via ``new_full`` on that
    first tensor.

    Args:
        sequences: Non-empty list of tensors to pad.
        batch_first: If true the result is laid out as
            ``(batch, max_len, *trailing)``; otherwise as
            ``(max_len, batch, *trailing)``.
        padding_value: Fill value for positions past the end of a sequence.

    Returns:
        A newly allocated tensor holding every sequence, padded with
        ``padding_value``.

    Raises:
        ValueError: If ``sequences`` is empty (there is no tensor to take
            the trailing dimensions from).
    """
    if len(sequences) == 0:
        raise ValueError("pad_sequence() requires at least one sequence")

    first = sequences[0]
    trailing_dims = tuple(first.size()[1:])
    max_len = max(seq.size(0) for seq in sequences)
    if batch_first:
        out_dims = (len(sequences), max_len) + trailing_dims
    else:
        out_dims = (max_len, len(sequences)) + trailing_dims

    out_tensor = first.new_full(out_dims, padding_value)
    for i, tensor in enumerate(sequences):
        length = tensor.size(0)
        # Use index notation to prevent duplicate references to the tensor
        if batch_first:
            out_tensor[i, :length, ...] = tensor
        else:
            out_tensor[:length, i, ...] = tensor

    return out_tensor


def _as_detached_tensor(value):
    """Return ``value`` as a tensor detached from any autograd graph.

    ``torch.tensor(existing_tensor)`` emits a copy-construct ``UserWarning``
    and copies the data. Existing tensors are detached instead, and anything
    else goes through ``torch.as_tensor``. Sharing memory with the input is
    safe for ``collate_fn`` because ``pad_sequence`` copies every sequence
    into a newly allocated output tensor.
    """
    if isinstance(value, torch.Tensor):
        return value.detach()
    return torch.as_tensor(value)


def collate_fn(batch):
    """Collate ``(video_frames, text_sentences, label)`` samples into a batch.

    Passed as ``collate_fn=collate_fn`` to the ``DataLoader`` built in
    ``main()``. The video and text entries of each sample are converted to
    tensors and padded along their first dimension with ``pad_sequence``
    (``batch_first=True``); the labels are stacked with ``torch.tensor``.

    Args:
        batch: Iterable of 3-tuples ``(video_frames, text_sentences, label)``.

    Returns:
        A tuple ``(video_frames_padded, text_sentences_padded, labels)``.
    """
    video_frames, text_sentences, labels = zip(*batch)
    video_frames_padded = pad_sequence(
        [_as_detached_tensor(v) for v in video_frames], batch_first=True
    )
    text_sentences_padded = pad_sequence(
        [_as_detached_tensor(t) for t in text_sentences], batch_first=True
    )
    labels = torch.tensor(labels)
    return video_frames_padded, text_sentences_padded, labels