지난 시간의 포켓몬스터 데이터셋의 이미지 처리에 이어서, 이번에도 여러가지 신경망 모델을
사용하여 학습하고 시각화 해보는 시간을 갖도록 하겠습니다.
이번에 사용할 모델은 ResNet-50입니다.
1. ResNet에 대한 잠깐 설명
ResNet(잔차 신경망, 레스넷)은 스킵 커넥션을 이용해서 잔차를 학습하도록 만든 알고리즘으로,
Resnet 이전의 일반적인 CNN 신경망보다 예측 정확도가 높습니다.
- Residual(잔차) : 관측치와 회귀식의 예측치와의 차이
- Network(신경망) : 기존의 모델보다 진보된 신경망
2. ResNet의 특징
- 기존의 방식보다 더 빠른 Short cut을 진행시킵니다.
1) 일정 시점마다 input x 자체를 skip connection을 통해서 연결
2) gradient flow가 원활하게 이루어짐 → 깊은 층에서도 weight가 잘 갱신됨 (모델을 깊게 쌓는 것에 대한 부담이 줄어듦)
- OneHotEncoder, LabelEncoder 사용
- 721번부터 이미지 파일 형식이 달라지므로(png → jpg) 파일 확장자를 따로 구분해줍니다.
포켓몬스터 데이터를 로드하는 클래스를 작성해줍니다.
이는 라벨 인코더를 사용해 영어 이름과 한국 이름을 구분한 것입니다.
# 포켓몬 데이터셋 준비
class PokemonDatasetBuilder():
    """Read the Pokemon info CSV and build dataset objects from it.

    Args:
        dataset_class: ``"multilabel"`` builds ``PokemonDatasetMultilabel``
            objects; any other value builds ``PokemonDatasetMulticlass``.
        file_path: Path of the CSV file; it is read with ``cp949`` encoding.
        transform: Passed through unchanged to every dataset that is built.
        splits: When true, calling the builder returns train, validation and
            test datasets; otherwise a single dataset over the whole frame.
    """

    # Rows before this position get a ".png" suffix, the remaining rows ".jpg".
    _PNG_ROW_COUNT = 721

    def __init__(self, dataset_class, file_path, transform=None, splits=True):
        self._splits = splits
        self._dataset_class = dataset_class
        self.df = pd.read_csv(file_path, encoding='cp949')
        self.transform = transform
        self._preprocess_frame()

    def __call__(self, test_split=0.1, val_split=0.1):
        """Return the list of datasets.

        Args:
            test_split: Fraction of all rows sampled into the test frame.
            val_split: Fraction of the remaining rows sampled into the
                validation frame.

        Returns:
            ``[train, val, test]`` datasets when splitting is enabled,
            otherwise a one-element list.
        """
        if self._splits:
            dfs = self._create_splits(test_split, val_split)
        else:
            dfs = [self.df]

        if self._dataset_class == "multilabel":
            # The encoder is fitted on the full frame so every split shares
            # the same category order.
            OHE = self._make_one_hot_encoder()
            OHE.fit(self.df[["EnglishName"]])
            return [
                PokemonDatasetMultilabel(df, OHE, self.transform) for df in dfs
            ]

        LE1 = LabelEncoder()
        LE2 = LabelEncoder()
        LE1.fit(self.df["EnglishName"])
        LE2.fit(self.df["KoreanName"])
        return [
            PokemonDatasetMulticlass(df, LE1, LE2, self.transform) for df in dfs
        ]

    @staticmethod
    def _make_one_hot_encoder():
        """Create a dense-output OneHotEncoder on old and new scikit-learn."""
        try:
            # Newer scikit-learn spells the dense-output switch `sparse_output`.
            return OneHotEncoder(sparse_output=False, handle_unknown="ignore")
        except TypeError:
            # Older releases only know the original `sparse` keyword.
            return OneHotEncoder(sparse=False, handle_unknown="ignore")

    def _create_splits(self, test_split, val_split):
        """Split ``self.df`` into train, validation and test frames.

        Both samples use ``random_state=42``, so the split is reproducible.
        """
        df_test = self.df.sample(frac=test_split, random_state=42)
        df_train = self.df.drop(df_test.index)
        df_val = df_train.sample(frac=val_split, random_state=42)
        df_train = df_train.drop(df_val.index)
        return [df_train, df_val, df_test]

    def _preprocess_frame(self):
        """Append the image file extension to every ``EnglishName`` value.

        Missing names stay missing through the concatenation and are then
        replaced by the string ``"None"``.
        """
        names = self.df["EnglishName"]
        png_rows = min(len(names), self._PNG_ROW_COUNT)
        suffixes = pd.Series(
            [".png"] * png_rows + [".jpg"] * (len(names) - png_rows),
            index=names.index,
        )
        # Assign the whole column at once: chained `df[col].iloc[...] = ...`
        # and chained `fillna(inplace=True)` operate on a temporary object and
        # are not guaranteed to write back into the frame.
        self.df["EnglishName"] = (names + suffixes).fillna("None")
