Skip to content

Commit

Permalink
updated train.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
saleemhamo committed Jul 23, 2024
1 parent 0cfe658 commit e2a1ca3
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions models/fine_grained/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from data.charades_sta import CharadesSTA
from models.coarse_grained.helpers import config
from utils.config import Config
from utils.model_utils import save_model, get_device
from utils.constants import CHARADES_VIDEOS_DIR, CHARADES_ANNOTATIONS_TRAIN, CHARADES_ANNOTATIONS_TEST, \
Expand Down Expand Up @@ -33,13 +34,17 @@ def collate_fn(batch):
# Pad video frames
video_frames_padded = pad_sequence([torch.tensor(np.array(v)) for v in video_frames], batch_first=True)

# Convert text sentences to strings if they are tensors
if isinstance(text_sentences[0], torch.Tensor):
text_sentences = [text_sentence.tolist() for text_sentence in text_sentences]
text_sentences = [" ".join(map(str, text_sentence)) for text_sentence in text_sentences]
# Tokenize text sentences
if config.fine_grained_text_extractor == 'bert':
text_sentences = [bert_tokenizer(text, return_tensors='pt')['input_ids'].squeeze(0) for text in text_sentences]
elif config.fine_grained_text_extractor == 'clip':
text_sentences = [clip_processor(text=[text], return_tensors='pt')['input_ids'].squeeze(0) for text in
text_sentences]
else:
raise ValueError("Invalid text_extractor value in config")

# Pad text sentences
text_sentences_padded = pad_sequence([torch.tensor(np.array(t)) for t in text_sentences], batch_first=True)
text_sentences_padded = pad_sequence(text_sentences, batch_first=True)

labels = torch.tensor(labels)
return video_frames_padded, text_sentences_padded, labels
Expand Down

0 comments on commit e2a1ca3

Please sign in to comment.