Pytorch를 이용하여 Model을 저장하는 방법은 아래와 같습니다.
import torch
import torch.nn as nn
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# CNN_model 예시
class CNN_model(nn.Module):
#tistory 코드 블럭의 문제인지 indent가 맞질 않습니다...
def __init__(self):
...
...
def forward(self, x):
...
...
model = CNN_model()
# torch.save(model, path_dict_file_name)
# model 전체 저장
torch.save(model, 'model.pt')
# state_dict = 학습 가능한 매개변수(weight, bias)가 담겨있는 딕셔너리
# model의 state_dict만 저장
torch.save(model.state_dict(), 'model_state_dict.pt')
저장된 Model을 불러오는 방법은 아래와 같습니다.
import torch
import torch.nn as nn
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# torch.load(path_file_name)
# model 전체가 저장된 .pt 파일 로드
model = torch.load('model.pt')
# torch.load(path_dict_file_name)
# state_dict = 학습 가능한 매개변수(weight, bias)가 담겨있는 딕셔너리
# model의 state_dict만 저장된 .pt 파일 로드
# CNN_model 예시
class CNN_model(nn.Module):
#tistory 코드 블럭의 문제인지 indent가 맞질 않습니다...
def __init__(self):
...
...
def forward(self, x):
...
...
model = CNN_model().to(device)
model_dict = torch.load('model_state_dict.pt')
model.load_state_dict(model_dict)
model 자체를 저장한 파일을 불러올때는 'torch.load(path_file_name)' 으로 불러온 뒤 바로 model에 할당해주면 되고
학습 가능한 매개변수(weight, bias)만 저장한 state_dict 파일을 불러오는 경우
'torch.load(path_dict_file_name)'으로 해당 dict를 불러 온 뒤 'model.load_state_dict(dict)'로 로드를 해주면 됩니다.
※ state_dict로 저장 / 불러오는 경우 코드상에 해당 모델의 구조가 구현되어있어야 합니다.
※ '모델을 전체 저장한 파일의 크기' > '모델의 state_dict(매개변수)만 저장한 파일의 크기'
저는 모델 학습 진행 중 일정 Epoch 마다 모델을 저장해두는데 아래와 같이
check point 파일을 불러오는 함수를 구성해두었습니다.
import os
import torch
def load_ckpt(path, filename, extension):
if os.path.exists(path):
list_ckpt = os.listdir(path)
else :
print('load_ckpt : The directory could not be found.')
return False, 0
# path 안에 check point 파일이 여러개 있을때 가장 마지막에 저장된(epoch가 가장 높은)
# check point 파일을 load 하기 위해 '파일 이름' 과 '확장자'를 제거하여
# epoch만 적힌 문자만 남게 합니다.
for i in range(len(list_ckpt)):
list_ckpt[i] = list_ckpt[i].replace(filename,'')
list_ckpt[i] = list_ckpt[i].replace(extension, '')
# 현재 리스트가 문자열로 구성되어있어 int형으로 형변환 후 sort를 진행합니다.
list_ckpt = list(map(int, list_ckpt))
list_ckpt.sort()
if len(list_ckpt) != 0:
# Check point file or dict, epoch
return torch.load(os.path.join(path, filename + str(list_ckpt[-1]) + extension)), list_ckpt[-1]
else:
# Check point file does not exist
return False, 0
위 함수는 아래와 같이 사용할 수 있습니다.
class Generator(nn.Module):
def __init__(self):
...
...
def forward(self, x):
...
...
class Discriminator(nn.Module):
def __init__(self):
...
...
def forward(self, x):
...
...
# check point가 저장된 폴더 이름
ckpt_path = 'GAN_MNIST_CheckPoint'
G = Generator().to(device)
D = Discriminator().to(device)
# path = os.path.join(ckpt_path, "Generator") -> /GAN_MNIST_CheckPoint/Generator/
# filename = 'G_ckpt_epoch_[epoch_num]' -> epoch 숫자가 입력되기 전 까지의 파일 이름 입력
# extension = '.pt' -> 확장자 입력
G_model_dict, G_epoch = load_ckpt(os.path.join(ckpt_path, "Generator"), 'G_ckpt_epoch_', '.pt')
# path = os.path.join(ckpt_path, "Discriminator") -> /GAN_MNIST_CheckPoint/Discriminator/
# filename = 'D_ckpt_epoch_[epoch_num]' -> epoch 숫자가 입력되기 전 까지의 파일 이름 입력
# extension = '.pt' -> 확장자 입력
D_model_dict, D_epoch = load_ckpt(os.path.join(ckpt_path, "Discriminator"), 'D_ckpt_epoch_', '.pt')
# 저는 state_dict 파일을 불러와서 'load_state_dict'를 이용하였습니다.
# check point 파일 로드에 실패하였을 경우 첫번째 인자 값이 False로 넘어오기 때문에
# 아래와 같이 체크 후 모델에 적용을 시켜주면 됩니다.
if G_model_dict:
G.load_state_dict(G_model_dict)
if D_model_dict:
D.load_state_dict(D_model_dict)
'python' 카테고리의 다른 글
CycleGAN 학습을 위해 python으로 이미지 크롤링하기(wikiart, Vincent van Gogh, selenium) (0) | 2022.04.08 |
---|