当前位置: 首页 > news >正文

深度学习(onnx量化)

onnx中的动态量化和静态量化概念与pytorch中的核心思想一致,但实现工具、流程和具体api有所不同。

onnx量化通常依赖onnxrunntime来执行量化模型,并使用onnx工具库进行模型转换。

除了pytorch量化和onnx量化,实际工作中一般像英伟达、地平线、昇腾等不同的芯片都会有各自独特的工具链和加速算子,按照官方教程使用即可。

下面同样给了两个例子,可以验证一下。结合上篇代码可以做个对比。

动态量化:

import torch
import torch.nn as nn
import warnings
import numpy as np
import onnxruntime as ortwarnings.filterwarnings("ignore")
from onnxruntime.quantization import QuantType, quantize_dynamicclass SimpleLSTM(nn.Module):"""简单的LSTM模型,适合动态量化"""def __init__(self, input_size=10, hidden_size=50, num_layers=2, output_size=15):super().__init__()self.hidden_size = hidden_sizeself.num_layers = num_layers# LSTM层self.lstm = nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=True)# 全连接层self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):# 初始化隐藏状态h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)# LSTM前向传播out, _ = self.lstm(x, (h0, c0))# 只取最后一个时间步的输出out = self.fc(out[:, -1, :])return outdef quantize_onnx_model(input_model_path, output_model_path):quantize_dynamic(input_model_path,output_model_path,weight_type=QuantType.QInt8    )print(f"ONNX模型已动态量化并保存到: {output_model_path}")if __name__ == '__main__':model = SimpleLSTM()x = torch.randn(1, 10, 10)  # 假设输入
torch.onnx.export(model,                  # model being runx,                      # model input (or a tuple for multiple inputs)"simple_lstm.onnx",     # where to save the modelexport_params=True,     # store the trained parameter weights inside the model fileopset_version=12,       # the ONNX version to export the model toinput_names = ['input'],   # the model's input namesoutput_names = ['output'])quantize_onnx_model("simple_lstm.onnx", "simple_lstm_quantized.onnx")# 测试ONNX模型和量化后的模型x = np.random.randn(1, 10, 10).astype(np.float32)  # 假设输入ort_session = ort.InferenceSession("simple_lstm.onnx")ort_session_quantized = ort.InferenceSession("simple_lstm_quantized.onnx")inputs = {ort_session.get_inputs()[0].name: x}outputs = ort_session.run(None, inputs)print("ONNX模型输出:\n", outputs[0])inputs = {ort_session_quantized.get_inputs()[0].name: x}outputs_quantized = ort_session_quantized.run(None, inputs)print("动态量化后的ONNX模型输出:\n", outputs_quantized[0])

 静态量化:

import torch
import numpy as np
import warnings
import onnx
import onnxruntime as ort
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static, CalibrationDataReader
warnings.filterwarnings("ignore")class Model(torch.nn.Module):def __init__(self):super().__init__()# QuantStub converts tensors from floating point to quantized# self.quant = torch.ao.quantization.QuantStub()self.conv = torch.nn.Conv2d(1, 100, 1)self.conv1 = torch.nn.Conv2d(100, 100, 1)self.conv2 = torch.nn.Conv2d(100, 100, 1)self.conv3 = torch.nn.Conv2d(100, 1, 1)self.relu1 = torch.nn.ReLU()self.relu2 = torch.nn.ReLU()# DeQuantStub converts tensors from quantized to floating point#  self.dequant = torch.ao.quantization.DeQuantStub()def forward(self, x):# x = self.quant(x)x = self.conv(x)x = self.conv1(x)x = self.relu1(x)x = self.conv2(x)x = self.relu2(x)x = self.conv3(x)#  x = self.dequant(x)return x# 1. 准备校准数据集类
class CustomCalibrationDataReader(CalibrationDataReader):def __init__(self, calibration_data_path, input_name):"""初始化校准数据读取器参数:calibration_data_path: 校准数据.npz文件路径input_name: 模型输入名称"""self.data = np.load(calibration_data_path)self.input_name = input_nameself.datasize = len(self.data.files[0])self.enum_data = iter(self.data[self.data.files[0]])def get_next(self):"""获取下一批校准数据"""try:batch = next(self.enum_data)return {self.input_name: np.expand_dims(batch, axis=0)}except StopIteration:return Nonedef rewind(self):"""重置数据迭代器"""self.enum_data = iter(self.data[self.data.files[0]])# 2. 主量化函数
def quantize_onnx_model_static(original_model_path, quantized_model_path, calibration_data_path):"""执行ONNX模型静态量化参数:original_model_path: 原始FP32模型路径quantized_model_path: 量化后模型保存路径calibration_data_path: 校准数据集路径(.npz格式)"""# 加载原始模型model = onnx.load(original_model_path)# 获取模型输入名称input_name = model.graph.input[0].name# 创建校准数据读取器calibration_data_reader = CustomCalibrationDataReader(calibration_data_path, input_name)quantize_static(model_input=original_model_path,model_output=quantized_model_path,calibration_data_reader=calibration_data_reader,quant_format=QuantFormat.QDQ ,  # QDQ 或 QOperatorper_channel=True,               # 每通道量化reduce_range=True,              # 减少量化范围(某些CPU需要)activation_type=QuantType.QInt8,  # 激活量化类型weight_type=QuantType.QInt8,      # 权重量化类型
    )print(f"量化完成!量化模型已保存至: {quantized_model_path}")# 3. 辅助函数:生成校准数据集
