-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbuild_index.py
37 lines (29 loc) · 1.19 KB
/
build_index.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from PIL import Image
import requests
import numpy as np
from transformers import AutoProcessor, CLIPVisionModelWithProjection
import os
import glob
from rich.progress import track
import torch
from src.clip_index import CLIPIndex
# Build a cosine-similarity CLIP index over the training images.
#
# Loads a fine-tuned CLIP vision tower, embeds every data/images/train/*.jpg,
# and inserts each embedding into a CLIPIndex keyed by the integer image id
# taken from the filename stem (e.g. data/images/train/123.jpg -> 123).

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Vision model comes from the local fine-tuned checkpoint; the processor is
# taken from the base checkpoint it was presumably fine-tuned from.
model = CLIPVisionModelWithProjection.from_pretrained("./clip_epochs1_bz364_lr364")
model.to(device)
model.eval()  # inference only: disable dropout / training-mode layers

processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

index = CLIPIndex()
index.create_index("clip_index")

# NOTE(review): recursive=True is a no-op without '**' in the pattern; kept
# so the matched file set is identical to the original.
image_paths = glob.glob('data/images/train/*.jpg', recursive=True)

for image_path in track(image_paths, description="Processing..."):
    # Filename stem is the numeric image id. os.path handles both '/' and
    # '\\' separators, unlike the previous manual '/'-split.
    image_id = int(os.path.splitext(os.path.basename(image_path))[0])

    try:
        # Image.open is the call most likely to fail on a corrupted file,
        # so it belongs inside the try as well.
        with Image.open(image_path) as image:
            inputs = processor(images=image, return_tensors="pt")
    except Exception:
        # BUG FIX: the original bare `except:` printed and then fell
        # through, reusing `inputs` from the previous iteration (or raising
        # NameError on the first image). Skip unreadable files instead.
        print(str(image_id), " is corrupted\n")
        continue

    with torch.no_grad():
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = model(**inputs)
        image_embeds = outputs.image_embeds

    # Index stores the embedding under its integer image id.
    index.insert_cosine(image_embeds, np.array([image_id]))

index.save()