AI

(U-Net) 모델 분석 및 Pytorch로 구현 - 1

친환경입냄새 2022. 4. 12. 18:46

오늘은 Biomedical 분야 Semantic Segmentation을 위한 Convolutional network 인 U-Net에 대해 포스팅 하려고 합니다.

 

더보기

https://link.springer.com/content/pdf/10.1007/978-3-319-24574-4_28.pdf

(U-Net: Convolutional Networks for Biomedical Image Segmentation
Olaf Ronneberger, Philipp Fischer, and Thomas Brox)


U-Net의 구조

- U-Net의 구조는 아래와 같습니다.

U-Net은 'Contracting path', 'Bottle Neck', 'Expanding path'로 구성되어있습니다.


'수축 단계(Contracting path)'

입력된 이미지에 아래 연산을 반복

1. Convolution( 3x3 kernel, stride : 1)  X 2

2. Max Pooling( 2x2 kernel, stride : 2)


'전환 구간(Bottle neck)'

수축 단계를 거친 후 팽창 단계로 전환되는 구간

1. Convolution( 3x3 kernel, stride : 1) X 2


'팽창 단계(Expanding path)'

1. Up-Convolution( 2x2 kernel, stride : 2) 이후 같은 단계에 있는 Feature map을 합침

2. Convolution( 3x3 kernel, stride : 1) X 2


 

Feature map의 채널 수는 '이미지 입력', 'Max Pooling' , 'Up-Convolution' 이후 첫 'Convolution' 연산을 수행할 때

변화가 생깁니다.

 

각 연산을  수행 할 때 padding을 추가하지 않으므로 Feature map의 크기는 점점 작아지게 됩니다.

 


Pytorch로 구현한 U-Net 모델

