神经网络模型导出及开放标准格式ONNX

发布时间:2026/7/5 5:02:32
神经网络模型导出及开放标准格式ONNX 模型格式1.1 现代主流/安全优化格式新生态1 safetensorsHugging Face主导背景目前AI绘画Stable Diffusion、ComfyUI和主流大模型最推荐的格式。详细说明核心特点100%安全它只包含纯粹的张量数据模型权重和一小段描述结构的JSON文本没有任何可执行代码完全免疫了传统格式的木马病毒风险。性能支持零拷贝Zero-copy和内存映射mmap加载速度极快远超旧格式。使用场景SD/SDXL/Flux模型的Checkpoint、LoRA、ControlNet以及大语言模型权重。2 gguf背景替代了早期的ggml格式是目前端侧/本地大模型LLM量化的绝对霸主。详细说明核心特点将模型的所有组件权重、分词器、超参数、元数据封装在单个文件内。它的扩展性极强支持未来添加新的元数据而不破坏向后兼容性。量化优势专门为CPU/GPU混合推理优化支持各种强度的量化如Q4_K_M、Q8_0能让几百亿参数的大模型在消费级显卡甚至手机、MacBook上流畅运行。适用场景Llama3、Gemma、Qwen等大语言模型的本地部署最近ComfyUI也开始流行用GGUF格式来运行轻量化的Flux或SDXL模型。1.2 传统开发/原生框架格式多见于训练与研究这类格式通常伴随官方框架诞生虽然灵活但在跨平台部署或安全性上存在一定局限。1 pth/ptPyTorch详细说明核心特点PyTorch框架的原生存储格式它基于PyTorch的pickle模块进行序列化。优缺点它极其灵活不仅能保存模型权重还能把训练了一半的优化器状态Optimizer、Epoch轮数甚至网络结构代码一起存进去是模型训练和微调的标配。但缺点是不安全加载恶意的pth文件可能会在电脑上自动执行破坏性代码。适用场景AI绘画中的早期特定模型如部分ControlNet、ESRGAN放大算法模型、PyTorch算法开发与训练。2 ckptCheckpoint背景这是Stable Diffusion早期1.5时代最常用的格式本质上和pth一样也是基于pickle序列化。详细说明核心特点由于存在和pth一模一样的安全漏洞可注入恶意脚本目前在生图领域已经全面被safetensors淘汰。如果你在网上下载到老的ckpt权重建议用工具转换成safetensors再使用。3 h5/keras/pb (TensorFlow/Keras)背景Google生态TensorFlow下的主流格式。详细说明核心特点h5HDF5多用于保存Keras模型的权重或完整结构pbProtocol Buffers则是TensorFlow用于生产端部署的图模型格式。适用场景传统计算机视觉如某些老的人脸识别、目标检测算法和工业级TensorFlow工业管线。1.3 夸平台推理与硬件加速格式生产端部署当模型训练完成后为了让它在手机、网页、或者不同显卡英伟达、AMD、Intel上跑得飞快通常会转换成以下格式。1 onnxOpen Neural Network Exchange详细说明核心特点“通用翻译官”由微软、Meta等联合发起的开放标准。它把模型结构抽象成一张通用的计算图让你可以把PyTorch训练的模型无缝转换到TensorFlow或其他推理引擎中运行。适用场景跨平台部署比如你在ComfyUI里用到的某些WD14自动打标插件、LayerDiffusion背景分离模型很多底层都在跑ONNX格式因为它在CPU或非英伟达显卡上的兼容性极好。2 engineNVIDIA TensorRT背景英伟达官方对自家显卡进行极致硬件加速的专用格式。详细说明核心特点速度榨干者你必须在自己的显卡上把onnx或safetensors现场“编译”成 engine 格式。它会根据你当前的显卡架构如RTX 4090进行剪枝、层融合和量化加速。优缺点速度达到物理极限通常比原模型快 50%~100%。但完全没有通用性在4090上编译的engine文件拿到3080上是用不了的甚至显卡驱动升级了都可能需要重新编译。适用场景追求极致生图速度的WebUI/ComfyUI TensorRT加速工作流或工业级实时AI推理。3 tflite/onnxruntime (移动端轻量化)背景tflite是TensorFlow Lite格式专门给安卓、iOS手机或嵌入式设备如树莓派使用做了极端的体积压缩和定点量化。1.4 总结在实际生产应用时模型包含训练和部署两大关键环节在训练环节模型可通过不同框架来实现典型如PyTorch、TensorFlow等由于框架的不同会产生不同格式的训练模型而在部署环节由于硬件平台的不同又需要需要将不同格式的模型进行差异化修改这种多到多的操作实施起来很繁琐。经过工业界和学术界数年的探索模型部署有了一条流行的流水线如上图为了让模型最终能够部署到某一环境上开发者们可以使用任意一种深度学习框架来定义网络结构并通过训练确定网络中的参数。之后模型的结构和参数会被转换成一种只描述网络结构的中间表示一些针对网络结构的优化会在中间表示上进行。最后用面向硬件的高性能编程框架(如 CUDAOpenCL编写能高效执行深度学习网络中算子的推理引擎会把中间表示转换成特定的文件格式并在对应硬件平台上高效运行模型比如中间表示ONNX转换支持华为芯片推理的OM文件。2 模型导出及格式转换2.1 导出pth格式模型在上一篇文章中源码trainNN.py已经导出mnist_cnn.pth模型文件PyTorch使用Python原生pickle序列化任意Python对象模型、张量、字典、优化器、数字、自定义类等这是pth文件的核心。pickle可执行任意代码所以处于安全考虑不要加载来源不明的pt文件Torch 2.0提供weights_onlyTrue安全加载模式仅读取张量、禁止反序列化Python对象。新版PyTorch默认开启_use_new_zipfile_serializationTruepth本质是一个zip压缩包旧版本是纯二进制pickle文件无压缩。对于新版pt文件可以直接用unzip model.pt解压查看内部文件。如解压后内容如下1 目录及文件解释versionpickle序列化协议版本pickle4/pickle5用于版本兼容校验.format_versionTorch自定义ZIP存储格式版本区分新旧存储结构版本不匹配会直接加载失败.storage_alignment张量二进制存储内存对齐字节数GPU/CPU加载时用来对齐内存提升张量读取速度一般是1或64byteorder存储硬件字节序x86机器固定是little小端序用来解析.storage里的浮点/整数二进制data.pkl整个模型的逻辑骨架用pickle序列化了checkpoint /state_dict字典但只存张量的「描述信息」不存权重数字data目录下存储张量权重二进制数据.data该目录可能是特定场景如torch.package/export的残留或误生成2 data.pkl详解直接用UE打开该二进制文件内容如下在VS Code中安装PKL Viewer扩展也一并打开该文件进行分析内容如下1 0: \x80 PROTO 2 2 2: c GLOBAL collections OrderedDict 3 27: q BINPUT 0 4 29: ) EMPTY_TUPLE 5 30: R REDUCE 6 31: q BINPUT 1 7 33: ( MARK 8 34: X BINUNICODE features.0.weight 9 56: q BINPUT 2 10 58: c GLOBAL torch._utils _rebuild_tensor_v2 11 91: q BINPUT 3 12 93: ( MARK 13 94: ( MARK 14 95: X BINUNICODE storage 15 107: q BINPUT 4 16 109: c GLOBAL torch FloatStorage 17 129: q BINPUT 5 18 131: X BINUNICODE 0 19 137: q BINPUT 6 20 139: X BINUNICODE cuda:0 21 150: q BINPUT 7 22 152: M BININT2 288 23 155: t TUPLE (MARK at 94) 24 156: q BINPUT 8 25 158: Q BINPERSID 26 159: K BININT1 0 27 161: ( MARK 28 162: K BININT1 32 29 164: K BININT1 1 30 166: K BININT1 3 31 168: K BININT1 3 32 170: t TUPLE (MARK at 161) 33 171: q BINPUT 9 34 173: ( MARK 35 174: K BININT1 9 36 176: K BININT1 9 37 178: K BININT1 3 38 180: K BININT1 1 39 182: t TUPLE (MARK at 173) 40 183: q BINPUT 10 41 185: \x89 NEWFALSE 42 186: h BINGET 0 43 188: ) EMPTY_TUPLE 44 189: R REDUCE 45 190: q BINPUT 11 46 192: t TUPLE (MARK at 93) 47 193: q BINPUT 12 48 195: R REDUCE 49 196: q BINPUT 13 50 198: X BINUNICODE features.0.bias 51 218: q BINPUT 14 52 220: h BINGET 3最一开始是由Pickle规范PEP307/PEP574)规定的0x80PROTO头声明了本次序列化使用的Pickle协议版本0x80后紧跟1字节参数不是协议号本身而是有对应的映射表随后的二进制流可以通过Python标准库的pickletools模块来查看包含所有协议版本的完整指令对照表源码内带详细注释说明每个指令的作用该文件路径可以通过以下代码获取python -c import pickletools; print(pickletools.__file__)也可以通过如下文件直接打印出相应对照表printpkl.py该文件运行输出内容如下接下来继续结合PKL Viewer解析结果进行分析第2行内容等价于python代码collections.OrderedDict把该类压入到Pickle栈此时栈为[OrderedDict]第3行内容表示将OrderedDict类缓存到memo表即memo[0] class collections.OrderedDict第4行将空元组压入栈此时栈为[OrderedDict, ()]第5行REDUCE的含义是callable(*args)对应的代码是弹出参数空元组及类并调用类实现OrderedDict(*())之后将结果及用OrderedDict类构造一个空实例对象压入到栈此时栈为[ordereddict实例]第6行保存对象到memo表即memo[1] OrderedDict()此时实际上已经得到state_dict OrderedDict()此时在memo表中分别存储了类及其实例对象后续可根据需要进行复用。接下来代码开始向OrderedDict插入元素第7行向栈插入MARK分隔符用于标记一个可变长度对象的开始位置第8行向栈压入一个Unicode字符串对象后续它将作为key第9行将字符串存入到memo表第10行向栈中压入_rebuild_tensor_v2重建函数此时栈内容为[OrderedDict(), MARK33, features.0.weight, _rebuild_tensor_v2]第11行将其存入到memo表第1213行两次将MARK压入到栈第14~22行依次将“storage”FloatStorage“0”“cuda:0”288入栈最终由第23行的TUPLE和最近MARK94产生元组对象并入栈此时栈内容为[OrderedDict(), MARK33, features.0.weight, _rebuild_tensor_v2, MARK93, (storage, FloatStorage, 0, cuda:0, 288)]第25行BINPERSID会调用unpickler.persistent_load(pid)来真正从data/0中的数据实现Storage对象这里正是结构和数据分别存放的精华所在例如这里weight数据288个float数据总大小为1152字节和data/0文件大小正好对应上此时栈内容为[OrderedDict(), MARK33, features.0.weight, _rebuild_tensor_v2, MARK93, FloatStorage(288)]而_rebuild_tensor_v2函数的参数定义如下_rebuild_tensor_v2( storage, storage_offset, size, stride, requires_grad, backward_hooks )可见第2个参数是storage_offset就是说生成Tensor时不一定从Storage的开头开始而是可以从指定storage_offset下标开始这正好对应源码中第26行内容这里向栈中压入下标0接下来第27~32行给出_rebuild_tensor_v2函数的第3个参数size(32,1,3,3)即weight形状是32x1x3x3正好是288个元素再接下来第34~39行给出函数第4个参数stride(9, 9, 3, 1)即stride[0]1*3*39stride[1]3*39stride[2]3stride[3]1随后第41行向栈中压入False值对应参数requires_grad之后第42~44先通过BINGET 0将之前memo中预存的OrderedDict类压栈然后再将空元组压栈之后通过REDUCE实例化OrderedDict对象作为参数backward_hooks的值最终通过第46行的和MARK93闭合产生参数元组并在48行完成_rebuild_tensor_v2函数调用产生tensor实例。接下来过程类似最终栈内容大致如下OrderedDict() MARK33 features.0.weight Tensor(...) features.0.bias Tensor(...) features.1.weight Tensor(...)各个元素初始化完毕后最终会通过SETITEMS完成OrderedDict对象赋值该指令会把最近MARK后的所有key——value键值对插入到OrderedDict对象最终栈变成[ OrderedDict({ features.0.weight: tensor(...), features.0.bias: tensor(...), features.1.weight: tensor(...), ... }) ]2.2 pth转换为onnx1 ProtobufProtocol Buffers简称Protobuf是Google开发的一种语言无关、平台无关、可扩展的机制用于序列化结构化数据。简单来说可以把它想象成一种更高效、更快速的“XML”或“JSON”。它的核心作用是将内存中的复杂数据结构如对象、结构体转换成紧凑的二进制字节流以便进行网络传输或持久化存储反过来也能将二进制数据还原为原始数据结构。使用Protobuf通常遵循以下三个步骤1定义数据结构在一个.proto文件中使用Protobuf的语法定义好数据结构即”消息“如addressbook.proto内容如下syntax proto3; // 定义一个人的信息 message Person { string name 1; int32 id 2; string email 3; } // 定义整个通讯录repeated 表示可包含多个 Person message AddressBook { repeated Person people 1; }2生成源代码使用Protobuf编译器protoc根据你的 .proto 文件自动生成你所用编程语言的代码这些代码包含了读写该数据结构的方法比如可以使用如下命令生成python语言的代码protoc --python_out. addressbook.proto以上命令会生成addressbook_pb2.py文件里面包含了读写Person和Address的所有类方法。但是由于我的电脑上是直接安装的最新版的protoc 编译器核心版本35.0而使用pip install protobuf安装的Protobuf运行时库版本版本是6.33.6python“语言前缀”是6“核心版本号”是33.6这两个是不同的Protobuf核心版本在使用python protobuf库去直接解析高版本编译器库生成的addressbook_pb2.py时会有版本冲突问题这里使用Python生态自带的grpcio-tools来生成代码它内嵌了一个与当前Python Protobuf库严格配套的protoc使用如下命令安装并生成addressbook_pb2.py文件pip install grpcio-tools python -m grpc_tools.protoc -I. --python_out. addressbook.proto文件内容如下addressbook_pb2.py虽然机器生成的代码不太好看但它确实包含了我们定义的字段。3使用生成的代码在应用程序中直接调用生成的代码来序列化写和反序列化读数据。现在创建main.py文件写入下面的代码。它做了两件事写入序列化一个包含两人的通讯录到磁盘文件再读取反序列化并打印出来。main.py运行效果如下相比XML和JSON等文本格式Protobuf在性能上优势显著可以用一个形象的比喻来理解JSON/XML像一封用自然语言写的书信内容清晰易读但有很多冗余字符如引号、括号、标签体积大传输慢。Protobuf则像一封用密文写的电报体积小、传输快但只有拥有“密码本”即.proto定义文件的人才能看懂。Protobuf是一个为高性能、高效率和跨语言兼容性而生的数据序列化方案非常适合对性能和带宽有严苛要求的系统内部通信。它为了追求极致的体积和速度在二进制流中绝对不会保存任何字段名而是采用的是Tag - [Length] - ValueTLV的紧凑格式1Tag标签/字段编号每一个字段在.proto图纸里都有一个独一无二的数字编号二进制里只存这个编号。2Wire Type传输类型用来告诉解析器这个数据是个变长整数Varint、还是个有固定长度的字符串/嵌套块Tag和Wire Type会被打包成一个字节称为 Field Key。Field Key (Tag 3) | Wire Type3Value真实数据如果是字符串前面会用一个数字表明它的字节长度Length后面紧跟纯字符的 ASCII 码。Google Protobuf官方的C底层源码库在头文件wire_format_lite.h里定义了6种核心数据物理传输类型Wire Typeclass WireFormatLite { public: enum WireType { WIRETYPE_VARINT 0, // int32, int64, uint32, uint64, sint32, sint64, bool, enum WIRETYPE_FIXED64 1, // fixed64, sfixed64, double固定8字节 WIRETYPE_LENGTH_DELIMITED 2, // string, bytes, 嵌套的 message, packed repeated fields WIRETYPE_START_GROUP 3, // 已废弃的老式分组 WIRETYPE_END_GROUP 4, // 已废弃的老式分组 WIRETYPE_FIXED32 5, // fixed32, sfixed32, float固定 4 字节 }; };int64在内存里明明是固定占8个字节64位的长整型为什么把它分配给Varint变长整数这正是Protobuf极度高产和聪明的核心算法Varint (Variable-length quantity)。传统的存法如果你存一个很小的数字比如数字6在传统的二进制里使用int64也必须雷打不动地占满8个字节的内存也就是一堆00 00 00 00 00 00 00 06白白浪费空间。Protobuf的存法当Wire Type被识别为0时它会启动Varint编码每个字节的最高位MSB作为标志位后面7位存数字。如果是数字6它在二进制流里只占用1个字节即06就存完了只有当你的数字巨大无比比如好几十亿时它才会慢慢扩充最多扩充到10个字节。下面以上述实例序列化后的二进制数据为例来说明解析过程二进制数据如下解析时结合addressbook.proto解码文件先从最外套娃AddressBook开始然后再进一步解析内部嵌套结构Person。10A 21开启第一个Person容器0A0000 101000001解构为字段编号1对应AddressBook.peopleWire Type为010WIRETYPE_LENGTH_DELIMITED2表示repeated fields21换算成十进制是33意味着第一个人占用了接下来的33个字节。20A 06 E5 BC A0 E4 B8 89解析名字 张三0A0000 1010解构为Person.name编号1Wire Type 206字符串长度为6个字节数据段E5 BC A0 E4 B8 89在UTF8编码中一个汉字占3个字节E5 BC A0 张E4 B8 89 三310 E9 07解析 ID 100110解构为Person.id编号2Wire Type 0E9 07典型的Varint变长整数编码E9对应二进制1110 1001最高位是1表示其后还有字节有效数据位是后7位110 100107对应二进制0000 0111最高位0代表结束有效数据是后7位000 0111低位在前把000 0111和110 1001进行拼接得到0000 0011 1110 1001换算成十进制正好是100141A 14 7A 68 ... 63 6F 6D解析邮箱1A解构为Person.email编号3Wire Type 214十六进制的14换算成十进制是20代表邮箱长20个字节数据段7A 68 61 6E 67 73 61 6E 40 65 78 61 6D 70 6C 65 2E 63 6F 6D 转换成 ASCII 码对应字符串zhangsanexample.com按同样流程可以解析出第2个Person此外还可以通过命令protoc --decodeAddressBook addressbook.proto addressbook.data让protoc帮助完成解析2 onnx结构onnxOpen Neural Network Exchange作为模型界的“通用翻译官”是由微软、Meta等联合发起的开放标准。它把模型结构抽象成一张通用的计算图让你可以把PyTorch训练的模型无缝转换到TensorFlow或其他推理引擎中运行。ONNX文件的本质是一个通过Protocol BuffersProtobuf序列化后的二进制文件。如果你去解构一个.onnx文件它的核心组成部分符合严格的层次结构以下对主要消息体进行介绍1顶层容器ModelProto这是整个.onnx文件的根节点相当于总指挥部包含以下主要内容。ir_version该模型对应的ONNX中间表示版本当前最高已演进至0x000000000000000D即IR Version 13opset_import一个列表声明了模型依赖的算子集版本OperatorSetIdProtograph核心计算图GraphPrototraining_info存放训练或微调阶段的梯度和优化器信息TrainingInfoProtofunctions模型本地定义的函数列表FunctionProtoconfiguration该字段专门用于多设备部署场景它允许一个 ONNX 模型文件预先定义好多种硬件部署方案DeviceConfigurationProto此外该消息体还包括producer_nameproducer_versiondomainmodel_version等一些字符串元数据。2核心大脑GraphProto定义了网络中所有的数据和计算节点形成一个有向无环图Directed Acyclic GraphDAG主要元素如下node计算节点列表NodeProto必须按拓扑顺序排列每个节点指定了它使用的算子类型、输入参数名、输出参数名以及属性Attributes如卷积的stride、paddinginitializer存放静态权重如权重矩阵、偏置的常规张量列表TensorProtosparse_initializer稀疏矩阵格式的静态权重SparseTensorProtoinput/output网络的输入和输出边界定义ValueInfoProto3计算单元NodeProto网络中的“层”或算子例如一个卷积层或激活层主要内容如下op_type算子名称例如ConvReluMatMulinput/output字符串数组ONNX就是通过匹配不同节点的输入输出字符串名称来确立节点间的连接线的attribute该算子的静态配置参数列表AttributeProto如卷积的strides4属性参数AttributeProto由于Protobuf没有原生的联合体UnionONNX在这里用了一个绝妙的设计通过一个枚举AttributeType配合一堆optional字段实现了Union联合体的等价类型message AttributeProto { // 1. 先用一个枚举来标记“当前类型” optional AttributeType type 20; // 2. 下面这一堆字段在实际使用时【有且只能有一个】有值 optional float f 2; // 如果是浮点数就存这里 optional int64 i 3; // 如果是整数就存这里 optional bytes s 4; // 如果是字符串就存这里 optional TensorProto t 5; // 如果是高维张量/权重就存这里 optional GraphProto g 6; // 如果是子计算图就存这里 // 下面是对应的数组复数形式 repeated float floats 7; repeated int64 ints 8; // ... }单数类型字段ffloatiint64sbytes/stringtTensorProtogGraphProto/子图复数类型字段floatsintsstringstensorsgraphs5数据载体TensorProto专门用来存多维数组和权重数据的底层实体。dims形状/维度例如 [64, 3, 3, 3]data_type极其丰富的低精度/高精度数据类型枚举。可以看到最新版已经扩充了FLOAT8INT4UINT4甚至是2位的INT2/UINT2数据存储位置 (data_location)DEFAULT(0)数据直接压缩塞在当前二进制文件里通过float_data、int32_data或原生的raw_data存储EXTERNAL(1)大模型专用的外部权重解耦机制。数据不塞在.onnx里而是记录在external_data键值对中指明外部.data文件的路径、偏移量offset通常推荐4096字节对齐以支持mmap和数据长度6类型描述TypeProto用来规定某个张量或者变量在运行时的数据形态。value使用了一个oneof语法它可以是tensor_type: 包含基础元素类型elem_type和形状描述TensorShapeProtosequence_type/map_type/optional_type用于支持传统机器学习或非张量的数据流动规范中还有其他一些消息体这里不再进行详细介绍更多内容请查看onnx.proto3文件。3 导出onnx可以通过如下代码将生成的mnist_cnn.pth导出为onnx格式的文件pth2onnx.py程序运行输出如下可见程序不仅导出了onnx模型还使用测试图片对onnx模型和pth模型的预测结果进行了对比从结果看onnx模型和原始pth模型的预测结果仅有很小误差可以通过在线网站netron来图形化显示onnx的结构还可以通过如下命令导出完整DAG结构protoc --decodeonnx.ModelProto onnx.proto3 mnist_cnn.onnx mnist_cnn_text.txt当然需要提前下载onnx.proto3文件并将其和mnist_cnn.onnx模型文件放到同一目录下下面结合二进制模型文件和onnx.proto3文件说明下解析产生txt结构文件的关键点。从二进制文件开头08 06 12 ... 38 2E 30的18字节数据可以解析出ModelProto消息的ir_version、producer_name和producer_version成员该部分解析比较简单不再详述。接下来3A 0011 1010高5位数值7对应GraphProto类型编号底层类型2对应嵌套message结构接下来B6 88 67这3个字节解析后是整个graph的大小1688630由于Graph包含多个节点和相关权重数据所以它的字节数较大对应该实例来说整个onnx文件大小是1688656字节头部184占22字节尾部4个字节是opset_import算子集版本其他所有数据都对应Graph内容。大图长度交代完后正式跨入GraphProto内部接下来流里碰到了0A0000 1010在GraphProto中字段编号为1WireType为2对应repeated NodeProto node接下来是第一个算子节点C8 01是第一个NodeProto算子节点的总长度Varint编码C8100100001000 0001拼接二进制对应十进制200即后续200个字节对应第一个node节点内容。其他内容解析不再详解说明这里重点说一下2A 12 0A 09 64 69 6C 61 74 69 6F 6E 73 40 01 40 01 A0 01 07对应AttributeProto message部分的解析首先2A35是字段编号2A72对应WireType在NodeProto中字段5对应repeated AttributeProto attribute12是长度即18字节接下来0A 09对应属性名“dilations”接下来的40 01 40 01对应两个ints值为1之后A0 01 07有点特殊在解析A0发现字段编号超过15所以这里Tag实际上是双字节Varint编码所以要和01一起进行解析A0砍掉开头的1剩下7位010 000001砍掉开头的0剩下7位000 0001拼接后二进制为10100000字段编号是20WireType为0对应enum AttributeType类型type其值为7AttributeProto在onnx.proto3文件中定义如下View Code之后二进制内容解析类似这里不再详述。3 onnx模型预测过程详解之前已经就onnx模型解析进行了详细分析但是由于之前模型输入都是float型为了方便验证分析模型预测过程这里通过ONNX库手动构建ONNX神经网络图并通过该网络模型来详细分析预测过程。首先给出模型生成的源码sampleU8.py运行该python程序即可导出模型文件resize_conv_addU8.onnx用netron打开该模型可视化图像如下图中输出节点conv_input_u8和conv_output是为了调试而故意进行的输出正常模型保持output一个输出节点即可。首先对Resize节点进行分析它的输入是1x3x8x8 UINT8型inputFLOAT型roi以及Shape为[4]的FLOAT型scales输出是1x3x16x16 UINT8型conv_input_u8。这里input为输入的特征图Tensor代表一张Batch size为1、3 通道RGB、大小为8x8的图像它是被缩放的数据源roi全称是Region of Interest即“感兴趣区域”简单来说它的作用是告诉算子只放大/缩小原图中的某一个“局部裁剪区域”而不是整张图它是一个一维浮点数Tensor格式通常是 [start_r1, start_r2, ..., end_r1, end_r2, ...]用来指定一个高维边界框的起始和结束坐标归一化到 0-1 之间初始化为[]说明要针对整张图像进行全局缩放不需要裁剪局部即选择整张原图进行缩放scales是缩放比例因子是一个一维浮点数数组它的形状是[4]对应输入Tensor的4个维度[N, C, H, W]由输出形状可以确定scales静态值实际上应该是[1.0, 1.0, 2.0, 2.0]经过Resize放到后正好得到[1, 3, 16, 16]输出特征图Tensor。在实际运行调试时会分别对行列数据进行扩展复制对应数据如下因为在onnx库中卷积对应的输入需要时FLOAT型所以这里直接使用一个Cast节点对UINT8型输入强制类型转换为FLOAT型输出conv_input_f32这正是Conv节点的输入另外Conv还有两个输入[32, 3, 3, 3] FLOAT型输入权重conv_weight和[32] FLOAT型输入偏置conv_bias并产生[1, 32, 16, 16] FLOAT型输出conv_output。这里conv_weight是卷积核权重也称滤波器32是输出通道意味着有32个不同的滤波器同时去提取特征3是输入通道必须与conv_input_f32的通道数严格对齐每个卷积核内部都有3层最后的3, 3对应卷积核的尺寸Kernel Size即高3x宽3的滑窗conv_bias偏置形状[32]表示每一个输出通道卷积核配一个偏置常数做完矩阵乘法后要加上这个值。此外Conv算子还有其他一些属性attribute如stridespads等。其中strides表示kernel窗口滑动步长默认值为1即每次做完一次卷积滑动1个像素pads表示是否对输入特征图四周进行0填充对于二维图像H和W来说它的对应顺序是[常规方向的Top, 常规方向的Left, 常规方向的Bottom, 常规方向的Right]之所以进行填充是为了保证进行完卷积后输出特征图可以保持原始尺寸以标准的二维卷积为例其输出尺寸公式为如果不进行填充则输出高和宽都为(16-300)/1114而如果上下左右都做1像素填充则输出高和宽为(16-311)/1116可见能保持输入尺寸。实例中正是进行了pads[1, 1, 1, 1]的填充则实际卷积过程如下所示图中输入特征图的3个通道都进行了填充之后每个通道和卷积核的3个通道依次做矩阵乘法3个结果相加再加上相应的偏置输出相应的加权和结果之后再在W和H依次滑动kernel窗口即可获得第一个卷积核对应的16x16输出再利用其余卷积核最终可获得32x16x16输出特征图。最后的Add节点非常简单就是执行数学上的加法A B C但在深度学习和ONNX的底层它触发了一个非常经典的矩阵计算机制——广播机制Broadcasting。如果按照严格的线性代数规则两个矩阵相加它们的形状Shape必须完全一模一样。但是在这里一个形状是[1, 32, 16, 16]另一个是[1]无法直接进行相加这里ONNX会启动广播机制完成相加过