程序员最近都爱上了这个网站  程序员们快来瞅瞅吧!  it98k网:it98k.com

本站消息

站长简介/公众号

  出租广告位,需要合作请联系站长

+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

2024-11(1)

pytorch中x.norm(p=2,dim=1,keepdim=True)的理解

发布于2019-12-07 16:37     阅读(9017)     评论(0)     点赞(23)     收藏(3)


代码:x.norm(p=2,dim=1,keepdim=True)

功能:求指定维度上的范数。

函数原型:【返回输入张量给定维dim 上每行的p范数】

                 torch.norm(input, p, dim, out=None,keepdim=False) → Tensor

         注:范数求法:【对N个数据求p范数】

                  https://private.codecogs.com/gif.latex?%7C%7Cx%7C%7C_%7Bp%7D%20%3D%20%5Csqrt%5Bp%5D%7Bx_%7B1%7D%5E%7Bp%7D%20+%20x_%7B2%7D%5E%7Bp%7D%20+%20%5Cldots%20+%20x_%7BN%7D%5E%7Bp%7D%7D

函数参数

input (Tensor) – 输入张量

p (float) – 范数计算中的幂指数值

dim (int) – 缩减的维度,dim=0是对0维度上的一个向量求范数,返回结果数量等于其列的个数,也就是说有多少个0维度的向                          量, 将得到多少个范数。dim=1同理。

out (Tensor, optional) – 结果张量

keepdim(bool)– 保持输出的维度 。当keepdim=False时,输出比输入少一个维度(就是指定的dim求范数的维度)。而                                                keepdim=True时,输出与输入维度相同,仅仅是输出在求范数的维度上元素个数变为1。这也是为什么有时                                      我们把参数中的dim称为缩减的维度,因为norm运算之后,此维度或者消失或者元素个数变为1。

 

例子说明

已知一个3×4矩阵,如下:

tensor([[ 1.,  2.,  3.,  4.],

        [ 2.,  4.,  6.,  8.],

        [ 3.,  6.,  9., 12.]])

1)dim参数,分别对其行和列分别求2范数:

inputs1 = torch.norm(inputs, p=2, dim=1, keepdim=True)

print(inputs1)

inputs2 = torch.norm(inputs, p=2, dim=0, keepdim=True)

print(inputs2)

结果分别为:

tensor([[ 5.4772],

        [10.9545],

        [16.4317]])

tensor([[ 3.7417,  7.4833, 11.2250, 14.9666]])

2)keepdim参数

inputs3 = inputs.norm(p=2, dim=1, keepdim=False)

print(inputs3)

inputs3为:

tensor([ 5.4772, 10.9545, 16.4317])

 

输出inputs1和inputs3的shape:

print(inputs1.shape)

print(inputs3.shape)

torch.Size([3, 1])

torch.Size([3])

可以看到inputs3少了一维,其实就是dim=1(求范数)那一维(列)少了,因为从4列变成1列,就是3行中求每一行的2范数,就剩1列了,不保持这一维不会对数据产生影响。或者也可以这么理解,就是数据每个数据有没有用[]扩起来。

即:

keepdim = True,用[]扩起来;

keepdim = False,不用[]括起来;

 

【不写keepdim,则默认不保留dim的那个维度】:

inputs4 = torch.norm(inputs, p=2, dim=1)

print(inputs4)

tensor([ 5.4772, 10.9545, 16.4317])

 

【不写dim,则计算Tensor中所有元素的2范数】:

inputs5 = torch.norm(inputs, p=2)

print(inputs5)

tensor(20.4939)

等价于这句话:

inputs6 = inputs.pow(2).sum().sqrt()

print(inputs6)

tensor(20.4939)

总之,norm操作后dim这一维变为1或者消失。



所属网站分类: 技术文章 > 博客

作者:好好学习

链接:https://www.pythonheidong.com/blog/article/170104/1a35da457b229b6a9c29/

来源:python黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

23 0
收藏该文
已收藏

评论内容:(最多支持255个字符)