
Deep Learning (PyTorch Quantization)

Dynamic quantization and static quantization are the two main model quantization techniques in PyTorch. Both aim to shrink model size and speed up inference by replacing high-precision data types (such as float32) with low-precision ones (such as int8).

Dynamic quantization: the quantization parameters (scale and zero_point) of the activations are computed on the fly at inference time. The weights are typically quantized when the model is loaded or before the first run.

Static quantization: before deployment, a representative calibration dataset is used to determine the quantization parameters (scale and zero_point) of all weights and all activations in the network ahead of time. These parameters remain fixed (static) during inference.
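To make scale and zero_point concrete, here is a minimal pure-Python sketch of affine (asymmetric) int8 quantization derived from a tensor's min/max range. It illustrates the idea behind PyTorch's min/max observers but is not their exact implementation:

```python
# Toy affine quantization: map floats in [min_val, max_val] onto int8.
# Real PyTorch observers add corner-case handling (e.g. making sure the
# range covers 0.0), so treat this as an illustration only.
values = [-1.0, 0.0, 0.5, 2.0]
qmin, qmax = -128, 127  # signed int8 range

min_val, max_val = min(values), max(values)
scale = (max_val - min_val) / (qmax - qmin)
zero_point = int(round(qmin - min_val / scale))
zero_point = max(qmin, min(qmax, zero_point))  # clamp into int8 range

# Quantize, then dequantize to observe the round-trip error
quantized = [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]
dequantized = [(q - zero_point) * scale for q in quantized]

print("scale:", scale)            # step size of the int8 grid
print("zero_point:", zero_point)  # int8 value that represents float 0.0
print("round trip:", dequantized)
```

The round-trip error of each value is bounded by the scale (one quantization step), which is why a tight min/max range from good calibration data matters for static quantization.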

In deployment, static quantization is the more common choice. The two examples below can be used to verify both approaches.

Dynamic quantization:

import torch
import torch.nn as nn
import warnings

warnings.filterwarnings("ignore")
# torch.serialization.add_safe_globals([torch.ScriptObject])


class SimpleLSTM(nn.Module):
    """A simple LSTM model, well suited to dynamic quantization."""

    def __init__(self, input_size=10, hidden_size=50, num_layers=2, output_size=15):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # LSTM layers
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        # Fully connected layer
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Initialize hidden and cell states
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        # LSTM forward pass
        out, _ = self.lstm(x, (h0, c0))
        # Take only the output of the last time step
        out = self.fc(out[:, -1, :])
        return out


def save_fp32_model(model_fp32, x):
    model_fp32.eval()
    y = model_fp32(x)
    print("FP32 model output:", y)
    torch.save(model_fp32.state_dict(), 'model_fp32.pth')


def load_fp32_model(x):
    model_fp32 = SimpleLSTM()
    model_fp32.load_state_dict(torch.load('model_fp32.pth'))
    model_fp32.eval()
    y_fp32 = model_fp32(x)
    print("Loaded FP32 model output:", y_fp32)
    return model_fp32


def save_int8_model(model_fp32, x):
    model_int8 = torch.quantization.quantize_dynamic(
        model_fp32, {nn.LSTM, nn.Linear}, dtype=torch.qint8
    )
    model_int8.eval()
    y_int8 = model_int8(x)
    print("INT8 model output:", y_int8)
    torch.save(model_int8.state_dict(), 'model_int8.pth')


def load_int8_model(x):
    # Rebuild the quantized module structure before loading the state dict
    model_fp32 = SimpleLSTM()
    model_int8 = torch.quantization.quantize_dynamic(
        model_fp32, {nn.LSTM, nn.Linear}, dtype=torch.qint8
    )
    model_int8.load_state_dict(torch.load('model_int8.pth', weights_only=False))
    model_int8.eval()
    y_int8 = model_int8(x)
    print("Loaded INT8 model output:", y_int8)
    return model_int8


if __name__ == '__main__':
    x = torch.randn(1, 10, 10)
    model_fp32 = SimpleLSTM()
    save_fp32_model(model_fp32, x)
    save_int8_model(model_fp32, x)
    load_fp32_model(x)
    load_int8_model(x)

Static quantization:

import torch
import numpy as np
import warnings

warnings.filterwarnings("ignore")


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 100, 1)
        self.conv1 = torch.nn.Conv2d(100, 100, 1)
        self.conv2 = torch.nn.Conv2d(100, 100, 1)
        self.conv3 = torch.nn.Conv2d(100, 1, 1)
        self.relu1 = torch.nn.ReLU()
        self.relu2 = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.dequant(x)
        return x


def save_fp32_model(model_fp32, x):
    model_fp32.eval()
    y = model_fp32(x)
    print("FP32 model output:", y)
    torch.save(model_fp32.state_dict(), 'model_fp32.pth')
    torch.onnx.export(model_fp32, x, 'model_fp32.onnx',
                      input_names=['input'], output_names=['output'])


def load_fp32_model(x):
    model_fp32 = Model()
    model_fp32.load_state_dict(torch.load('model_fp32.pth'))
    model_fp32.eval()
    y_fp32 = model_fp32(x)
    print("Loaded FP32 model output:", y_fp32)
    return model_fp32


def save_int8_model(model_fp32, x):
    model_fp32.eval()
    model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')
    model_fp32_fused = torch.ao.quantization.fuse_modules(
        model_fp32, [['conv1', 'relu1'], ['conv2', 'relu2']])
    model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)
    # Calibration: run representative inputs through the prepared model
    # so the observers can record activation ranges
    with torch.no_grad():
        for i in range(100):
            input_data = torch.randn(1, 1, 4, 4)
            model_fp32_prepared(input_data)
    model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
    model_int8.eval()
    y_int8 = model_int8(x)
    print("INT8 model output:", y_int8)
    torch.save(model_int8.state_dict(), 'model_int8.pth')
    torch.onnx.export(model_int8, x, 'model_int8.onnx',
                      input_names=['input'], output_names=['output'])


def load_int8_model(x):
    # Rebuild the same fused + quantized module structure, then load weights
    model_fp32 = Model()
    model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')
    model_fp32_fused = torch.ao.quantization.fuse_modules(
        model_fp32, [['conv1', 'relu1'], ['conv2', 'relu2']])
    model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)
    model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
    model_int8.load_state_dict(torch.load('model_int8.pth'))
    model_int8.eval()
    y_int8 = model_int8(x)
    print("Loaded INT8 model output:", y_int8)
    return model_int8


if __name__ == '__main__':
    x = np.array([[0.1, 0.2, 0.3, 0.4],
                  [0.5, 0.6, 0.7, 0.8],
                  [0.9, 0.1, 0.2, 0.3],
                  [0.4, 0.5, 0.6, 0.7]], dtype=np.float32)
    x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)
    model_fp32 = Model()
    save_fp32_model(model_fp32, x)
    save_int8_model(model_fp32, x)
    load_fp32_model(x)
    load_int8_model(x)
