深度学习笔记--修改替换Pytorch权重文件的Key值
创始人
2025-05-31 15:56:50
0

目录

1--前言

2--问题描述

2--代码

3--测试


1--前言

        最近复现一篇 Paper,需要使用预训练的模型,但预训练模型和自定义模型的 key 值不匹配,导致无法顺利加载预训练权重文件;

2--问题描述

        需要使用的预训练模型如下:

import torchif __name__ == "__main__":weights_files = './joint_model_stgcn.pt' # 权重文件路径weights = torch.load(weights_files) # 加载权重文件for k, v in weights.items(): # key, valueprint(k)  # 打印 key(参数名)

        原权重文件的 key 值如下:

A
...
st_gcn_networks.9.gcn.conv.weight
st_gcn_networks.9.gcn.conv.bias
st_gcn_networks.9.tcn.0.weight
st_gcn_networks.9.tcn.0.bias
st_gcn_networks.9.tcn.0.running_mean
st_gcn_networks.9.tcn.0.running_var
st_gcn_networks.9.tcn.0.num_batches_tracked
st_gcn_networks.9.tcn.2.weight
st_gcn_networks.9.tcn.2.bias
st_gcn_networks.9.tcn.3.weight
st_gcn_networks.9.tcn.3.bias
st_gcn_networks.9.tcn.3.running_mean
st_gcn_networks.9.tcn.3.running_var
st_gcn_networks.9.tcn.3.num_batches_tracked
edge_importance.0
edge_importance.1
edge_importance.2
edge_importance.3
edge_importance.4
edge_importance.5
edge_importance.6
edge_importance.7
edge_importance.8
edge_importance.9

fcn.weight
fcn.bias

        需求是修改以下 key 值,以适配自定义模型:

edge_importance.0 -> edge_importance0
edge_importance.1 -> edge_importance1
edge_importance.2 -> edge_importance2
edge_importance.3 -> edge_importance3
edge_importance.4 -> edge_importance4
edge_importance.5 -> edge_importance5
edge_importance.6 -> edge_importance6
edge_importance.7 -> edge_importance7
edge_importance.8 -> edge_importance8
edge_importance.9 -> edge_importance9

2--代码

        基于原权重文件,利用 collections.OrderedDict() 创建新的权重文件:

import torch
import collectionsif __name__ == "__main__":# 加载原权重文件weights_files = './joint_model_stgcn.pt'weights = torch.load(weights_files)# 修改new_d = weightsfor i in range(10):new_d = collections.OrderedDict([('edge_importance'+str(i), v) if k == 'edge_importance.'+str(i) else (k, v) for k, v in new_d.items()])# 测试for k, v in new_d.items(): # key, valueprint(k)  # 打印参数名# 保存torch.save(new_d, 'new_joint_model_stgcn.pt')

        修改后的 key 值:

A
...
st_gcn_networks.9.gcn.conv.weight
st_gcn_networks.9.gcn.conv.bias
st_gcn_networks.9.tcn.0.weight
st_gcn_networks.9.tcn.0.bias
st_gcn_networks.9.tcn.0.running_mean
st_gcn_networks.9.tcn.0.running_var
st_gcn_networks.9.tcn.0.num_batches_tracked
st_gcn_networks.9.tcn.2.weight
st_gcn_networks.9.tcn.2.bias
st_gcn_networks.9.tcn.3.weight
st_gcn_networks.9.tcn.3.bias
st_gcn_networks.9.tcn.3.running_mean
st_gcn_networks.9.tcn.3.running_var
st_gcn_networks.9.tcn.3.num_batches_tracked
edge_importance0
edge_importance1
edge_importance2
edge_importance3
edge_importance4
edge_importance5
edge_importance6
edge_importance7
edge_importance8
edge_importance9

fcn.weight
fcn.bias

3--测试

        测试原权重文件和新权重文件的 value 是否相同:

import torchif __name__ == "__main__":origin_weights_files = './joint_model_stgcn.pt'origin_weights = torch.load(origin_weights_files)new_weights_files = './new_joint_model_stgcn.pt'new_weights = torch.load(new_weights_files)print(origin_weights['A'] == new_weights['A'])print(origin_weights['edge_importance.0'] == new_weights['edge_importance0'])

相关内容

热门资讯

A股玻尿酸巨头出手!2700字... 医美龙头巨子生物“成分争议”风波持续发酵。日前,美妆博主大嘴博士(香港大学化学博士郝宇)发文,质疑巨...
计算机组成原理实验1---运算...     本实验为哈尔滨工业大学计算机组成原理实验,实验内容均为个人完成,...
3 ROS1通讯编程提高(1) 3 ROS1通讯编程提高3.1 使用VS Code编译ROS13.1.1 VS Code的安装和配置...
前端-session、jwt 目录:   (1)session (2&#x...
前端学习第三阶段-第4章 jQ... 4-1 jQuery介绍及常用API导读 01-jQuery入门导读 02-JavaScri...
EL表达式JSTL标签库 EL表达式     EL:Expression Language 表达式语言     ...
数字温湿度传感器DHT11模块... 模块实例https://blog.csdn.net/qq_38393591/article/deta...
【内网安全】 隧道搭建穿透上线... 文章目录内网穿透-Ngrok-入门-上线1、服务端配置:2、客户端连接服务端ÿ...
【Spring Cloud A... 文章目录前言Metadata元数据ClassMetadataSpring中常见的一些元注解Nacos...
React篇-关于React的... 一.简介1.介绍用于构建用户界面的 JavaScript 库2.创建项目(1)手动创建Documen...
win7 Pro 英文版添加中... win7pro x64英文版添加中文语言包1、下载语言包,并解压成lp.cab,复制到...
Android开发-Andro... 01  Android UI 1.1  UI 用户界面(User Interface,...
基于springboot教师人... 基于springboot教师人事档案管理系统【源码+论文】 开发语言:Jav...
编写软件界面的方式 本文重点解决如下问题:编写软件的界面有哪几种方式?通常情形下࿰...
keil调试专题篇 调试的前提是需要连接调试器比如STLINK。 然后点击菜单或者快捷图标均可进入调试模式。 如果前面...
GO语言小锤硬磕十三、数组与切... 数组用来保存一组相同类型的数据,go语言数组也分一维数组和多维数组。 直接上代码看一下...
三级数据库备考--数据库应用系... 1.数据库应用系统设计包括概念设计、逻辑设计、物理设计3个步骤,每个步骤的设计活动按照...
prometheus数据持久化... https://segmentfault.com/a/1190000015710814 promet...
孩子用什么样的灯对眼睛没有伤害... 现代社会高速发展,越来越多的人开始重视身体健康,尤其是很多家长ÿ...
微软Bing GPT支持AI绘... 我想要一张图片:大象、珊瑚、火山、云朵我想要一张图片:亚特兰蒂斯...