updated train.py
saleemhamo committed Jul 23, 2024
1 parent 8091868 commit 5707106
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions models/fine_grained/train.py
@@ -14,11 +14,23 @@
 from models.fine_grained.components.supervision import SupervisionLoss
 from models.fine_grained.components.qd_detr import QDDETRModel
 from models.fine_grained.data_loaders.charades_sta_dataset import CharadesSTADatasetFineGrained
+from transformers import BertTokenizer, CLIPProcessor
+import numpy as np
 
 # Setup logger
 logger = setup_logger('train_logger')
 
+# Initialize config
+config = Config()
+
+# Initialize tokenizer based on config
+if config.fine_grained_text_extractor == 'bert':
+    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+elif config.fine_grained_text_extractor == 'clip':
+    tokenizer = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
+else:
+    raise ValueError("Invalid text_extractor value in config")
 
 
 def fine_grained_retrieval(train_loader, config):
     device = get_device(logger)
@@ -99,16 +111,21 @@ def pad_sequence(sequences, batch_first=False, padding_value=0.0):
 def collate_fn(batch):
     video_frames, text_sentences, labels = zip(*batch)
 
+    # Tokenize text sentences
+    if config.fine_grained_text_extractor == 'bert':
+        text_sentences = [tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=512)['input_ids'].squeeze(0) for text in text_sentences]
+    elif config.fine_grained_text_extractor == 'clip':
+        text_sentences = [tokenizer(texts=[text], return_tensors='pt')['input_ids'].squeeze(0) for text in text_sentences]
+
     # Convert video frames and text sentences to numpy arrays first
     video_frames_padded = pad_sequence([torch.tensor(np.array(v)) for v in video_frames], batch_first=True)
-    text_sentences_padded = pad_sequence([torch.tensor(np.array(t)) for t in text_sentences], batch_first=True)
+    text_sentences_padded = pad_sequence(text_sentences, batch_first=True)
     labels = torch.tensor(labels)
     return video_frames_padded, text_sentences_padded, labels
 
 
 def main():
     logger.info("Loading configuration.")
-    config = Config()
     charades_sta = CharadesSTA(
         video_dir=CHARADES_VIDEOS_DIR,
         train_file=CHARADES_ANNOTATIONS_TRAIN,
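The tokenization added to collate_fn yields per-caption tensors of different lengths depending on the extractor: the BERT branch pads every caption to a fixed 512 tokens, while the CLIP branch returns one tensor per caption whose length follows that caption, so the subsequent pad_sequence call is what brings a batch to a common length. A minimal sketch of that behaviour, using PyTorch's built-in pad_sequence (same signature as the helper defined in this file); the captions and the printed CLIP shapes are made up for illustration:

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizer, CLIPProcessor

captions = ["a person opens the door", "someone sits down on the couch and reads a book"]

# BERT branch: padding='max_length' already gives every caption the same fixed length (512),
# so padding the batch afterwards has nothing left to do for this extractor.
bert_tok = BertTokenizer.from_pretrained('bert-base-uncased')
bert_ids = [bert_tok(c, return_tensors='pt', padding='max_length',
                     truncation=True, max_length=512)['input_ids'].squeeze(0)
            for c in captions]
print([t.shape for t in bert_ids])   # [torch.Size([512]), torch.Size([512])]

# CLIP branch: each caption keeps its own token count, so the list is ragged.
clip_proc = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
clip_ids = [clip_proc(text=[c], return_tensors='pt')['input_ids'].squeeze(0)
            for c in captions]
print([t.shape for t in clip_ids])   # e.g. [torch.Size([7]), torch.Size([11])]

# pad_sequence pads the ragged tensors up to the longest caption in the batch.
clip_batch = pad_sequence(clip_ids, batch_first=True)
print(clip_batch.shape)              # e.g. torch.Size([2, 11])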

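With config and the tokenizer created once at module level, collate_fn can be handed directly to a DataLoader. A hypothetical wiring sketch, assuming the updated script is importable as models.fine_grained.train; the toy dataset, its tensor shapes, and the batch size are invented stand-ins for CharadesSTADatasetFineGrained and the real loader settings:

import torch
from torch.utils.data import DataLoader, Dataset

# Importing the module above also builds its module-level config and tokenizer.
from models.fine_grained.train import collate_fn


class ToyClipTextDataset(Dataset):
    """Stand-in for CharadesSTADatasetFineGrained: items are (video_frames, sentence, label)."""

    def __init__(self):
        self.items = [
            (torch.randn(12, 3, 224, 224), "a person opens the door", 1),
            (torch.randn(8, 3, 224, 224), "someone closes the fridge", 0),
        ]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


loader = DataLoader(ToyClipTextDataset(), batch_size=2, shuffle=True, collate_fn=collate_fn)
video_frames, text_ids, labels = next(iter(loader))
# video_frames is padded to the longest clip in the batch, text_ids to the longest
# tokenized caption (or to 512 for the BERT extractor), and labels is a 1-D tensor.
print(video_frames.shape, text_ids.shape, labels.shape)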
