python

(Pytorch) Pytorch를 이용하여 학습한 모델 저장/불러오는 방법

친환경입냄새 2022. 4. 7. 13:55

Pytorch를 이용하여 Model을 저장하는 방법은 아래와 같습니다.

import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# CNN_model 예시
class CNN_model(nn.Module):
#tistory 코드 블럭의 문제인지 indent가 맞질 않습니다...
	def __init__(self):
    	...
        ...
    
    def forward(self, x):
    	...
        ...


model = CNN_model()

# torch.save(model, path_dict_file_name)

# model 전체 저장
torch.save(model, 'model.pt')

# state_dict = 학습 가능한 매개변수(weight, bias)가 담겨있는 딕셔너리
# model의 state_dict만 저장
torch.save(model.state_dict(), 'model_state_dict.pt')

저장된 Model을 불러오는 방법은 아래와 같습니다.

import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# torch.load(path_file_name)

# model 전체가 저장된 .pt 파일 로드
model = torch.load('model.pt')


# torch.load(path_dict_file_name)
# state_dict = 학습 가능한 매개변수(weight, bias)가 담겨있는 딕셔너리
# model의 state_dict만 저장된 .pt 파일 로드

# CNN_model 예시

class CNN_model(nn.Module):
#tistory 코드 블럭의 문제인지 indent가 맞질 않습니다...
	def __init__(self):
    	...
        ...
    
    def forward(self, x):
    	...
        ...
    
model = CNN_model().to(device)

model_dict = torch.load('model_state_dict.pt')

model.load_state_dict(model_dict)

model 자체를 저장한 파일을 불러올때는 'torch.load(path_file_name)' 으로 불러온 뒤 바로 model에 할당해주면 되고

 

학습 가능한 매개변수(weight, bias)만 저장한 state_dict 파일을 불러오는 경우

 

'torch.load(path_dict_file_name)'으로 해당 dict를 불러 온 뒤 'model.load_state_dict(dict)'로 로드를 해주면 됩니다.

 

※ state_dict로 저장 / 불러오는 경우 코드상에 해당 모델의 구조가 구현되어있어야 합니다.

 

※ '모델을 전체 저장한 파일의 크기' > '모델의 state_dict(매개변수)만 저장한 파일의 크기'

 


저는 모델 학습 진행 중 일정 Epoch 마다 모델을 저장해두는데 아래와 같이

 

check point 파일을 불러오는 함수를 구성해두었습니다.

 

import os
import torch

def load_ckpt(path, filename, extension):
    if os.path.exists(path):
        list_ckpt = os.listdir(path)
    else :
        print('load_ckpt : The directory  could not be found.')
        return False, 0

	# path 안에 check point 파일이 여러개 있을때 가장 마지막에 저장된(epoch가 가장 높은)
    # check point 파일을 load 하기 위해 '파일 이름' 과 '확장자'를 제거하여
    # epoch만 적힌 문자만 남게 합니다.
    for i in range(len(list_ckpt)):
        list_ckpt[i] = list_ckpt[i].replace(filename,'')
        list_ckpt[i] = list_ckpt[i].replace(extension, '')

	# 현재 리스트가 문자열로 구성되어있어 int형으로 형변환 후 sort를 진행합니다.
    list_ckpt = list(map(int, list_ckpt))
    list_ckpt.sort()

    if len(list_ckpt) != 0:
        # Check point file or dict, epoch
        return torch.load(os.path.join(path, filename + str(list_ckpt[-1]) + extension)), list_ckpt[-1]
    else:
        # Check point file does not exist
        return False, 0

위 함수는 아래와 같이 사용할 수 있습니다.

class Generator(nn.Module):
    def __init__(self):
        ...
        ...

    def forward(self, x):
        ...
        ...


class Discriminator(nn.Module):
    def __init__(self):
        ...
        ...

    def forward(self, x):
        ...
        ...
        

# check point가 저장된 폴더 이름
ckpt_path = 'GAN_MNIST_CheckPoint'

G = Generator().to(device)
D = Discriminator().to(device)

# path = os.path.join(ckpt_path, "Generator") -> /GAN_MNIST_CheckPoint/Generator/
# filename = 'G_ckpt_epoch_[epoch_num]' -> epoch 숫자가 입력되기 전 까지의 파일 이름 입력
# extension = '.pt' -> 확장자 입력
G_model_dict, G_epoch = load_ckpt(os.path.join(ckpt_path, "Generator"), 'G_ckpt_epoch_', '.pt')

# path = os.path.join(ckpt_path, "Discriminator") -> /GAN_MNIST_CheckPoint/Discriminator/
# filename = 'D_ckpt_epoch_[epoch_num]' -> epoch 숫자가 입력되기 전 까지의 파일 이름 입력
# extension = '.pt' -> 확장자 입력
D_model_dict, D_epoch = load_ckpt(os.path.join(ckpt_path, "Discriminator"), 'D_ckpt_epoch_', '.pt')

# 저는 state_dict 파일을 불러와서 'load_state_dict'를 이용하였습니다.
# check point 파일 로드에 실패하였을 경우 첫번째 인자 값이 False로 넘어오기 때문에
# 아래와 같이 체크 후 모델에 적용을 시켜주면 됩니다.

if G_model_dict:
    G.load_state_dict(G_model_dict)

if D_model_dict:
    D.load_state_dict(D_model_dict)