diff --git a/README.md b/README.md
index 30cd449..3ec4c71 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,23 @@
-# 【CVPR'2023】Bidirectional Cross-Modal Knowledge Exploration for Video Recognition with Pre-trained Vision-Language Models
+
【CVPR'2023】🚴 BIKE: Bidirectional Cross-Modal Knowledge Exploration for Video Recognition with Pre-trained Vision-Language Models
[![Conference](http://img.shields.io/badge/CVPR-2023-6790AC.svg)](https://cvpr.thecvf.com/)
[![Paper](http://img.shields.io/badge/Paper-arxiv.2301.00182-b31b1b.svg)](https://arxiv.org/abs/2301.00182)
-
-This is the official implementation of our **BIKE**: [Cross-Modal Knowledge Exploration for Video Recognition with Pre-trained Vision-Language Models](https://arxiv.org/abs/2301.00182).
+
+
+[Wenhao Wu](https://whwu95.github.io/)1,2, [Xiaohan Wang](https://scholar.google.com/citations?user=iGA10XoAAAAJ&hl=en)3, [Haipeng Luo]()4, [Jingdong Wang](https://jingdongwang2017.github.io/)2, [Yi Yang](https://scholar.google.com/citations?user=RMSuNFwAAAAJ&hl=en)3, [Wanli Ouyang](https://wlouyang.github.io/)5,1
+
+
+1[The University of Sydney](https://www.sydney.edu.au/), 2[Baidu](https://vis.baidu.com/#/), 3[ZJU](https://www.zju.edu.cn/english/), 4[UCAS](https://english.ucas.ac.cn/), 5[Shanghai AI Lab](https://www.shlab.org.cn/)
+
+
+
+
+
-I am currently traveling and may not be able to open-source the code until May.
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bidirectional-cross-modal-knowledge/action-recognition-in-videos-on-ucf101)](https://paperswithcode.com/sota/action-recognition-in-videos-on-ucf101?p=bidirectional-cross-modal-knowledge)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bidirectional-cross-modal-knowledge/action-recognition-in-videos-on-activitynet)](https://paperswithcode.com/sota/action-recognition-in-videos-on-activitynet?p=bidirectional-cross-modal-knowledge)
@@ -22,21 +30,269 @@ I am currently traveling and may not be able to open-source the code until May.
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bidirectional-cross-modal-knowledge/zero-shot-action-recognition-on-activitynet)](https://paperswithcode.com/sota/zero-shot-action-recognition-on-activitynet?p=bidirectional-cross-modal-knowledge)
+This is the official implementation of our 🚴 **BIKE** (BIdirectional Knowledge Exploration), which leverages a cross-modal bridge to enhance video recognition by exploring bidirectional knowledge.
+
+
+📣 I also have other cross-modal video projects that may interest you ✨.
+
+
+> [**Revisiting Classifier: Transferring Vision-Language Models for Video Recognition**](https://arxiv.org/abs/2207.01297)
+> Accepted by AAAI 2023 | [[Text4Vis Code]](https://github.com/whwu95/Text4Vis)
+> Wenhao Wu, Zhun Sun, Wanli Ouyang
+
+
+> [**Cap4Video: What Can Auxiliary Captions Do for Text-Video Retrieval?**](https://arxiv.org/abs/2301.00184)
+> Accepted by CVPR 2023 as 🌟Highlight🌟 | [[Cap4Video Code]](https://github.com/whwu95/Cap4Video)
+> Wenhao Wu, Haipeng Luo, Bo Fang, Jingdong Wang, Wanli Ouyang
+
+
+
+## News
+- 🚀 **Efficient Training**: Almost all models can be trained with **8 NVIDIA 32G V100 GPUs**. In particular, we can train the ViT-L/14 (336px) backbone on **8 GPUs** and achieve **88.3%** top-1 accuracy on the Kinetics-400 dataset!
+- `Apr 20, 2023`: The main training code has been released, including single-node and multi-node multi-GPU distributed training. Thanks for your star 😝.
+- `Feb 28, 2023`: 🎉Our **BIKE** has been accepted by **CVPR-2023**.
+
+## Overview
+🚴**BIKE** explores bidirectional cross-modal knowledge from the pre-trained vision-language model (e.g., CLIP) to introduce auxiliary attributes and category-dependent temporal saliency for improved video recognition.
+
+![BIKE](docs/bike.png)
+
+
+## Content
+- [Prerequisites](#prerequisites)
+- [Data Preparation](#data-preparation)
+- [Model Zoo](#model-zoo)
+- [Training](#training)
+- [Testing](#testing)
+- [BibTeX & Citation](#bibtex)
+- [Acknowledgment](#acknowledgment)
+
+
+## Prerequisites
+
+The code is built with the following libraries:
+
+- [PyTorch](https://pytorch.org/) >= 1.8
+- RandAugment
+- pprint
+- tqdm
+- dotmap
+- yaml
+- csv
+- Optional: decord (for on-the-fly video training)
+- Optional: torchnet (for mAP evaluation on ActivityNet)
+
+
+
+
+## Data Preparation
+
+
+
+### Video Loader
+
+**(Recommended)** To train all of our models, we extract videos into frames for fast reading. Please refer to the [MVFNet](https://github.com/whwu95/MVFNet/blob/main/data_process/DATASETS.md) repo for a detailed guide on dataset processing.
+The annotation file is a text file with multiple lines; each line specifies the frame directory of a video, the total number of frames, and the label, separated by whitespace.
+Example of annotation:
+
+```sh
+abseiling/-7kbO0v4hag_000107_000117 300 0
+abseiling/-bwYZwnwb8E_000013_000023 300 0
+```
+
+
+(Optional) We can also decode the videos in an online fashion using [decord](https://github.com/dmlc/decord). This approach should work but has not been tested; all of the provided models were trained with offline frames.
+Example of annotation:
+
+```sh
+ abseiling/-7kbO0v4hag_000107_000117.mp4 0
+ abseiling/-bwYZwnwb8E_000013_000023.mp4 0
+```
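+
+For reference, here is a minimal, hypothetical sketch (not part of the training code) of how an annotation line in either format above could be parsed:
+
+```python
+from dataclasses import dataclass
+
+@dataclass
+class VideoRecord:
+    path: str          # frame directory (offline) or video file (online decoding)
+    num_frames: int    # total extracted frames; -1 for the video-file format
+    label: int         # integer class id
+
+def parse_annotation(line: str) -> VideoRecord:
+    """Parse one whitespace-separated annotation line of either format."""
+    parts = line.split()
+    if len(parts) == 3:    # offline frames: <frame_dir> <num_frames> <label>
+        return VideoRecord(parts[0], int(parts[1]), int(parts[2]))
+    if len(parts) == 2:    # online decoding: <video_file> <label>
+        return VideoRecord(parts[0], -1, int(parts[1]))
+    raise ValueError(f"Unexpected annotation line: {line!r}")
+
+print(parse_annotation("abseiling/-7kbO0v4hag_000107_000117 300 0"))
+```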
+
+
+
+### Annotation
+Annotation information consists of two parts: the video labels and the category descriptions.
+
+- Video Label: As mentioned above, this part is the same as in traditional video recognition. Please refer to [lists/k400/kinetics_rgb_train_se320.txt](lists/k400/kinetics_rgb_train_se320.txt) for the format.
+- Category Description: We also need a textual description for each video category. Please refer to [lists/k400/kinetics_400_labels.csv](lists/k400/kinetics_400_labels.csv) for the format (a minimal parsing sketch follows below).
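+
+As a hedged illustration (assuming the CSV has an `id,name` header, the typical layout of `kinetics_400_labels.csv`), the category descriptions can be loaded like this:
+
+```python
+import csv
+
+def load_categories(csv_path: str) -> dict:
+    """Map integer class id -> textual category description."""
+    categories = {}
+    with open(csv_path, newline="") as f:
+        for row in csv.DictReader(f):   # assumes an 'id,name' header
+            categories[int(row["id"])] = row["name"]
+    return categories
+
+# categories = load_categories("lists/k400/kinetics_400_labels.csv")
+# print(categories[0])  # e.g., "abseiling"
+```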
+
+
+
+
+## 📱 Model Zoo
+
+Here we provide some off-the-shelf pre-trained checkpoints of our models in the following tables.
+
+- *# Inference Views = # Temporal clips x # Spatial crops*
+### Kinetics-400
+
+| Architecture | Input | Views | Top-1(%) | checkpoint | Train log| config|
+|:------------:|:-------------------:|:------------------:|:-----------------:|:--------------:|:--------------:|:--------------:|
+| ViT-B/32 | 8x224² | 4x3 | 80.5 | [Github](https://github.com/whwu95/Text4Vis/releases/download/v1/k400-vitb-32-f16.pt) | [log](exps/k400/ViT-B/32/f16/log.txt) | [config](configs/k400/k400_train_rgb_vitb-32-f16.yaml) |
+| ViT-B/16 | 8x224² | 4x3 | 84.0 | [Github]() | [log](exps/k400/ViT-B/16/f8/log.txt) | [config](configs/k400/k400_train_rgb_vitb-16-f8.yaml) |
+| ViT-L/14* | 8x224² | 4x3 | 87.4 | [OneDrive]() | [log](exps/k400/ViT-L/14/f8/log.txt) | [config](configs/k400/k400_train_rgb_vitl-14-f8.yaml) |
+| ViT-L/14 | 16x224² | 4x3 | 88.1 | [OneDrive]() | [log](exps/k400/ViT-L/14/f16/log.txt) | [config](configs/k400/k400_train_rgb_vitl-14-f16.yaml) |
+| ViT-L/14 | 8x336² | 4x3 | 88.3 | [OneDrive]() | [log](exps/k400/ViT-L/14-336px/f8/log.txt) | [config](configs/k400/k400_train_rgb_vitl-14-336-f8.yaml) |
+| ViT-L/14 | 16x336² | 4x3 | 88.7 | [OneDrive]() | [log](exps/k400/ViT-L/14-336px/f16/log.txt) | [config](configs/k400/k400_train_rgb_vitl-14-336-f16.yaml) |
+
+
+*Note: \* indicates that in the paper, the ViT-L model with 8 frames was used for zero-shot evaluation on UCF-101, HMDB-51, ActivityNet, and Kinetics-600.*
+
+### Untrimmed Video Recognition: ActivityNet
+| Architecture | Input | Views | Top-1 (%) | mAP (%) | checkpoint | Train log| config|
+|:------------:|:-------------------:|:------------------:|:-----------------:|:--------------:|:--------------:|:--------------:|:--------------:|
+| ViT-L/14 | 16x224² | 4x1 | 94.4 | 96.3 | [OneDrive]() | [log](exps/anet/ViT-L/14/f16/log.txt) | [config](configs/anet/anet_k400_finetune.yaml) |
+| ViT-L/14 | 16x336² | 4x1 | 94.7 | 96.3 | [OneDrive]() | [log](exps/anet/ViT-L/14-336px/f16/log.txt) | [config](configs/anet/anet_k400_finetune_336.yaml) |
+
+### Multi-label Action Recognition: Charades
+| Architecture | Input | Views | mAP (%) | checkpoint | Train log| config|
+|:------------:|:-------------------:|:------------------:|:-----------------:|:--------------:|:--------------:|:--------------:|
+| ViT-L/14 | 16x336² | 4x1 | 50.4 | [OneDrive]() | [log](exp_nce/charades/ViT-L/14-336px/16f/log.txt) | [config](configs/charades/charades_k400_finetune_336.yaml) |
-## 📣 Updates
-- [x] **[Feb 28, 2023]** 🎉Our **BIKE** has been accepted by **CVPR-2023**.
+### UCF-101
+| Architecture | Input | Views | Top-1 (%) | checkpoint | Train log| config|
+|:------------:|:-------------------:|:------------------:|:-----------------:|:--------------:|:--------------:|:--------------:|
+| ViT-L/14 | 16x224² | 1x1 | 98.8 | [OneDrive]() | [log](exps/ucf101/ViT-L/14/f16/log.txt) | [config](configs/ucf101/ucf_k400_finetune.yaml) |
+| ViT-L/14 | 16x336² | 1x1 | 98.6 | [OneDrive]() | [log](exps/ucf101/ViT-L/14-336px/f16/log.txt) | [config](configs/ucf101/ucf_k400_finetune_336.yaml) |
-## 📌 Bibtex
-If you find this repository useful, please star🌟 this repo and cite📑 our paper:
+### HMDB-51
+| Architecture | Input | Views | Top-1 (%) | checkpoint | Train log| config|
+|:------------:|:-------------------:|:------------------:|:-----------------:|:--------------:|:--------------:|:--------------:|
+| ViT-L/14 | 16x224² | 1x1 | 82.2 | [OneDrive]() | [log](exps/hmdb51/ViT-L/14/f16/log.txt) | [config](configs/hmdb51/hmdb_k400_finetune.yaml) |
+| ViT-L/14 | 16x336² | 1x1 | 83.1 | [OneDrive]() | [log](exps/hmdb51/ViT-L/14-336px/f16/log.txt) | [config](configs/hmdb51/hmdb_k400_finetune_336.yaml) |
+
+
+
+
+## 🚀 Training
+This implementation supports Multi-GPU `DistributedDataParallel` training, which is faster and simpler than `DataParallel` training.
+
+1. **Single Machine**: To train our model on Kinetics-400 with 8 GPUs in *Single Machine*, you can run:
+```sh
+# For example, we train the 8 Frames ViT-B/32 video model (i.e., video branch).
+sh scripts/run_train.sh configs/k400/k400_train_rgb_vitb-32-f8.yaml
+
+# We train the attributes branch.
+sh scripts/run_train_attributes.sh configs/k400/k400_train_rgb_vitb-32-f8.yaml
+```
+
+2. **Multiple Machines**: We also provide scripts to train larger models on multiple machines (e.g., 2 nodes with 16 GPUs in total).
+
+```sh
+# For example, we train the 8-frame ViT-L/14-336 with 2 machines as follows:
+# On the first machine, set the IP of the first machine as --master_addr and set --nnodes to 2.
+# Compared with the single-machine training script, only the node_id argument needs to be added.
+sh scripts/run_train_multinodes.sh configs/k400/k400_train_rgb_vitl-14-336-f8.yaml 0
+
+# On the second machine, --master_addr is still the IP of the first machine.
+sh scripts/run_train_multinodes.sh configs/k400/k400_train_rgb_vitl-14-336-f8.yaml 1
+```
+
+
+
+3. **Few-shot Recognition**: To train our model under the few-shot scenario, you only need to add one line to the general config file (an illustrative sketch of the subsampling follows the snippet below).
+
+```yaml
+# You can refer to configs/k400/k400_few_shot.yaml
+data:
+ ... # general configurations
+ shot: 2 # i.e., 2-shot setting
+```
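+
+Conceptually, `shot: K` keeps at most K training videos per class. A minimal, illustrative sketch of such subsampling (the function name and details are assumptions, not the repository's actual sampling code):
+
+```python
+import random
+from collections import defaultdict
+
+def make_few_shot_list(lines, shot, seed=1024):
+    """Keep at most `shot` training samples per class from an annotation list."""
+    by_class = defaultdict(list)
+    for line in lines:
+        by_class[line.split()[-1]].append(line)   # label is the last whitespace-separated field
+    rng = random.Random(seed)
+    subset = []
+    for samples in by_class.values():
+        rng.shuffle(samples)
+        subset.extend(samples[:shot])
+    return subset
+```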
+
+
+
+## ⚡ Testing
+We support single-view validation (default) and multi-view (4x3 views) validation.
+
+```sh
+# The testing command for obtaining top-1/top-5 accuracy.
+sh scripts/run_test.sh Your-Config.yaml Your-Trained-Model.pt
+
+# The command for zero-shot evaluation is similar.
+sh scripts/run_test_zeroshot.sh Your-Config.yaml Your-Trained-Model.pt
+```
+
+We provide more examples of testing commands below.
+
+
+General / Few-shot Video Recognition
+
+```sh
+# Efficient Setting: Single view evaluation.
+# E.g., ViT-L/14 8 Frames on Kinetics-400. You should get around 86.5% top-1 accuracy.
+sh scripts/run_test.sh configs/k400/k400_train_rgb_vitl-14-f8.yaml exps/k400/ViT-L/14/8f/model_best.pt
+
+# Accurate Setting: Multi-view evaluation (4 clips x 3 crops).
+# You should get around 87.4% top-1 accuracy.
+sh scripts/run_test.sh configs/k400/k400_train_rgb_vitl-14-f8.yaml exps/k400/ViT-L/14/8f/model_best.pt --test_crops 3 --test_clips 4
+
+# Test the Charades dataset using the mAP metric. You should achieve around 50.4 mAP.
+sh scripts/run_test_charades.sh configs/charades/charades_k400_finetune_336.yaml exps/charades/ViT-L/14-336px/16f/model_best.pt --test_crops 1 --test_clips 4
+
+# Test the ActivityNet dataset using the top-1 and mAP metrics. You should achieve around 96.1 mAP.
+sh scripts/run_test_anet.sh configs/anet/anet_k400_finetune_336.yaml exps/anet/ViT-L/14-336px/16f/model_best.pt
```
+
+
+
+Zero-shot Evaluation
+
+
+We use the Kinetics-400 pre-trained model (e.g., [ViT-L/14 with 8 frames](configs/k400/k400_train_rgb_vitl-14-f8.yaml)) to perform cross-dataset zero-shot evaluation, i.e., UCF101, HMDB51, ActivityNet, Kinetics-600.
+
+
+- Half-classes Evaluation: A traditional evaluation protocol that randomly selects half of the test dataset's classes, repeats the selection ten times, and reports the mean accuracy with the standard deviation over the ten runs (see the sketch below).
+
+
+- Full-classes Evaluation: Perform evaluation on the entire dataset.
+
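+For clarity, a minimal NumPy sketch of the half-classes protocol described above (the `scores`/`labels` arrays and the function are illustrative, not the repository's API):
+
+```python
+import numpy as np
+
+def half_classes_accuracy(scores, labels, runs=10, seed=0):
+    """Mean +/- std of top-1 accuracy over `runs` random half-class subsets.
+
+    scores: (num_videos, num_classes) video-to-text similarity matrix.
+    labels: (num_videos,) ground-truth class ids.
+    """
+    rng = np.random.default_rng(seed)
+    num_classes = scores.shape[1]
+    accs = []
+    for _ in range(runs):
+        subset = rng.choice(num_classes, num_classes // 2, replace=False)
+        mask = np.isin(labels, subset)                          # keep only videos of the sampled classes
+        preds = subset[scores[mask][:, subset].argmax(axis=1)]  # classify within the sampled classes only
+        accs.append(100.0 * (preds == labels[mask]).mean())
+    return float(np.mean(accs)), float(np.std(accs))
+```
+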
+```sh
+# On ActivityNet: reporting the half-classes and full-classes results
+# Half-classes: 86.18 ± 1.05, Full-classes: 80.04
+sh scripts/run_test_zeroshot.sh configs/anet/anet_zero_shot.yaml exps/k400/ViT-L/14/f8/last_model.pt
+
+# On UCF101: reporting the half-classes and full-classes results
+# Half-classes: 86.63 ± 3.4, Full-classes: 80.83
+sh scripts/run_test_zeroshot.sh configs/ucf101/ucf_zero_shot.yaml exps/k400/ViT-L/14/f8/last_model.pt
+
+# On HMDB51: reporting the half-classes and full-classes results
+# Half-classes: 61.37 ± 3.68, Full-classes: 52.75
+sh scripts/run_test_zeroshot.sh configs/hmdb51/hmdb_zero_shot.yaml exps/k400/ViT-L/14/f8/last_model.pt
+
+# On Kinetics-600: manually calculate the mean accuracy and standard deviation over the three splits.
+# Split1: 70.14, Split2: 68.31, Split3: 67.15
+# Average: 68.53 ± 1.23
+sh scripts/run_test.sh configs/k600/k600_zero_shot_split1.yaml exps/k400/ViT-L/14/f8/last_model.pt
+sh scripts/run_test.sh configs/k600/k600_zero_shot_split2.yaml exps/k400/ViT-L/14/f8/last_model.pt
+sh scripts/run_test.sh configs/k600/k600_zero_shot_split3.yaml exps/k400/ViT-L/14/f8/last_model.pt
+```
+
+
+
+## 📌 BibTeX & Citation
+
+If you use our code in your research or wish to refer to the baseline results, please use the following BibTeX entry.
+
+
+```bibtex
@inproceedings{bike,
title={Bidirectional Cross-Modal Knowledge Exploration for Video Recognition with Pre-trained Vision-Language Models},
author={Wu, Wenhao and Wang, Xiaohan and Luo, Haipeng and Wang, Jingdong and Yang, Yi and Ouyang, Wanli},
- booktitle=CVPR,
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2023}
}
```
+
+
+## 🎗️ Acknowledgement
+
+This repository is built upon [Text4Vis](https://github.com/whwu95/Text4Vis) and [CLIP](https://github.com/openai/CLIP). Sincere thanks for their wonderful work.
+
+
+## 👫 Contact
+For any questions, please file an issue or contact [Wenhao Wu](https://whwu95.github.io/).
diff --git a/configs/k400/k400_train_rgb_vitb-16-f8.yaml b/configs/k400/k400_train_rgb_vitb-16-f8.yaml
new file mode 100644
index 0000000..053cef6
--- /dev/null
+++ b/configs/k400/k400_train_rgb_vitb-16-f8.yaml
@@ -0,0 +1,49 @@
+resume:
+pretrain:
+seed: 1024
+data:
+ dataset: k400
+ modality: RGB
+ num_segments: 8
+ seg_length: 1
+ batch_size: 32
+ workers: 4
+ num_classes: 400
+ image_tmpl: 'img_{:05d}.jpg'
+ train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames'
+ train_list: 'lists/k400/kinetics_rgb_train_se320.txt'
+ val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv
+ val_list: lists/k400/kinetics_rgb_val_se320.txt
+ label_list: 'lists/k400/kinetics_400_labels.csv'
+ input_size: 224
+ random_shift: True
+ output_path: exps
+network:
+ arch: ViT-B/16 #ViT-B/32 ViT-B/16
+ init: True
+ tm: False
+ drop_out: 0.0
+ emb_dropout: 0.0
+ type: clip_k400
+ sim_header: Transf # Transf None
+ interaction: VCS # DP VCS
+ joint_st: False
+ drop: 0
+ fix_text: True
+ fix_video: False
+solver:
+ type: cosine
+ epochs: 30
+ start_epoch: 0
+ epoch_offset: 0
+ optim: adamw
+ lr: 5.e-5
+ lr_warmup_step: 5
+ weight_decay: 0.2
+ loss_type: NCE
+ evaluate: False
+ clip_ratio: 0.1
+ grad_accumulation_steps: 1
+logging:
+ print_freq: 10
+ eval_freq: 1
\ No newline at end of file
diff --git a/configs/k400/k400_train_rgb_vitl-14-336-f16.yaml b/configs/k400/k400_train_rgb_vitl-14-336-f16.yaml
new file mode 100644
index 0000000..076fdbe
--- /dev/null
+++ b/configs/k400/k400_train_rgb_vitl-14-336-f16.yaml
@@ -0,0 +1,49 @@
+resume:
+pretrain:
+seed: 1024
+data:
+ dataset: k400
+ modality: RGB
+ num_segments: 16
+ seg_length: 1
+ batch_size: 8 # mem: 26G
+ workers: 4
+ num_classes: 400
+ image_tmpl: 'img_{:05d}.jpg'
+ train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames'
+ train_list: 'lists/k400/kinetics_rgb_train_se320.txt'
+ val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv
+ val_list: lists/k400/kinetics_rgb_val_se320.txt
+ label_list: 'lists/k400/kinetics_400_labels.csv'
+ input_size: 336
+ random_shift: True
+ output_path: exps
+network:
+ arch: ViT-L/14-336px #ViT-B/32 ViT-B/16
+ init: True
+ tm: False
+ drop_out: 0.0
+ emb_dropout: 0.0
+ type: clip_k400
+ sim_header: Transf # Transf None
+ interaction: VCS # DP VCS
+ joint_st: False
+ drop: 0
+ fix_text: True
+ fix_video: False
+solver:
+ type: cosine
+ epochs: 20
+ start_epoch: 0
+ epoch_offset: 0
+ optim: adamw
+ lr: 5.e-5
+ lr_warmup_step: 5
+ weight_decay: 0.2
+ loss_type: NCE
+ evaluate: False
+ clip_ratio: 0.1
+ grad_accumulation_steps: 1
+logging:
+ print_freq: 10
+ eval_freq: 1
\ No newline at end of file
diff --git a/configs/k400/k400_train_rgb_vitl-14-336-f32.yaml b/configs/k400/k400_train_rgb_vitl-14-336-f32.yaml
new file mode 100644
index 0000000..26f6004
--- /dev/null
+++ b/configs/k400/k400_train_rgb_vitl-14-336-f32.yaml
@@ -0,0 +1,49 @@
+resume:
+pretrain:
+seed: 1024
+data:
+ dataset: k400
+ modality: RGB
+ num_segments: 32
+ seg_length: 1
+ batch_size: 5 # mem: 31G
+ workers: 4
+ num_classes: 400
+ image_tmpl: 'img_{:05d}.jpg'
+ train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames'
+ train_list: 'lists/k400/kinetics_rgb_train_se320.txt'
+ val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv
+ val_list: lists/k400/kinetics_rgb_val_se320.txt
+ label_list: 'lists/k400/kinetics_400_labels.csv'
+ input_size: 336
+ random_shift: True
+ output_path: exps
+network:
+ arch: ViT-L/14-336px #ViT-B/32 ViT-B/16
+ init: True
+ tm: False # False tsm tokent1d tokenshift
+ drop_out: 0.0
+ emb_dropout: 0.0
+ type: clip_k400
+ sim_header: Transf # Transf None
+ interaction: VCS # DP VCS
+ joint_st: False
+ drop: 0
+ fix_text: True
+ fix_video: False
+solver:
+ type: cosine
+ epochs: 20
+ start_epoch: 0
+ epoch_offset: 0
+ optim: adamw
+ lr: 5.e-5
+ lr_warmup_step: 5
+ weight_decay: 0.2
+ loss_type: NCE
+ evaluate: False
+ clip_ratio: 0.1
+ grad_accumulation_steps: 1
+logging:
+ print_freq: 10
+ eval_freq: 1
\ No newline at end of file
diff --git a/configs/k400/k400_train_rgb_vitl-14-336-f8.yaml b/configs/k400/k400_train_rgb_vitl-14-336-f8.yaml
new file mode 100644
index 0000000..2d720cc
--- /dev/null
+++ b/configs/k400/k400_train_rgb_vitl-14-336-f8.yaml
@@ -0,0 +1,49 @@
+resume:
+pretrain:
+seed: 1024
+data:
+ dataset: k400
+ modality: RGB
+ num_segments: 8
+ seg_length: 1
+ batch_size: 16
+ workers: 4
+ num_classes: 400
+ image_tmpl: 'img_{:05d}.jpg'
+ train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames'
+ train_list: 'lists/k400/kinetics_rgb_train_se320.txt'
+ val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv
+ val_list: lists/k400/kinetics_rgb_val_se320.txt
+ label_list: 'lists/k400/kinetics_400_labels.csv'
+ input_size: 336
+ random_shift: True
+ output_path: exps
+network:
+ arch: ViT-L/14-336px #ViT-B/32 ViT-B/16
+ init: True
+ tm: False
+ drop_out: 0.0
+ emb_dropout: 0.0
+ type: clip_k400
+ sim_header: Transf # Transf None
+ interaction: VCS # DP VCS
+ joint_st: False
+ drop: 0
+ fix_text: True
+ fix_video: False
+solver:
+ type: cosine
+ epochs: 20
+ start_epoch: 0
+ epoch_offset: 0
+ optim: adamw
+ lr: 5.e-5
+ lr_warmup_step: 5
+ weight_decay: 0.2
+ loss_type: NCE
+ evaluate: False
+ clip_ratio: 0.1
+ grad_accumulation_steps: 1
+logging:
+ print_freq: 10
+ eval_freq: 1
\ No newline at end of file
diff --git a/configs/k400/k400_train_rgb_vitl-14-f16.yaml b/configs/k400/k400_train_rgb_vitl-14-f16.yaml
new file mode 100644
index 0000000..93eb3c7
--- /dev/null
+++ b/configs/k400/k400_train_rgb_vitl-14-f16.yaml
@@ -0,0 +1,49 @@
+resume:
+pretrain:
+seed: 1024
+data:
+ dataset: k400
+ modality: RGB
+ num_segments: 16
+ seg_length: 1
+ batch_size: 32
+ workers: 4
+ num_classes: 400
+ image_tmpl: 'img_{:05d}.jpg'
+ train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames'
+ train_list: 'lists/k400/kinetics_rgb_train_se320.txt'
+ val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv
+ val_list: lists/k400/kinetics_rgb_val_se320.txt
+ label_list: 'lists/k400/kinetics_400_labels.csv'
+ input_size: 224
+ random_shift: True
+ output_path: exps
+network:
+ arch: ViT-L/14 #ViT-B/32 ViT-B/16
+ init: True
+ tm: False
+ drop_out: 0.0
+ emb_dropout: 0.0
+ type: clip_k400
+ sim_header: Transf # Transf None
+ interaction: VCS # DP VCS
+ joint_st: False
+ drop: 0
+ fix_text: True
+ fix_video: False
+solver:
+ type: cosine
+ epochs: 20
+ start_epoch: 0
+ epoch_offset: 0
+ optim: adamw
+ lr: 5.e-5
+ lr_warmup_step: 5
+ weight_decay: 0.2
+ loss_type: NCE
+ evaluate: False
+ clip_ratio: 0.1
+ grad_accumulation_steps: 1
+logging:
+ print_freq: 10
+ eval_freq: 1
\ No newline at end of file
diff --git a/configs/k400/k400_train_rgb_vitl-14-f8.yaml b/configs/k400/k400_train_rgb_vitl-14-f8.yaml
new file mode 100644
index 0000000..b5651fc
--- /dev/null
+++ b/configs/k400/k400_train_rgb_vitl-14-f8.yaml
@@ -0,0 +1,49 @@
+resume:
+pretrain:
+seed: 1024
+data:
+ dataset: k400
+ modality: RGB
+ num_segments: 8
+ seg_length: 1
+ batch_size: 32
+ workers: 4
+ num_classes: 400
+ image_tmpl: 'img_{:05d}.jpg'
+ train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames'
+ train_list: 'lists/k400/kinetics_rgb_train_se320.txt'
+ val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv
+ val_list: lists/k400/kinetics_rgb_val_se320.txt
+ label_list: 'lists/k400/kinetics_400_labels.csv'
+ input_size: 224
+ random_shift: True
+ output_path: exps
+network:
+ arch: ViT-L/14 #ViT-B/32 ViT-B/16
+ init: True
+ tm: False
+ drop_out: 0.0
+ emb_dropout: 0.0
+ type: clip_k400
+ sim_header: Transf # Transf None
+ interaction: VCS # DP VCS
+ joint_st: False
+ drop: 0
+ fix_text: True
+ fix_video: False
+solver:
+ type: cosine
+ epochs: 20
+ start_epoch: 0
+ epoch_offset: 0
+ optim: adamw
+ lr: 5.e-5
+ lr_warmup_step: 5
+ weight_decay: 0.2
+ loss_type: NCE
+ evaluate: False
+ clip_ratio: 0.1
+ grad_accumulation_steps: 1
+logging:
+ print_freq: 10
+ eval_freq: 1
\ No newline at end of file
diff --git a/docs/bike.png b/docs/bike.png
new file mode 100644
index 0000000..09b8ddb
Binary files /dev/null and b/docs/bike.png differ
diff --git a/scripts/run_test.sh b/scripts/run_test.sh
new file mode 100644
index 0000000..faf1d58
--- /dev/null
+++ b/scripts/run_test.sh
@@ -0,0 +1,12 @@
+#!/usr/bin/env bash
+if [ -f "$1" ]; then
+ config=$1
+else
+ echo "need a config file"
+ exit
+fi
+
+weight=$2
+
+python -m torch.distributed.launch --master_port 1239 --nproc_per_node=8 \
+ test.py --config ${config} --weights ${weight} ${@:3}
\ No newline at end of file
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..af019ac
--- /dev/null
+++ b/test.py
@@ -0,0 +1,226 @@
+import os
+import argparse
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from torch.nn.parallel import DistributedDataParallel
+import torch.distributed as dist
+import torch.backends.cudnn as cudnn
+import torchvision
+import torch.nn.functional as F
+import time
+from utils.utils import init_distributed_mode, AverageMeter, reduce_tensor, accuracy
+import clip
+
+import yaml
+from dotmap import DotMap
+from datasets.video import Video_dataset
+from datasets.transforms import GroupScale, GroupCenterCrop, Stack, ToTorchFormatTensor, GroupNormalize, GroupOverSample, GroupFullResSample
+from modules.video_clip import video_header
+from modules.text_prompt import text_prompt
+
+
+def get_parser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str, help='global config file')
+ parser.add_argument('--weights', type=str, default=None)
+ parser.add_argument('--dist_url', default='env://',
+ help='url used to set up distributed training')
+ parser.add_argument('--world_size', default=1, type=int,
+ help='number of distributed processes')
+ parser.add_argument("--local_rank", type=int,
+ help='local rank for DistributedDataParallel')
+ parser.add_argument(
+ "--precision",
+ choices=["amp", "fp16", "fp32"],
+ default="amp",
+ help="Floating point precition."
+ )
+ parser.add_argument('--test_crops', type=int, default=1)
+ parser.add_argument('--test_clips', type=int, default=1)
+ parser.add_argument('--dense', default=False, action="store_true",
+ help='use dense sample for test as in Non-local I3D')
+ args = parser.parse_args()
+ return args
+
+def update_dict(dict):
+ new_dict = {}
+ for k, v in dict.items():
+ new_dict[k.replace('module.', '')] = v
+ return new_dict
+
+
+def main(args):
+ init_distributed_mode(args)
+
+ with open(args.config, 'r') as f:
+ config = yaml.load(f, Loader=yaml.FullLoader)
+
+ config = DotMap(config)
+
+ device = "cpu"
+ if torch.cuda.is_available():
+ device = "cuda"
+ cudnn.benchmark = True
+
+ # get fp16 model and weight
+ model, clip_state_dict = clip.load(
+ config.network.arch,
+ device='cpu', jit=False,
+ internal_modeling=config.network.tm,
+ T=config.data.num_segments,
+ dropout=config.network.drop_out,
+ emb_dropout=config.network.emb_dropout,
+ pretrain=config.network.init,
+ joint_st= config.network.joint_st) # Must set jit=False for training ViT-B/32
+
+ video_head = video_header(
+ config.network.sim_header,
+ config.network.interaction,
+ clip_state_dict)
+
+ if args.precision == "amp" or args.precision == "fp32":
+ model = model.float()
+
+
+ input_mean = [0.48145466, 0.4578275, 0.40821073]
+ input_std = [0.26862954, 0.26130258, 0.27577711]
+
+ # rescale size
+ if 'something' in config.data.dataset:
+ scale_size = (240, 320)
+ else:
+ scale_size = 256 if config.data.input_size == 224 else config.data.input_size
+
+ # crop size
+ input_size = config.data.input_size
+
+ # control the spatial crop
+ if args.test_crops == 1: # one crop
+ cropping = torchvision.transforms.Compose([
+ GroupScale(scale_size),
+ GroupCenterCrop(input_size),
+ ])
+ elif args.test_crops == 3: # do not flip, so only 3 crops (left right center)
+ cropping = torchvision.transforms.Compose([
+ GroupFullResSample(
+ crop_size=input_size,
+ scale_size=scale_size,
+ flip=False)
+ ])
+ elif args.test_crops == 5: # do not flip, so only 5 crops
+ cropping = torchvision.transforms.Compose([
+ GroupOverSample(
+ crop_size=input_size,
+ scale_size=scale_size,
+ flip=False)
+ ])
+ elif args.test_crops == 10:
+ cropping = torchvision.transforms.Compose([
+ GroupOverSample(
+ crop_size=input_size,
+ scale_size=scale_size,
+ )
+ ])
+ else:
+ raise ValueError("Only 1, 3, 5, 10 crops are supported while we got {}".format(args.test_crops))
+
+
+ val_data = Video_dataset(
+ config.data.val_root, config.data.val_list, config.data.label_list,
+ random_shift=False, num_segments=config.data.num_segments,
+ modality=config.data.modality,
+ image_tmpl=config.data.image_tmpl,
+ test_mode=True,
+ transform=torchvision.transforms.Compose([
+ cropping,
+ Stack(roll=False),
+ ToTorchFormatTensor(div=True),
+ GroupNormalize(input_mean, input_std),
+ ]),
+ dense_sample=args.dense,
+ test_clips=args.test_clips)
+
+ val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
+ val_loader = DataLoader(val_data,
+ batch_size=config.data.batch_size, num_workers=config.data.workers,
+ sampler=val_sampler, pin_memory=True, drop_last=False)
+
+ if os.path.isfile(args.weights):
+ checkpoint = torch.load(args.weights, map_location='cpu')
+ if dist.get_rank() == 0:
+ print('load model: epoch {}'.format(checkpoint['epoch']))
+
+ model.load_state_dict(update_dict(checkpoint['model_state_dict']))
+ video_head.load_state_dict(update_dict(checkpoint['fusion_model_state_dict']))
+ del checkpoint
+
+ if args.distributed:
+ model = DistributedDataParallel(model.cuda(), device_ids=[args.gpu], find_unused_parameters=True)
+ if config.network.sim_header != "None":
+ video_head = DistributedDataParallel(video_head.cuda(), device_ids=[args.gpu])
+
+ classes = text_prompt(val_data)
+ n_class = classes.size(0)
+
+ prec1 = validate(
+ val_loader, classes, device,
+ model, video_head, config, n_class, args.test_crops, args.test_clips)
+ return
+
+
+def validate(val_loader, classes, device, model, video_head, config, n_class, test_crops, test_clips):
+ top1 = AverageMeter()
+ top5 = AverageMeter()
+ model.eval()
+ video_head.eval()
+ proc_start_time = time.time()
+
+ with torch.no_grad():
+ text_inputs = classes.to(device)
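+        # encode all class texts once; return_token=True also returns per-token text features consumed by the video head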
+ cls_feature, text_features = model.module.encode_text(text_inputs, return_token=True)
+ for i, (image, class_id) in enumerate(val_loader):
+ batch_size = class_id.numel()
+ num_crop = test_crops
+
+ num_crop *= test_clips # 4 clips for testing when using dense sample
+
+ class_id = class_id.to(device)
+ n_seg = config.data.num_segments
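+            # the dataloader stacks crops/clips along the channel dim:
+            # (B, num_crop * n_seg * 3, H, W) -> (B * num_crop, n_seg, 3, H, W), where num_crop already includes test_clips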
+ image = image.view((-1, n_seg, 3) + image.size()[-2:])
+ b, t, c, h, w = image.size()
+ image_input = image.to(device).view(-1, c, h, w)
+ image_features = model.module.encode_image(image_input).view(b, t, -1)
+ cnt_time = time.time() - proc_start_time
+ similarity = video_head(image_features, text_features, cls_feature) # b dim
+ similarity = F.softmax(similarity, -1)
+ similarity = similarity.reshape(batch_size, num_crop, -1).mean(1) # bs dim
+
+ similarity = similarity.view(batch_size, -1, n_class).softmax(dim=-1)
+ similarity = similarity.mean(dim=1, keepdim=False)
+
+ prec = accuracy(similarity, class_id, topk=(1, 5))
+ prec1 = reduce_tensor(prec[0])
+ prec5 = reduce_tensor(prec[1])
+
+ top1.update(prec1.item(), class_id.size(0))
+ top5.update(prec5.item(), class_id.size(0))
+
+ if i % config.logging.print_freq == 0 and dist.get_rank() == 0:
+ runtime = float(cnt_time) / (i+1) / (batch_size * dist.get_world_size())
+ print(
+ ('Test: [{0}/{1}], average {runtime:.4f} sec/video \t'
+ 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
+ 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
+ i, len(val_loader), runtime=runtime, top1=top1, top5=top5)))
+
+ if dist.get_rank() == 0:
+ print('-----Evaluation is finished------')
+ print('Overall Prec@1 {:.03f}% Prec@5 {:.03f}%'.format(top1.avg, top5.avg))
+ return top1.avg
+
+if __name__ == '__main__':
+ args = get_parser()
+ main(args)
+