

[DeepLearning] Image Classification Model_Pokemon 809 Set_ep.2


Last time we preprocessed the images in the Pokémon dataset. This time we continue by training a neural network model on that data and visualizing the results.

 

The model used this time is ResNet-50.

 

 

1. A brief explanation of ResNet

 

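ResNet (residual neural network) is an architecture in which each block uses a skip connection and learns only a residual. This makes much deeper networks trainable, and ResNet reaches higher accuracy than the plain CNNs that came before it.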

 

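  • Residual: in statistics, the difference between an observed value and the value predicted by a regression. In ResNet, it is the difference F(x) = H(x) - x between the mapping H(x) a block should learn and the block's input x.
  • Network: the deep network built by stacking these residual blocks.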
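In equation form, one residual block computes the following. The notation follows the original ResNet paper by He et al.: H is the mapping the block should learn and F is what its layers actually fit. The equations assume an identity shortcut, so x and y have the same shape. The last equation follows from the sum rule. The shortcut adds an identity term I to the Jacobian, which gives the gradient a path back to earlier layers that bypasses the block's weights.

```latex
\mathbf{y} = \mathcal{F}(\mathbf{x}, \{W_i\}) + \mathbf{x},
\qquad
\mathcal{F}(\mathbf{x}) := \mathcal{H}(\mathbf{x}) - \mathbf{x},
\qquad
\frac{\partial \mathbf{y}}{\partial \mathbf{x}} = \frac{\partial \mathcal{F}}{\partial \mathbf{x}} + I
```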

 

2. Features of ResNet

 

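  • It adds shortcut (skip) connections, which plain CNNs do not have.
    1) At the end of every block, the input x itself is added to the block's output through a skip connection.
    2) Gradients flow back smoothly through the shortcuts, so the weights of early layers keep being updated. This makes stacking the model deeper much less of a problem.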
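The skip connection described above can be illustrated with a toy NumPy sketch. This is not torchvision's Bottleneck: real blocks use convolutions and batch normalization, and the dense weights w1 and w2 here are hypothetical. The sketch shows why an identity mapping is easy to learn. When the residual branch outputs zero, the block simply passes its input through, and this still holds after stacking 50 blocks.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def residual_block(x, w1, w2):
    """Toy residual block: y = relu(F(x) + x) with F(x) = w2 @ relu(w1 @ x)."""
    f = w2 @ relu(w1 @ x)   # residual branch F(x)
    return relu(f + x)      # skip connection adds the input back unchanged

rng = np.random.default_rng(0)
x = np.abs(rng.normal(size=4))   # non-negative input, like the output of a previous ReLU
zeros = np.zeros((4, 4))         # hypothetical all-zero weights

# With zero weights the branch outputs 0, so the block returns x itself
identity_out = residual_block(x, zeros, zeros)

# Even 50 stacked residual blocks still return x
deep_out = x
for _ in range(50):
    deep_out = residual_block(deep_out, zeros, zeros)

# A plain (non-residual) layer with the same zero weights outputs all zeros instead
plain_out = relu(zeros @ x)
```

In a plain network, the layers would have to learn the identity explicitly. With a shortcut, driving F toward zero is enough.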

 

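For the data pipeline in this post:

  • Names are encoded with OneHotEncoder (multilabel) or LabelEncoder (multiclass).
  • From row 721 on, the image files have a different format (png vs. jpg), so the two groups are handled separately.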
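The extension rule can be expressed as a minimal sketch, using a hypothetical helper image_file_name and placeholder names. Row indices are 0-based, so row 721 is the 722nd row. Assuming the CSV is sorted by Pokédex number, that boundary falls between #721 and #722, which is where Generation VII begins.

```python
PNG_ROWS = 721  # rows 0-720 use .png, rows 721 and later use .jpg

def image_file_name(row_index, name):
    """Hypothetical helper: file name for a 0-based CSV row."""
    extension = ".png" if row_index < PNG_ROWS else ".jpg"
    return name + extension

# Placeholder names, not read from the real CSV
names = ["bulbasaur", "ivysaur", "venusaur"]
files = [image_file_name(i, n) for i, n in enumerate(names)]
```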

 

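Next, we write a class that loads the Pokémon data.

It encodes the English names and the Korean names with separate encoders: a OneHotEncoder in the multilabel case, and LabelEncoders otherwise.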
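The two encoders compute what this NumPy-only sketch computes (the names are placeholders, not the real CSV). LabelEncoder and OneHotEncoder (with the default categories="auto") both store their classes in sorted order. That is why get_labels_from_vector can recover a name by indexing categories_[0] with a one-hot vector.

```python
import numpy as np

# Placeholder file names, as they look after the extension is appended
names = np.array(["Pikachu.png", "Bulbasaur.png", "Charizard.png", "Pikachu.png"])

# LabelEncoder-style: sorted unique classes, each name mapped to an integer id
classes, ids = np.unique(names, return_inverse=True)

# OneHotEncoder-style: one row per sample with a single 1 at the class position
one_hot = np.eye(len(classes))[ids]

# Decoding a vector back to its name, the way get_labels_from_vector does
decoded = classes[one_hot[0] == 1]
```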

 

# Prepare the Pokémon dataset
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

class PokemonDatasetBuilder():
    def __init__(self, dataset_class, file_path, transform=None, splits=True):
        
        self._splits = splits
        self._dataset_class = dataset_class
        self.df = pd.read_csv(file_path, encoding = 'cp949')  # cp949: the CSV contains Korean text
        self.transform = transform
        self._preprocess_frame()
        
    def __call__(self, test_split=0.1, val_split=0.1):
        
        dfs = []
        if self._splits:
            dfs.extend(self._create_splits(test_split, val_split))
        else:
            dfs.append(self.df)
        
        datasets = []
        
        if self._dataset_class == "multilabel":
            # sparse_output requires scikit-learn >= 1.2 (older versions call it sparse)
            OHE = OneHotEncoder(sparse_output=False, handle_unknown="ignore")
            OHE.fit(self.df[["EnglishName"]])
            
            for df in dfs:
                datasets.append(PokemonDatasetMultilabel(df, OHE, self.transform))

        else:
            LE1 = LabelEncoder()   # English names
            LE2 = LabelEncoder()   # Korean names
            
            LE1.fit(self.df["EnglishName"])
            LE2.fit(self.df["KoreanName"])
        
            for df in dfs:
                datasets.append(PokemonDatasetMulticlass(df, LE1, LE2, self.transform))
        
        return datasets
            
    def _create_splits(self, test_split, val_split):
        # Test set: test_split of all rows; validation set: val_split of the remaining rows
        df_test = self.df.sample(frac=test_split, random_state=42)
        df_train = self.df.drop(df_test.index)
        df_val = df_train.sample(frac=val_split, random_state=42)
        df_train = df_train.drop(df_val.index)

        return [df_train, df_val, df_test]

    def _preprocess_frame(self):
        # Fill missing names first (NaN + ".png" would raise a TypeError)
        names = self.df["EnglishName"].fillna("None")
        # Rows 0-720 are .png images, rows 721 and later are .jpg images.
        # Assigning the whole column avoids chained .iloc assignment,
        # which pandas may silently apply to a copy.
        is_png = np.arange(len(self.df)) < 721
        self.df["EnglishName"] = np.where(is_png, names + ".png", names + ".jpg")
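The size of each split can be checked with a small sketch of the arithmetic in _create_splits. It assumes the CSV has 809 rows and that DataFrame.sample(frac=...) draws round(frac * n) rows, which is how pandas computes the sample size.

```python
def split_sizes(n_rows, test_split=0.1, val_split=0.1):
    """Row counts produced by _create_splits (sketch, assuming round(frac * n))."""
    n_test = round(test_split * n_rows)   # test: 10% of all rows
    n_rest = n_rows - n_test
    n_val = round(val_split * n_rest)     # validation: 10% of what remains
    n_train = n_rest - n_val
    return n_train, n_val, n_test

print(split_sizes(809))  # (655, 73, 81)
```

The validation set is therefore 10% of the remaining 90%, about 9% of the whole data. Valid training indices run from 0 to 654.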

 

The dataset class for the multilabel setting (each target is a one-hot vector over the English names):

 

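import torch
from torch.utils.data import Dataset

class PokemonDatasetMultilabel(Dataset):
    def __init__(self, df, encoder, transform=None):
        
        self.df = df
        self.transform = transform
        self.encoder = encoder
        # One-hot target per image. The encoder is fit on EnglishName only, so KoreanName
        # is not encoded: its values are unknown categories (an all-zero vector), and
        # scikit-learn >= 1.2 rejects a column name that differs from the fitted one.
        self.labels = encoder.transform(self.df[["EnglishName"]])
    
    def __getitem__(self, idx):
        
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        # WORKING_DIR, IMAGES_DIR and process_image are defined outside this post;
        # column 0 is assumed to hold the image file name
        image_file = WORKING_DIR + IMAGES_DIR + self.df.iloc[idx, 0]
        
        image = process_image(image_file, self.transform)

        return image, self.labels[idx]

    def __len__(self):
        return self.df.shape[0]
    
    def get_labels_from_vector(self, vector):
        # Map a one-hot (or multi-hot) vector back to category names
        labels = self.encoder.categories_[0][vector==1]
        if vector.sum() == 1:
            return labels[0], "None"
        else:
            return tuple(labels)

    def get_labels_from_id(self, type1, type2=None):
        # Map one or two class indices back to category names
        if type2 is not None: 
            return self.encoder.categories_[0][type1], self.encoder.categories_[0][type2]
        else:
            return self.encoder.categories_[0][type1]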

 

Create the dataset builder (dataset_generator). It splits the images into training, validation, and test datasets.

 

from torch.utils.data import DataLoader
from torchvision import transforms

OUTPUT_TYPE = "multilabel" 

# Resize every image to 120x120 and convert it to a tensor
transformations = transforms.Compose([transforms.Resize((120, 120)), transforms.ToTensor()])
# transformations = transforms.Compose([transforms.ToTensor()])
dataset_generator = PokemonDatasetBuilder(OUTPUT_TYPE, WORKING_DIR+INFO_DIR, transformations)
train_dataset, val_dataset, test_dataset = dataset_generator()

train_dataloader = DataLoader(train_dataset, batch_size = 4, shuffle=True)
val_dataloader = DataLoader(val_dataset)    # batch_size defaults to 1
test_dataloader = DataLoader(test_dataset)
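With batch_size=4, a DataLoader yields ceil(n / 4) batches per epoch. Because drop_last defaults to False, the last batch is smaller unless n is divisible by 4. This pure-Python sketch of that arithmetic assumes 655 training rows, which is what the 0.1/0.1 splits give for an 809-row CSV.

```python
import math

def batch_sizes(n_samples, batch_size):
    """Sizes of the batches a DataLoader yields with drop_last=False (the default)."""
    n_batches = math.ceil(n_samples / batch_size)
    return [min(batch_size, n_samples - i * batch_size) for i in range(n_batches)]

train_batches = batch_sizes(655, 4)   # 164 batches, the last one holding 3 images
```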

 

The shape of one image in the dataset:

 

train_dataset[0][0].shape
torch.Size([3, 120, 120])
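torch.Size([3, 120, 120]) is channels × height × width. Resize((120, 120)) fixes the spatial size, and ToTensor moves the channel axis to the front and scales uint8 pixels to [0, 1]. The following NumPy sketch mimics that conversion for an RGB array. The helper to_tensor_like is hypothetical and is not part of torchvision.

```python
import numpy as np

def to_tensor_like(image_hwc):
    """Mimic transforms.ToTensor for a uint8 RGB array: HWC -> CHW, scaled to [0, 1]."""
    return image_hwc.astype(np.float32).transpose(2, 0, 1) / 255.0

fake_image = np.full((120, 120, 3), 255, dtype=np.uint8)   # a white 120x120 RGB image
chw = to_tensor_like(fake_image)
print(chw.shape)  # (3, 120, 120)
```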

 

Let's pick one sample from this data and display its image.

 

import random

# A random Pokémon from the training set (any valid index, including 0)
i = random.randrange(len(train_dataset))
pokemon = train_dataset[i]
show_image(pokemon[0], *train_dataset.get_labels_from_vector(pokemon[1]))

# Sample usage for multiclass
# show_image(pokemon[0], *train_dataset.get_labels(*pokemon[1:]))
이 포켓몬의 이름은? 터검니 

This loads the image from the image folder and looks up its name in the CSV, where the file names were translated into Korean.

 

Now we define the model.

 

import torch
import torch.nn as nn
import torch.nn.functional as F

class PokemonFCBlock(nn.Module):
    """Classifier head that replaces ResNet's final fc layer."""
    def __init__(self, in_features, out_features, mode="none", dropout=0.2):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_features, 1000),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(1000, 1000),
            nn.LeakyReLU(), 
            nn.Dropout(dropout),
            nn.Linear(1000, out_features))        
        self.mode = mode.lower()

    def forward(self, x):        
        if self.mode == "logistic":    # independent sigmoid per class (pairs with BCELoss)
            return torch.sigmoid(self.fc(x))
        elif self.mode == "softmax":   # probability distribution over the classes
            return F.softmax(self.fc(x), dim=1)
        elif self.mode == "none":      # raw logits (pairs with CrossEntropyLoss)
            return self.fc(x)
        else:
            raise ValueError(f"Unknown mode: {self.mode}")

class PokemonMultilabelCNN(nn.Module):    
    def __init__(self, base_model, output_size, dropout):        
        super().__init__()    
        self.base_model = base_model 
        in_features = self.base_model.fc.in_features   # 2048 for ResNet-50
        # nn.CrossEntropyLoss applies log-softmax itself, so the head returns raw logits
        new_final = PokemonFCBlock(in_features, output_size, mode="none", dropout=dropout)
        self.base_model.fc = new_final        

    def forward(self, x):
        return self.base_model(x)   

    def freeze(self):
        # Freeze the pretrained backbone; only the new fc head stays trainable
        for name, child in self.base_model.named_children():
            if name != "fc":
                for param in child.parameters():
                    param.requires_grad = False             

    def unfreeze(self):
        # Make every layer trainable again (e.g. to fine-tune the whole network)
        for param in self.base_model.parameters():          
            param.requires_grad = True
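The three mode options of PokemonFCBlock differ only in the final activation, as this NumPy sketch with hypothetical scores shows. Sigmoid scores each class independently, which suits true multi-label targets where several entries can be 1. Softmax produces a single distribution that sums to 1. Every target in this dataset is one-hot, so the task is effectively single-label. Both functions are monotonic, so the top-ranked class is the same either way.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z):
    e = np.exp(z - z.max())   # subtract the max for numerical stability
    return e / e.sum()

scores = np.array([2.0, 1.0, -1.0])   # hypothetical logits for three classes
print(sigmoid(scores))   # independent per-class probabilities (need not sum to 1)
print(softmax(scores))   # one distribution over the classes, summing to 1
```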

 

Next, download a ResNet-50 pretrained on ImageNet.

Let's look at the model's structure right away.

 

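from torchsummary import summary
from torchvision import models

# ImageNet weights; the weights= argument replaces pretrained=True in torchvision >= 0.13
base_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1, progress=False)
model = PokemonMultilabelCNN(base_model, 809, 0.5) # start testing from dropout = 0.5
model.to(device)
model.freeze()
# model summary example (torchsummary)
summary(model, (3, 120, 120))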
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 60, 60]           9,408
       BatchNorm2d-2           [-1, 64, 60, 60]             128
              ReLU-3           [-1, 64, 60, 60]               0
         MaxPool2d-4           [-1, 64, 30, 30]               0
            Conv2d-5           [-1, 64, 30, 30]           4,096
       BatchNorm2d-6           [-1, 64, 30, 30]             128
              ReLU-7           [-1, 64, 30, 30]               0
            Conv2d-8           [-1, 64, 30, 30]          36,864
       BatchNorm2d-9           [-1, 64, 30, 30]             128
             ReLU-10           [-1, 64, 30, 30]               0
           Conv2d-11          [-1, 256, 30, 30]          16,384
      BatchNorm2d-12          [-1, 256, 30, 30]             512
           Conv2d-13          [-1, 256, 30, 30]          16,384
      BatchNorm2d-14          [-1, 256, 30, 30]             512
             ReLU-15          [-1, 256, 30, 30]               0
       Bottleneck-16          [-1, 256, 30, 30]               0
           Conv2d-17           [-1, 64, 30, 30]          16,384
      BatchNorm2d-18           [-1, 64, 30, 30]             128
             ReLU-19           [-1, 64, 30, 30]               0
           Conv2d-20           [-1, 64, 30, 30]          36,864
      BatchNorm2d-21           [-1, 64, 30, 30]             128
             ReLU-22           [-1, 64, 30, 30]               0
           Conv2d-23          [-1, 256, 30, 30]          16,384
      BatchNorm2d-24          [-1, 256, 30, 30]             512
             ReLU-25          [-1, 256, 30, 30]               0
       Bottleneck-26          [-1, 256, 30, 30]               0
           Conv2d-27           [-1, 64, 30, 30]          16,384
      BatchNorm2d-28           [-1, 64, 30, 30]             128
             ReLU-29           [-1, 64, 30, 30]               0
           Conv2d-30           [-1, 64, 30, 30]          36,864
      BatchNorm2d-31           [-1, 64, 30, 30]             128
             ReLU-32           [-1, 64, 30, 30]               0
           Conv2d-33          [-1, 256, 30, 30]          16,384
      BatchNorm2d-34          [-1, 256, 30, 30]             512
             ReLU-35          [-1, 256, 30, 30]               0
       Bottleneck-36          [-1, 256, 30, 30]               0
           Conv2d-37          [-1, 128, 30, 30]          32,768
      BatchNorm2d-38          [-1, 128, 30, 30]             256
             ReLU-39          [-1, 128, 30, 30]               0
           Conv2d-40          [-1, 128, 15, 15]         147,456
      BatchNorm2d-41          [-1, 128, 15, 15]             256
             ReLU-42          [-1, 128, 15, 15]               0
           Conv2d-43          [-1, 512, 15, 15]          65,536
      BatchNorm2d-44          [-1, 512, 15, 15]           1,024
           Conv2d-45          [-1, 512, 15, 15]         131,072
      BatchNorm2d-46          [-1, 512, 15, 15]           1,024
             ReLU-47          [-1, 512, 15, 15]               0
       Bottleneck-48          [-1, 512, 15, 15]               0
           Conv2d-49          [-1, 128, 15, 15]          65,536
      BatchNorm2d-50          [-1, 128, 15, 15]             256
             ReLU-51          [-1, 128, 15, 15]               0
           Conv2d-52          [-1, 128, 15, 15]         147,456
      BatchNorm2d-53          [-1, 128, 15, 15]             256
             ReLU-54          [-1, 128, 15, 15]               0
           Conv2d-55          [-1, 512, 15, 15]          65,536
      BatchNorm2d-56          [-1, 512, 15, 15]           1,024
             ReLU-57          [-1, 512, 15, 15]               0
       Bottleneck-58          [-1, 512, 15, 15]               0
           Conv2d-59          [-1, 128, 15, 15]          65,536
      BatchNorm2d-60          [-1, 128, 15, 15]             256
             ReLU-61          [-1, 128, 15, 15]               0
           Conv2d-62          [-1, 128, 15, 15]         147,456
      BatchNorm2d-63          [-1, 128, 15, 15]             256
             ReLU-64          [-1, 128, 15, 15]               0
           Conv2d-65          [-1, 512, 15, 15]          65,536
      BatchNorm2d-66          [-1, 512, 15, 15]           1,024
             ReLU-67          [-1, 512, 15, 15]               0
       Bottleneck-68          [-1, 512, 15, 15]               0
           Conv2d-69          [-1, 128, 15, 15]          65,536
      BatchNorm2d-70          [-1, 128, 15, 15]             256
             ReLU-71          [-1, 128, 15, 15]               0
           Conv2d-72          [-1, 128, 15, 15]         147,456
      BatchNorm2d-73          [-1, 128, 15, 15]             256
             ReLU-74          [-1, 128, 15, 15]               0
           Conv2d-75          [-1, 512, 15, 15]          65,536
      BatchNorm2d-76          [-1, 512, 15, 15]           1,024
             ReLU-77          [-1, 512, 15, 15]               0
       Bottleneck-78          [-1, 512, 15, 15]               0
           Conv2d-79          [-1, 256, 15, 15]         131,072
      BatchNorm2d-80          [-1, 256, 15, 15]             512
             ReLU-81          [-1, 256, 15, 15]               0
           Conv2d-82            [-1, 256, 8, 8]         589,824
      BatchNorm2d-83            [-1, 256, 8, 8]             512
             ReLU-84            [-1, 256, 8, 8]               0
           Conv2d-85           [-1, 1024, 8, 8]         262,144
      BatchNorm2d-86           [-1, 1024, 8, 8]           2,048
           Conv2d-87           [-1, 1024, 8, 8]         524,288
      BatchNorm2d-88           [-1, 1024, 8, 8]           2,048
             ReLU-89           [-1, 1024, 8, 8]               0
       Bottleneck-90           [-1, 1024, 8, 8]               0
           Conv2d-91            [-1, 256, 8, 8]         262,144
      BatchNorm2d-92            [-1, 256, 8, 8]             512
             ReLU-93            [-1, 256, 8, 8]               0
           Conv2d-94            [-1, 256, 8, 8]         589,824
      BatchNorm2d-95            [-1, 256, 8, 8]             512
             ReLU-96            [-1, 256, 8, 8]               0
           Conv2d-97           [-1, 1024, 8, 8]         262,144
      BatchNorm2d-98           [-1, 1024, 8, 8]           2,048
             ReLU-99           [-1, 1024, 8, 8]               0
      Bottleneck-100           [-1, 1024, 8, 8]               0
          Conv2d-101            [-1, 256, 8, 8]         262,144
     BatchNorm2d-102            [-1, 256, 8, 8]             512
            ReLU-103            [-1, 256, 8, 8]               0
          Conv2d-104            [-1, 256, 8, 8]         589,824
     BatchNorm2d-105            [-1, 256, 8, 8]             512
            ReLU-106            [-1, 256, 8, 8]               0
          Conv2d-107           [-1, 1024, 8, 8]         262,144
     BatchNorm2d-108           [-1, 1024, 8, 8]           2,048
            ReLU-109           [-1, 1024, 8, 8]               0
      Bottleneck-110           [-1, 1024, 8, 8]               0
          Conv2d-111            [-1, 256, 8, 8]         262,144
     BatchNorm2d-112            [-1, 256, 8, 8]             512
            ReLU-113            [-1, 256, 8, 8]               0
          Conv2d-114            [-1, 256, 8, 8]         589,824
     BatchNorm2d-115            [-1, 256, 8, 8]             512
            ReLU-116            [-1, 256, 8, 8]               0
          Conv2d-117           [-1, 1024, 8, 8]         262,144
     BatchNorm2d-118           [-1, 1024, 8, 8]           2,048
            ReLU-119           [-1, 1024, 8, 8]               0
      Bottleneck-120           [-1, 1024, 8, 8]               0
          Conv2d-121            [-1, 256, 8, 8]         262,144
     BatchNorm2d-122            [-1, 256, 8, 8]             512
            ReLU-123            [-1, 256, 8, 8]               0
          Conv2d-124            [-1, 256, 8, 8]         589,824
     BatchNorm2d-125            [-1, 256, 8, 8]             512
            ReLU-126            [-1, 256, 8, 8]               0
          Conv2d-127           [-1, 1024, 8, 8]         262,144
     BatchNorm2d-128           [-1, 1024, 8, 8]           2,048
            ReLU-129           [-1, 1024, 8, 8]               0
      Bottleneck-130           [-1, 1024, 8, 8]               0
          Conv2d-131            [-1, 256, 8, 8]         262,144
     BatchNorm2d-132            [-1, 256, 8, 8]             512
            ReLU-133            [-1, 256, 8, 8]               0
          Conv2d-134            [-1, 256, 8, 8]         589,824
     BatchNorm2d-135            [-1, 256, 8, 8]             512
            ReLU-136            [-1, 256, 8, 8]               0
          Conv2d-137           [-1, 1024, 8, 8]         262,144
     BatchNorm2d-138           [-1, 1024, 8, 8]           2,048
            ReLU-139           [-1, 1024, 8, 8]               0
      Bottleneck-140           [-1, 1024, 8, 8]               0
          Conv2d-141            [-1, 512, 8, 8]         524,288
     BatchNorm2d-142            [-1, 512, 8, 8]           1,024
            ReLU-143            [-1, 512, 8, 8]               0
          Conv2d-144            [-1, 512, 4, 4]       2,359,296
     BatchNorm2d-145            [-1, 512, 4, 4]           1,024
            ReLU-146            [-1, 512, 4, 4]               0
          Conv2d-147           [-1, 2048, 4, 4]       1,048,576
     BatchNorm2d-148           [-1, 2048, 4, 4]           4,096
          Conv2d-149           [-1, 2048, 4, 4]       2,097,152
     BatchNorm2d-150           [-1, 2048, 4, 4]           4,096
            ReLU-151           [-1, 2048, 4, 4]               0
      Bottleneck-152           [-1, 2048, 4, 4]               0
          Conv2d-153            [-1, 512, 4, 4]       1,048,576
     BatchNorm2d-154            [-1, 512, 4, 4]           1,024
            ReLU-155            [-1, 512, 4, 4]               0
          Conv2d-156            [-1, 512, 4, 4]       2,359,296
     BatchNorm2d-157            [-1, 512, 4, 4]           1,024
            ReLU-158            [-1, 512, 4, 4]               0
          Conv2d-159           [-1, 2048, 4, 4]       1,048,576
     BatchNorm2d-160           [-1, 2048, 4, 4]           4,096
            ReLU-161           [-1, 2048, 4, 4]               0
      Bottleneck-162           [-1, 2048, 4, 4]               0
          Conv2d-163            [-1, 512, 4, 4]       1,048,576
     BatchNorm2d-164            [-1, 512, 4, 4]           1,024
            ReLU-165            [-1, 512, 4, 4]               0
          Conv2d-166            [-1, 512, 4, 4]       2,359,296
     BatchNorm2d-167            [-1, 512, 4, 4]           1,024
            ReLU-168            [-1, 512, 4, 4]               0
          Conv2d-169           [-1, 2048, 4, 4]       1,048,576
     BatchNorm2d-170           [-1, 2048, 4, 4]           4,096
            ReLU-171           [-1, 2048, 4, 4]               0
      Bottleneck-172           [-1, 2048, 4, 4]               0
AdaptiveAvgPool2d-173           [-1, 2048, 1, 1]               0
          Linear-174                 [-1, 1000]       2,049,000
       LeakyReLU-175                 [-1, 1000]               0
         Dropout-176                 [-1, 1000]               0
          Linear-177                 [-1, 1000]       1,001,000
       LeakyReLU-178                 [-1, 1000]               0
         Dropout-179                 [-1, 1000]               0
          Linear-180                  [-1, 809]         809,809
  PokemonFCBlock-181                  [-1, 809]               0
          ResNet-182                  [-1, 809]               0
================================================================
Total params: 27,367,841
Trainable params: 3,859,809
Non-trainable params: 23,508,032
----------------------------------------------------------------
Input size (MB): 0.16
Forward/backward pass size (MB): 85.01
Params size (MB): 104.40
Estimated Total Size (MB): 189.58
----------------------------------------------------------------
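The parameter counts in the summary can be checked by hand. Only the new head is trainable, and it consists of three Linear layers (LeakyReLU and Dropout have no parameters). The frozen part is torchvision's ResNet-50 (25,557,032 parameters) minus its original 2048 → 1000 fc layer. A short arithmetic sketch:

```python
def linear_params(in_features, out_features):
    """Parameters of an nn.Linear layer: weight matrix plus bias vector."""
    return in_features * out_features + out_features

# The three Linear layers of PokemonFCBlock
head = [(2048, 1000), (1000, 1000), (1000, 809)]
trainable = sum(linear_params(i, o) for i, o in head)

# ResNet-50 total minus its original fc layer = the frozen backbone
backbone = 25_557_032 - linear_params(2048, 1000)

print(trainable, backbone)  # 3859809 23508032
```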

 

Next, set the loss function and the optimizer.

We use cross-entropy loss and Adam.

 

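# Hyperparameter tuning; callbacks are not used
# optimizer tuning

# The targets are float one-hot vectors, which CrossEntropyLoss treats as class probabilities (PyTorch >= 1.10)
criterion = nn.CrossEntropyLoss() # initially BCELoss
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) # initially AdamW (lr = 5e-4, no weight decay set)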
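nn.CrossEntropyLoss applies log-softmax to its input, so it expects raw logits. If the head ends in a sigmoid (mode="logistic"), every score is squeezed into (0, 1). With 809 classes, the loss then has a floor well above zero. The following NumPy reimplementation of softmax cross-entropy is a sketch, not PyTorch's code. It compares that limiting case with unbounded logits.

```python
import numpy as np

def cross_entropy(logits, target_probs):
    """Softmax cross-entropy for one sample, the quantity nn.CrossEntropyLoss computes."""
    shifted = logits - logits.max()                       # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum())   # log-softmax
    return -(target_probs * log_probs).sum()

num_classes = 809
target = np.zeros(num_classes)
target[5] = 1.0                       # one-hot target, as in this dataset

sigmoid_best = target.copy()          # limiting best case for sigmoid outputs: 1 for the target, 0 elsewhere
logits_confident = target * 20.0      # raw logits have no such ceiling

print(cross_entropy(sigmoid_best, target))       # about 5.70, the floor with a sigmoid head
print(cross_entropy(logits_confident, target))   # close to 0
print(np.log(num_classes))                       # about 6.70, the loss of a uniform guess
```

With a sigmoid head the loss can only move between about 6.70 and 5.70, which weakens the training signal. Returning raw logits removes that limit.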

 

We train the model on the device with a model trainer class. The base class only defines the interface:

class ModelTrainer():

    
    def __init__(self, model, optimizer, criterion, device):
        
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        # self.callback = callback
        self.epoch = 1
        
    def train_loop(self, dataloader, epoch=100):

        raise NotImplementedError
        
    def validate_loop(self, dataloader):

        raise NotImplementedError
        
    def train(self, train_dataloader, val_dataloader):

        raise NotImplementedError

 

The subclass below implements the loops. As the epochs progress, it prints the training cost and the validation cost.

 

from tqdm import tqdm

class MultilabelModelTrainer(ModelTrainer):

    def __init__(self, model, optimizer, criterion, device):
        super().__init__(model, optimizer, criterion, device)

    def train_loop(self, dataloader, epoch=100):
        
        self.model.train()   # enable Dropout; BatchNorm updates its running statistics
        cost = 0 
        t = tqdm(dataloader)
        if epoch:
            t.set_description(f"Training mode, Epoch {self.epoch}")
            
        for feature, target in t:
            feature, target = feature.to(self.device), target.to(self.device)
            output = self.model(feature).double()   # match the float64 one-hot targets
            loss = self.criterion(output, target)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
            # Weight each batch's mean loss by its size so the last, smaller batch counts correctly
            cost += loss.item() * feature.shape[0]
        
        return cost / len(dataloader.dataset)
            
    def validate_loop(self, dataloader):
        
        self.model.eval()   # disable Dropout; BatchNorm uses its running statistics
        cost = 0
        with torch.no_grad():
            for feature, target in tqdm(dataloader, desc="Validation mode"):
                feature, target = feature.to(self.device), target.to(self.device)
                output = self.model(feature).double()
                loss = self.criterion(output, target)
                
                cost += loss.item() * feature.shape[0]
        
        return cost / len(dataloader.dataset)
        
    def train(self, train_dataloader, val_dataloader, num_epochs=100):
        
        for _ in range(num_epochs):   # exactly num_epochs epochs
            train_cost = self.train_loop(train_dataloader)
            val_cost = self.validate_loop(val_dataloader)
            print(f'Epoch {self.epoch}\n1) train cost : {train_cost},\n2) validation cost : {val_cost}')
            self.epoch += 1
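The cost returned by each loop is a per-sample average. Each batch's mean loss is multiplied by its batch size, and the sum is divided by the dataset size. This pure-Python sketch uses hypothetical losses and batch sizes to show how that differs from averaging the batch means directly when the last batch is smaller.

```python
def epoch_cost(batch_losses, batch_sizes):
    """Per-sample average loss from per-batch mean losses, as train_loop accumulates it."""
    total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)

losses = [0.5, 0.5, 2.0]   # hypothetical mean loss of each batch
sizes = [4, 4, 3]          # the last batch is smaller

weighted = epoch_cost(losses, sizes)       # 10 / 11, about 0.909
naive = sum(losses) / len(losses)          # 1.0, over-weights the small last batch
```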

 

Run training over the epochs, print the costs, and finally save the model.

 

# Check the elapsed time
# The epoch number is displayed separately

model_trainer = MultilabelModelTrainer(model, optimizer, criterion, device) # (be sure to check the time)
model_trainer.train(train_dataloader, val_dataloader)
torch.save(model.state_dict(), "pokemon_epoch_Cross_entropy_loss_ADAM_LR=0.0001(100회).pt")

 

(Output shown only for epochs 94 to 100.)

 

Now let's load the saved model.

 

model.load_state_dict(torch.load("/content/pokemon_epoch_Cross_entropy_loss_ADAM_LR=0.0001(100회).pt", map_location=device))

model

 

The visualization below is built with Colab's form fields (the #@title and #@param annotations).

 

Shall we enter a Pokémon number?

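#@title Shall we enter a Pokémon number?
번호를_입력해요 = 5 #@param {type:"number"}
pokemon = train_dataset[번호를_입력해요]   # index into the shuffled training split, not the Pokédex number
model.eval()  # turn off Dropout for inference
with torch.no_grad():
    feature = pokemon[0].to(device)
    prediction = model(feature.unsqueeze(0))   # class scores, shape [1, 809]
print(f'>>>>>>>>[ {번호를_입력해요} ]번째 포켓몬')
show_image(pokemon[0], *train_dataset.get_labels_from_vector(pokemon[1]))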
>>>>>>>>[ 5 ]번째 포켓몬
이 포켓몬의 이름은? 리자몽 
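The cell above computes prediction but displays only the ground-truth label. One way to see what the model actually predicts is to rank the 809 scores and map the top indices back to names. This is a NumPy sketch with a hypothetical helper top_k, placeholder names, and made-up scores. In the notebook, it would presumably be called as top_k(prediction[0].cpu().numpy(), train_dataset.encoder.categories_[0]).

```python
import numpy as np

def top_k(scores, class_names, k=3):
    """Hypothetical helper: the k highest-scoring (name, score) pairs for one image."""
    order = np.argsort(scores)[::-1][:k]   # indices sorted by descending score
    return [(str(class_names[i]), float(scores[i])) for i in order]

# Placeholder class names and scores, not real model output
class_names = np.array(["bulbasaur.png", "charizard.png", "pikachu.png", "squirtle.png"])
scores = np.array([0.1, 2.5, 1.2, -0.3])

print(top_k(scores, class_names, k=2))  # [('charizard.png', 2.5), ('pikachu.png', 1.2)]
```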

 

This concludes the visualization.
