From 5707106fbe13121649834a56d6e7bef73f708541 Mon Sep 17 00:00:00 2001
From: saleemhamo
Date: Tue, 23 Jul 2024 19:17:30 +0300
Subject: [PATCH] updated train.py

---
 models/fine_grained/train.py | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/models/fine_grained/train.py b/models/fine_grained/train.py
index f7ae03f..3664dd2 100644
--- a/models/fine_grained/train.py
+++ b/models/fine_grained/train.py
@@ -14,11 +14,23 @@ from models.fine_grained.components.supervision import SupervisionLoss
 from models.fine_grained.components.qd_detr import QDDETRModel
 from models.fine_grained.data_loaders.charades_sta_dataset import CharadesSTADatasetFineGrained
+from transformers import BertTokenizer, CLIPProcessor
 import numpy as np
 
 
 # Setup logger
 logger = setup_logger('train_logger')
 
+# Initialize config
+config = Config()
+
+# Initialize tokenizer based on config
+if config.fine_grained_text_extractor == 'bert':
+    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+elif config.fine_grained_text_extractor == 'clip':
+    tokenizer = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
+else:
+    raise ValueError("Invalid text_extractor value in config")
+
 
 def fine_grained_retrieval(train_loader, config):
     device = get_device(logger)
@@ -99,16 +111,21 @@ def pad_sequence(sequences, batch_first=False, padding_value=0.0):
 
 def collate_fn(batch):
     video_frames, text_sentences, labels = zip(*batch)
+
+    # Tokenize text sentences
+    if config.fine_grained_text_extractor == 'bert':
+        text_sentences = [tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=512)['input_ids'].squeeze(0) for text in text_sentences]
+    elif config.fine_grained_text_extractor == 'clip':
+        text_sentences = [tokenizer(texts=[text], return_tensors='pt')['input_ids'].squeeze(0) for text in text_sentences]
     # Convert video frames and text sentences to numpy arrays first
     video_frames_padded = pad_sequence([torch.tensor(np.array(v)) for v in video_frames], batch_first=True)
-    text_sentences_padded = pad_sequence([torch.tensor(np.array(t)) for t in text_sentences], batch_first=True)
+    text_sentences_padded = pad_sequence(text_sentences, batch_first=True)
     labels = torch.tensor(labels)
     return video_frames_padded, text_sentences_padded, labels
 
 
 def main():
     logger.info("Loading configuration.")
-    config = Config()
     charades_sta = CharadesSTA(
         video_dir=CHARADES_VIDEOS_DIR,
         train_file=CHARADES_ANNOTATIONS_TRAIN,