-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
63 lines (52 loc) · 2.31 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from dataloader import mean, std, dataloader
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
from skimage import io
import numpy as np
import pandas as pd
from fire import Fire
import torch
from autoencoder import Encoder
def show_image(img, mean=mean, std=std):
    """Display ``img`` after undoing its mean/std normalisation.

    Args:
        img: Normalised image data to display.
        mean: Offset added back to the image; defaults to the ``mean``
            imported from ``dataloader``.
        std: Scale multiplied back into the image; defaults to the ``std``
            imported from ``dataloader``.
    """
    # Invert the normalisation (x * std + mean), then clamp to [0, 1] so the
    # result is a valid floating-point image for display.
    restored = np.clip(img * std + mean, 0, 1)

    plt.tight_layout()
    io.imshow(restored)
def _lookup_coordinates(db, filename):
    """Return ``[filename, ra, dec]`` for one image file name.

    The object id is the file name with ``'_large'`` and ``'.jpg'`` removed.
    If that id is not an integer, or no row of ``db`` has it in the
    ``'#OBJID'`` column, ``ra`` and ``dec`` are ``None``.
    """
    idx = filename.replace('_large', '').replace('.jpg', '')
    try:
        # First matching row wins, as a two-element (RA, DEC) array.
        ra, dec = db[db['#OBJID'] == int(idx)][['RA', 'DEC']].values[0]
    except (ValueError, IndexError):
        # ValueError: the stem is not an integer id.
        # IndexError: no row in the table carries this id.
        return [filename, None, None]
    return [filename, ra, dec]


def save_embeddings(weights, path='data/', device='cuda', code_size=64, batch_size=16):
    """Encode the train and validation images and save codes with coordinates.

    Loads an ``Encoder`` from ``weights``, runs it over the train loader and
    then the validation loader, and writes two files into ``embedding/``:

    * ``codes.<code_size>.npy`` -- one encoder output row per image.
    * ``filenames.<code_size>.npy`` -- one ``[filename, RA, DEC]`` entry per
      image, in the same order as the codes. RA/DEC are ``None`` when the
      image's object id cannot be found in ``coordinates.csv``.

    Args:
        weights: Path to the saved encoder state dict.
        path: Dataset root passed to ``dataloader``.
        device: Torch device the encoder and batches are moved to.
        code_size: Size passed to ``Encoder``; also used in the output names.
        batch_size: Batch size passed to ``dataloader``.
    """
    encoder = Encoder(code_size).to(device)
    encoder.load_state_dict(torch.load(weights, map_location=device))
    encoder.eval()

    # shuffle=False keeps a stable order so codes and file names line up.
    # The third value returned by dataloader is not used here.
    train_loader, val_loader, _ = dataloader(root=path, batch_size=batch_size,
                                             shuffle=False, transform=False,
                                             drop_last=False)
    # Read from the current working directory, not from `path`.
    db = pd.read_csv('coordinates.csv')

    codes_embed = []
    filenames_embed = []
    with torch.no_grad():
        # Train first, then validation; both are handled identically.
        for loader in (train_loader, val_loader):
            for batch, filenames in tqdm(loader):
                batch = batch.to(device)
                codes_embed.extend(encoder(batch).cpu().numpy())
                filenames_embed.extend(
                    _lookup_coordinates(db, name) for name in filenames
                )

    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs('embedding', exist_ok=True)
    np.save('embedding/codes.{}.npy'.format(code_size), codes_embed)
    np.save('embedding/filenames.{}.npy'.format(code_size), filenames_embed)
# Command-line entry point: python-fire maps CLI arguments onto the
# parameters of save_embeddings.
if __name__ == "__main__":
    Fire(save_embeddings)