diff --git a/models/fine_grained/components/text_feature_extractor.py b/models/fine_grained/components/text_feature_extractor.py
index 48406e1..697f663 100644
--- a/models/fine_grained/components/text_feature_extractor.py
+++ b/models/fine_grained/components/text_feature_extractor.py
@@ -20,8 +20,10 @@ def __init__(self):
         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
     def extract_features(self, text, device):
+        # Check if input is already tokenized (a tensor of token ids)
         if isinstance(text, torch.Tensor):
-            text = text.tolist()
-            inputs = self.processor(text=text, return_tensors="pt", padding=True, truncation=True).to(device)
+            inputs = {"input_ids": text.to(device)}
+        else:
+            inputs = self.processor(text=[text], return_tensors="pt", padding=True, truncation=True).to(device)
         outputs = self.model.get_text_features(**inputs)
         return outputs
diff --git a/models/fine_grained/train.py b/models/fine_grained/train.py
index e866e4a..abc3292 100644
--- a/models/fine_grained/train.py
+++ b/models/fine_grained/train.py
@@ -1,6 +1,8 @@
 # models/fine_grained/train.py
 import torch
 from torch.utils.data import DataLoader
+from torch.nn.utils.rnn import pad_sequence
+import numpy as np
 from data.charades_sta import CharadesSTA
 from utils.config import Config
 from utils.model_utils import save_model, get_device
@@ -15,21 +17,32 @@
 from models.fine_grained.components.qd_detr import QDDETRModel
 from models.fine_grained.data_loaders.charades_sta_dataset import CharadesSTADatasetFineGrained
 from transformers import BertTokenizer, CLIPProcessor
-import numpy as np
+import os
 
 # Setup logger
 logger = setup_logger('train_logger')
 
-# Initialize config
-config = Config()
+# Initialize tokenizers
+bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+
+def collate_fn(batch):
+    video_frames, text_sentences, labels = zip(*batch)
+
+    # Pad video frames
+    video_frames_padded = pad_sequence([torch.tensor(np.array(v)) for v in video_frames], batch_first=True)
 
-# Initialize tokenizer based on config
-if config.fine_grained_text_extractor == 'bert':
-    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
-elif config.fine_grained_text_extractor == 'clip':
-    tokenizer = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
-else:
-    raise ValueError("Invalid text_extractor value in config")
+    # Tokenize text sentences unless they are already token-id tensors
+    if not isinstance(text_sentences[0], torch.Tensor):
+        text_sentences = [clip_processor(text=[t], return_tensors='pt')['input_ids'].squeeze(0)
+                          for t in text_sentences]
+
+    # Pad tokenized text to a common length
+    text_sentences_padded = pad_sequence(list(text_sentences), batch_first=True)
+
+    labels = torch.tensor(labels)
+    return video_frames_padded, text_sentences_padded, labels
 
 
 def fine_grained_retrieval(train_loader, config):
@@ -64,8 +77,6 @@
         for video_frames, text_sentence, labels in train_loader:
             video_frames, text_sentence, labels = video_frames.to(device), text_sentence.to(device), labels.to(device)
 
-            logger.info(f"Processing text_sentence: {text_sentence}")
-
             # Extract features
             enhanced_text_features = text_extractor.extract_features(text_sentence, device)
             enhanced_video_features = video_extractor.extract_features(video_frames, device)
@@ -88,48 +99,9 @@
     return detector
 
 
-def pad_sequence(sequences, batch_first=False, padding_value=0.0):
-    """Pad a list of sequences to the same length."""
-    max_size = sequences[0].size()
-    trailing_dims = max_size[1:]
-    max_len = max([s.size(0) for s in sequences])
-    if batch_first:
-        out_dims = (len(sequences), max_len) + trailing_dims
-    else:
-        out_dims = (max_len, len(sequences)) + trailing_dims
-
-    out_tensor = sequences[0].new_full(out_dims, padding_value)
-    for i, tensor in enumerate(sequences):
-        length = tensor.size(0)
-        # Use index notation to prevent duplicate references to the tensor
-        if batch_first:
-            out_tensor[i, :length, ...] = tensor
-        else:
-            out_tensor[:length, i, ...] = tensor
-
-    return out_tensor
-
-
-def collate_fn(batch):
-    video_frames, text_sentences, labels = zip(*batch)
-
-    # Tokenize text sentences
-    if config.fine_grained_text_extractor == 'bert':
-        text_sentences = [tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=512)[
-            'input_ids'].squeeze(0) for text in text_sentences]
-    elif config.fine_grained_text_extractor == 'clip':
-        text_sentences = [tokenizer(text=[text], return_tensors='pt')['input_ids'].squeeze(0) for text in
-                          text_sentences]
-
-    # Convert video frames and text sentences to numpy arrays first
-    video_frames_padded = pad_sequence([torch.tensor(np.array(v)) for v in video_frames], batch_first=True)
-    text_sentences_padded = pad_sequence(text_sentences, batch_first=True)
-    labels = torch.tensor(labels)
-    return video_frames_padded, text_sentences_padded, labels
-
-
 def main():
     logger.info("Loading configuration.")
+    config = Config()
     charades_sta = CharadesSTA(
         video_dir=CHARADES_VIDEOS_DIR,
         train_file=CHARADES_ANNOTATIONS_TRAIN,
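For context, a minimal self-contained sketch of the padding behavior the new `collate_fn` relies on: `torch.nn.utils.rnn.pad_sequence` right-pads variable-length token-id tensors to the longest sequence in the batch. The token ids below are arbitrary placeholders, not real tokenizer output.

```python
# Sketch of the batch padding used by the new collate_fn.
# Token ids are arbitrary placeholders, not real tokenizer output.
import torch
from torch.nn.utils.rnn import pad_sequence

text_ids = [
    torch.tensor([7, 3, 9, 2]),  # 4-token sentence
    torch.tensor([5, 1, 2]),     # 3-token sentence
]
padded = pad_sequence(text_ids, batch_first=True)  # zero-pads to length 4
print(padded.shape)  # torch.Size([2, 4])
print(padded[1])     # tensor([5, 1, 2, 0]) -- trailing 0 is padding
```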
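And a hedged usage sketch of the two input paths the revised `extract_features` now handles, a raw string versus a pre-tokenized id tensor. It assumes the `openai/clip-vit-base-patch32` checkpoint can be downloaded, and it mirrors the extractor's calls directly rather than importing the repo's class.

```python
# Hedged sketch of both extract_features input paths (string vs. id tensor).
# Assumes network access to download openai/clip-vit-base-patch32.
import torch
from transformers import CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Path 1: raw string -- the processor builds input_ids and attention_mask.
inputs = processor(text=["a person opens a door"], return_tensors="pt",
                   padding=True, truncation=True).to(device)
features = model.get_text_features(**inputs)  # shape: (1, 512)

# Path 2: pre-tokenized ids (e.g. padded by collate_fn) passed directly.
features = model.get_text_features(input_ids=inputs["input_ids"])  # same result
```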