class UNet(torch.nn.Module):
    def __init__(self):
        super().__init__()

		# Convolution, Batch Normalization, ReLU 연산을 합친 함수
        def CBR2d(input_channel, output_channel, kernel_size=3, stride=1):
            layer = nn.Sequential(
                nn.Conv2d(input_channel, output_channel, kernel_size=kernel_size, stride=stride),
                nn.BatchNorm2d(num_features=output_channel),
                nn.ReLU()
            )
            return layer

		# Contracting path
        # 572x572x1 => 568x568x64
        self.conv1 = nn.Sequential(
            CBR2d(1, 64, 3, 1),
            CBR2d(64, 64, 3, 1)
        )
		# 568x568x64 => 284x284x64 		
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)


		# 284x284x64 => 280x280x128
        self.conv2 = nn.Sequential(
            CBR2d(64, 128, 3, 1),
            CBR2d(128, 128, 3, 1)
        )
		# 280x280x128 => 140x140x128
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)


		# 140x140x128 => 136x136x256
        self.conv3 = nn.Sequential(
            CBR2d(128, 256, 3, 1),
            CBR2d(256, 256, 3, 1)
        )
		# 136x136x256 => 68x68x256
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)


		# 68x68x256 => 64x64x512
        # Contracting path 마지막에 Dropout 적용
        self.conv4 = nn.Sequential(
            CBR2d(256, 512, 3, 1),
            CBR2d(512, 512, 3, 1),
            nn.Dropout(p=0.5)
        )
		# 64x64x512 => 32x32x512
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
		# Contracting path 끝
        
		
        # Bottlneck 구간
        # 32x32x512 => 28x28x1024
        self.bottleNeck = nn.Sequential(
            CBR2d(512, 1024, 3, 1),
            CBR2d(1024, 1024, 3, 1),
        )
		# Bottlneck 구간 끝


		# Expanding path
        # channel 수를 감소 시키며 Up-Convolution
		# 28x28x1024 => 56x56x512        
        self.upconv1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)


		# Up-Convolution 이후 channel = 512
        # Contracting path 중 같은 단계의 Feature map을 가져와 Up-Convolution 결과의 Feature map과 Concat 연산
        # => channel = 1024 가 됩니다.
        # forward 부분을 참고해주세요
        # 56x56x1024 => 52x52x512
        self.ex_conv1 = nn.Sequential(
            CBR2d(1024, 512, 3, 1),
            CBR2d(512, 512, 3, 1)
        )

		# 52x52x512 => 104x104x256
        self.upconv2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)

		# 104x104x512 => 100x100x256
        self.ex_conv2 = nn.Sequential(
            CBR2d(512, 256, 3, 1),
            CBR2d(256, 256, 3, 1)
        )

		# 100x100x256 => 200x200x128
        self.upconv3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)

		# 200x200x256 => 196x196x128
        self.ex_conv3 = nn.Sequential(
            CBR2d(256, 128, 3, 1),
            CBR2d(128, 128, 3, 1)
        )

		# 196x196x128 => 392x392x64
        self.upconv4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)

		# 392x392x128 => 388x388x64
        self.ex_conv4 = nn.Sequential(
            CBR2d(128, 64, 3, 1),
            CBR2d(64, 64, 3, 1),

        )

		# 논문 구조상 output = 2 channel
	    # train 데이터에서 세포 / 배경을 검출하는것이 목표여서 class_num = 1로 지정
		# 388x388x64 => 388x388x1
        self.fc = nn.Conv2d(64, 1, kernel_size=1, stride=1)


    def forward(self, x):
    	# Contracting path
        # 572x572x1 => 568x568x64
        layer1 = self.conv1(x)
        
        # Max Pooling
        # 568x568x64 => 284x284x64
        out = self.pool1(layer1)

		# 284x284x64 => 280x280x128
        layer2 = self.conv2(out)
        
        # Max Pooling
        # 280x280x128 => 140x140x128
        out = self.pool2(layer2)

		# 140x140x128 => 136x136x256
        layer3 = self.conv3(out)
        
        # Max Pooling
        # 136x136x256 => 68x68x256
        out = self.pool3(layer3)

		# 68x68x256 => 64x64x512
        layer4 = self.conv4(out)
        
        # Max Pooling
        # 64x64x512 => 32x32x512
        out = self.pool4(layer4)

		# bottleneck 
        # 32x32x512 => 28x28x1024
        bottleNeck = self.bottleNeck(out)

		# Expanding path
        # 28x28x1024 => 56x56x512
        upconv1 = self.upconv1(bottleNeck)
        
        # Contracting path 중 같은 단계의 Feature map을 가져와 합침
        # Up-Convolution 결과의 Feature map size 만큼 CenterCrop 하여 Concat 연산
        # 56x56x512 => 56x56x1024
        cat1 = torch.cat((transforms.CenterCrop((upconv1.shape[2], upconv1.shape[3]))(layer4), upconv1), dim=1)
        
        # 56x56x1024 => 52x52x512
        ex_layer1 = self.ex_conv1(cat1)

		# 52x52x512 => 104x104x256
        upconv2 = self.upconv2(ex_layer1)
        
        # Contracting path 중 같은 단계의 Feature map을 가져와 합침
        # Up-Convolution 결과의 Feature map size 만큼 CenterCrop 하여 Concat 연산
        # 104x104x256 => 104x104x512
        cat2 = torch.cat((transforms.CenterCrop((upconv2.shape[2], upconv2.shape[3]))(layer3), upconv2), dim=1)
        
        # 104x104x512 => 100x100x256
        ex_layer2 = self.ex_conv2(cat2)

		# 100x100x256 => 200x200x128
        upconv3 = self.upconv3(ex_layer2)
        
        # Contracting path 중 같은 단계의 Feature map을 가져와 합침
        # Up-Convolution 결과의 Feature map size 만큼 CenterCrop 하여 Concat 연산
        # 200x200x128 => 200x200x256
        cat3 = torch.cat((transforms.CenterCrop((upconv3.shape[2], upconv3.shape[3]))(layer2), upconv3), dim=1)
        
        # 200x200x256 => 196x196x128
        ex_layer3 = self.ex_conv3(cat3)

		# 196x196x128=> 392x392x64
        upconv4 = self.upconv4(ex_layer3)
        
        # Contracting path 중 같은 단계의 Feature map을 가져와 합침
        # Up-Convolution 결과의 Feature map size 만큼 CenterCrop 하여 Concat 연산
        # 392x392x64 => 392x392x128
        cat4 = torch.cat((transforms.CenterCrop((upconv4.shape[2], upconv4.shape[3]))(layer1), upconv4), dim=1)
        
        # 392x392x128 => 388x388x64
        out = self.ex_conv4(cat4)

		# 388x388x64 => 388x388x1
        out = self.fc(out)
        return out

해당 모델을 학습하는 Pytorch 코드는 다음 포스팅에 올리도록 하겠습니다.

 

모델 학습시 입력 이미지(512x512)가 모델을 거친 후 (324x324) size의 이미지가 출력됩니다.

- label 데이터의 324x324를 Center Crop 하여 학습을 진행하였습니다.

 

 

그리고 모델을 거친 후에 입력 이미지와 같은 크기의 결과를 얻기 위해

 

논문에 나와있는 'Mirroring Extrapolation'을 적용하고 'Overlap-tile strategy' 기법을 이용하여 코드 구현중입니다.

 

이 부분도 구현을 마치게 되면 포스팅 하도록 하겠습니다.

 


Training dataset은 ISBI 2012 em image segmentation 대회에서 사용한 이미지 데이터 셋 (30장)을 사용하였습니다.

 

training data = 24장

validation data = 6장

Epoch : 100회

Batch size = 4

 

학습 결과는 아래와 같습니다.

 


Training set 결과

Input(512 x 512)

 

Label(324 x 324 CenterCrop)

 

Output(324 x 324)


Validation set 결과

Input(512 x 512)

 

Label(324 x 324 CenterCrop)

 

Output(324 x 324)