Skip to content

Commit

Permalink
updated train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
saleemhamo committed Jul 23, 2024
1 parent db44be8 commit f8fa3cb
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion models/fine_grained/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,36 @@ def fine_grained_retrieval(train_loader, config):
return detector


def pad_sequence(sequences, batch_first=False, padding_value=0.0):
"""Pad a list of sequences to the same length."""
max_size = sequences[0].size()
trailing_dims = max_size[1:]
max_len = max([s.size(0) for s in sequences])
if batch_first:
out_dims = (len(sequences), max_len) + trailing_dims
else:
out_dims = (max_len, len(sequences)) + trailing_dims

out_tensor = sequences[0].new_full(out_dims, padding_value)
for i, tensor in enumerate(sequences):
length = tensor.size(0)
# Use index notation to prevent duplicate references to the tensor
if batch_first:
out_tensor[i, :length, ...] = tensor
else:
out_tensor[:length, i, ...] = tensor

return out_tensor


def collate_fn(batch):
video_frames, text_sentences, labels = zip(*batch)
video_frames_padded = pad_sequence([torch.tensor(v) for v in video_frames], batch_first=True)
text_sentences_padded = pad_sequence([torch.tensor(t) for t in text_sentences], batch_first=True)
labels = torch.tensor(labels)
return video_frames_padded, text_sentences_padded, labels


def main():
logger.info("Loading configuration.")
config = Config()
Expand All @@ -86,7 +116,7 @@ def main():
annotations = annotations[:5] # For testing purposes

dataset = CharadesSTADatasetFineGrained(annotations, CHARADES_VIDEOS_DIR)
train_loader = DataLoader(dataset, batch_size=config.fine_grained_batch_size, shuffle=True)
train_loader = DataLoader(dataset, batch_size=config.fine_grained_batch_size, shuffle=True, collate_fn=collate_fn)
logger.info("Data loader created.")

model = fine_grained_retrieval(train_loader, config)
Expand Down

0 comments on commit f8fa3cb

Please sign in to comment.