Create requirements.txt #42

Open · wants to merge 8 commits into base: main
2 changes: 1 addition & 1 deletion INSTRUCTION.md
@@ -39,7 +39,7 @@
from preparation.utils import save_vid_aud_txt

# Initialize video and audio data loaders
-video_loader = AVSRDataLoader(modality="video", detector="retinaface", convert_gray=False)
+video_loader = AVSRDataLoader(modality="visual", detector="retinaface", convert_gray=False)
audio_loader = AVSRDataLoader(modality="audio")

# Specify the file path to the data
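Note: after this rename, every AVSRDataLoader call site must pass modality="visual" instead of "video" for the video pipeline. A minimal usage sketch, extrapolated from the INSTRUCTION.md snippet and the load_data signature in preparation/data/data_module.py (the input path is illustrative):

from preparation.data.data_module import AVSRDataLoader

# "visual" now selects the landmark-based video pipeline; "audio" is unchanged.
video_loader = AVSRDataLoader(modality="visual", detector="retinaface", convert_gray=False)
audio_loader = AVSRDataLoader(modality="audio")

data_filename = "clip.mp4"  # hypothetical input clip
video = video_loader.load_data(data_filename)  # landmarks=None, so the detector runs
audio = audio_loader.load_data(data_filename)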
4 changes: 2 additions & 2 deletions avg_ckpts.py
@@ -29,12 +29,12 @@ def ensemble(args):
last = [
os.path.join(args.exp_dir, args.exp_name, f"epoch={n}.ckpt")
for n in range(
-args.trainer.max_epochs - 10,
+args.trainer.max_epochs - 1,
args.trainer.max_epochs,
)
]
model_path = os.path.join(
-args.exp_dir, args.exp_name, f"model_avg_10.pth"
+args.exp_dir, args.exp_name, f"model_avg_1.pth"
)
torch.save(average_checkpoints(last), model_path)
return model_path
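Note that range(args.trainer.max_epochs - 1, args.trainer.max_epochs) yields a single index, so this now "averages" only the final checkpoint, and model_avg_1.pth is effectively a copy of the last epoch's weights (the old code averaged the last 10). For reference, a sketch of what checkpoint averaging typically does; this is an assumption about average_checkpoints' behavior, not its actual source:

import torch

def average_checkpoints_sketch(paths):
    # Load each checkpoint's state dict and average all parameters elementwise.
    avg = None
    for path in paths:
        state = torch.load(path, map_location="cpu")
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(paths) for k, v in avg.items()}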
2 changes: 1 addition & 1 deletion datamodule/av_dataset.py
@@ -75,7 +75,7 @@ def load_list(self, label_path):
def __getitem__(self, idx):
dataset_name, rel_path, input_length, token_id = self.list[idx]
path = os.path.join(self.root_dir, dataset_name, rel_path)
-if self.modality == "video":
+if self.modality == "visual":
video = load_video(path)
video = self.video_transform(video)
return {"input": video, "target": token_id}
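The rename does not change the sample format: visual samples are still a dict with "input" and "target" keys. A one-line consumption sketch (assuming dataset is an instance of the class patched above):

sample = dataset[0]
video, token_id = sample["input"], sample["target"]  # transformed video tensor, token ids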
10 changes: 5 additions & 5 deletions demo.py
@@ -14,7 +14,7 @@ def __init__(self, cfg, detector="retinaface"):
self.modality = cfg.data.modality
if self.modality in ["audio", "audiovisual"]:
self.audio_transform = AudioTransform(subset="test")
-if self.modality in ["video", "audiovisual"]:
+if self.modality in ["visual", "audiovisual"]:
if detector == "mediapipe":
from preparation.detectors.mediapipe.detector import LandmarksDetector
from preparation.detectors.mediapipe.video_process import VideoProcess
@@ -27,12 +27,12 @@ def __init__(self, cfg, detector="retinaface"):
self.video_process = VideoProcess(convert_gray=False)
self.video_transform = VideoTransform(subset="test")

-if cfg.data.modality in ["audio", "video"]:
+if cfg.data.modality in ["audio", "visual"]:
from lightning import ModelModule
elif cfg.data.modality == "audiovisual":
from lightning_av import ModelModule
self.modelmodule = ModelModule(cfg)
-self.modelmodule.model.load_state_dict(torch.load(cfg.pretrained_model_path, map_location=lambda storage, loc: storage))
+self.modelmodule.model.load_state_dict(torch.load(cfg.pretrained_model_path, map_location=lambda storage, loc: storage), strict=False)
self.modelmodule.eval()


@@ -46,15 +46,15 @@ def forward(self, data_filename):
audio = audio.transpose(1, 0)
audio = self.audio_transform(audio)

-if self.modality in ["video", "audiovisual"]:
+if self.modality in ["visual", "audiovisual"]:
video = self.load_video(data_filename)
landmarks = self.landmarks_detector(video)
video = self.video_process(video, landmarks)
video = torch.tensor(video)
video = video.permute((0, 3, 1, 2))
video = self.video_transform(video)

-if self.modality == "video":
+if self.modality == "visual":
with torch.no_grad():
transcript = self.modelmodule(video)
elif self.modality == "audio":
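Besides the rename, this file now loads the pretrained weights with strict=False, which silently tolerates missing or unexpected keys and can mask a mismatched checkpoint. A sketch of keeping the mismatch visible (variable names follow the diff above; load_state_dict does return the unmatched keys):

state = torch.load(cfg.pretrained_model_path, map_location=lambda storage, loc: storage)
result = self.modelmodule.model.load_state_dict(state, strict=False)
# load_state_dict returns a named tuple of the keys it could not match;
# logging them turns a silently partial load into a visible one.
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)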
6 changes: 3 additions & 3 deletions lightning.py
@@ -21,7 +21,7 @@ def __init__(self, cfg):
self.cfg = cfg
if self.cfg.data.modality == "audio":
self.backbone_args = self.cfg.model.audio_backbone
-elif self.cfg.data.modality == "video":
+elif self.cfg.data.modality == "visual":
self.backbone_args = self.cfg.model.visual_backbone

self.text_transform = TextTransform()
@@ -36,9 +36,9 @@ def __init__(self, cfg):
self.model.encoder.frontend.load_state_dict(tmp_ckpt)
elif self.cfg.transfer_encoder:
tmp_ckpt = {k.replace("encoder.", ""): v for k, v in ckpt.items() if k.startswith("encoder.")}
-self.model.encoder.load_state_dict(tmp_ckpt, strict=True)
+self.model.encoder.load_state_dict(tmp_ckpt, strict=False)
else:
-self.model.load_state_dict(ckpt)
+self.model.load_state_dict(ckpt, strict=False)

def configure_optimizers(self):
optimizer = torch.optim.AdamW([{"name": "model", "params": self.model.parameters(), "lr": self.cfg.optimizer.lr}], weight_decay=self.cfg.optimizer.weight_decay, betas=(0.9, 0.98))
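One nit on the transfer_encoder branch above: k.replace("encoder.", "") rewrites every occurrence of the substring, so a key such as encoder.encoder.proj.weight would lose both prefixes. str.removeprefix (Python >= 3.9) strips only the leading one; a suggestion, not part of this PR:

ckpt = {"encoder.frontend.conv.weight": 0, "decoder.out.weight": 1}  # toy values
tmp_ckpt = {
    k.removeprefix("encoder."): v
    for k, v in ckpt.items()
    if k.startswith("encoder.")
}
assert tmp_ckpt == {"frontend.conv.weight": 0}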
3 changes: 2 additions & 1 deletion preparation/asr_infer.py
@@ -70,7 +70,7 @@
f = open(label_filename, "w")

# Load ASR model
-model = whisper.load_model("medium.en", device="cuda")
+model = whisper.load_model("large-v3", device="cuda")

# Transcription
for filename in tqdm(files_to_process):
@@ -90,6 +90,7 @@
continue

# Write transcript to a text file
+print(transcript)
if transcript:
with open(dst_filename, "w") as k:
k.write(f"{transcript}")
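Swapping medium.en for large-v3 trades speed and GPU memory for accuracy and adds multilingual coverage; large-v3 also requires a reasonably recent openai-whisper release. The added print(transcript) reads like debugging output and may be worth removing before merge. For context, the call pattern this script relies on (the audio path is illustrative):

import whisper

model = whisper.load_model("large-v3", device="cuda")
result = model.transcribe("clip.wav")  # returns a dict; decoded text is under "text"
transcript = result["text"].strip()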
4 changes: 2 additions & 2 deletions preparation/data/data_module.py
@@ -12,7 +12,7 @@
class AVSRDataLoader:
def __init__(self, modality, detector="retinaface", convert_gray=True):
self.modality = modality
-if modality == "video":
+if modality == "visual":
if detector == "retinaface":
from detectors.retinaface.detector import LandmarksDetector
from detectors.retinaface.video_process import VideoProcess
@@ -32,7 +32,7 @@ def load_data(self, data_filename, landmarks=None, transform=True):
audio, sample_rate = self.load_audio(data_filename)
audio = self.audio_process(audio, sample_rate)
return audio
-if self.modality == "video":
+if self.modality == "visual":
video = self.load_video(data_filename)
if not landmarks:
landmarks = self.landmarks_detector(video)
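A detail worth flagging in the visual branch: if not landmarks: treats an empty landmark list the same as None, so passing [] re-runs the detector. If only the "not supplied" case should trigger detection, an explicit None check is safer (a suggestion, not part of this PR):

if landmarks is None:  # run the detector only when no landmarks were supplied
    landmarks = self.landmarks_detector(video)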
2 changes: 1 addition & 1 deletion preparation/preprocess_lrs2lrs3.py
@@ -85,7 +85,7 @@
# Load Data
args.data_dir = os.path.normpath(args.data_dir)
vid_dataloader = AVSRDataLoader(
modality="video", detector=args.detector, convert_gray=False
modality="visual", detector=args.detector, convert_gray=False
)
aud_dataloader = AVSRDataLoader(modality="audio")

2 changes: 1 addition & 1 deletion preparation/preprocess_vox2.py
@@ -91,7 +91,7 @@

# Load data
vid_dataloader = AVSRDataLoader(
modality="video", detector=args.detector, convert_gray=False
modality="visual", detector=args.detector, convert_gray=False
)
aud_dataloader = AVSRDataLoader(modality="audio")

14 changes: 14 additions & 0 deletions requirements.txt
@@ -0,0 +1,14 @@
"pip<24.1"
openai-whisper
torch==2.0.1
torchvision==0.15.2
torchaudio==2.0.2
hydra-core==1.3.2
"numpy<1.24"
pytorch-lightning==1.5.10
omegaconf==2.2.0
fairseq==0.10.0
tensorboardX
tensorrt
av
ffmpeg-python
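Two notes on this file. First, requirement specifiers must not be quoted inside requirements.txt (quotes are a shell convention; pip rejects "pip<24.1" as an invalid requirement), so the pins above are written unquoted. Second, the pip<24.1 pin cannot constrain the pip that is running the install, because the line is only read after that pip has started resolving. If the pin exists to keep old metadata installable (fairseq 0.10.0 and its pinned dependencies predate the stricter version parsing pip 24.1 enforces), pip likely needs to be downgraded first in a separate step, e.g. python -m pip install "pip<24.1", before running pip install -r requirements.txt.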