博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
(原)InsightFace及其mxnet代码
阅读量:6988 次
发布时间:2019-06-27

本文共 3384 字,大约阅读时间需要 11 分钟。

转载请注明出处:

论文

InsightFace : Additive Angular Margin Loss for Deep Face Recognition

官方mxnet代码:

 

说明:没用过mxnet,下面的代码注释只是纯粹从代码的角度来分析并进行注释,如有错误之处,敬请谅解,并欢迎指出。

 

先查看sphereface,查看$\psi (\theta )$的介绍:

论文arcface中,定义$\psi (\theta )$为:

$\psi (\theta )=\cos ({

{\theta }_{yi}}+m)$

同时对w及x均进行了归一化,为了使得训练能收敛,增加了一个参数s=64,最终loss如下:

$L=-\frac{1}{m}\sum\limits_{i=1}^{m}{\log \frac{

{
{e}^{s(\cos ({
{\theta }_{yi}}+m))}}}{
{
{e}^{s(\cos ({
{\theta }_{yi}}+m))}}+\sum\nolimits_{j=n,j\ne yi}^{n}{
{
{e}^{s\cos {
{\theta }_{j}}}}}}}$

其中,

${

{W}_{j}}=\frac{
{
{W}_{j}}}{\left\| {
{W}_{j}} \right\|}$,${
{x}_{i}}=\frac{
{
{x}_{i}}}{\left\| {
{x}_{i}} \right\|}$,$\cos {
{\theta }_{j}}=W_{j}^{T}{
{x}_{i}}$

程序中先对w及x归一化,然后通过全连接层得到cosθ,再扩大s倍,得到scosθ。

对于yi处,由于

$\cos (\theta +m)=\cos \theta \cos m-\sin \theta \sin m$

以及

$\sin \theta =\sqrt{1-{

{\cos }^{2}}\theta }$

得到sinθ。

由于$\cos (\theta +m)$非单调,设置了easy_margin标志,当其为真时,使用0作为阈值,当特征和权重的cos值小于0,直接截断;当其为假时,使用cos(pi-m)=-cos(m)作为阈值。该阈值小于0。

之后判断时,当easy_margin为真时,若s*cos(θ+m)小于0,直接使用s*cos(θ);当easy_margin为假时,若s*cos(θ+m)小于0,使用s*cos(θ)-s*m*sin(m)。

具体的代码如下(完整代码见参考网址):

1     s = args.margin_s  # 参数s 2     m = args.margin_m  # 参数m 3  4     _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0) # (C,F) 5     _weight = mx.symbol.L2Normalization(_weight, mode='instance')   # 对w进行归一化 6     nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s # 对x进行归一化,并得到s*x,(B,F) 7     fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') # Y=XW'+b,(B,F)*(C,F)'=(B,C),'为转置,此处得到scos(theta) 8      9     zy = mx.sym.pick(fc7, gt_label, axis=1)  # 得到fc7中gt_label位置的值。(B,1)或者(B),即当前batch中yi处的scos(theta)10     cos_t = zy/s  # 由于fc7及zy均为cos的s倍,此处除以s,得到实际的cos值。(B,1)或者(B)11     12     cos_m = math.cos(m)13     sin_m = math.sin(m)14     mm = math.sin(math.pi-m)*m # sin(pi-m)*m = sin(m)*m15     threshold = math.cos(math.pi-m)  # 阈值,避免theta + m >= pi,实际上threshold < 016     if args.easy_margin:17       cond = mx.symbol.Activation(data=cos_t, act_type='relu') #easy_margin=True,直接使用0作为阈值,得到超过阈值的索引18     else:19       cond_v = cos_t - threshold #easy_margin=False,使用threshold(负数)作为阈值。20       cond = mx.symbol.Activation(data=cond_v, act_type='relu') # 得到超过阈值的索引21     body = cos_t*cos_t  # 通过cos*cos + sin * sin = 1, 来得到sin_theta22     body = 1.0-body23     sin_t = mx.sym.sqrt(body)  # sin_theta24     new_zy = cos_t*cos_m # cos(theta+m)=cos(theta)*cos(m)-sin(theta)*sin(m),此处为cos(theta)*cos(m)25     b = sin_t*sin_m # 此处为sin(theta)*sin(m)26     new_zy = new_zy - b # 此处为cos(theta)*cos(m)-sin(theta)*sin(m)=cos(theta+m)27     new_zy = new_zy*s # 此处为s*cos(theta+m),扩充了s倍28     if args.easy_margin:29       zy_keep = zy   # zy_keep为zy,即s*cos(theta)30     else:31       zy_keep = zy - s*mm  # zy_keep为zy-s*sin(m)*m=s*cos(theta)-s*m*sin(m)32     new_zy = mx.sym.where(cond, new_zy, zy_keep) # cond中>0的保持new_zy=s*cos(theta+m)不变,<0的裁剪为zy_keep= s*cos(theta) or s*cos(theta)-s*m*sin(m)33 34     diff = new_zy - zy # 35     diff = mx.sym.expand_dims(diff, 1)36     gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)37     body = mx.sym.broadcast_mul(gt_one_hot, diff) # 对应yi处为new_zy - zy38     fc7 = fc7+body # 对应yi处,fc7=zy + (new_zy - zy) = new_zy,即cond中>0的为s*cos(theta+m),<0的裁剪为s*cos(theta) or s*cos(theta)-s*m*sin(m)

 

你可能感兴趣的文章
我的友情链接
查看>>
Office365 SKU-1
查看>>
通过JDBC向数据库中存储&读取Blob数据
查看>>
2019年我国云计算行业存在的问题和发展趋势
查看>>
内置模块(二)
查看>>
C编程技巧
查看>>
week5
查看>>
Unity3D常用网络框架与实战解析 学习
查看>>
继承(原型链继承)
查看>>
如何利用 Visual Studio 自定义项目或工程模板(转载)
查看>>
java.lang.Object底层代码分析-jdk1.8
查看>>
获取函数所在模块的方法
查看>>
QtTableView
查看>>
Android应用开发基础--Adapter
查看>>
条件随机场
查看>>
别人要访问我的电脑上部署的tomcat,必须关闭防火墙吗?
查看>>
作业六
查看>>
c++ 二叉树打印节点路径
查看>>
ios--编码规范
查看>>
JsCV Core v0.2发布 & Javascript图像处理系列目录
查看>>