Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

训练好的模型 #1

Open
ForeUP opened this issue Apr 6, 2023 · 12 comments
Open

训练好的模型 #1

ForeUP opened this issue Apr 6, 2023 · 12 comments

Comments

@ForeUP
Copy link

ForeUP commented Apr 6, 2023

您好,请问方便分享一些训练好的模型吗?我想先测试一下效果,非常感谢!

@cwhgn
Copy link
Collaborator

cwhgn commented Apr 7, 2023

您好,如果您是需要预训练模型,可以在SOLIDER上下载,用SOLIDER-REID的训练命令直接训练即可得到finetune后的ReID模型。如果您希望直接使用训练好的ReID模型,我们争取近期更新一版,把模型放上去。

@ForeUP
Copy link
Author

ForeUP commented Apr 7, 2023

好的,非常感谢!

@cwhgn
Copy link
Collaborator

cwhgn commented Apr 11, 2023

您好,训练好的ReID模型链接我们已经更新到Readme中了,欢迎试用。

@deep-practice
Copy link

用训练好的模型提取特征,相同的人和不同的人相似度都很高

@cwhgn
Copy link
Collaborator

cwhgn commented Apr 20, 2023

可以用runtest.sh确认下效果。

@deep-practice
Copy link

deep-practice commented Apr 23, 2023

`import torch
import torch.nn as nn
import torchvision.transforms as T
import cv2
from config import cfg
from model import make_model
import torch.nn.functional as F
import numpy as np
from PIL import Image, ImageFile
import os.path as osp

def read_image(img_path):
got_img = False
if not osp.exists(img_path): raise IOError("{} does not exist".format(img_path))
while not got_img:
try:
img = Image.open(img_path).convert('RGB')
got_img = True
except IOError:
print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
pass
return img

@torch.no_grad()
def get_feature(img,model,device,normalize=False):
input = val_transforms(img).unsqueeze(0)
input = input.to(device)
output, _ = model(input)
if normalize:
output = F.normalize(output)
return output

val_transforms = T.Compose([
T.Resize(cfg.INPUT.SIZE_TEST),
T.ToTensor(),
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
config_file = "configs/msmt17/swin_small.yml"
cfg.merge_from_file(config_file)
cfg.freeze()
model = make_model(cfg, num_class=1000, camera_num=0, view_num = 0, semantic_weight = 0.2)
model.load_param("weights/swin_small_msmt17.pth")
device = "cuda"

if device:
if torch.cuda.device_count() > 1:
print('Using {} GPUs for inference'.format(torch.cuda.device_count()))
model = nn.DataParallel(model)
model.to(device)

model.eval()

img1 = read_image("test_imgs/p5.jpg")
img2 = read_image("test_imgs/p7.jpg")
feature1 = get_feature(img1,model,device,normalize=True)
feature2 = get_feature(img2,model,device,normalize=True)
feature1,feature2 = F.normalize(feature1),F.normalize(feature2)
sim = torch.mm(feature1,feature2.t())
print(sim)

@cwhgn 这是我参考test.py写的代码,能帮忙看看问题在哪儿吗?图片之前相似度高达0.9+

`

@deep-practice
Copy link

模型用的官方提供的swin_small_msmt17.pth

@cwhgn
Copy link
Collaborator

cwhgn commented Apr 24, 2023

目前看我暂时也没找到您代码的问题。几点建议哈:1)可以先check下swin_small_msmt17.pth载入是否成功;2)分步check下图片特征和提供的test.py的输出是否一致;3)如果和test.py一致,则有没有可能两者相似度就是很高。

@deep-practice
Copy link

1)模型载入成功
2)对比了test.py提取的特征和我demo里面提取的特征,数值是一样的
3)使用欧式距离,相同图片和不同图片能区分开,但换成余弦相似度,任何图片的相似度都在0.9+

@cwhgn
Copy link
Collaborator

cwhgn commented Apr 25, 2023

针对3,我理解是因为训练时采用的是欧式距离,见

dist_mat = euclidean_dist(global_feat, global_feat)

如果你希望余弦距离可分的话,可以尝试用余弦距离进行训练。

@deep-practice
Copy link

明白了,谢谢

@MyraBaba
Copy link

MyraBaba commented Jul 5, 2023

@deep-practice

A silly newbee question.

this p5 and p7 image is person cropped image. What should be the dimesion of the image ?

is 384 x 192 ?

Is there any newbee document to make inference and compare distance of two detected person image bow.

Detected person could be from yolo or from solider

Best

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

4 participants