멀티 레이블로 사용
class PokemonDatasetMultilabel(Dataset):
    """Dataset yielding ``(image, label_vector)`` pairs.

    Args:
        df: Frame with ``EnglishName`` and ``KoreanName`` columns; its first
            column is used as the image file name.
        encoder: Fitted one-hot encoder exposing ``transform`` and
            ``categories_``.
        transform: Passed to ``process_image`` for every loaded image.
    """

    def __init__(self, df, encoder, transform=None):
        self.df = df
        self.transform = transform
        self.encoder = encoder
        self.type1 = encoder.transform(self.df[["EnglishName"]])
        # Pass raw values for the second column: an encoder fitted on a frame
        # remembers its column name, and newer scikit-learn rejects a frame
        # whose column name differs. The encoded values are unaffected.
        # NOTE(review): whether Korean names can ever match the fitted
        # categories depends on how the encoder was fitted -- confirm.
        self.type2 = encoder.transform(self.df[["KoreanName"]].to_numpy())

    def __getitem__(self, idx):
        """Return the processed image and the sum of both encoded vectors."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image_file = WORKING_DIR + IMAGES_DIR + self.df.iloc[idx, 0]
        image = process_image(image_file, self.transform)
        return image, (self.type1[idx] + self.type2[idx])

    def __len__(self):
        return self.df.shape[0]

    def get_labels_from_vector(self, vector):
        """Decode a label vector into category names.

        Returns ``(name, "None")`` when exactly one entry is set, otherwise a
        tuple of every category whose entry equals 1.
        """
        labels = self.encoder.categories_[0][vector == 1]
        if vector.sum() == 1:
            return labels[0], "None"
        return tuple(labels)

    def get_labels_from_id(self, type1, type2=None):
        """Map one or two category indices to their category names.

        Returns a pair when ``type2`` is given, otherwise the single name
        for ``type1``.
        """
        categories = self.encoder.categories_[0]
        if type2 is not None:
            return categories[type1], categories[type2]
        # The single-id case must look up `type1`; `type2` is None here.
        return categories[type1]
이미지의 트레인 데이터셋, 검증 데이터셋, 테스트 데이터셋을 만드는
generator를 선언해줍니다.
# Dataset flavour handed to the builder.
OUTPUT_TYPE = "multilabel"

# Resize every image to 120x120 and convert it to a tensor.
# Alternative without resizing: transforms.Compose([transforms.ToTensor()])
transformations = transforms.Compose(
    [
        transforms.Resize((120, 120)),
        transforms.ToTensor(),
    ]
)

dataset_generator = PokemonDatasetBuilder(
    OUTPUT_TYPE,
    WORKING_DIR + INFO_DIR,
    transformations,
)
# Calling the builder with its default split fractions returns the train,
# validation and test datasets in that order.
train_dataset, val_dataset, test_dataset = dataset_generator()

# Only the training loader shuffles and batches (4 samples per batch); the
# other two keep the DataLoader defaults.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=4)
val_dataloader = DataLoader(val_dataset)
test_dataloader = DataLoader(test_dataset)
데이터셋의 shape입니다.
# Shape of the image tensor of the first training sample.
train_dataset[0][0].shape
torch.Size([3, 120, 120])
이 데이터에서 하나만 뽑아 이미지를 출력해봅니다.
# Random Pokemon
# Draw the index from the dataset's actual length instead of a hard-coded
# range, so every sample (including index 0) can be picked and the code keeps
# working when the split sizes change.
i = random.randrange(len(train_dataset))
pokemon = train_dataset[i]
show_image(pokemon[0], *train_dataset.get_labels_from_vector(pokemon[1]))
# Sample usage for multiclass
# show_image(pokemon[0], *train_dataset.get_labels(*pokemon[1:]))
이 포켓몬의 이름은? 터검니
이는 해당 폴더에서 이미지를 꺼내고, 경로를 꺼내 한글로 번역한 csv를 사용한 것입니다.
모델을 생성해줍니다.
class PokemonFCBlock(nn.Module):
    """Three-layer fully connected head with an optional output activation.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        mode: ``"logistic"`` applies a sigmoid, ``"softmax"`` a softmax over
            the last dimension, ``"none"`` returns the raw linear output.
            Matching is case-insensitive.
        dropout: Dropout probability used after each hidden activation.
    """

    def __init__(self, in_features, out_features, mode="none", dropout=0.2):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_features, 1000),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(1000, 1000),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(1000, out_features))
        self.mode = mode.lower()

    def forward(self, x):
        """Apply the linear stack followed by the activation chosen by ``mode``."""
        if self.mode == "logistic":
            return torch.sigmoid(self.fc(x))
        elif self.mode == "softmax":
            # Normalise over the feature (last) dimension explicitly; relying
            # on the implicit `dim` is deprecated and selects another axis
            # for 3-D and 4-D inputs.
            return F.softmax(self.fc(x), dim=-1)
        elif self.mode == "none":
            return self.fc(x)
        else:
            # NOTE(review): UnknownModeException is not defined in this
            # section; presumably it is declared elsewhere -- confirm.
            raise UnknownModeException
class PokemonMultilabelCNN(nn.Module):
    """Wrap a backbone and replace its ``fc`` layer with a sigmoid head.

    Args:
        base_model: Module with an ``fc`` attribute that exposes
            ``in_features``; that attribute is replaced in place.
        output_size: Number of outputs produced by the new head.
        dropout: Dropout probability used inside the new head.
    """

    def __init__(self, base_model, output_size, dropout):
        super().__init__()
        self.base_model = base_model
        in_features = self.base_model.fc.in_features
        # Forward `dropout` to the head; previously the argument was accepted
        # but ignored, so the head always used its own default.
        new_final = PokemonFCBlock(
            in_features, output_size, mode="logistic", dropout=dropout
        )
        self.base_model.fc = new_final

    def forward(self, x):
        return self.base_model(x)

    def freeze(self):
        """Disable gradients for every backbone child except ``fc``."""
        for name, child in self.base_model.named_children():
            if name != "fc":
                for param in child.parameters():
                    param.requires_grad = False

    def unfreeze(self):
        """Enable gradients for every parameter of the backbone."""
        for param in self.base_model.parameters():
            param.requires_grad = True
이제 사전 학습된(pretrained) ResNet 모델을 다운로드해 줍니다.
바로 모델의 형태를 구경해보겠습니다.
# Load ResNet-50 with pretrained weights; the download progress bar is off.
base_model = models.resnet50(progress=False, pretrained=True)

# 809 outputs; 0.5 is the dropout value the experiments start from.
model = PokemonMultilabelCNN(base_model, output_size=809, dropout=0.5)
model.to(device)
model.freeze()

# Example model summary for a 3 x 120 x 120 input.
summary(model, (3, 120, 120))
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 60, 60] 9,408
BatchNorm2d-2 [-1, 64, 60, 60] 128
ReLU-3 [-1, 64, 60, 60] 0
MaxPool2d-4 [-1, 64, 30, 30] 0
Conv2d-5 [-1, 64, 30, 30] 4,096
BatchNorm2d-6 [-1, 64, 30, 30] 128
ReLU-7 [-1, 64, 30, 30] 0
Conv2d-8 [-1, 64, 30, 30] 36,864
BatchNorm2d-9 [-1, 64, 30, 30] 128
ReLU-10 [-1, 64, 30, 30] 0
Conv2d-11 [-1, 256, 30, 30] 16,384
BatchNorm2d-12 [-1, 256, 30, 30] 512
Conv2d-13 [-1, 256, 30, 30] 16,384
BatchNorm2d-14 [-1, 256, 30, 30] 512
ReLU-15 [-1, 256, 30, 30] 0
Bottleneck-16 [-1, 256, 30, 30] 0
Conv2d-17 [-1, 64, 30, 30] 16,384
BatchNorm2d-18 [-1, 64, 30, 30] 128
ReLU-19 [-1, 64, 30, 30] 0
Conv2d-20 [-1, 64, 30, 30] 36,864
BatchNorm2d-21 [-1, 64, 30, 30] 128
ReLU-22 [-1, 64, 30, 30] 0
Conv2d-23 [-1, 256, 30, 30] 16,384
BatchNorm2d-24 [-1, 256, 30, 30] 512
ReLU-25 [-1, 256, 30, 30] 0
Bottleneck-26 [-1, 256, 30, 30] 0
Conv2d-27 [-1, 64, 30, 30] 16,384
BatchNorm2d-28 [-1, 64, 30, 30] 128
ReLU-29 [-1, 64, 30, 30] 0
Conv2d-30 [-1, 64, 30, 30] 36,864
BatchNorm2d-31 [-1, 64, 30, 30] 128
ReLU-32 [-1, 64, 30, 30] 0
Conv2d-33 [-1, 256, 30, 30] 16,384
BatchNorm2d-34 [-1, 256, 30, 30] 512
ReLU-35 [-1, 256, 30, 30] 0
Bottleneck-36 [-1, 256, 30, 30] 0
Conv2d-37 [-1, 128, 30, 30] 32,768
BatchNorm2d-38 [-1, 128, 30, 30] 256
ReLU-39 [-1, 128, 30, 30] 0
Conv2d-40 [-1, 128, 15, 15] 147,456
BatchNorm2d-41 [-1, 128, 15, 15] 256
ReLU-42 [-1, 128, 15, 15] 0
Conv2d-43 [-1, 512, 15, 15] 65,536
BatchNorm2d-44 [-1, 512, 15, 15] 1,024
Conv2d-45 [-1, 512, 15, 15] 131,072
BatchNorm2d-46 [-1, 512, 15, 15] 1,024
ReLU-47 [-1, 512, 15, 15] 0
Bottleneck-48 [-1, 512, 15, 15] 0
Conv2d-49 [-1, 128, 15, 15] 65,536
BatchNorm2d-50 [-1, 128, 15, 15] 256
ReLU-51 [-1, 128, 15, 15] 0
Conv2d-52 [-1, 128, 15, 15] 147,456
BatchNorm2d-53 [-1, 128, 15, 15] 256
ReLU-54 [-1, 128, 15, 15] 0
Conv2d-55 [-1, 512, 15, 15] 65,536
BatchNorm2d-56 [-1, 512, 15, 15] 1,024
ReLU-57 [-1, 512, 15, 15] 0
Bottleneck-58 [-1, 512, 15, 15] 0
Conv2d-59 [-1, 128, 15, 15] 65,536
BatchNorm2d-60 [-1, 128, 15, 15] 256
ReLU-61 [-1, 128, 15, 15] 0
Conv2d-62 [-1, 128, 15, 15] 147,456
BatchNorm2d-63 [-1, 128, 15, 15] 256
ReLU-64 [-1, 128, 15, 15] 0
Conv2d-65 [-1, 512, 15, 15] 65,536
BatchNorm2d-66 [-1, 512, 15, 15] 1,024
ReLU-67 [-1, 512, 15, 15] 0
Bottleneck-68 [-1, 512, 15, 15] 0
Conv2d-69 [-1, 128, 15, 15] 65,536
BatchNorm2d-70 [-1, 128, 15, 15] 256
ReLU-71 [-1, 128, 15, 15] 0
Conv2d-72 [-1, 128, 15, 15] 147,456
BatchNorm2d-73 [-1, 128, 15, 15] 256
ReLU-74 [-1, 128, 15, 15] 0
Conv2d-75 [-1, 512, 15, 15] 65,536
BatchNorm2d-76 [-1, 512, 15, 15] 1,024
ReLU-77 [-1, 512, 15, 15] 0
Bottleneck-78 [-1, 512, 15, 15] 0
Conv2d-79 [-1, 256, 15, 15] 131,072
BatchNorm2d-80 [-1, 256, 15, 15] 512
ReLU-81 [-1, 256, 15, 15] 0
Conv2d-82 [-1, 256, 8, 8] 589,824
BatchNorm2d-83 [-1, 256, 8, 8] 512
ReLU-84 [-1, 256, 8, 8] 0
Conv2d-85 [-1, 1024, 8, 8] 262,144
BatchNorm2d-86 [-1, 1024, 8, 8] 2,048
Conv2d-87 [-1, 1024, 8, 8] 524,288
BatchNorm2d-88 [-1, 1024, 8, 8] 2,048
ReLU-89 [-1, 1024, 8, 8] 0
Bottleneck-90 [-1, 1024, 8, 8] 0
Conv2d-91 [-1, 256, 8, 8] 262,144
BatchNorm2d-92 [-1, 256, 8, 8] 512
ReLU-93 [-1, 256, 8, 8] 0
Conv2d-94 [-1, 256, 8, 8] 589,824
BatchNorm2d-95 [-1, 256, 8, 8] 512
ReLU-96 [-1, 256, 8, 8] 0
Conv2d-97 [-1, 1024, 8, 8] 262,144
BatchNorm2d-98 [-1, 1024, 8, 8] 2,048
ReLU-99 [-1, 1024, 8, 8] 0
Bottleneck-100 [-1, 1024, 8, 8] 0
Conv2d-101 [-1, 256, 8, 8] 262,144
BatchNorm2d-102 [-1, 256, 8, 8] 512
ReLU-103 [-1, 256, 8, 8] 0
Conv2d-104 [-1, 256, 8, 8] 589,824
BatchNorm2d-105 [-1, 256, 8, 8] 512
ReLU-106 [-1, 256, 8, 8] 0
Conv2d-107 [-1, 1024, 8, 8] 262,144
BatchNorm2d-108 [-1, 1024, 8, 8] 2,048
ReLU-109 [-1, 1024, 8, 8] 0
Bottleneck-110 [-1, 1024, 8, 8] 0
Conv2d-111 [-1, 256, 8, 8] 262,144
BatchNorm2d-112 [-1, 256, 8, 8] 512
ReLU-113 [-1, 256, 8, 8] 0
Conv2d-114 [-1, 256, 8, 8] 589,824
BatchNorm2d-115 [-1, 256, 8, 8] 512
ReLU-116 [-1, 256, 8, 8] 0
Conv2d-117 [-1, 1024, 8, 8] 262,144
BatchNorm2d-118 [-1, 1024, 8, 8] 2,048
ReLU-119 [-1, 1024, 8, 8] 0
Bottleneck-120 [-1, 1024, 8, 8] 0
Conv2d-121 [-1, 256, 8, 8] 262,144
BatchNorm2d-122 [-1, 256, 8, 8] 512
ReLU-123 [-1, 256, 8, 8] 0
Conv2d-124 [-1, 256, 8, 8] 589,824
BatchNorm2d-125 [-1, 256, 8, 8] 512
ReLU-126 [-1, 256, 8, 8] 0
Conv2d-127 [-1, 1024, 8, 8] 262,144
BatchNorm2d-128 [-1, 1024, 8, 8] 2,048
ReLU-129 [-1, 1024, 8, 8] 0
Bottleneck-130 [-1, 1024, 8, 8] 0
Conv2d-131 [-1, 256, 8, 8] 262,144
BatchNorm2d-132 [-1, 256, 8, 8] 512
ReLU-133 [-1, 256, 8, 8] 0
Conv2d-134 [-1, 256, 8, 8] 589,824
BatchNorm2d-135 [-1, 256, 8, 8] 512
ReLU-136 [-1, 256, 8, 8] 0
Conv2d-137 [-1, 1024, 8, 8] 262,144
BatchNorm2d-138 [-1, 1024, 8, 8] 2,048
ReLU-139 [-1, 1024, 8, 8] 0
Bottleneck-140 [-1, 1024, 8, 8] 0
Conv2d-141 [-1, 512, 8, 8] 524,288
BatchNorm2d-142 [-1, 512, 8, 8] 1,024
ReLU-143 [-1, 512, 8, 8] 0
Conv2d-144 [-1, 512, 4, 4] 2,359,296
BatchNorm2d-145 [-1, 512, 4, 4] 1,024
ReLU-146 [-1, 512, 4, 4] 0
Conv2d-147 [-1, 2048, 4, 4] 1,048,576
BatchNorm2d-148 [-1, 2048, 4, 4] 4,096
Conv2d-149 [-1, 2048, 4, 4] 2,097,152
BatchNorm2d-150 [-1, 2048, 4, 4] 4,096
ReLU-151 [-1, 2048, 4, 4] 0
Bottleneck-152 [-1, 2048, 4, 4] 0
Conv2d-153 [-1, 512, 4, 4] 1,048,576
BatchNorm2d-154 [-1, 512, 4, 4] 1,024
ReLU-155 [-1, 512, 4, 4] 0
Conv2d-156 [-1, 512, 4, 4] 2,359,296
BatchNorm2d-157 [-1, 512, 4, 4] 1,024
ReLU-158 [-1, 512, 4, 4] 0
Conv2d-159 [-1, 2048, 4, 4] 1,048,576
BatchNorm2d-160 [-1, 2048, 4, 4] 4,096
ReLU-161 [-1, 2048, 4, 4] 0
Bottleneck-162 [-1, 2048, 4, 4] 0
Conv2d-163 [-1, 512, 4, 4] 1,048,576
BatchNorm2d-164 [-1, 512, 4, 4] 1,024
ReLU-165 [-1, 512, 4, 4] 0
Conv2d-166 [-1, 512, 4, 4] 2,359,296
BatchNorm2d-167 [-1, 512, 4, 4] 1,024
ReLU-168 [-1, 512, 4, 4] 0
Conv2d-169 [-1, 2048, 4, 4] 1,048,576
BatchNorm2d-170 [-1, 2048, 4, 4] 4,096
ReLU-171 [-1, 2048, 4, 4] 0
Bottleneck-172 [-1, 2048, 4, 4] 0
AdaptiveAvgPool2d-173 [-1, 2048, 1, 1] 0
Linear-174 [-1, 1000] 2,049,000
LeakyReLU-175 [-1, 1000] 0
Dropout-176 [-1, 1000] 0
Linear-177 [-1, 1000] 1,001,000
LeakyReLU-178 [-1, 1000] 0
Dropout-179 [-1, 1000] 0
Linear-180 [-1, 809] 809,809
PokemonFCBlock-181 [-1, 809] 0
ResNet-182 [-1, 809] 0
================================================================
Total params: 27,367,841
Trainable params: 3,859,809
Non-trainable params: 23,508,032
----------------------------------------------------------------
Input size (MB): 0.16
Forward/backward pass size (MB): 85.01
Params size (MB): 104.40
Estimated Total Size (MB): 189.58
----------------------------------------------------------------
손실 함수와 최적화 함수를 지정해줍니다.
손실 함수는 크로스 엔트로피 로스(CrossEntropyLoss), 최적화 함수는 Adam입니다.
# Hyperparameter setup; no callback is used.
# The first experiment used BCELoss and AdamW (lr=5e-4, weight decay unset).
# NOTE(review): the model head ends in a sigmoid while CrossEntropyLoss is
# documented to take unnormalised logits -- confirm this pairing is intended.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
모델 트레이너를 사용해서 디바이스에 학습시켜 줍니다.
class ModelTrainer:
    """Base trainer that stores the training components.

    Subclasses must implement ``train_loop``, ``validate_loop`` and
    ``train``; the versions here only raise ``NotImplementedError``.
    """

    def __init__(self, model, optimizer, criterion, device):
        self.model, self.optimizer = model, optimizer
        self.criterion, self.device = criterion, device
        # No callback hook is stored at the moment.
        # Epoch counter, starting at 1.
        self.epoch = 1

    def train_loop(self, dataloader, epoch=100):
        """Run one training pass over ``dataloader`` (abstract)."""
        raise NotImplementedError

    def validate_loop(self, dataloader):
        """Run one validation pass over ``dataloader`` (abstract)."""
        raise NotImplementedError

    def train(self, train_dataloader, val_dataloader):
        """Run the full training procedure (abstract)."""
        raise NotImplementedError
epoch를 시작하면 갈수록 트레인 cost와 검증 cost를 보여주게 됩니다.
class MultilabelModelTrainer(ModelTrainer):
    """Trainer that runs training and validation passes and prints their costs."""

    def __init__(self, model, optimizer, criterion, device):
        super().__init__(model, optimizer, criterion, device)

    def train_loop(self, dataloader, epoch=100):
        """Run one optimisation pass over ``dataloader``.

        Args:
            dataloader: Yields ``(feature, target)`` batches.
            epoch: Only its truthiness is used; a truthy value labels the
                progress bar with the current ``self.epoch``.

        Returns:
            The batch-size-weighted loss sum divided by
            ``len(dataloader.dataset)``.
        """
        self.model.train()
        cost = 0
        t = tqdm(dataloader)
        if epoch:
            t.set_description(f"Training mode, Epoch {self.epoch}")
        for feature, target in t:
            feature, target = feature.to(self.device), target.to(self.device)
            # The model output is cast to double before the loss is computed.
            output = self.model(feature).double()
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            # Clear gradients after the step so the next batch starts clean.
            self.optimizer.zero_grad()
            cost += loss.item() * feature.shape[0]
        return cost / len(dataloader.dataset)

    def validate_loop(self, dataloader):
        """Compute the mean loss over ``dataloader`` without gradient tracking."""
        self.model.eval()
        cost = 0
        with torch.no_grad():
            for feature, target in tqdm(dataloader, desc="Validation mode"):
                feature, target = feature.to(self.device), target.to(self.device)
                output = self.model(feature).double()
                loss = self.criterion(output, target)
                cost += loss.item() * feature.shape[0]
        return cost / len(dataloader.dataset)

    def train(self, train_dataloader, val_dataloader, epochs=100):
        """Alternate training and validation passes for ``epochs`` epochs.

        The previous ``range(1, 100)`` ran only 99 epochs; the count is now
        an explicit parameter that defaults to 100.
        """
        for _ in range(epochs):
            train_cost = self.train_loop(train_dataloader)
            val_cost = self.validate_loop(val_dataloader)
            self.epoch += 1
            print(f'1) 학습 cost : {train_cost},\n2) 검증 cost : {val_cost}')
epoch를 통해 학습을 진행, cost를 뽑고 마지막으로 모델을 저장해줍니다.
# Keep an eye on the run time; the epoch number is shown separately on the
# progress bar.
model_trainer = MultilabelModelTrainer(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
)
model_trainer.train(train_dataloader, val_dataloader)

# Save the trained weights (state dict only).
torch.save(
    model.state_dict(),
    "pokemon_epoch_Cross_entropy_loss_ADAM_LR=0.0001(100회).pt",
)
(에포크 94~100회까지만 조회)
모델을 로드해보겠습니다.
# Restore the saved weights, mapping tensors onto the active device.
model.load_state_dict(
    torch.load(
        "/content/pokemon_epoch_Cross_entropy_loss_ADAM_LR=0.0001(100회).pt",
        map_location=device,
    )
)
# Display the model structure as the cell output.
model
아래 colab만의 양식란 추가를 통해 구현한 시각화 입니다.
포켓몬 번호를 입력해볼까요?
#@title 포켓몬 번호를 입력해볼까요?
번호를_입력해요 = 5 #@param {type:"number"}
pokemon = train_dataset[번호를_입력해요]

# Run the forward pass without tracking gradients.
with torch.no_grad():
    feature = pokemon[0].to(device)
    # Add a leading batch dimension of size one before calling the model.
    prediction = model(torch.unsqueeze(feature, 0))

print(f'>>>>>>>>[ {번호를_입력해요} ]번째 포켓몬')
# Show the image with the labels decoded from its target vector; `prediction`
# is computed above but not used for the display.
show_image(pokemon[0], *train_dataset.get_labels_from_vector(pokemon[1]))
>>>>>>>>[ 5 ]번째 포켓몬
이 포켓몬의 이름은? 리자몽
시각화까지 종료
'딥러닝 > 개인구현 정리' 카테고리의 다른 글
[DeepLearning] GAN 모델 활용_CelebA얼굴 이미지 구분_2 (0) | 2023.02.08 |
---|---|
[DeepLearning] GAN 모델 활용_CelebA얼굴 이미지 구분_1 (0) | 2023.02.01 |
[DeepLearning] 이미지 구분 모델_Pokemon 809 세트_ep.1 (0) | 2023.01.26 |
[이미지 처리] 타코와 브리또의 이미지 구분 모델 (0) | 2023.01.24 |
[자연어 처리 학습] 셰익스피어 비극 대본집_중세말투 학습_2 (0) | 2023.01.23 |