A quick tutorial on how to perform knowledge distillation with R, in eager mode.
Hi everyone! Welcome to my blog. Here I will share tutorials on things that were complicated for me and that could interest other R users. Not surprisingly, a lot of these tutorials will involve TensorFlow or other deep learning topics.
Sometimes things are possible in R, but since our community is smaller, we don’t have as many resources or tutorials as the Python community. That is why some particular tasks are cumbersome in R, especially when the few available tutorials or interface packages start to accumulate errors or bugs because they are not used often by an active community.
I am not an expert, so I will cite the sources of my code and parameters whenever I can. I used a small image size so as not to blow up my GPU; there is an example with fine tuning and a bigger GPU here.
There is probably a lack of optimization, but at least it is a working skeleton. If you have suggestions for improvement, comments are welcome :D
I first wrote this code for the Cassava Leaf Disease Classification, a Kaggle competition where the goal was to train a model to identify diseases on cassava leaves. Here the distillation is done from one EfficientNetB0 to another one.
As presented in this discussion thread on Kaggle, knowledge distillation is defined as “simply trains another individual model to match the output of an ensemble”. Source. It is in fact slightly more complicated: the second neural net (the student) makes predictions on the images, but its total loss is a function of its own loss as well as a loss based on the difference between its predictions and those of its teacher or of the ensemble.
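To make the combination of the two losses concrete, here is a minimal sketch in base R (no TensorFlow needed). The helper names `softmax_rows()`, `cross_entropy()` and `distillation_total_loss()` are made up for illustration, the logits are toy numbers, and the weighting with `alpha` and a `temperature` follows the general recipe of the Keras example rather than the exact code of this post.

```r
# Row-wise softmax with a temperature: a higher temperature gives softer probabilities.
softmax_rows <- function(logits, temperature = 1) {
  z <- logits / temperature
  z <- z - apply(z, 1, max)   # subtract each row maximum for numerical stability
  e <- exp(z)
  e / rowSums(e)
}

# Mean categorical cross-entropy between target and predicted probabilities.
cross_entropy <- function(target, predicted, eps = 1e-7) {
  mean(-rowSums(target * log(pmax(predicted, eps))))
}

# Hypothetical helper: student loss, distillation loss and their weighted sum.
distillation_total_loss <- function(y_true, student_logits, teacher_logits,
                                    alpha = 0.9, temperature = 3) {
  student_loss <- cross_entropy(y_true, softmax_rows(student_logits))
  distillation_loss <- cross_entropy(
    softmax_rows(teacher_logits, temperature),
    softmax_rows(student_logits, temperature)
  )
  c(student = student_loss,
    distillation = distillation_loss,
    total = alpha * student_loss + (1 - alpha) * distillation_loss)
}

# Toy batch: 2 images, 3 classes.
y_true         <- rbind(c(1, 0, 0), c(0, 1, 0))
teacher_logits <- rbind(c(4, 1, 0), c(0, 5, 1))
student_logits <- rbind(c(2, 1, 1), c(1, 2, 0))

losses <- distillation_total_loss(y_true, student_logits, teacher_logits)
print(losses)
```

With `alpha` close to 1 the student mostly learns from the true labels; lowering it gives more weight to the teacher.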
This approach makes it possible to compress an ensemble into one model and thereby reduce the inference time or, if the student is trained to match the output of a single model, to increase the overall performance of the model. I discovered this approach by looking at the top solutions, such as this one, of the Plant Pathology 2020 competition, another competition involving computer vision and leaves.
I will let you go to the source mentioned above to understand how it could potentially work. It is not certain, but it seems related to the learning of specific features versus forcing the student to learn “multiple views”, that is, multiple types of features to detect in the images.
There is, of course, no starting material to do it in R. Thankfully, there is a code example on the Keras website. In this example, they create a model class, a Distiller, to perform the knowledge distillation. There is, however, one problem: models are not inheritable in R. There are examples of inheritance with R6 for callbacks, like here, but the models are not an R6 class. To overcome this problem, I used the code example as a guide and reproduced the steps by following the approach in this guide for eager execution in Keras with R. I took other code from the TensorFlow website for R.
The code is quite hard to understand at first glance. The reason is that everything is executed in a single for loop, since everything is done in eager mode. It did not seem possible to do it differently. So there are a lot of variables around to collect metrics during training. If you want to understand the code, just take it out of the loop and run it outside of the for loop, before reconstructing the loop around it. I did not use tfdatasets as shown in the guide for eager execution, so instead of make_iterator_one_shot() and iterator_get_next(), here we loop over the train_generator to produce the batches.
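The structure of such a loop may be easier to see on a toy problem first. The sketch below is a stand-in, not the distillation code: it fits a linear regression by mini-batch gradient descent in base R, with the same skeleton of an outer loop over epochs, an inner loop over batch indices, and a vector that collects one metric per epoch. All names and numbers in it are illustrative.

```r
set.seed(42)
n <- 200
x <- cbind(1, rnorm(n))                  # design matrix with an intercept column
y <- as.vector(x %*% c(2, -3)) + rnorm(n, sd = 0.1)

weights    <- c(0, 0)
lr         <- 0.05
batch_size <- 20
nb_batch   <- n / batch_size             # 10 batches per epoch
nb_epoch   <- 30
loss_plot  <- c()                        # one value per epoch, for plotting later

for (epoch in 1:nb_epoch) {
  loss_epoch <- 0                        # reset the accumulator at each epoch
  for (batch in 1:nb_batch) {
    rows <- ((batch - 1) * batch_size + 1):(batch * batch_size)
    xb <- x[rows, , drop = FALSE]
    yb <- y[rows]
    # Forward pass and loss on the batch
    residuals <- as.vector(xb %*% weights) - yb
    loss <- mean(residuals^2)
    # Gradient of the mean squared error, then one descent step
    gradient <- 2 * as.vector(t(xb) %*% residuals) / batch_size
    weights <- weights - lr * gradient
    # Collect the metric of the batch
    loss_epoch <- loss_epoch + loss
  }
  loss_plot <- c(loss_plot, loss_epoch / nb_batch)
}

print(round(weights, 2))
```

The distillation loop further down follows the same pattern, with the gradient computed by `tf$GradientTape()` instead of by hand.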
library(tidyverse)
library(tensorflow)
tf$executing_eagerly()
[1] TRUE
tensorflow::tf_version()
[1] '2.3'
Here I flex with my own version of keras. Basically, it is a fork with application wrappers for EfficientNet.
Disclaimer: I did not write the code for the really handy application wrappers. It came from this commit, for which the PR was on hold until the full release of tf 2.3, as stated in this PR. I am not sure why the PR is closed.
devtools::install_github("Cdk29/keras", dependencies = FALSE)
# Load the fork: the generators and layers used below come from keras
library(keras)
labels<-read_csv('train.csv')
head(labels)
# A tibble: 6 x 2
image_id label
<chr> <dbl>
1 1000015157.jpg 0
2 1000201771.jpg 3
3 100042118.jpg 1
4 1000723321.jpg 1
5 1000812911.jpg 3
6 1000837476.jpg 3
labels$CBB<-0
labels$CBSD<-0
labels$CGM<-0
labels$CMD<-0
labels$Healthy<-0
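# Row indices for each class, reconstructed from the label column (0 = CBB, ..., 4 = Healthy)
idx0<-which(labels$label==0)
idx1<-which(labels$label==1)
idx2<-which(labels$label==2)
idx3<-which(labels$label==3)
idx4<-which(labels$label==4)
labels$CBB[idx0]<-1
labels$CBSD[idx1]<-1
labels$CGM[idx2]<-1
labels$CMD[idx3]<-1
labels$Healthy[idx4]<-1
“Would it have been easier to create a function to convert the labelling?” you may ask.
Probably.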
#labels$label<-NULL
head(labels)
# A tibble: 6 x 7
image_id label CBB CBSD CGM CMD Healthy
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 1000015157.jpg 0 1 0 0 0 0
2 1000201771.jpg 3 0 0 0 1 0
3 100042118.jpg 1 0 1 0 0 0
4 1000723321.jpg 1 0 1 0 0 0
5 1000812911.jpg 3 0 0 0 1 0
6 1000837476.jpg 3 0 0 0 1 0
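For the curious, a function doing the label conversion could look like the sketch below. The name `one_hot_labels()` and the toy file names are hypothetical; the class order (0 = CBB, 1 = CBSD, 2 = CGM, 3 = CMD, 4 = Healthy) is the one used in the columns above.

```r
# Hypothetical helper: turn integer labels (0 to 4) into one-hot columns.
one_hot_labels <- function(label,
                           classes = c("CBB", "CBSD", "CGM", "CMD", "Healthy")) {
  # outer() compares every label with every zero-based class index
  one_hot <- outer(label, seq_along(classes) - 1, FUN = "==") * 1
  colnames(one_hot) <- classes
  as.data.frame(one_hot)
}

# Toy data frame with the same two columns as train.csv
toy <- data.frame(image_id = c("a.jpg", "b.jpg", "c.jpg"), label = c(0, 3, 1))
toy <- cbind(toy, one_hot_labels(toy$label))
print(toy)
```

On the real data, the equivalent call would be `cbind(labels, one_hot_labels(labels$label))`.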
val_labels<-read_csv('validation_set.csv')
train_labels<-labels[which(!labels$image_id %in% val_labels$image_id),]
table(train_labels$image_id %in% val_labels$image_id)
FALSE
19256
train_labels$label<-NULL
val_labels$label<-NULL
head(train_labels)
# A tibble: 6 x 6
image_id CBB CBSD CGM CMD Healthy
<chr> <dbl> <dbl> <dbl> <dbl> <dbl>
1 1000015157.jpg 1 0 0 0 0
2 1000201771.jpg 0 0 0 1 0
3 100042118.jpg 0 1 0 0 0
4 1000723321.jpg 0 1 0 0 0
5 1000812911.jpg 0 0 0 1 0
6 1000837476.jpg 0 0 0 1 0
head(val_labels)
# A tibble: 6 x 6
image_id CBB CBSD CGM CMD Healthy
<chr> <dbl> <dbl> <dbl> <dbl> <dbl>
1 1003442061.jpg 0 0 0 0 1
2 1004672608.jpg 0 0 0 1 0
3 1007891044.jpg 0 0 0 1 0
4 1009845426.jpg 0 0 0 1 0
5 1010648150.jpg 0 0 0 1 0
6 1011139244.jpg 0 0 0 1 0
image_path<-'cassava-leaf-disease-classification/train_images/'
#data augmentation
datagen <- image_data_generator(
  rotation_range = 40,
  width_shift_range = 0.2,
  height_shift_range = 0.2,
  shear_range = 0.2,
  zoom_range = 0.5,
  horizontal_flip = TRUE,
  fill_mode = "reflect"
)
img_path<-"cassava-leaf-disease-classification/train_images/1000015157.jpg"
img <- image_load(img_path, target_size = c(448, 448))
img_array <- image_to_array(img)
img_array <- array_reshape(img_array, c(1, 448, 448, 3))
img_array<-img_array/255
# Generator that will flow augmented images
augmentation_generator <- flow_images_from_data(
  img_array,
  generator = datagen,
  batch_size = 1
)
op <- par(mfrow = c(2, 2), pty = "s", mar = c(1, 0, 1, 0))
for (i in 1:4) {
  batch <- generator_next(augmentation_generator)
  plot(as.raster(batch[1,,,]))
}
par(op)
Okay, so here is an interesting thing: I will compress the code that creates a generator, to make it easier to call again.
Why? Well, apparently a generator does not yield infinite batches, and the for loop of the distiller will stop working for no obvious reason at epoch 7, when reaching the end of the validation generator.
When we iterate over it, validation_generator yields 8 images and 8 labels per batch, until batch 267, which contains only 5 images (and creates the bug when we try to add the loss of the batch to the loss of the epoch). Batch 268 does not exist. So the solution seems to be to recreate the validation generator on the fly and restart the iterations.
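# arg.list is not shown in the post; it is reconstructed here from the validation generator call further down
arg.list <- list(dataframe = val_labels,
                 directory = image_path,
                 class_mode = "other",
                 x_col = "image_id",
                 y_col = c("CBB","CBSD", "CGM", "CMD", "Healthy"),
                 target_size = c(228, 228),
                 batch_size=8)
validation_generator <- do.call(flow_images_from_dataframe, arg.list)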
dim(validation_generator[266][[1]])
[1] 8 228 228 3
dim(validation_generator[267][[1]])
[1] 5 228 228 3
dim(val_labels)
[1] 2141 6
2141/8
[1] 267.625
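The arithmetic behind the failure at epoch 7 can be checked in a few lines of base R. This is only a sketch: it assumes that the generator is indexed from zero, which is what the dimensions printed above suggest, and it takes the 40 validation steps per epoch used in the training loop further down.

```r
n_images   <- 2141   # rows of val_labels
batch_size <- 8
val_step   <- 40     # validation steps per epoch in the training loop

n_batches       <- ceiling(n_images / batch_size)             # total number of batches
last_batch_size <- n_images - (n_batches - 1) * batch_size    # size of the short batch
last_index      <- n_batches - 1                              # zero-based index of the short batch

# Epoch e reads the indices (1 + val_step * (e - 1)):(val_step * e),
# so the short batch is first reached at epoch ceiling(last_index / val_step).
first_bad_epoch <- ceiling(last_index / val_step)

cat(n_batches, "batches, last one with", last_batch_size, "images,",
    "reached at epoch", first_bad_epoch, "\n")
```

This gives 268 batches, a last batch of 5 images at index 267, and a first problematic epoch of 7, which matches what was observed.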
train_generator <- flow_images_from_dataframe(dataframe = train_labels,
                                              directory = image_path,
                                              generator = datagen,
                                              class_mode = "other",
                                              x_col = "image_id",
                                              y_col = c("CBB","CBSD", "CGM", "CMD", "Healthy"),
                                              target_size = c(228, 228),
                                              batch_size=8)
validation_generator <- flow_images_from_dataframe(dataframe = val_labels,
                                                   directory = image_path,
                                                   class_mode = "other",
                                                   x_col = "image_id",
                                                   y_col = c("CBB","CBSD", "CGM", "CMD", "Healthy"),
                                                   target_size = c(228, 228),
                                                   batch_size=8)
train_generator
<tensorflow.python.keras.preprocessing.image.DataFrameIterator>
conv_base<-keras::application_efficientnet_b0(weights = "imagenet", include_top = FALSE, input_shape = c(228, 228, 3))
freeze_weights(conv_base)
model <- keras_model_sequential() %>%
  conv_base %>%
  layer_global_max_pooling_2d() %>%
  layer_batch_normalization() %>%
  layer_dropout(rate=0.5) %>%
  layer_dense(units=5, activation="softmax")
#unfreeze_weights(model, from = 'block5a_expand_conv')
unfreeze_weights(conv_base, from = 'block5a_expand_conv')
model %>% load_model_weights_hdf5("fine_tuned_eff_net_weights.15.hdf5")
summary(model)
Model: "sequential"
______________________________________________________________________
Layer (type) Output Shape Param #
======================================================================
efficientnetb0 (Functional) (None, 8, 8, 1280) 4049571
______________________________________________________________________
global_max_pooling2d (GlobalMa (None, 1280) 0
______________________________________________________________________
batch_normalization (BatchNorm (None, 1280) 5120
______________________________________________________________________
dropout (Dropout) (None, 1280) 0
______________________________________________________________________
dense (Dense) (None, 5) 6405
======================================================================
Total params: 4,061,096
Trainable params: 3,707,853
Non-trainable params: 353,243
______________________________________________________________________
conv_base_student<-keras::application_efficientnet_b0(weights = "imagenet", include_top = FALSE, input_shape = c(228, 228, 3))
freeze_weights(conv_base_student)
student <- keras_model_sequential() %>%
  conv_base_student %>%
  layer_global_max_pooling_2d() %>%
  layer_batch_normalization() %>%
  layer_dropout(rate=0.5) %>%
  layer_dense(units=5, activation="softmax")
student
Model
Model: "sequential_1"
______________________________________________________________________
Layer (type) Output Shape Param #
======================================================================
efficientnetb0 (Functional) (None, 8, 8, 1280) 4049571
______________________________________________________________________
global_max_pooling2d_1 (Global (None, 1280) 0
______________________________________________________________________
batch_normalization_1 (BatchNo (None, 1280) 5120
______________________________________________________________________
dropout_1 (Dropout) (None, 1280) 0
______________________________________________________________________
dense_1 (Dense) (None, 5) 6405
======================================================================
Total params: 4,061,096
Trainable params: 8,965
Non-trainable params: 4,052,131
______________________________________________________________________
Source code for knowledge distillation with Keras: https://keras.io/examples/vision/knowledge_distillation/
Details on eager execution in R and various useful code: https://keras.rstudio.com/articles/eager_guide.html
Other source code in R: https://tensorflow.rstudio.com/tutorials/advanced/
I am using an alpha parameter of 0.9, as suggested by this article.
i=1
alpha=0.9 #On_the_Efficacy_of_Knowledge_Distillation_ICCV_2019
temperature=3
optimizer <- optimizer_adam()
train_loss <- tf$keras$metrics$Mean(name='student_loss')
train_accuracy <- tf$keras$metrics$CategoricalAccuracy(name='train_accuracy')
nb_epoch<-12
nb_batch<-300
val_step<-40
count_epoch<-0
# Vectors collecting one value per epoch, for the graphs after training
train_loss_plot <- c()
distilation_loss_plot <- c()
accuracy_plot <- c()
val_loss_plot <- c()
val_accuracy_plot <- c()
for (epoch in 1:nb_epoch) {
  cat("Epoch: ", epoch, " -----------\n")
  # Init metrics
  train_loss_epoch <- 0
  accuracies_on_epoch <- c()
  distilation_loss_epoch <- 0
  val_loss_epoch <- 0
  val_accuaries_on_epoch <- c()
  # Formula to not see the same batches over and over on each epoch
  # Uses count_epoch instead of epoch, because count_epoch can be reset below
  count_epoch<-count_epoch+1
  idx_batch <- (1+nb_batch*(count_epoch-1)):(nb_batch*count_epoch)
  idx_val_set <- (1+val_step*(count_epoch-1)):(val_step*count_epoch)
  # Dirty solution to restart on a new validation generator before reaching the end of the current one
  if (as.integer((dim(val_labels)[1]/8)-1) %in% idx_val_set) {
    count_epoch<-1
    idx_val_set <- (1+val_step*(count_epoch-1)):(val_step*count_epoch)
    validation_generator <- do.call(flow_images_from_dataframe, arg.list)
  }
  # Same check for the train generator, recreated with the training arguments
  if (as.integer((dim(train_labels)[1]/8)-1) %in% idx_batch) {
    count_epoch<-1
    idx_batch <- (1+nb_batch*(count_epoch-1)):(nb_batch*count_epoch)
    train_generator <- flow_images_from_dataframe(dataframe = train_labels,
                                                  directory = image_path,
                                                  generator = datagen,
                                                  class_mode = "other",
                                                  x_col = "image_id",
                                                  y_col = c("CBB","CBSD", "CGM", "CMD", "Healthy"),
                                                  target_size = c(228, 228),
                                                  batch_size=8)
  }
  for (batch in idx_batch) {
    x = train_generator[batch][[1]]
    y = train_generator[batch][[2]]
    # Forward pass of teacher
    teacher_predictions = model(x)
    with(tf$GradientTape() %as% tape, {
      student_predictions = student(x)
      student_loss = tf$losses$categorical_crossentropy(y, student_predictions)
      # Soften both outputs with the temperature; the softmax runs over the class axis
      # (axis = 1L), as in the Keras example. The run logged below used axis = 0L.
      distillation_loss = tf$losses$categorical_crossentropy(tf$nn$softmax(teacher_predictions/temperature, axis=1L),
                                                             tf$nn$softmax(student_predictions/temperature, axis=1L))
      loss = alpha * student_loss + (1 - alpha) * distillation_loss
    })
    # Compute gradients
    # To vary the learning rate:
    # optimizer <- optimizer_adam(lr = 0.0001)
    gradients <- tape$gradient(loss, student$trainable_variables)
    optimizer$apply_gradients(purrr::transpose(list(gradients, student$trainable_variables)))
    # Collect the metrics of the student
    train_loss_epoch <- train_loss_epoch + student_loss
    distilation_loss_epoch <- distilation_loss_epoch + distillation_loss
    accuracy_on_batch <- train_accuracy(y_true=y, y_pred=student_predictions)
    accuracies_on_epoch <- c(accuracies_on_epoch, as.numeric(accuracy_on_batch))
  }
  # Collect info on the current epoch, for the graphs and cat()
  train_loss_epoch <- mean(as.vector(as.numeric(train_loss_epoch))/nb_batch)
  train_loss_plot <- c(train_loss_plot, train_loss_epoch)
  distilation_loss_epoch <- mean(as.vector(as.numeric(distilation_loss_epoch))/nb_batch)
  distilation_loss_plot <- c(distilation_loss_plot, distilation_loss_epoch)
  accuracies_on_epoch <- mean(accuracies_on_epoch)
  accuracy_plot <- c(accuracy_plot, accuracies_on_epoch)
  for (step in idx_val_set) {
    # Unpack the data
    x = validation_generator[step][[1]]
    y = validation_generator[step][[2]]
    # Compute predictions
    student_predictions = student(x)
    # Calculate the loss
    student_loss = tf$losses$categorical_crossentropy(y, student_predictions)
    # Collect the metrics of the student
    # This line fails with a shape error when the last, shorter validation batch is reached
    val_loss_epoch <- val_loss_epoch + student_loss
    accuracy_on_val_step <- train_accuracy(y_true=y, y_pred=student_predictions)
    val_accuaries_on_epoch <- c(val_accuaries_on_epoch, as.numeric(accuracy_on_val_step))
  }
  # Collect info on the current epoch, for the graphs and cat()
  val_loss_epoch <- mean(as.vector(as.numeric(val_loss_epoch))/val_step)
  val_loss_plot <- c(val_loss_plot, val_loss_epoch)
  val_accuaries_on_epoch <- mean(val_accuaries_on_epoch)
  val_accuracy_plot <- c(val_accuracy_plot, val_accuaries_on_epoch)
  # Printing
  cat("Total loss (epoch): ", epoch, ": ", train_loss_epoch, "\n")
  cat("Distillater loss : ", epoch, ": ", distilation_loss_epoch, "\n")
  cat("Accuracy (epoch): ", epoch, ": ", accuracies_on_epoch, "\n")
  cat("Val loss : ", epoch, ": ", val_loss_epoch, "\n")
  cat("Val Accuracy (epoch): ", epoch, ": ", val_accuaries_on_epoch, "\n")
}
Epoch: 1 -----------
Total loss (epoch): 1 : 1.970847
Distillater loss : 1 : 1.006515
Accuracy (epoch): 1 : 0.5028956
Val loss : 1 : 1.683065
Val Accuracy (epoch): 1 : 0.5337647
Epoch: 2 -----------
Total loss (epoch): 2 : 1.671758
Distillater loss : 2 : 1.006187
Accuracy (epoch): 2 : 0.5482197
Val loss : 2 : 1.746699
Val Accuracy (epoch): 2 : 0.5590533
Epoch: 3 -----------
Total loss (epoch): 3 : 1.618646
Distillater loss : 3 : 1.006112
Accuracy (epoch): 3 : 0.5649438
Val loss : 3 : 1.531488
Val Accuracy (epoch): 3 : 0.5679042
Epoch: 4 -----------
Total loss (epoch): 4 : 1.562328
Distillater loss : 4 : 1.005987
Accuracy (epoch): 4 : 0.575272
Val loss : 4 : 1.508584
Val Accuracy (epoch): 4 : 0.5776999
Epoch: 5 -----------
Total loss (epoch): 5 : 1.406053
Distillater loss : 5 : 1.005917
Accuracy (epoch): 5 : 0.5815135
Val loss : 5 : 1.372146
Val Accuracy (epoch): 5 : 0.5892469
Epoch: 6 -----------
Total loss (epoch): 6 : 1.520737
Distillater loss : 6 : 1.005878
Accuracy (epoch): 6 : 0.5893831
Val loss : 6 : 1.34237
Val Accuracy (epoch): 6 : 0.5902871
Epoch: 7 -----------
Total loss (epoch): 7 : 1.508101
Distillater loss : 7 : 1.005872
Accuracy (epoch): 7 : 0.5920453
Val loss : 7 : 2.097656
Val Accuracy (epoch): 7 : 0.5925921
Epoch: 8 -----------
Total loss (epoch): 8 : 1.267969
Distillater loss : 8 : 1.005815
Accuracy (epoch): 8 : 0.5949023
Val loss : 8 : 1.78513
Val Accuracy (epoch): 8 : 0.5982342
Epoch: 9 -----------
Total loss (epoch): 9 : 1.510699
Distillater loss : 9 : 1.005925
Accuracy (epoch): 9 : 0.5991197
Val loss : 9 : 1.387395
Val Accuracy (epoch): 9 : 0.5990839
Epoch: 10 -----------
Total loss (epoch): 10 : 1.495111
Distillater loss : 10 : 1.005821
Accuracy (epoch): 10 : 0.6014644
Val loss : 10 : 2.015202
Val Accuracy (epoch): 10 : 0.6017212
Epoch: 11 -----------
Total loss (epoch): 11 : 1.454717
Distillater loss : 11 : 1.00589
Accuracy (epoch): 11 : 0.6021615
Val loss : 11 : 1.659231
Val Accuracy (epoch): 11 : 0.6034705
Epoch: 12 -----------
Total loss (epoch): 12 : 1.37668
Distillater loss : 12 : 1.005826
Accuracy (epoch): 12 : 0.6035553
Val loss : 12 : 1.438908
Val Accuracy (epoch): 12 : 0.604799
What about global_step = tf.train.get_or_create_global_step(), described here? It seems to refer only to the number of batches seen by the graph. Source.
total_loss_plot<-c()
#instead of collecting them during the training :
total_loss_plot <- alpha * train_loss_plot + (1 - alpha) * distilation_loss_plot
data <- data.frame("Student_loss" = train_loss_plot,
"Distillation_loss" = distilation_loss_plot,
"Total_loss" = total_loss_plot,
"Epoch" = 1:length(train_loss_plot),
"Val_loss" = val_loss_plot,
"Train_accuracy"= accuracy_plot,
"Val_accuracy"= val_accuracy_plot)
head(data)
Student_loss Distillation_loss Total_loss Epoch Val_loss
1 1.970847 1.006515 1.874414 1 1.683065
2 1.671758 1.006187 1.605201 2 1.746699
3 1.618646 1.006112 1.557393 3 1.531488
4 1.562328 1.005987 1.506694 4 1.508584
5 1.406053 1.005917 1.366040 5 1.372146
6 1.520737 1.005878 1.469251 6 1.342370
Train_accuracy Val_accuracy
1 0.5028956 0.5337647
2 0.5482197 0.5590533
3 0.5649438 0.5679042
4 0.5752720 0.5776999
5 0.5815135 0.5892469
6 0.5893831 0.5902871
Where Total_loss is alpha * train_loss_plot + (1 - alpha) * distilation_loss_plot.
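As a quick check, the Total_loss column can be recomputed by hand as the weighted sum of the two other losses. The sketch below only copies the first three rows of the table above; the variable names are illustrative.

```r
alpha <- 0.9
student_loss      <- c(1.970847, 1.671758, 1.618646)   # Student_loss, epochs 1 to 3
distillation_loss <- c(1.006515, 1.006187, 1.006112)   # Distillation_loss, epochs 1 to 3

# Weighted sum of the two losses, as in the training loop
total_loss <- alpha * student_loss + (1 - alpha) * distillation_loss
print(round(total_loss, 6))
```

The rounded values agree with the Total_loss column (1.874414, 1.605201 and 1.557393).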
ggplot(data, aes(Epoch)) +
scale_colour_manual(values=c(Student_loss="#F8766D",Val_loss="#00BFC4", Distillation_loss="#DE8C00", Total_loss="#1aff8c")) +
geom_line(aes(y = Student_loss, colour = "Student_loss")) +
geom_line(aes(y = Val_loss, colour = "Val_loss")) +
geom_line(aes(y = Total_loss, colour = "Total_loss")) +
geom_line(aes(y = Distillation_loss, colour = "Distillation_loss"))
#Validation set
ggplot(data, aes(Epoch)) +
geom_line(aes(y = Train_accuracy, colour = "Train_accuracy")) +
geom_line(aes(y = Val_accuracy, colour = "Val_accuracy"))
Is that all? Well, no. Here we performed knowledge distillation to teach only the head of the student network.
The next step would be to repeat the knowledge distillation after unfreezing part of the student, after writing something like:
unfreeze_weights(conv_base_student, from = 'block5a_expand_conv')
But I will not bet my small GPU card on this or start a fire in my basement for the sake of the tutorial.
As I mentioned earlier, I adapted this code from my Kaggle version, where the GPU is much bigger. Take a look if you want to see, but basically the end output looks like this:
Well, that’s it for this post, which is probably already long enough for a blog post!