What is Transfer Learning?
Transfer learning is a machine learning method in which a model developed for one task is reused as the starting point for a model on a second task. The pre-trained model, called the base model, is used after removing its last few layers. To build a CNN for the new task, we attach a new top part, called the head, to the frozen base model and train only the head.
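The core idea can be sketched without any deep learning library. In this hypothetical toy example (all names and numbers are illustrative, not from a real pre-trained network), the "frozen base" is a fixed random projection and the "head" is a small logistic-regression layer trained on top of it; only the head's weights ever change.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "base model": a fixed 10 -> 4 projection whose weights stay frozen
W_base = rng.normal(size=(10, 4))
W_base_before = W_base.copy()

# Trainable "head": a single logistic-regression layer on the base features
w_head = np.zeros(4)

# Synthetic binary-classification data
X = rng.normal(size=(100, 10))
y = (X[:, 0] > 0).astype(float)

for _ in range(200):
    feats = X @ W_base                        # base forward pass (frozen)
    p = 1 / (1 + np.exp(-(feats @ w_head)))   # head forward pass (sigmoid)
    grad = feats.T @ (p - y) / len(y)         # gradient w.r.t. head weights only
    w_head -= 0.5 * grad                      # update the head; base untouched

# The base weights are identical to before training
assert np.array_equal(W_base, W_base_before)
```

Because gradients are computed only for `w_head`, the "base" stays exactly as pre-trained, which is what `base_model.trainable = False` achieves in Keras.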
What is Fine Tuning?
Fine-tuning is better suited to larger datasets, especially when the custom dataset is quite different from the dataset the base model was originally trained on. Fine-tuning is similar to transfer learning except that the base model is only half-frozen: the early layers stay frozen while the later layers are trained along with the head.
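A minimal sketch of the "half-frozen" idea, using a hypothetical stand-in `Layer` class instead of real Keras layers (the layer count of 155 is illustrative): everything before a chosen cut-off is frozen, everything after it stays trainable.

```python
# Hypothetical stand-in for a Keras layer with a trainable flag
class Layer:
    def __init__(self):
        self.trainable = True

# Suppose the base model has 155 layers
layers = [Layer() for _ in range(155)]

# Freeze the first `fine_tune_at` layers; the rest remain trainable
fine_tune_at = 100
for layer in layers[:fine_tune_at]:
    layer.trainable = False

print(sum(l.trainable for l in layers))  # 55
```

This is exactly the loop used in the fine-tuning code below, just applied to real `base_model.layers`.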
When to use these techniques
| Dataset | Technique |
|---|---|
| Large and different | Train the whole model |
| Large and similar | Fine tuning |
| Small and different | Fine tuning |
| Small and similar | Transfer learning |
Example
Code
Transfer Learning
```python
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator

IMG_SHAPE = (128, 128, 3)

# include_top=False drops MobileNetV2's fully connected classifier head
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')

# Freeze the base model so its weights are not updated during training
base_model.trainable = False

# The base model outputs a 4x4x1280 feature map (20,480 values per image);
# GlobalAveragePooling2D reduces it to a 1280-dimensional vector
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
prediction_layer = tf.keras.layers.Dense(units=1, activation='sigmoid')(global_average_layer)

# Combine the two networks (base_model, prediction_layer) into one model
model = tf.keras.models.Model(inputs=base_model.input, outputs=prediction_layer)
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.0001),
              loss='binary_crossentropy',
              metrics=['accuracy'])

# train_dir and validation_dir must point to your image directories
data_gen_train = ImageDataGenerator(rescale=1/255.)
data_gen_valid = ImageDataGenerator(rescale=1/255.)
train_generator = data_gen_train.flow_from_directory(train_dir, target_size=(128, 128), batch_size=128, class_mode='binary')
valid_generator = data_gen_valid.flow_from_directory(validation_dir, target_size=(128, 128), batch_size=128, class_mode='binary')

# fit_generator/evaluate_generator are deprecated; fit/evaluate accept generators
model.fit(train_generator, epochs=50, validation_data=valid_generator)
valid_loss, valid_accuracy = model.evaluate(valid_generator)
print("Accuracy after transfer learning: {}".format(valid_accuracy))
```
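The comment about the 4\*4\*1280 (20,480) feature map deserves a quick demonstration. Global average pooling simply averages each channel over the spatial dimensions, which can be reproduced with plain NumPy (the batch size of 2 here is just for illustration):

```python
import numpy as np

# Simulate the base model's output for a batch of 2 images:
# a 4x4 spatial grid with 1280 channels each
features = np.random.rand(2, 4, 4, 1280)

# GlobalAveragePooling2D averages over the spatial axes (height, width),
# turning 4*4*1280 = 20,480 values per image into a 1280-dim vector
pooled = features.mean(axis=(1, 2))

print(pooled.shape)  # (2, 1280)
```

Connecting the Dense head to this 1280-dim vector instead of the flattened 20,480-value map keeps the number of new weights small.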
Fine tuning
```python
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator

IMG_SHAPE = (128, 128, 3)

# include_top=False drops MobileNetV2's fully connected classifier head
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')

# "Half-freeze" the base model: unfreeze it, then re-freeze its first 100 layers
base_model.trainable = True
fine_tune_at = 100
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

# The base model outputs a 4x4x1280 feature map (20,480 values per image);
# GlobalAveragePooling2D reduces it to a 1280-dimensional vector
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
prediction_layer = tf.keras.layers.Dense(units=1, activation='sigmoid')(global_average_layer)

model = tf.keras.models.Model(inputs=base_model.input, outputs=prediction_layer)
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.0001),
              loss='binary_crossentropy',
              metrics=['accuracy'])

# train_dir and validation_dir must point to your image directories
data_gen_train = ImageDataGenerator(rescale=1/255.)
data_gen_valid = ImageDataGenerator(rescale=1/255.)
train_generator = data_gen_train.flow_from_directory(train_dir, target_size=(128, 128), batch_size=128, class_mode='binary')
valid_generator = data_gen_valid.flow_from_directory(validation_dir, target_size=(128, 128), batch_size=128, class_mode='binary')

# fit_generator/evaluate_generator are deprecated; fit/evaluate accept generators
model.fit(train_generator, epochs=5, validation_data=valid_generator)
valid_loss, valid_accuracy = model.evaluate(valid_generator)
print('Validation accuracy after fine tuning: {}'.format(valid_accuracy))
```
Result
Transfer Learning
Accuracy after transfer learning: 0.9629999995231628
Fine Tuning
Validation accuracy after fine tuning: 0.9729999899864197