def generate_calibration_data(output_path, num_samples=100):"""生成校准数据集参数:output_path: 校准数据保存路径(.npz)num_samples: 生成样本数量"""        # 创建随机输入数据 (根据实际模型调整)calibration_data = []for _ in range(num_samples):data = np.random.randn(1,4,4).astype(np.float32)  # 生成随机数据
        calibration_data.append(data)# 保存为.npz文件np.savez(output_path, calibration_data=np.array(calibration_data))print(f"已生成 {num_samples} 个校准样本到: {output_path}")# 4. 使用示例
if __name__ == "__main__":MODEL_FP32 = 'model_fp32.onnx'MODEL_INT8 = 'model_int8.onnx'model_fp32 = Model()x = torch.randn(1, 1, 4, 4)  # 假设输入
    torch.onnx.export(model_fp32,x,MODEL_FP32,input_names=['input'],output_names=['output'])# 步骤1: 生成校准数据 (如果已有数据可跳过)generate_calibration_data("calibration_data.npz", num_samples=100)# 步骤2: 执行静态量化
    quantize_onnx_model_static(original_model_path=MODEL_FP32 ,quantized_model_path=MODEL_INT8,calibration_data_path="calibration_data.npz")# 步骤3: 验证量化模型 (可选)# 加载量化模型x = np.random.randn(1, 1, 4, 4).astype(np.float32)  # 假设输入ort_session = ort.InferenceSession(MODEL_FP32)ort_session_quantized = ort.InferenceSession(MODEL_INT8)inputs = {ort_session.get_inputs()[0].name: x}outputs = ort_session.run(None, inputs)print("ONNX模型输出:\n", outputs[0])inputs = {ort_session_quantized.get_inputs()[0].name: x}outputs_quantized = ort_session_quantized.run(None, inputs)print("静态量化后的ONNX模型输出:\n", outputs_quantized[0])
http://www.vanclimg.com/news/709.html

相关文章:

  • Redisson
  • uni-app项目跑APP报useStore报错
  • P13493 【MX-X14-T3】心电感应 题解
  • DE_aemmprty 草稿纸合集
  • 题解:P13308 故障
  • mmap提高LCD显示效率
  • 用 Python 构建可扩展的验证码识别系统
  • Java学习Day28
  • 在运维工作中,Dockerfile中常见指令有哪些?
  • 英语_阅读_Rivers are important in culture_单词_待读
  • 题解:P12151 【MX-X11-T5】「蓬莱人形 Round 1」俄罗斯方块
  • 在运维工作中,docker封闭了哪些资源?
  • SciTech-EECS-Library: img2pdf 与 pdf2image : Python 的 pdf 与 image 双向转换库
  • 深度学习(pytorch量化)
  • 在运维工作中,Docker怎么清理容器磁盘空间?
  • 生成函数
  • CVE-2021-45232 Apache APISIX Dashboard身份验证绕过漏洞 (复现)
  • 在运维工作中,如果运行的一个容器突然挂了,如何排查?
  • IIS中配置HTTPS证书的详细步骤
  • 李超线段树
  • 非常值得学习渲染入门的一个教程
  • Linux开机自动登录的一种方法
  • 7月28日
  • 2025 ZR暑假集训 CD联考 Day2 E 环球旅行
  • zk后集训
  • 乘法逆元(部分施工)、exgcd
  • 夏令营Ⅲ期
  • 集成学习算法
  • K 近邻算法
  • 二叉树 (动态规划)