-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmake_combined_category_annotations.py
80 lines (61 loc) · 2.08 KB
/
make_combined_category_annotations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import sys
import json
from pathlib import Path
from shiprsimagenet import ShipRSImageNet
# Names of the coarse categories this script produces, keyed by their new
# COCO category id (1-based, in this order).
new_class_names = dict(enumerate((
    'Other Ship',
    'Warship',
    'Other Merchant',
    'Container Ship',
    'Cargo Ship',
    'Barge',
    'Fishing Vessel',
    'Oil Tanker',
    'Motorboat',
    'Dock',
), start=1))

# Lookup applied to each annotation's original `category_id` (keys 1-50, in
# ascending order) giving the coarse id in `new_class_names`. Every id in
# 2..36 collapses into 2 ('Warship'); several scattered ids collapse into
# 3 ('Other Merchant').
class_id_mappings = {
    1: 1,
    **dict.fromkeys(range(2, 37), 2),
    37: 3, 38: 4, 39: 3, 40: 5, 41: 6,
    **dict.fromkeys(range(42, 46), 3),
    46: 7, 47: 8, 48: 3, 49: 9, 50: 10,
}

# COCO-style `categories` entries for the coarse classes; each category is
# its own supercategory.
new_coco_categories = [
    {'id': cat_id, 'name': cat_name, 'supercategory': cat_name}
    for cat_id, cat_name in new_class_names.items()
]
def main():
    """Command-line entry point.

    Usage: python make_combined_category_annotations.py <path to ShipRSImageNet dataset>

    Loads the dataset at the given directory and writes a
    combined-category copy of its 'train' and 'val' COCO annotation files
    (see ``make_combined_category_annotations``).

    Exits with status 1, after printing a message to stderr, when the path
    argument is missing, does not exist, or is not a directory.
    """
    # Diagnostics go to stderr so stdout only carries normal progress output.
    if len(sys.argv) < 2:
        print("Usage: python make_combined_category_annotations.py <path to ShipRSImageNet dataset>", file=sys.stderr)
        sys.exit(1)
    dataset_dir = sys.argv[1]
    if not os.path.exists(dataset_dir):
        print("Path does not exist", file=sys.stderr)
        sys.exit(1)
    if not os.path.isdir(dataset_dir):
        print("Path to dataset is not a directory", file=sys.stderr)
        sys.exit(1)

    print("Loading dataset...")
    dataset = ShipRSImageNet(dataset_dir)
    coco_root = Path(dataset.coco_root_dir)
    annotation_paths = [
        coco_root / dataset.get_coco_annotation_file_name(split)
        for split in ('train', 'val')
    ]

    print("Making combined category annotations...")
    for annotation_path in annotation_paths:
        make_combined_category_annotations(annotation_path)
    print(f"Saved combined annotations in directory {dataset.coco_root_dir}")
def make_combined_category_annotations(annotation_file: Path, *, id_mapping=None, categories=None) -> Path:
    """Write a copy of a COCO annotation file with its categories collapsed.

    Reads ``annotation_file``, replaces its ``categories`` list with
    ``categories``, and rewrites every annotation's ``category_id`` through
    ``id_mapping``. The result is written next to the input as
    ``<stem>_combined_categories.json``. The input file is not modified.

    Args:
        annotation_file: Path (``Path`` or ``str``) to a COCO-format JSON
            annotation file.
        id_mapping: Lookup from original category id to new category id.
            Defaults to the module-level ``class_id_mappings``.
        categories: Replacement COCO ``categories`` list. Defaults to the
            module-level ``new_coco_categories``.

    Returns:
        Path of the file that was written.

    Raises:
        KeyError: If an annotation's ``category_id`` has no entry in
            ``id_mapping``.
    """
    annotation_file = Path(annotation_file)
    mapping = class_id_mappings if id_mapping is None else id_mapping
    new_categories = new_coco_categories if categories is None else categories

    # COCO JSON is UTF-8; don't depend on the platform's locale encoding.
    with annotation_file.open('r', encoding='utf-8') as f:
        coco = json.load(f)

    coco['categories'] = new_categories
    # Files that carry only image info (no 'annotations' key) pass through.
    for annotation in coco.get('annotations', []):
        old_id = annotation['category_id']
        try:
            annotation['category_id'] = mapping[old_id]
        except KeyError:
            # Same exception type as before, but say where the bad id is.
            raise KeyError(
                f"category_id {old_id!r} of annotation {annotation.get('id')!r} "
                f"in {annotation_file} has no entry in the id mapping"
            ) from None

    output_file = annotation_file.parent / f"{annotation_file.stem}_combined_categories.json"
    with output_file.open('w', encoding='utf-8') as f:
        json.dump(coco, f)
    return output_file
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()