发布于2020-03-10 19:52 阅读(1534) 评论(0) 点赞(3) 收藏(1)
老司机请无视 方便其他框架转过来的兄弟快速上手
首先是mxnet的安装
安装地址
其次是gluoncv的安装(gluoncv就是mxnet框架的model_zoo)
pip install gluoncv
在mxnet中的Tensor都是ndarray类型,由mxnet.nd这个库进行创建、加减乘除等等一大堆的操作<br>
例如:
form mxnet import nd
from mxnet import gpu
# 生成shape为[1,3,224,224]的0矩阵(默认是在cpu上)
zeros = nd.zeros([1, 3, 224, 224])
# 从cpu转gpu
zeros.as_in_context(gpu())
```<br>
更多操作查询api文档<br>
## 2. mxnet.numpy (1.6版本以上才有的接口)
到本文发布为止,官网api中未提供numpy接口详细说明...1.6刚出来的功能<br>
用法和numpy一样(牛逼的地方是可以用gpu计算),就是为了完善mxnet.ndarray的某些功能,可以用mxnet.numpy操作ndarray数据<br>
例如:
```python
from mxnet import nd
from mxnet import numpy as np
a = nd.ones([3,3])
# 用np扩展维度
a = a[np.newaxis, :]
print(a.shape)
建议用mxnet.nd创建数据 mxnet.numpy操作数据(加减乘除等等)
from mxnet import nd
from gluoncv import model_zoo
net = model_zoo.get_model("yolo3_darknet53_coco", pretrained=False)
net.initialize() # 初始化网络权重等等参数 很重要!
net.hybridize() # 由动态图转静态图 转静态图速度会快点
data_shape = (1, 3, 416, 416)
input_data = nd.random.uniform(-1, 1, data_shape)
out = net(input_data)
print(out)
mxnet.gluon这个接口主要就是mxnet的动态图的接口(静态图接口是mxnet.sym好像 没怎么用它)
常用gluon.nn来创建网络,比如Dense Conv2d 等等
gluon.nn中有这俩东西: nn.Sequential 和 nn.HybridSequential
都是用来构建动态图的,区别就是带Hybrid字眼的能通过hybridize()函数转化为静态图,所以基本使用nn.HybridSequential
同样创建网络时候所继承的Block也有对应的HybridBlock,效果一致
from mxnet.gluon import HybridBlock, nn
# 例子1
net = nn.HybridSequential()
# use net's name_scope to give child Blocks appropriate names.
with net.name_scope():
net.add(nn.Dense(10, activation='relu'))
net.add(nn.Dense(20))
net.hybridize()
# 例子2
class Model(HybridBlock):
def __init__(self, **kwargs):
# 网络各个layer必须在__init__中初始化 不可在hybrid_forward中初始化
super(Model, self).__init__(**kwargs)
# use name_scope to give child Blocks appropriate names.
with self.name_scope():
self.dense0 = nn.Dense(20)
self.dense1 = nn.Dense(20)
def hybrid_forward(self, F, x): # F指mxnet.ndarray 会调用ndarray中的方法操作数据
x = F.relu(self.dense0(x))
return F.relu(self.dense1(x))
model = Model()
model.initialize(ctx=mx.cpu(0))
model.hybridize()
model(mx.nd.zeros((10, 10), ctx=mx.cpu(0)))
[这块和pytorch一样 由dataset类和dataloader类构成}(http://mxnet.incubator.apache.org/api/python/docs/api/gluon/data/index.html)
from mxnet.gluon.data import Dataset
from mxnet.gluon.data import DataLoader
class DatasetBase(Dataset):
"""
只需要重写这三个函数
"""
def __init__(self, data_root, transform=None, is_train=True):
"""
指定数据集地址等等初始化操作
"""
super(DatasetBase, self).__init__()
self.transform = transform
self.all = [] # 假设其存放所有的数据信息(例如图片路径和label)
def __getitem__(self, idx):
"""
返回整个数据集中第idx个数据
returns the i-th element
"""
data = self.all[idx]
if self.transform is not None:
return self.transform(data)
return data
def __len__(self):
"""
which returns the total number elements.
"""
return len(self.all)
dataset = DatasetBase("path_to_img")
dataloader = DataLoader(
dataset=dataset,
batch_size=256,
shuffle=True,
num_workers=1
)
for data in dataloder:
# 得到数据
pass
想到啥再补充
作者:我是小白兔
链接:https://www.pythonheidong.com/blog/article/251362/8b8c40ed3996b6ca2f26/
来源:python黑洞网
任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任
昵称:
评论内容:(最多支持255个字符)
---无人问津也好,技不如人也罢,你都要试着安静下来,去做自己该做的事,而不是让内心的烦躁、焦虑,坏掉你本来就不多的热情和定力
Copyright © 2018-2021 python黑洞网 All Rights Reserved 版权所有,并保留所有权利。 京ICP备18063182号-1
投诉与举报,广告合作请联系vgs_info@163.com或QQ3083709327
免责声明:网站文章均由用户上传,仅供读者学习交流使用,禁止用做商业用途。若文章涉及色情,反动,侵权等违法信息,请向我们举报,一经核实我们会立即删除!