반응형
문제
pretrained weight를 불러와 모델에 올릴 때 parameter name이 달라서 오류 발생
model = Model()
model.load_state_dict(torch.load('pretrained_weight.pth'))
오류 메시지
RuntimeError: Error(s) in loading state_dict for Model:
Missing key(s) in state_dict: "backbone.conv1.weight", "backbone.bn1.weight", ... ,
Unexpected key(s) in state_dict: "module.backbone.conv1.weight", "module.backbone.bn1.weight", ...
해결
모델의 parameter name과 saved parameter name을 동일하게 변경
from collections import OrderedDict
new_state_dict = OrderedDict()
state_dict = torch.load('pretrained_weight.pth')
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
저장된 parameter name의 'module.' 에 해당하는 부분을 삭제 하는 코드이다. 일반적으로 사용할 수 있는 코드는 아님.
반응형