Contents
1--Preface
2--Problem Description
3--Code
4--Test
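1--Preface
I have recently been reproducing a paper that relies on a pretrained model. However, the keys in the pretrained model do not match those of my custom model, so the pretrained weight file cannot be loaded as-is.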
2--Problem Description
The pretrained weights to be used are loaded below, and their keys are printed:
import torch

if __name__ == "__main__":
    weights_files = './joint_model_stgcn.pt'  # path to the weight file
    weights = torch.load(weights_files)  # load the weight file (parameter name -> tensor)
    for k, v in weights.items():  # key, value
        print(k)  # print the key (parameter name)
The keys of the original weight file are:
A
...
st_gcn_networks.9.gcn.conv.weight
st_gcn_networks.9.gcn.conv.bias
st_gcn_networks.9.tcn.0.weight
st_gcn_networks.9.tcn.0.bias
st_gcn_networks.9.tcn.0.running_mean
st_gcn_networks.9.tcn.0.running_var
st_gcn_networks.9.tcn.0.num_batches_tracked
st_gcn_networks.9.tcn.2.weight
st_gcn_networks.9.tcn.2.bias
st_gcn_networks.9.tcn.3.weight
st_gcn_networks.9.tcn.3.bias
st_gcn_networks.9.tcn.3.running_mean
st_gcn_networks.9.tcn.3.running_var
st_gcn_networks.9.tcn.3.num_batches_tracked
edge_importance.0
edge_importance.1
edge_importance.2
edge_importance.3
edge_importance.4
edge_importance.5
edge_importance.6
edge_importance.7
edge_importance.8
edge_importance.9
fcn.weight
fcn.bias
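Before renaming anything, it helps to list exactly which keys disagree. The sketch below compares the checkpoint keys with the keys the custom model expects, using plain set differences. The two key lists are hypothetical stand-ins. In practice they would come from `weights.keys()` and `model.state_dict().keys()`. PyTorch's `model.load_state_dict(weights, strict=False)` reports the same information through the `missing_keys` and `unexpected_keys` fields of its return value.

```python
# Hypothetical key lists: a few keys from the checkpoint above, and the
# names a custom model would register (edge_importance0, edge_importance1, ...).
checkpoint_keys = [
    "st_gcn_networks.9.tcn.3.num_batches_tracked",
    "edge_importance.0",
    "edge_importance.1",
    "fcn.weight",
    "fcn.bias",
]
model_keys = [
    "st_gcn_networks.9.tcn.3.num_batches_tracked",
    "edge_importance0",
    "edge_importance1",
    "fcn.weight",
    "fcn.bias",
]


def diff_keys(checkpoint_keys, model_keys):
    """Return (keys only in the checkpoint, keys only in the model), sorted."""
    ckpt, model = set(checkpoint_keys), set(model_keys)
    unexpected = sorted(ckpt - model)  # present in the file, unknown to the model
    missing = sorted(model - ckpt)     # expected by the model, absent from the file
    return unexpected, missing


unexpected, missing = diff_keys(checkpoint_keys, model_keys)
print("unexpected:", unexpected)  # ['edge_importance.0', 'edge_importance.1']
print("missing:", missing)        # ['edge_importance0', 'edge_importance1']
```

Each entry in `unexpected` should pair with one in `missing`. Those pairs are exactly the renames to perform.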
The goal is to rename the following keys so that they match the custom model:
edge_importance.0 -> edge_importance0
edge_importance.1 -> edge_importance1
edge_importance.2 -> edge_importance2
edge_importance.3 -> edge_importance3
edge_importance.4 -> edge_importance4
edge_importance.5 -> edge_importance5
edge_importance.6 -> edge_importance6
edge_importance.7 -> edge_importance7
edge_importance.8 -> edge_importance8
edge_importance.9 -> edge_importance9
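Every rename in this table follows the same pattern: drop the dot before the layer index. The table can therefore be expressed as a regular expression and applied in a single pass over the dictionary. This is a minimal sketch. The helper name `rename_keys` is hypothetical, and the values are placeholder strings so the snippet runs without PyTorch. In the real checkpoint they would be tensors.

```python
import re
from collections import OrderedDict

# Matches keys such as "edge_importance.0" ... "edge_importance.9" (any index).
PATTERN = re.compile(r"^edge_importance\.(\d+)$")


def rename_keys(state_dict, pattern=PATTERN, repl=r"edge_importance\1"):
    """Return a new OrderedDict with matching keys renamed; order is preserved.

    Keys that do not match the pattern are passed through unchanged.
    """
    return OrderedDict((pattern.sub(repl, k), v) for k, v in state_dict.items())


# Placeholder values; in the real checkpoint these would be tensors.
weights = OrderedDict([
    ("st_gcn_networks.9.tcn.3.num_batches_tracked", "t0"),
    ("edge_importance.0", "e0"),
    ("edge_importance.1", "e1"),
    ("fcn.weight", "w"),
])
new_d = rename_keys(weights)
print(list(new_d))
# ['st_gcn_networks.9.tcn.3.num_batches_tracked', 'edge_importance0', 'edge_importance1', 'fcn.weight']
```

The approach differs from the `for i in range(10)` loop that rebuilds the dictionary once per layer. It makes a single pass and does not hardcode the number of layers. The `^...$` anchors keep longer keys such as `edge_importance.0.something` from being rewritten.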
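3--Code
Starting from the original weight file, build a new one with collections.OrderedDict(), which keeps every parameter in its original position: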
import torch
import collections

if __name__ == "__main__":
    # Load the original weight file
    weights_files = './joint_model_stgcn.pt'
    weights = torch.load(weights_files)

    # Rename edge_importance.i -> edge_importancei for i = 0..9
    new_d = weights
    for i in range(10):
        new_d = collections.OrderedDict(
            [('edge_importance' + str(i), v) if k == 'edge_importance.' + str(i) else (k, v)
             for k, v in new_d.items()])

    # Check the result
    for k, v in new_d.items():  # key, value
        print(k)  # print the parameter name

    # Save the new weight file
    torch.save(new_d, 'new_joint_model_stgcn.pt')
The keys after renaming:
A
...
st_gcn_networks.9.gcn.conv.weight
st_gcn_networks.9.gcn.conv.bias
st_gcn_networks.9.tcn.0.weight
st_gcn_networks.9.tcn.0.bias
st_gcn_networks.9.tcn.0.running_mean
st_gcn_networks.9.tcn.0.running_var
st_gcn_networks.9.tcn.0.num_batches_tracked
st_gcn_networks.9.tcn.2.weight
st_gcn_networks.9.tcn.2.bias
st_gcn_networks.9.tcn.3.weight
st_gcn_networks.9.tcn.3.bias
st_gcn_networks.9.tcn.3.running_mean
st_gcn_networks.9.tcn.3.running_var
st_gcn_networks.9.tcn.3.num_batches_tracked
edge_importance0
edge_importance1
edge_importance2
edge_importance3
edge_importance4
edge_importance5
edge_importance6
edge_importance7
edge_importance8
edge_importance9
fcn.weight
fcn.bias
4--Test
Check that the values in the original and new weight files are identical:
import torch

if __name__ == "__main__":
    origin_weights_files = './joint_model_stgcn.pt'
    origin_weights = torch.load(origin_weights_files)
    new_weights_files = './new_joint_model_stgcn.pt'
    new_weights = torch.load(new_weights_files)
    # torch.equal returns a single bool (same shape and same elements),
    # whereas == would print an element-wise boolean tensor
    print(torch.equal(origin_weights['A'], new_weights['A']))
    print(torch.equal(origin_weights['edge_importance.0'], new_weights['edge_importance0']))
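The test above only spot-checks two entries. Below is a minimal sketch of a full check, assuming the rename rule from the table (`edge_importance.N -> edge_importanceN`). It maps every original key to its expected new name and compares the values with an injectable equality function. The helper names `expected_name` and `all_values_match` are hypothetical. Plain lists stand in for tensors here; with real checkpoints, pass `torch.equal` as `eq`.

```python
import re


def expected_name(key):
    """Apply the rename rule edge_importance.N -> edge_importanceN."""
    return re.sub(r"^edge_importance\.(\d+)$", r"edge_importance\1", key)


def all_values_match(origin, new, eq=lambda a, b: a == b):
    """True if both dicts have the same size and every original value
    is stored unchanged under its renamed key."""
    if len(origin) != len(new):
        return False
    for k, v in origin.items():
        nk = expected_name(k)
        if nk not in new or not eq(v, new[nk]):
            return False
    return True


# Placeholder values; with tensors, call all_values_match(o, n, eq=torch.equal).
origin = {"A": [1, 2], "edge_importance.0": [3.0], "fcn.bias": [0.5]}
new = {"A": [1, 2], "edge_importance0": [3.0], "fcn.bias": [0.5]}
print(all_values_match(origin, new))  # True
```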