Skip to content

Commit

Permalink
support face recgnition (#3282)
Browse files Browse the repository at this point in the history
* add mobilefacenet

* exp with pplcnet on face

* support face recognition

* fix face recognition models bugs

* add face recognition README
  • Loading branch information
leo-q8 authored Nov 1, 2024
1 parent e3f7dde commit 4b3503b
Show file tree
Hide file tree
Showing 18 changed files with 1,092 additions and 178 deletions.
122 changes: 122 additions & 0 deletions docs/zh_CN/models/Face_Recognition/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# 人脸识别模型

------


## 目录

- [1. 模型和应用场景介绍](##1.-模型和应用场景介绍)
- [2. 支持模型列列表](#2.-支持模型列表)
- [3. 模型快速体验](#3.-模型快速体验)
- [3.1 安装paddlepaddle](#3.1-安装-paddlepaddle)
- [3.2 安装PaddleClas](#3.2-安装-PaddleClas)
- [3.3 模型训练与评估](#3.3-模型训练与评估)
- [3.3.1 下载数据集](#3.3.1-下载数据集)
- [3.3.2 模型训练](#3.3.2-模型训练)
- [3.3.2 模型评估](#3.3.3-模型评估)


## 1. 模型和应用场景介绍

人脸识别模型通常以经过检测提取和关键点矫正处理的标准化人脸图像作为输入。人脸识别模型从这些图像中提取具有高度辨识性的人脸特征,以便供后续模块使用,如人脸匹配和验证等任务。PaddleClas 目前支持了基于 [ArcFace](https://arxiv.org/abs/1801.07698) 损失函数训练的人脸识别模型,包括 [MobileFaceNet](https://arxiv.org/abs/1804.07573) 和 ResNet50。同时也支持在 AgeDB-30、CFP-FP、LFW,CPLFW 和 CALFW 5个常用的人脸识别数据集上进行评估。

## 2. 支持模型列表

|模型|训练数据集|输出特征维度 | 损失函数 |Acc (%)<br>AgeDB-30/CFP-FP/LFW | 模型参数量(M) |模型下载|
|-|-|-|-|:-:|-|-|
| MobileFaceNet |MS1Mv3 |128 |ArcFace |96.28/96.71/99.58 | 0.99 |[下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/mobilefacenet.pdparams)|
| ResNet50 |MS1Mv3 |512 |ArcFace |98.12/98.56/99.77 | 25.56 |[下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/resnet50_face.pdparams)|

**注:**

* 上述评估指标用到的数据集来自 [insightface](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#validation-datasets) 提供的bin文件,评估流程也完全与该仓库对齐
* PaddleClas参照一般人脸识别模型的训练设置,将训练辨率设为112x112。原始的ResNet50模型在该分辨率下使用ArcFace损失函数进行训练时较难收敛,在此这我们参考 [insightface](https://github.com/deepinsight/insightface/blob/a1eb8523fbe50b0c0e39a9fa96d4e2a6936b46be/recognition/arcface_torch/backbones/iresnet.py#L39) 仓库中 IResNet50 的实现,将原始 ResNet50 的整体下采样倍率由原先的 32x 调整为 16x。详见训练配置文件[`Face_Recognition/FaceRecognition_ArcFace_ResNet50.yaml`](../../../../ppcls/configs/Face_Recognition/FaceRecognition_ArcFace_ResNet50.yaml)

## 3. 模型快速体验

### 3.1 安装 paddlepaddle

- 您的机器安装的是 CUDA9 或 CUDA10,请运行以下命令安装

```bash
python3 -m pip install paddlepaddle-gpu -i https://mirror.baidu.com/pypi/simple
```

- 您的机器是 CPU,请运行以下命令安装

```bash
python3 -m pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple
```

更多的版本需求,请参照[飞桨官网安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。

### 3.2 安装 paddleclas

请确保已clone本项目,本地构建安装:

```
若之前安装有paddleclas,先使用以下命令卸载
python3 -m pip uninstall paddleclas
```

使用下面的命令构建:

```
cd path/to/PaddleClas
pip install -v -e .
```

### 3.3 模型训练与评估

#### 3.3.1 下载数据
执行以下命令下载人脸识别数据集MS1Mv3和人脸识别评估数据集,并解压到指定目录

```bash
cd path/to/PaddleClas
wget https://paddleclas.bj.bcebos.com/data/MS1M_v3.tar -P ./dataset/
tar -xf ./dataset/MS1M_v3.tar -C ./dataset/
```
成功执行后进入 `dataset/`目录,可以看到以下数据:

```bash
MS1M_v3
├── images # 训练图像保存目录
│ ├── 00000001.jpg # 训练图像文件
│ ├── 00000002.jpg # 训练图像文件
│ │ ...
├── agedb_30.bin # AgeDB-30 评估集文件
├── calfw.bin # CALFW 评估集文件
├── cfp_fp.bin # CFP-FP 评估集文件
├── cplfw.bin # CPLFW 评估集文件
├── label.txt # 训练集标注文件。每行给出图像的路径和人脸图像类别(人脸身份)id,使用空格分隔,内容举例:images/00000001.jpg 0
└── lfw.bin # LFW 评估集文件
```
* 注:上述MS1Mv3数据集的训练图像和标签是从 [insightface](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-retinaface) 提供的`rec`格式的文件中恢复出来的,具体恢复过程可以参考 [AdaFace 仓库](https://github.com/mk-minchul/AdaFace/blob/master/README_TRAIN.md)。各评估集的`bin`文件也由 [insightface](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#validation-datasets) 提供。

### 3.3.2 模型训练

`ppcls/configs/Face_Recognition` 目录中提供了训练配置,以 MobileFaceNet 为例,可以通过如下脚本启动训练:

```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3
python3 -m paddle.distributed.launch \
--gpus="0,1,2,3" \
tools/train.py \
-c ppcls/configs/Face_Recognition/FaceRecognition_ArcFace_MobileFaceNet.yaml
```

* 注:当前精度最佳的模型会保存在 `output/MobileFaceNet/best_model.pdparams`


### 3.3.3 模型评估

训练好模型之后,可以通过以下命令实现对模型指标的评估。

```bash
python3 tools/eval.py \
-c ppcls/configs/Face_Recognition/FaceRecognition_ArcFace_MobileFaceNet.yaml \
-o Global.pretrained_model=output/MobileFaceNet/best_model
```

其中 `-o Global.pretrained_model="output/MobileFaceNet/best_model"` 指定了当前最佳权重所在的路径,如果指定其他权重,只需替换对应的路径即可。

1 change: 1 addition & 0 deletions ppcls/arch/backbone/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from .model_zoo.resnest import ResNeSt50_fast_1s1x64d, ResNeSt50, ResNeSt101, ResNeSt200, ResNeSt269
from .model_zoo.googlenet import GoogLeNet
from .model_zoo.mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2, MobileNetV2_x1_5, MobileNetV2_x2_0
from .model_zoo.mobilefacenet import MobileFaceNet
from .model_zoo.shufflenet_v2 import ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2_x1_0, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0, ShuffleNetV2_swish
from .model_zoo.ghostnet import GhostNet_x0_5, GhostNet_x1_0, GhostNet_x1_3
from .model_zoo.alexnet import AlexNet
Expand Down
61 changes: 42 additions & 19 deletions ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,21 @@ def make_divisible(v, divisor=8, min_value=None):
new_v += divisor
return new_v

def _create_act(act):
if act == "hardswish":
return nn.Hardswish()
elif act == "relu":
return nn.ReLU()
elif act == "relu6":
return nn.ReLU6()
elif act == "prelu":
return nn.PReLU()
elif act == "leaky_relu":
return nn.LeakyReLU(negative_slope=0.1)
else:
raise RuntimeError(
"The activation function is not supported: {}".format(act))


class ConvBNLayer(TheseusLayer):
def __init__(self,
Expand All @@ -61,9 +76,9 @@ def __init__(self,
kernel_size,
stride,
groups=1,
use_act=True):
act="relu"):
super().__init__()
self.use_act = use_act
self.act = act
self.conv = Conv2D(
in_channels=in_channels,
out_channels=out_channels,
Expand All @@ -78,19 +93,19 @@ def __init__(self,
out_channels,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
if self.use_act:
self.act = nn.ReLU()
if self.act is not None:
self.act = _create_act(act)

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.use_act:
if self.act:
x = self.act(x)
return x


class SEModule(TheseusLayer):
def __init__(self, channel, reduction=4):
def __init__(self, channel, reduction=4, act='relu'):
super().__init__()
self.avg_pool = AdaptiveAvgPool2D(1)
self.conv1 = Conv2D(
Expand All @@ -99,7 +114,7 @@ def __init__(self, channel, reduction=4):
kernel_size=1,
stride=1,
padding=0)
self.relu = nn.ReLU()
self.act = _create_act(act)
self.conv2 = Conv2D(
in_channels=channel // reduction,
out_channels=channel,
Expand All @@ -112,7 +127,7 @@ def forward(self, x):
identity = x
x = self.avg_pool(x)
x = self.conv1(x)
x = self.relu(x)
x = self.act(x)
x = self.conv2(x)
x = self.hardsigmoid(x)
x = paddle.multiply(x=identity, y=x)
Expand All @@ -128,7 +143,8 @@ def __init__(self,
split_pw=False,
use_rep=False,
use_se=False,
use_shortcut=False):
use_shortcut=False,
act="relu"):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
Expand All @@ -151,7 +167,7 @@ def __init__(self,
kernel_size=kernel_size,
stride=stride,
groups=in_channels,
use_act=False)
act=None)
self.dw_conv_list.append(dw_conv)
self.dw_conv = nn.Conv2D(
in_channels=in_channels,
Expand All @@ -168,29 +184,32 @@ def __init__(self,
stride=stride,
groups=in_channels)

self.act = nn.ReLU()
self.act = _create_act(act)

if use_se:
self.se = SEModule(in_channels)
self.se = SEModule(in_channels, act=act)

if self.split_pw:
pw_ratio = 0.5
self.pw_conv_1 = ConvBNLayer(
in_channels=in_channels,
kernel_size=1,
out_channels=int(out_channels * pw_ratio),
stride=1)
stride=1,
act=act)
self.pw_conv_2 = ConvBNLayer(
in_channels=int(out_channels * pw_ratio),
kernel_size=1,
out_channels=out_channels,
stride=1)
stride=1,
act=act)
else:
self.pw_conv = ConvBNLayer(
in_channels=in_channels,
kernel_size=1,
out_channels=out_channels,
stride=1)
stride=1,
act=act)

def forward(self, x):
if self.use_rep:
Expand Down Expand Up @@ -260,6 +279,7 @@ def __init__(self,
dropout_prob=0,
use_last_conv=True,
class_expand=1280,
act="relu",
**kwargs):
super().__init__(**kwargs)
self.scale = scale
Expand All @@ -271,11 +291,13 @@ def __init__(self,
in_channels=3,
kernel_size=3,
out_channels=make_divisible(32 * scale),
stride=2), RepDepthwiseSeparable(
stride=2,
act=act), RepDepthwiseSeparable(
in_channels=make_divisible(32 * scale),
out_channels=make_divisible(64 * scale),
stride=1,
dw_size=3)
dw_size=3,
act=act)
])

# stages
Expand All @@ -294,7 +316,8 @@ def __init__(self,
split_pw=split_pw,
use_rep=use_rep,
use_se=use_se,
use_shortcut=use_shortcut)
use_shortcut=use_shortcut,
act=act)
for i in range(depths[depth_idx])
]))

Expand All @@ -309,7 +332,7 @@ def __init__(self,
stride=1,
padding=0,
bias_attr=False)
self.act = nn.ReLU()
self.act = _create_act(act)
self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")

self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
Expand Down
18 changes: 11 additions & 7 deletions ppcls/arch/backbone/legendary_models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def __init__(self,
class_num=1000,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
stride_list=[2, 2, 2, 2, 2],
max_pool=True,
data_format="NCHW",
input_image_channel=3,
return_patterns=None,
Expand Down Expand Up @@ -359,11 +360,13 @@ def __init__(self,
for in_c, out_c, k, s in self.stem_cfg[version]
])

self.max_pool = MaxPool2D(
kernel_size=3,
stride=stride_list[1],
padding=1,
data_format=data_format)
self.max_pool = max_pool
if max_pool:
self.max_pool = MaxPool2D(
kernel_size=3,
stride=stride_list[1],
padding=1,
data_format=data_format)
block_list = []
for block_idx in range(len(self.block_depth)):
# paddleclas' special improvement version
Expand All @@ -377,7 +380,7 @@ def __init__(self,
self.num_filters[block_idx] * self.channels_mult,
num_filters=self.num_filters[block_idx],
stride=self.stride_list[block_idx + 1]
if i == 0 and block_idx != 0 else 1,
if i == 0 and (block_idx != 0 or not max_pool) else 1,
shortcut=shortcut,
if_first=block_idx == i == 0 if version == "vd" else True,
layer=layer,
Expand Down Expand Up @@ -411,7 +414,8 @@ def _forward(self, x):
x = paddle.transpose(x, [0, 2, 3, 1])
x.stop_gradient = True
x = self.stem(x)
x = self.max_pool(x)
if self.max_pool:
x = self.max_pool(x)
x = self.blocks(x)
x = self.avg_pool(x)
x = self.flatten(x)
Expand Down
Loading

0 comments on commit 4b3503b

Please sign in to comment.