深度学习笔记--修改替换Pytorch权重文件的Key值
创始人
2025-05-31 15:56:50
0

目录

1--前言

2--问题描述

2--代码

3--测试


1--前言

        最近复现一篇 Paper,需要使用预训练的模型,但预训练模型和自定义模型的 key 值不匹配,导致无法顺利加载预训练权重文件;

2--问题描述

        需要使用的预训练模型如下:

import torchif __name__ == "__main__":weights_files = './joint_model_stgcn.pt' # 权重文件路径weights = torch.load(weights_files) # 加载权重文件for k, v in weights.items(): # key, valueprint(k)  # 打印 key(参数名)

        原权重文件的 key 值如下:

A
...
st_gcn_networks.9.gcn.conv.weight
st_gcn_networks.9.gcn.conv.bias
st_gcn_networks.9.tcn.0.weight
st_gcn_networks.9.tcn.0.bias
st_gcn_networks.9.tcn.0.running_mean
st_gcn_networks.9.tcn.0.running_var
st_gcn_networks.9.tcn.0.num_batches_tracked
st_gcn_networks.9.tcn.2.weight
st_gcn_networks.9.tcn.2.bias
st_gcn_networks.9.tcn.3.weight
st_gcn_networks.9.tcn.3.bias
st_gcn_networks.9.tcn.3.running_mean
st_gcn_networks.9.tcn.3.running_var
st_gcn_networks.9.tcn.3.num_batches_tracked
edge_importance.0
edge_importance.1
edge_importance.2
edge_importance.3
edge_importance.4
edge_importance.5
edge_importance.6
edge_importance.7
edge_importance.8
edge_importance.9

fcn.weight
fcn.bias

        需求是修改以下 key 值,以适配自定义模型:

edge_importance.0 -> edge_importance0
edge_importance.1 -> edge_importance1
edge_importance.2 -> edge_importance2
edge_importance.3 -> edge_importance3
edge_importance.4 -> edge_importance4
edge_importance.5 -> edge_importance5
edge_importance.6 -> edge_importance6
edge_importance.7 -> edge_importance7
edge_importance.8 -> edge_importance8
edge_importance.9 -> edge_importance9

2--代码

        基于原权重文件,利用 collections.OrderedDict() 创建新的权重文件:

import torch
import collectionsif __name__ == "__main__":# 加载原权重文件weights_files = './joint_model_stgcn.pt'weights = torch.load(weights_files)# 修改new_d = weightsfor i in range(10):new_d = collections.OrderedDict([('edge_importance'+str(i), v) if k == 'edge_importance.'+str(i) else (k, v) for k, v in new_d.items()])# 测试for k, v in new_d.items(): # key, valueprint(k)  # 打印参数名# 保存torch.save(new_d, 'new_joint_model_stgcn.pt')

        修改后的 key 值:

A
...
st_gcn_networks.9.gcn.conv.weight
st_gcn_networks.9.gcn.conv.bias
st_gcn_networks.9.tcn.0.weight
st_gcn_networks.9.tcn.0.bias
st_gcn_networks.9.tcn.0.running_mean
st_gcn_networks.9.tcn.0.running_var
st_gcn_networks.9.tcn.0.num_batches_tracked
st_gcn_networks.9.tcn.2.weight
st_gcn_networks.9.tcn.2.bias
st_gcn_networks.9.tcn.3.weight
st_gcn_networks.9.tcn.3.bias
st_gcn_networks.9.tcn.3.running_mean
st_gcn_networks.9.tcn.3.running_var
st_gcn_networks.9.tcn.3.num_batches_tracked
edge_importance0
edge_importance1
edge_importance2
edge_importance3
edge_importance4
edge_importance5
edge_importance6
edge_importance7
edge_importance8
edge_importance9

fcn.weight
fcn.bias

3--测试

        测试原权重文件和新权重文件的 value 是否相同:

import torchif __name__ == "__main__":origin_weights_files = './joint_model_stgcn.pt'origin_weights = torch.load(origin_weights_files)new_weights_files = './new_joint_model_stgcn.pt'new_weights = torch.load(new_weights_files)print(origin_weights['A'] == new_weights['A'])print(origin_weights['edge_importance.0'] == new_weights['edge_importance0'])

相关内容

热门资讯

豪掷23亿!追觅创始人俞浩入主... 出品|达摩财经12月16日晚,嘉美包装(002969.SZ)发布公告称,公司控股股东中包香港与逐越鸿...
嘎巴一下,就死了 图: gothic_jang闲聊聊。上次聊照护老人的话题,一位读者留言引起了热议:“丁香医生说坚持健...
华海财险股权变局:曾借虚假股东... 专为保险业 打造的垂直新媒体平台华海财产保险股份有限公司(下称“华海财险”)原始股东或将回归。日前,...
富泽人寿获批开业:170亿“身... 专为保险业 打造的垂直新媒体平台君康人寿风险处置的靴子,正式落地。今日,金融监督管理总局山东监管局发...
001330,今日跌停 2025.12.16本文字数:776,阅读时长大约2分钟作者 |第一财经 揭书宜12月16日,博纳影...
杨瀚森G联赛首发出战砍18分1... 北京时间12月16日,G联赛撕裂之城混音队主场迎战斯托克顿国王队,本场比赛杨瀚森首发出战。此役,开拓...
大陆集团三季度净利润4.86亿... 11月11日消息,大陆集团发布2024年第三季度业绩:销售额为98亿欧元,同比减少4.0%;调整后的...
印度据悉将传唤亚马逊和Flip... 11月11日消息,据报道,印度政府一位高级消息人士表示,印度金融犯罪机构将传唤印度最大电子商务零售商...
中兴通讯涨停创阶段新高,三机构... 11月11日消息,中兴通讯(000415.SZ)今日涨停,股价创2023年8月以来新高,成交额88....
工信部印发《重点工业产品碳足迹... 11月11日消息,为加快提升重点工业产品碳排放管理水平,促进行业绿色低碳转型,支撑实现碳达峰碳中和目...