程序员最近都爱上了这个网站  程序员们快来瞅瞅吧!  it98k网:it98k.com

本站消息

站长简介/公众号

  出租广告位,需要合作请联系站长

+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

2024-11(1)

mxnet快速上手

发布于2020-03-10 19:52     阅读(1534)     评论(0)     点赞(3)     收藏(1)


老司机请无视 方便其他框架转过来的兄弟快速上手

1.安装(版本1.5以上即可 推荐1.6)

首先是mxnet的安装
安装地址

其次是gluoncv的安装(gluoncv就是mxnet框架的model_zoo)

pip install gluoncv

2.mxnet官方指南

  1. mxnet官网
  2. Tutorials以及Api
  3. gluoncv教程
  4. 动手学习深度学习

3.mxnet中常用到的库

1. mxnet.nd (mxnet.ndarray)

在mxnet中的Tensor都是ndarray类型,由mxnet.nd这个库进行创建、加减乘除等等一大堆的操作<br>
例如:	
form mxnet import nd
from mxnet import gpu
# 生成shape为[1,3,224,224]的0矩阵(默认是在cpu上)
zeros = nd.zeros([1, 3, 224, 224])
# 从cpu转gpu
zeros.as_in_context(gpu())
```<br>
更多操作查询api文档<br>

## 2. mxnet.numpy (1.6版本以上才有的接口)
到本文发布为止,官网api中未提供numpy接口详细说明...1.6刚出来的功能<br>

用法和numpy一样(牛逼的地方是可以用gpu计算),就是为了完善mxnet.ndarray的某些功能,可以用mxnet.numpy操作ndarray数据<br>
例如:

```python
from mxnet import nd
from mxnet import numpy as np

a = nd.ones([3,3])
# 用np扩展维度
a = a[np.newaxis, :]
print(a.shape)

建议用mxnet.nd创建数据 mxnet.numpy操作数据(加减乘除等等)


## 3. gluoncv 这个库主要就是提供现成的model
from mxnet import nd
from gluoncv import model_zoo
net = model_zoo.get_model("yolo3_darknet53_coco", pretrained=False)
net.initialize()	# 初始化网络权重等等参数 很重要!
net.hybridize()	# 由动态图转静态图 转静态图速度会快点

data_shape = (1, 3, 416, 416)
input_data = nd.random.uniform(-1, 1, data_shape)
out = net(input_data)
print(out)

4. mxnet.gluon 重点

mxnet.gluon这个接口主要就是mxnet的动态图的接口(静态图接口是mxnet.sym好像 没怎么用它)

常用gluon.nn来创建网络,比如Dense Conv2d 等等

gluon.nn中有这俩东西: nn.Sequential 和 nn.HybridSequential

都是用来构建动态图的,区别就是带Hybrid字眼的能通过hybridize()函数转化为静态图,所以基本使用nn.HybridSequential

同样创建网络时候所继承的Block也有对应的HybridBlock,效果一致

from mxnet.gluon import HybridBlock, nn

# 例子1
net = nn.HybridSequential()
# use net's name_scope to give child Blocks appropriate names.
with net.name_scope():
    net.add(nn.Dense(10, activation='relu'))
    net.add(nn.Dense(20))
net.hybridize()

# 例子2
class Model(HybridBlock):
    def __init__(self, **kwargs):
	# 网络各个layer必须在__init__中初始化 不可在hybrid_forward中初始化
        super(Model, self).__init__(**kwargs)
        # use name_scope to give child Blocks appropriate names.
        with self.name_scope():
            self.dense0 = nn.Dense(20)
            self.dense1 = nn.Dense(20)

    def hybrid_forward(self, F, x):	# F指mxnet.ndarray 会调用ndarray中的方法操作数据
        x = F.relu(self.dense0(x))
        return F.relu(self.dense1(x))

model = Model()
model.initialize(ctx=mx.cpu(0))
model.hybridize()
model(mx.nd.zeros((10, 10), ctx=mx.cpu(0)))

5.数据集加载

[这块和pytorch一样 由dataset类和dataloader类构成}(http://mxnet.incubator.apache.org/api/python/docs/api/gluon/data/index.html)

from mxnet.gluon.data import Dataset
from mxnet.gluon.data import DataLoader

class DatasetBase(Dataset):
"""
只需要重写这三个函数
"""

    def __init__(self, data_root, transform=None, is_train=True):
	"""
	指定数据集地址等等初始化操作
	"""
        super(DatasetBase, self).__init__()
        self.transform = transform
	self.all = [] # 假设其存放所有的数据信息(例如图片路径和label)
	

    def __getitem__(self, idx):
	"""
	返回整个数据集中第idx个数据
	returns the i-th element
	"""
        data = self.all[idx]
        if self.transform is not None:
            return self.transform(data)
        return data

    def __len__(self):
	"""
	which returns the total number elements.
	"""
        return len(self.all)

dataset  = DatasetBase("path_to_img")
dataloader = DataLoader(
        dataset=dataset,
        batch_size=256,
        shuffle=True,
        num_workers=1
    )

for data in dataloder:
	# 得到数据
	pass

想到啥再补充



所属网站分类: 技术文章 > 博客

作者:我是小白兔

链接:https://www.pythonheidong.com/blog/article/251362/8b8c40ed3996b6ca2f26/

来源:python黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

3 0
收藏该文
已收藏

评论内容:(最多支持255个字符)