PyTorch 1.1.0 / Python 3.6 / CUDA 9.0: training a network (EfficientNet) on a 10-class task, loading data with a custom Dataset class

Published 2019-08-07 12:14


The material comes from the web, but I modified a few things, such as the data loading, which can now read images flexibly: a user-supplied path prefix plus the relative paths written in a txt file:

    # build MyDataset instances; img_path is an optional prefix prepended to the image paths listed in the txt file
    train_data = MyDataset(img_path = '', txt_path=train_txt_path, transform=trainTransform)
    valid_data = MyDataset(img_path = '', txt_path=valid_txt_path, transform=validTransform)
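
A minimal sketch of the txt file format this assumes: each line holds an image path and an integer label separated by whitespace, and img_path (when non-empty) is joined in front of each path. The folder and file names below are hypothetical:

    # Data/train.txt (hypothetical contents):
    #   images/cat_001.jpg 0
    #   images/dog_003.jpg 1
    # With img_path='./Data', each entry resolves to './Data/images/cat_001.jpg', and so on.
    train_data = MyDataset(img_path='./Data', txt_path='./Data/train.txt', transform=trainTransform)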

First, the custom data-loading class:

# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torch
from torch.autograd import Variable
import os
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F




class MyDataset(Dataset):
    def __init__(self, img_path, txt_path, transform = None, target_transform = None):
        fh = open(txt_path, 'r')
        imgs = []
        for line in fh:
            line = line.rstrip()
            words = line.split()
            img_path_result = os.path.join(img_path, words[0])
            imgs.append((img_path_result, int(words[1])))

        self.imgs = imgs        # the key is building this list; DataLoader supplies an index and __getitem__ loads the image
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = Image.open(fn).convert('RGB')     # pixel values 0-255; transforms.ToTensor() divides by 255 so they end up in 0-1

        if self.transform is not None:
            img = self.transform(img)   # apply the transform here (conversion to tensor, augmentation, etc.)

        return img, label

    def __len__(self):
        return len(self.imgs)


def validate(net, data_loader, set_name, classes_name):
    """
    对一批数据进行预测,返回混淆矩阵以及Accuracy
    :param net:
    :param data_loader:
    :param set_name:  eg: 'valid' 'train' 'tesst
    :param classes_name:
    :return:
    """
    net.eval()
    cls_num = len(classes_name)
    conf_mat = np.zeros([cls_num, cls_num])

    for data in data_loader:
        images, labels = data
        # Variable is a no-op wrapper in PyTorch >= 0.4, so this just passes the tensors through.
        # Note: everything stays on the CPU here, since labels[i].numpy() below requires CPU tensors.
        images = Variable(images)
        labels = Variable(labels)

        outputs = net(images)
        outputs.detach_()

        _, predicted = torch.max(outputs.data, 1)

        # accumulate the confusion matrix
        for i in range(len(labels)):
            cate_i = labels[i].numpy()
            pre_i = predicted[i].numpy()
            conf_mat[cate_i, pre_i] += 1.0

    for i in range(cls_num):
        # max(1.0, ...) guards against division by zero for empty rows/columns
        recall = conf_mat[i, i] / max(1.0, np.sum(conf_mat[i, :]))
        precision = conf_mat[i, i] / max(1.0, np.sum(conf_mat[:, i]))
        print('class:{:<10}, total num:{:<6}, correct num:{:<5}  Recall: {:.2%} Precision: {:.2%}'.format(
            classes_name[i], np.sum(conf_mat[i, :]), conf_mat[i, i], recall, precision))

    print('{} set Accuracy:{:.2%}'.format(set_name, np.trace(conf_mat) / np.sum(conf_mat)))

    return conf_mat, '{:.2}'.format(np.trace(conf_mat) / np.sum(conf_mat))


def show_confMat(confusion_mat, classes, set_name, out_dir):

    # normalize each row so the colormap reflects per-class proportions
    confusion_mat_N = confusion_mat.copy()
    for i in range(len(classes)):
        confusion_mat_N[i, :] = confusion_mat[i, :] / confusion_mat[i, :].sum()

    # pick a colormap
    cmap = plt.cm.get_cmap('Greys')  # more colormaps: http://matplotlib.org/examples/color/colormaps_reference.html
    plt.imshow(confusion_mat_N, cmap=cmap)
    plt.colorbar()

    # axis labels and title
    xlocations = np.array(range(len(classes)))
    plt.xticks(xlocations, list(classes), rotation=60)
    plt.yticks(xlocations, list(classes))
    plt.xlabel('Predict label')
    plt.ylabel('True label')
    plt.title('Confusion_Matrix_' + set_name)

    # print the raw count in each cell
    for i in range(confusion_mat_N.shape[0]):
        for j in range(confusion_mat_N.shape[1]):
            plt.text(x=j, y=i, s=int(confusion_mat[i, j]), va='center', ha='center', color='red', fontsize=10)
    # save
    plt.savefig(os.path.join(out_dir, 'Confusion_Matrix_' + set_name + '.png'))
    plt.close()


def normalize_invert(tensor, mean, std):
    # undo transforms.Normalize in place (useful when visualizing normalized input tensors)
    for t, m, s in zip(tensor, mean, std):
        t.mul_(s).add_(m)
    return tensor
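
Neither validate nor show_confMat is called by the training script below; here is a minimal sketch of how they could be wired up after training. The class names, output directory, and the CPU move are my assumptions, not part of the original code:

# hypothetical wiring of the helpers above; assumes a trained `model` and the `valid_loader` built in the training script
import os

classes_name = [str(i) for i in range(10)]      # placeholder names for the 10 classes
os.makedirs('./Result', exist_ok=True)          # show_confMat needs an existing output directory

# validate() keeps everything on the CPU (it calls .numpy() on labels and predictions), so move the model there first
conf_mat, acc = validate(model.cpu(), valid_loader, 'valid', classes_name)
show_confMat(conf_mat, classes_name, 'valid', out_dir='./Result')   # writes Confusion_Matrix_valid.png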

Then the training script:

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch import utils
from MyDataset import MyDataset
from torchsummary import summary
from torchstat import stat
from tensorboardX import SummaryWriter
writer = SummaryWriter('log')

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    #for batch_idx, (data, target) in enumerate(train_loader):
    for batch_idx, data_ynh in enumerate(train_loader):
        # unpack the image batch and labels
        data, target = data_ynh
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        output1 = torch.nn.functional.log_softmax(output, dim=1)
        loss = F.nll_loss(output1, target)
        #loss = F.l1_loss(output, target)
        loss.backward()
        optimizer.step()

        #new ynh
        # log a point to the training-loss curve every 10 batches
        if batch_idx % 10 == 0:
            niter = epoch * len(train_loader) + batch_idx
            writer.add_scalar('Train/Loss', loss.data, niter)

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))


def test(args, model, device, test_loader, epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        #for data, target in test_loader:
        for data_ynh in test_loader:
            # unpack the image batch and labels
            data, target = data_ynh
            data, target = data.to(device), target.to(device)
            output = model(data)
            output1 = torch.nn.functional.log_softmax(output, dim=1)
            test_loss += F.nll_loss(output1, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    # log the test accuracy (percent) to TensorBoard
    writer.add_scalar('Test/Accu', 100. * correct / len(test_loader.dataset), epoch)


    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=10, metavar='N',
                        help='input batch size for training (default: 10)')
    parser.add_argument('--test-batch-size', type=int, default=10, metavar='N',
                        help='input batch size for testing (default: 10)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # -------------------------------------------- step 1/5: load the data -------------------------------------------
    train_txt_path = './Data/train.txt'
    valid_txt_path = './Data/valid.txt'
    # preprocessing settings
    normMean = [0.4948052, 0.48568845, 0.44682974]
    normStd = [0.24580306, 0.24236229, 0.2603115]
    normTransform = transforms.Normalize(normMean, normStd)
    trainTransform = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomCrop(224, padding=4),
        transforms.ToTensor(),
        normTransform
    ])

    # no Resize here: this assumes the validation images already share a single size that can be batched
    validTransform = transforms.Compose([
        transforms.ToTensor(),
        normTransform
    ])

    # build MyDataset instances; img_path is an optional prefix prepended to the image paths listed in the txt file
    train_data = MyDataset(img_path = '', txt_path=train_txt_path, transform=trainTransform)
    valid_data = MyDataset(img_path = '', txt_path=valid_txt_path, transform=validTransform)

    # build the DataLoaders (use the command-line batch sizes and the CUDA worker settings from kwargs)
    train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(dataset=valid_data, batch_size=args.test_batch_size, **kwargs)




    # Option 1: build the architecture from its config (no pretrained weights); the copy built here
    # is immediately replaced by the pretrained model below, so it only serves as a reference.
    blocks_args, global_params = utils.get_model_params('efficientnet-b0', override_params=None)
    model = EfficientNet(blocks_args, global_params)

    # Option 2 (used): load ImageNet-pretrained weights. Note that the pretrained head has 1000
    # outputs; training with labels 0-9 still runs, but a dedicated 10-class head would be cleaner.
    model = EfficientNet.from_pretrained('efficientnet-b0').to(device)

    dummy_input = torch.rand(1, 3, 224, 224)
    #writer.add_graph(model, (dummy_input,))

    print(model)

    #stat(model, (3, 224, 224))
    #summary(model, (3, 224, 224))

    print("-------------------------------------------")



    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, valid_loader, epoch)

    if (args.save_model):
        torch.save(model.state_dict(), "mnist_cnn.pt")  # file name kept from the PyTorch MNIST example this script adapts

    writer.close()


if __name__ == '__main__':
    main()
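
The normMean/normStd values used above are hard-coded constants carried over from the source material; if you want statistics for your own data, here is a minimal sketch of estimating per-channel mean/std from the images listed in a MyDataset-style txt file. The helper name, paths, and sampling limit are hypothetical:

import os
import numpy as np
from PIL import Image

def compute_channel_stats(txt_path, img_prefix='', size=(224, 224), max_images=None):
    """Estimate per-channel mean and std (in the 0-1 range) over the listed images."""
    pixels = []
    with open(txt_path) as fh:
        paths = [line.split()[0] for line in fh if line.strip()]
    for rel_path in paths[:max_images]:
        img = Image.open(os.path.join(img_prefix, rel_path)).convert('RGB').resize(size)
        pixels.append(np.asarray(img, dtype=np.float32) / 255.0)    # H x W x C, scaled to [0, 1]
    stacked = np.stack(pixels)                                      # N x H x W x C; keeps every image in memory
    return stacked.mean(axis=(0, 1, 2)), stacked.std(axis=(0, 1, 2))

# hypothetical usage, limited to 1000 images to keep memory reasonable:
# normMean, normStd = compute_channel_stats('./Data/train.txt', img_prefix='', max_images=1000)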

 




Author: 雪儿

Link: https://www.pythonheidong.com/blog/article/10699/da49e9b78a3958b29d4e/

Source: python黑洞网

Please credit the source when reproducing this article in any form; unauthorized use may be pursued legally.
