
Pre-trained Models and Transfer Learning

1. Pre-trained Models


1.1 Pre-trained Models

  • Models that perform well are very heavy and take a long time to train.
    • ResNet-50 took 29 hours of training on 8 P100 GPUs.
    • Facebook's AI team cut this down to 1 hour, but used 256 GPUs to do it.
  • Many researchers release the models they have trained.
  • These released models are commonly called pre-trained models, and they are used either as-is or as the basis for transfer learning.


1.2 TensorFlow Hub

  • TensorFlow Hub is a library of reusable machine learning modules.
  • Installation: pip install tensorflow_hub (a quick import check follows).
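
A minimal sanity check after installing (a sketch, not from the original post; version numbers will vary by environment):

import tensorflow as tf
import tensorflow_hub as hub

# Confirm both packages import cleanly and print their versions.
print(tf.__version__)
print(hub.__version__)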


1.3 Loading MobileNet V2 (a Pre-trained Model)

import tensorflow as tf
import numpy as np
import tensorflow_hub as hub

# Load the MobileNet V2 ImageNet classifier from TF Hub as a frozen layer.
url = 'https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2'
model = tf.keras.Sequential([
    hub.KerasLayer(handle=url, input_shape=(224, 224, 3), trainable=False)
])

model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer (KerasLayer)     (None, 1001)              3540265   
=================================================================
Total params: 3,540,265
Trainable params: 0
Non-trainable params: 3,540,265
_________________________________________________________________
  • The model has relatively few parameters, about 3.5M (a quick output-shape check follows).
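
As a quick sanity check (a sketch, not in the original post), the module maps a batch of 224x224 RGB images to 1001 ImageNet logits, where class 0 is 'background':

# Run a dummy batch through the frozen model to confirm the output shape.
dummy = np.zeros((1, 224, 224, 3), dtype=np.float32)
print(model.predict(dummy).shape)  # expected: (1, 1001)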


1.4 Test Images: ImageNetV2

  • ImageNetV2 is a new test set built from a subset of the ImageNet classes.
  • Its labels were collected via Amazon Mechanical Turk.
  • Mechanical Turk is a platform that provides labeled images at relatively low cost for work like image labeling, which otherwise requires a lot of manual effort.


1.5 Downloading ImageNetV2

import pathlib

# Download and extract the ImageNetV2 "top images" archive.
im_url = 'https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-topimages.tar.gz'
data_root_orig = tf.keras.utils.get_file('imagenetV2', im_url,
                                         cache_dir='./data/imagenetv2-topimages',
                                         extract=True)
data_root = pathlib.Path('./data/imagenetv2-topimages')
print(data_root)
data/imagenetv2-topimages
  • Setting cache_dir lets you download ImageNetV2 to a directory of your choice (see the caveat below).
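
One caveat, as an assumption worth checking: get_file stores downloads under a datasets/ subfolder of cache_dir by default (via its cache_subdir argument), so a hard-coded data_root may not match every setup. A small defensive check (a sketch, not from the original post):

# Inspect where get_file actually put things before trusting
# a hard-coded path.
print(data_root_orig)
print(list(pathlib.Path(data_root_orig).parent.iterdir())[:5])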


1.6 Checking the Data Paths

# Print a few of the class directories to confirm the layout.
for idx, item in enumerate(data_root.iterdir()):
    print(item)
    if idx == 5:
        break
data/imagenetv2-topimages/797
data/imagenetv2-topimages/909
data/imagenetv2-topimages/135
data/imagenetv2-topimages/307
data/imagenetv2-topimages/763
data/imagenetv2-topimages/551
  • The dataset looks fine; each subdirectory is named with an ImageNet class index.


1.7 Loading the Labels

# Download the ImageNet label names (1001 entries; index 0 is 'background').
label_url = 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'
label_file = tf.keras.utils.get_file('label', label_url)
with open(label_file, 'r') as f:
    label_text = f.read().split('\n')[:-1]
print(len(label_text))
print(label_text[:10])
print(label_text[-10:])
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt
16384/10484 [==============================================] - 0s 0us/step
1001
['background', 'tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead', 'electric ray', 'stingray', 'cock', 'hen']
['buckeye', 'coral fungus', 'agaric', 'gyromitra', 'stinkhorn', 'earthstar', 'hen-of-the-woods', 'bolete', 'ear', 'toilet tissue']
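
Note the offset: the label file has 1001 entries with 'background' at index 0, while the ImageNetV2 directory names are the original 0-999 class indices, so the code below adds 1 when looking up a name. A quick check against the '797' directory printed earlier:

# Folder name c (an ImageNet class id, 0-999) maps to label_text[c + 1]
# because index 0 of the label file is 'background'.
print(label_text[797 + 1])  # name of the class stored in directory '797'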


1.8 Viewing the Images

import random

# Collect every image path and shuffle the list.
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]
random.shuffle(all_image_paths)

image_count = len(all_image_paths)
print(f'Image Count: {image_count}')
Image Count: 10002
  • There are 10,002 images in total.
  • What do these images actually look like?


import matplotlib.pyplot as plt

# Show nine random images with their class index and label name.
plt.figure(figsize=(12, 12))
for c in range(9):
    image_path = random.choice(all_image_paths)
    plt.subplot(3, 3, c + 1)
    plt.imshow(plt.imread(image_path))
    idx = int(image_path.split('/')[-2]) + 1  # +1 for the 'background' offset
    plt.title(str(idx) + ', ' + label_text[idx])
    plt.axis('off')
plt.show()

  • This is what they look like.


1.9 Test

  • pip install opencv-python
  • pip install opencv-contrib-python
import cv2

# Pick a random image, preprocess it, and check whether the true label
# appears in the model's top-5 predictions.
# (Note: cv2.imread returns BGR; strictly, the ImageNet model expects RGB,
# so cv2.cvtColor(img_draw, cv2.COLOR_BGR2RGB) would be more faithful.)
img = random.choice(all_image_paths)
label = int(img.split('/')[-2]) + 1
img_draw = cv2.imread(img)
img_resized = cv2.resize(img_draw, dsize=(224, 224))
img_resized = img_resized / 255.0
img_resized = np.expand_dims(img_resized, axis=0)
top_5_predict = model.predict(img_resized)[0].argsort()[::-1][:5]
print(top_5_predict)
print(label)
if label in top_5_predict:
    print('Answer is correct!!')
print(f'Ground-truth label is {label_text[label]}')

plt.imshow(plt.imread(img))
plt.show()
[827 427 636 836 713]
827
Answer is correct!!
Ground-truth label is stopwatch

  • Feed an image into the pre-trained model and predict; if the true label appears among the top 5 predictions by score, print that the answer is correct (a sample-based top-5 accuracy estimate is sketched below).
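
To get a rough sense of top-5 accuracy beyond a single image, the same loop can be run over a small random sample (a sketch under the same preprocessing as above; 100 is an arbitrary sample size, and the BGR-to-RGB conversion is added here for faithfulness):

# Estimate top-5 accuracy on a random sample of ImageNetV2 images.
correct = 0
sample_paths = random.sample(all_image_paths, 100)
for path in sample_paths:
    true_label = int(path.split('/')[-2]) + 1
    im = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
    im = cv2.resize(im, dsize=(224, 224)) / 255.0
    top_5 = model.predict(np.expand_dims(im, axis=0))[0].argsort()[::-1][:5]
    correct += int(true_label in top_5)
print(f'Top-5 accuracy on 100 samples: {correct / 100:.2%}')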


1.10 Computing Softmax and Visualizing the Probabilities

def softmax(x):
    # Numerically stable softmax: subtract the max before exponentiating.
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

logits = model.predict(img_resized)[0]
prediction = softmax(logits)

top_5_predict = prediction.argsort()[::-1][:5]
labels = [label_text[index] for index in top_5_predict]
  • Compute the softmax and obtain probabilities for the top-5 labels (an equivalent built-in op is shown below).
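
For reference, TensorFlow's built-in op gives the same probabilities (equivalent up to numerical precision):

# Equivalent to the manual softmax above.
prediction = tf.nn.softmax(logits).numpy()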


# Left: the image with its true label. Right: the top-5 class
# probabilities, with the true class highlighted in green if present.
plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)
plt.imshow(plt.imread(img))

idx = int(img.split('/')[-2]) + 1
plt.title(str(idx) + ', ' + label_text[idx])
plt.axis('off')

plt.subplot(1, 2, 2)
color = ['gray'] * 5
if idx in top_5_predict:
    color[top_5_predict.tolist().index(idx)] = 'green'
color = color[::-1]
plt.barh(range(5), prediction[top_5_predict][::-1] * 100, color=color)
plt.yticks(range(5), labels[::-1])
plt.show()

  • Visualizes the probabilities of the top-5 labels for the given image.


2. Transfer Learning


2.1 Transfer Learning

  • Transfer learning takes some of the layers from an existing pre-trained model and reuses them to build a new model for a similar task (a minimal sketch follows).
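
As a minimal sketch of the pattern (assuming the companion feature_vector module of the TF Hub handle used in section 1, and an arbitrary 10-class head): freeze the pre-trained base and train only a new classifier on top.

import tensorflow as tf
import tensorflow_hub as hub

# Feature-extractor pattern: frozen pre-trained base + trainable new head.
feature_url = 'https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4'
transfer_model = tf.keras.Sequential([
    hub.KerasLayer(feature_url, input_shape=(224, 224, 3), trainable=False),
    tf.keras.layers.Dense(10, activation='softmax'),  # 10 classes: example only
])
transfer_model.compile(optimizer='sgd',
                       loss='sparse_categorical_crossentropy',
                       metrics=['accuracy'])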


2.2 Data for Transfer Learning on the Pre-trained Model

https://www.kaggle.com/c/dog-breed-identification/data

  • Photos of dogs, each labeled with its breed.
  • The task: look at a photo and predict the dog's breed.
  • The data can be downloaded from the link above.
  • For convenience, the folder was renamed to dog_data.


2.3 Inspecting the Label Data

import pandas as pd

# Load the Kaggle labels: one row per image id and its breed.
label_text = pd.read_csv('./data/dog_data/labels.csv')
print(label_text.head())
                                 id             breed
0  000bec180eb18c7604dcecc8fe0dba07       boston_bull
1  001513dfcb2ffafc82cccf4d8bbaba97             dingo
2  001cdf01b096e06d78e9e5112d419397          pekinese
3  00214f311d5d2247d5dfe4fe24b2303d          bluetick
4  0021f9ceb3235effd7fcde7f7538ed62  golden_retriever
  • The label file contains an id and a breed for each image.


label_text.info()
print(label_text['breed'].nunique(), 'breeds')
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10222 entries, 0 to 10221
Data columns (total 2 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   id      10222 non-null  object
 1   breed   10222 non-null  object
dtypes: object(2)
memory usage: 159.8+ KB
120 breeds
  • Over ten thousand dog images spanning 120 breeds.


2.4 Visualizing the Data

# Show the first six training images with their breed labels.
plt.figure(figsize=(12, 8))
for c in range(6):
    image_id = label_text.loc[c, 'id']
    plt.subplot(2, 3, c + 1)
    plt.imshow(plt.imread('./data/dog_data/train/' + image_id + '.jpg'))
    plt.title(str(c) + ', ' + label_text.loc[c, 'breed'])
    plt.axis('off')
plt.show()

  • These are the kinds of photos we are working with.


2.5 Loading MobileNet V2

from tensorflow.keras.applications import MobileNetV2
import tensorflow as tf

mobilev2 = MobileNetV2()

# Take the output of the second-to-last layer (global average pooling)
# and attach a new 120-way softmax head for the dog breeds.
x = mobilev2.layers[-2].output
predictions = tf.keras.layers.Dense(120, activation='softmax')(x)
model = tf.keras.Model(inputs=mobilev2.input, outputs=predictions)

# Freeze everything except the last 20 layers.
for layer in model.layers[:-20]:
    layer.trainable = False
for layer in model.layers[-20:]:
    layer.trainable = True

model.compile(optimizer='sgd',
              loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.summary()
Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
Conv1_pad (ZeroPadding2D)       (None, 225, 225, 3)  0           input_2[0][0]                    
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 112, 112, 32) 864         Conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 112, 112, 32) 128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 112, 112, 32) 0           bn_Conv1[0][0]                   
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 112, 112, 32) 288         Conv1_relu[0][0]                 
__________________________________________________________________________________________________
expanded_conv_depthwise_BN (Bat (None, 112, 112, 32) 128         expanded_conv_depthwise[0][0]    
__________________________________________________________________________________________________
expanded_conv_depthwise_relu (R (None, 112, 112, 32) 0           expanded_conv_depthwise_BN[0][0] 
__________________________________________________________________________________________________
expanded_conv_project (Conv2D)  (None, 112, 112, 16) 512         expanded_conv_depthwise_relu[0][0
__________________________________________________________________________________________________
expanded_conv_project_BN (Batch (None, 112, 112, 16) 64          expanded_conv_project[0][0]      
__________________________________________________________________________________________________
block_1_expand (Conv2D)         (None, 112, 112, 96) 1536        expanded_conv_project_BN[0][0]   
__________________________________________________________________________________________________
block_1_expand_BN (BatchNormali (None, 112, 112, 96) 384         block_1_expand[0][0]             
__________________________________________________________________________________________________
block_1_expand_relu (ReLU)      (None, 112, 112, 96) 0           block_1_expand_BN[0][0]          
__________________________________________________________________________________________________
block_1_pad (ZeroPadding2D)     (None, 113, 113, 96) 0           block_1_expand_relu[0][0]        
...
__________________________________________________________________________________________________
block_16_depthwise_relu (ReLU)  (None, 7, 7, 960)    0           block_16_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_16_project (Conv2D)       (None, 7, 7, 320)    307200      block_16_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_16_project_BN (BatchNorma (None, 7, 7, 320)    1280        block_16_project[0][0]           
__________________________________________________________________________________________________
Conv_1 (Conv2D)                 (None, 7, 7, 1280)   409600      block_16_project_BN[0][0]        
__________________________________________________________________________________________________
Conv_1_bn (BatchNormalization)  (None, 7, 7, 1280)   5120        Conv_1[0][0]                     
__________________________________________________________________________________________________
out_relu (ReLU)                 (None, 7, 7, 1280)   0           Conv_1_bn[0][0]                  
__________________________________________________________________________________________________
global_average_pooling2d_1 (Glo (None, 1280)         0           out_relu[0][0]                   
__________________________________________________________________________________________________
dense (Dense)                   (None, 120)          153720      global_average_pooling2d_1[0][0] 
==================================================================================================
Total params: 2,411,704
Trainable params: 1,204,280
Non-trainable params: 1,207,424
__________________________________________________________________________________________________
  • Only the last 20 layers of the network are set to be trainable.
  • Total params and Trainable params differ because we chose to train only those 20 layers (a quick sanity check follows this list).
  • Even so, the summary shows an awful lot of layers.
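
A quick sanity check that the freezing took effect (a sketch; assumes numpy is imported as np from section 1):

# Count the weights that will actually be updated during training;
# this should match the 'Trainable params' line in the summary.
trainable = int(np.sum([np.prod(w.shape) for w in model.trainable_weights]))
print(f'Trainable parameters: {trainable:,}')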


2.6 Creating the Train X, Y Data

import cv2

# Load every training image, resize to 224x224, and scale to [0, 1].
# (Note: cv2.imread returns BGR; strictly, the ImageNet-pretrained base
# expects RGB, so cv2.cvtColor(img, cv2.COLOR_BGR2RGB) would be more faithful.)
train_X = []
for i in range(len(label_text)):
    img = cv2.imread('./data/dog_data/train/' + label_text['id'][i] + '.jpg')
    img = cv2.resize(img, dsize=(224, 224))
    img = img / 255.0
    train_X.append(img)
train_X = np.array(train_X)
print(train_X.shape)
print(train_X.size * train_X.itemsize, ' bytes')
(10222, 224, 224, 3)
12309577728  bytes


# Map each breed name to an integer class index.
unique_Y = label_text['breed'].unique().tolist()
train_Y = [unique_Y.index(breed) for breed in label_text['breed']]
train_Y = np.array(train_Y)
print(train_Y[:10])
print(train_Y[-10:])
[0 1 2 3 4 5 5 6 7 8]
[34 87 91 63 48  6 93 63 77 92]
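
Note that train_X is float64, which is why it occupies roughly 12.3 GB of RAM; casting to float32 would halve that. For larger datasets, a tf.data pipeline that decodes images on the fly avoids materializing the whole array (a sketch using the same paths and labels, not what this post actually runs):

# Streaming alternative: decode, resize, and scale images lazily.
# (On older TF 2.x, use tf.data.experimental.AUTOTUNE instead.)
paths = ['./data/dog_data/train/' + i + '.jpg' for i in label_text['id']]

def load_image(path, label):
    img = tf.io.decode_jpeg(tf.io.read_file(path), channels=3)
    img = tf.image.resize(img, (224, 224)) / 255.0
    return img, label

dataset = (tf.data.Dataset.from_tensor_slices((paths, train_Y))
           .map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
           .batch(32)
           .prefetch(tf.data.AUTOTUNE))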


2.7 Fit

history = model.fit(train_X, train_Y, epochs=10, validation_split=0.25, batch_size=32)
Epoch 1/10
240/240 [==============================] - 102s 424ms/step - loss: 3.2485 - accuracy: 0.2904 - val_loss: 1.9760 - val_accuracy: 0.4534
Epoch 2/10
240/240 [==============================] - 107s 445ms/step - loss: 1.6241 - accuracy: 0.6054 - val_loss: 1.6033 - val_accuracy: 0.5477
Epoch 3/10
240/240 [==============================] - 108s 449ms/step - loss: 1.1931 - accuracy: 0.7131 - val_loss: 1.4534 - val_accuracy: 0.5943
Epoch 4/10
240/240 [==============================] - 108s 449ms/step - loss: 0.9406 - accuracy: 0.7747 - val_loss: 1.3683 - val_accuracy: 0.6072
Epoch 5/10
240/240 [==============================] - 106s 441ms/step - loss: 0.7763 - accuracy: 0.8298 - val_loss: 1.3273 - val_accuracy: 0.6217
Epoch 6/10
240/240 [==============================] - 114s 474ms/step - loss: 0.6375 - accuracy: 0.8729 - val_loss: 1.3028 - val_accuracy: 0.6295
Epoch 7/10
240/240 [==============================] - 107s 447ms/step - loss: 0.5297 - accuracy: 0.9023 - val_loss: 1.2905 - val_accuracy: 0.6307
Epoch 8/10
240/240 [==============================] - 107s 446ms/step - loss: 0.4413 - accuracy: 0.9314 - val_loss: 1.2837 - val_accuracy: 0.6405
Epoch 9/10
240/240 [==============================] - 108s 448ms/step - loss: 0.3795 - accuracy: 0.9510 - val_loss: 1.2573 - val_accuracy: 0.6518
Epoch 10/10
240/240 [==============================] - 108s 449ms/step - loss: 0.3188 - accuracy: 0.9622 - val_loss: 1.2633 - val_accuracy: 0.6424
  • This is with transfer learning; without it, training would take far longer.


2.8 Checking the Training Progress

# Plot the loss and accuracy curves for training and validation.
plt.figure(figsize=(12, 8))
plt.plot(history.history['loss'], 'b-', label='loss')
plt.plot(history.history['val_loss'], 'r--', label='val_loss')
plt.plot(history.history['accuracy'], 'g-', label='accuracy')
plt.plot(history.history['val_accuracy'], 'k--', label='val_accuracy')
plt.xlabel('Epoch')
plt.legend()
plt.show()

  • Training accuracy reaches about 96% while validation accuracy plateaus around 64%, which looks like overfitting, though with only 10 epochs it is hard to say for sure (a common mitigation is sketched below).
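
If it is overfitting, a common first mitigation (a sketch, not something this post trains) is light data augmentation in front of the base; the layer names below assume TF >= 2.6 (older versions keep them under tf.keras.layers.experimental.preprocessing):

# Random flips and small rotations, active only during training.
augment = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.1),
])

# Wrap the existing model so augmentation runs before it.
augmented_model = tf.keras.Sequential([augment, model])
augmented_model.compile(optimizer='sgd',
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])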