程序员最近都爱上了这个网站  程序员们快来瞅瞅吧!  it98k网:it98k.com

本站消息

站长简介/公众号

  出租广告位,需要合作请联系站长

+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

torchvision.transforms

发布于2019-12-07 22:53     阅读(1051)     评论(0)     点赞(20)     收藏(0)


上一篇博客讲解了如何组织自己的数据集,这节讲解如何对数据集进行一系列的预处理操作,即Transform。

1. torchvision.transforms.Compose(transforms)

该API将多个 transform 组合起来对数据进行预处理,其中参数 transforms 是由 transform 构成的列表。
现在举一个例子,先看原始图片:
在这里插入图片描述
接着对这个图片进行预处理操作 ToTensor(),即将 PILImage 或者 numpy 的 ndarray 转化成 Tensor。

import torch
import torchvision
import torchvision.transforms as transforms
import cv2
import numpy as np
from PIL import Image

img_path = "./data/train/Daffodil/image_0001.jpg"

# transforms.ToTensor()
transform1 = transforms.Compose([
    transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
]
)

# Way1
# numpy.ndarray --> tensor
img = cv2.imread(img_path)  # 读取图像, returned 500*689*3 (H * W * channel),数值[0, 255]
print("img type", type(img), img.shape)
img1 = transform1(img)      # 归一化到 [0.0,1.0], img1 is a returned tensor, torch.Size([3, 500, 689])
print("img type after ToTensor transform", type(img1), img1.size())

# 转化为numpy.ndarray并显示
img_1 = img1.numpy()*255
img_1 = img_1.astype('uint8')
img_1 = np.transpose(img_1, (1,2,0))
cv2.imshow('img_1', img_1)  # cv2.imshow() input shape: (H * W * channel)
cv2.waitKey()

# Way2
# PIL --> tensor
img = Image.open(img_path).convert('RGB')    # 读取图像
img2 = transform1(img)        # 归一化到 [0.0, 1.0]
print("img type after ToTensor transform", type(img2), img2.size())

# 转化为PILImage并显示
img_2 = transforms.ToPILImage()(img2).convert('RGB')
print("type", img_2)
img_2.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39

在这里插入图片描述注意:shape为 (H,W,C) 的numpy.ndarray,转换成 size 为 [C,H,W]的 tensor ,并且 tensor 的取值范围是[0, 1.0]。通道数在第一个维度

2. 常见的一些预处理操作

  1. class torchvision.transforms.CenterCrop(size):将给定的PIL.Image进行中心切割,得到给定的size,size可以是tuple,(target_height, target_width)。size也可以是一个Integer,在这种情况下,切出来的图片的形状是正方形。
  2. class torchvision.transforms.RandomCrop(size, padding=0):切割中心点的位置随机选取。size可以是tuple也可以是Integer。
  3. class torchvision.transforms.RandomHorizontalFlip():随机水平翻转给定的PIL.Image,概率为0.5。即:一半的概率翻转,一半的概率不翻转。
  4. class torchvision.transforms.RandomSizedCrop(size, interpolation=2):先将给定的PIL.Image随机切割(尺寸不定),然后再resize成给定的size大小。
  5. class torchvision.transforms.Pad(padding, fill=0):将给定的PIL.Image的所有边用给定的pad value填充。 padding:要填充多少像素,fill:用什么值填充。
  6. transforms.Normalize((m1, m2, m3), (v1, v2, v3)):使用如下公式进行归一化:channel =(channel - mean)/ std,将在 0 到 1 之间的值变换到了 -1 到 1 区间。其中,mi、vi分别表示第 i 个通道的均值与方差。(比如 transforms.Normalize((.5,.5,.5),(.5,.5,.5)),因为transforms.ToTensor() 已经把数据处理成 [0, 1],那么(x - 0.5) / 0.5 就是 [-1.0, 1.0]。)

至于为什么对数据使用 transforms.Normalize() 变换到 [-1, 1],请读者自行思考!



所属网站分类: 技术文章 > 博客

作者:大壮

链接:https://www.pythonheidong.com/blog/article/170258/08bdbe31d0373d12945d/

来源:python黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

20 0
收藏该文
已收藏

评论内容:(最多支持255个字符)