目录
1--动态输入和静态输入
2--Pytorch API
3--完整代码演示
4--模型可视化
5--测试动态导出的Onnx模型
1--动态输入和静态输入
当使用 Pytorch 将网络导出为 Onnx 模型格式时,可以导出为动态输入和静态输入两种方式。动态输入即模型输入数据的部分维度是动态的,可以由用户在使用模型时自主设定;静态输入即模型输入数据的维度是静态的,不能够改变,当用户使用模型时只能输入指定维度的数据进行推理。
显然,动态输入的通用性比静态输入更强。
2--Pytorch API
在 Pytorch 中,通过 torch.onnx.export() 的 dynamic_axes 参数来指定动态输入和静态输入,dynamic_axes 的默认值为 None,即默认为静态输入。
以下展示动态导出的用法,通过定义 dynamic_axes 参数来设置动态导出输入。dynamic_axes 中的 0、2、3 表示相应的维度设置为动态值;
# 导出为动态输入
input_name = 'input'
output_name = 'output'
torch.onnx.export(model,
input_data,
"Dynamics_InputNet.onnx",
opset_version=11,
input_names=[input_name],
output_names=[output_name],
dynamic_axes={
input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'},
output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})
3--完整代码演示
在以下代码中,定义了一个网络,并使用动态导出和静态导出两种方式,将网络导出为 Onnx 模型格式。
import torch
import torch.nn as nn
class Model_Net(nn.Module):
def __init__(self):
super(Model_Net, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
)
def forward(self, data):
data = self.layer1(data)
return data
if __name__ == "__main__":
# 设置输入参数
Batch_size = 8
Channel = 3
Height = 256
Width = 256
input_data = torch.rand((Batch_size, Channel, Height, Width))
# 实例化模型
model = Model_Net()
# 导出为静态输入
input_name = 'input'
output_name = 'output'
torch.onnx.export(model,
input_data,
"Static_InputNet.onnx",
verbose=True,
input_names=[input_name],
output_names=[output_name])
# 导出为动态输入
torch.onnx.export(model,
input_data,
"Dynamics_InputNet.onnx",
opset_version=11,
input_names=[input_name],
output_names=[output_name],
dynamic_axes={
input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'},
output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})
4--模型可视化
通过 netron 库可视化导出的静态模型和动态模型,代码如下:
import netron
netron.start("./Dynamics_InputNet.onnx")
静态模型可视化:
动态模型可视化:
5--测试动态导出的Onnx模型
import numpy as np
import onnx
import onnxruntime
if __name__ == "__main__":
input_data1 = np.random.rand(4, 3, 256, 256).astype(np.float32)
input_data2 = np.random.rand(8, 3, 512, 512).astype(np.float32)
# 导入 Onnx 模型
Onnx_file = "./Dynamics_InputNet.onnx"
Model = onnx.load(Onnx_file)
onnx.checker.check_model(Model) # 验证Onnx模型是否准确
# 使用 onnxruntime 推理
model = onnxruntime.InferenceSession(Onnx_file, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])
input_name = model.get_inputs()[0].name
output_name = model.get_outputs()[0].name
output1 = model.run([output_name], {input_name:input_data1})
output2 = model.run([output_name], {input_name:input_data2})
print('output1.shape: ', np.squeeze(np.array(output1), 0).shape)
print('output2.shape: ', np.squeeze(np.array(output2), 0).shape)
由输出结果可知,对应动态输入 Onnx 模型,其输出维度也是动态的,并且为对应关系,则表明导出的 Onnx 模型无误。