Knowledge distillation with R and tensorflow

A quick tutorial on how to perform knowledge distillation with R, in eager mode.

Etienne Rolland https://github.com/Cdk29
2021-06-04

Welcome

Hi everyone! Welcome to my blog. Here I will share tutorials about things that were complicated for me and that other R users might find useful. Not surprisingly, a lot of these tutorials will involve tensorflow or other deep learning things.

Sometimes things are possible in R, but since our community is smaller, we don’t have as many resources or tutorials as the Python community. This explains why some particular tasks are cumbersome in R, especially when the few available tutorials or interface packages start to accumulate errors or bugs because they are not used often by an active community.

I am not an expert, so I will try to cite the sources of my code and parameters whenever I can. I used a small image size so as not to blow up my GPU; there is an example with fine tuning and a bigger GPU here.

There is probably a lack of optimization, but at least it is a working skeleton. If you have suggestions for improvement, comments are welcome :D

About the data

I first wrote this code in the context of the Cassava Leaf Disease Classification, a Kaggle competition where the goal was to train a model to identify diseases on cassava leaves. Here the distillation is done from one EfficientNetB0 to another.

What is knowledge distillation

As presented in this discussion thread on Kaggle, knowledge distillation is defined as “simply trains another individual model to match the output of an ensemble” (Source). It is in fact slightly more complicated: the second neural net (the student) makes predictions on the images, but its loss is then a function of its own loss as well as a loss based on the difference between its predictions and those of its teacher or of the ensemble.
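To make the combined loss concrete, here is a minimal base-R sketch for a single image. It assumes both networks output raw scores (logits); all the numbers and the helper names `softmax` and `cross_entropy` are invented for illustration, while `alpha` and `temperature` take the values used later in this post.

```r
# Toy example: one image, five classes. All numbers are invented.
softmax <- function(z) {
  e <- exp(z - max(z))  # subtract the max for numerical stability
  e / sum(e)
}
cross_entropy <- function(target, predicted) -sum(target * log(predicted))

alpha <- 0.9
temperature <- 3

y_true         <- c(0, 0, 0, 1, 0)             # one-hot label
teacher_logits <- c(0.5, 1.0, 0.2, 4.0, 0.3)
student_logits <- c(1.0, 0.8, 0.1, 2.0, 0.9)

# Loss of the student against the true label
student_loss <- cross_entropy(y_true, softmax(student_logits))

# Loss of the student against the softened output of the teacher
distillation_loss <- cross_entropy(
  softmax(teacher_logits / temperature),
  softmax(student_logits / temperature)
)

# Weighted sum of the two losses
loss <- alpha * student_loss + (1 - alpha) * distillation_loss
```

Dividing by the temperature before the softmax flattens both distributions, so the student also gets a signal from the classes the teacher considers unlikely.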

This approach makes it possible to compress an ensemble into one model and thereby reduce the inference time or, if the student is trained to match the output of a single model, to increase the overall performance of the model. I discovered this approach by looking at the top solutions of the Plant Pathology 2020 competition, another competition involving computer vision and leaves, such as this one.

I will let you go to the source mentioned above to understand how it could potentially work. It is not certain, but it seems related to the learning of specific features versus forcing the student to learn “multiple views”, that is, multiple types of features to detect in the images.

There is, of course, no starting material to do it in R. Thankfully, there is a code example on the Keras website. In this example, they create a model class, a distiller, to perform the knowledge distillation. There is, however, one problem: models are not inheritable in R. There are examples of inheritance with an R6 class for callbacks, like here, but the models are not R6 classes. To overcome this problem, I used the code example as a guide and reproduced the steps by following the approach in this guide for eager execution in Keras with R. I took other code from the tensorflow website for R.

The code is quite hard to understand at first glance. The reason is that everything is executed in a single for loop, since everything is done in eager mode. It did not seem possible to do it differently. So there are a lot of variables around to collect metrics during training. If you want to understand the code, just take it out of the for loop and run it on its own before reconstructing the loop around it. I did not use tfdatasets as shown in the guide for eager execution, so instead of make_iterator_one_shot() and iterator_get_next(), here we loop over the train_generator to produce the batches.

library(tidyverse)
library(tensorflow)
tf$executing_eagerly()
[1] TRUE
tensorflow::tf_version()
[1] '2.3'

Here I flex with my own version of keras. Basically, it is a fork with application wrappers for EfficientNet.

Disclaimer: I did not write the code for the really handy application wrappers. It comes from this commit, whose PR was on hold until the full release of tf 2.3, as stated in this PR. I am not sure why the PR is closed.

devtools::install_github("Cdk29/keras", dependencies = FALSE)
library(keras)
labels<-read_csv('train.csv')
head(labels)
# A tibble: 6 x 2
  image_id       label
  <chr>          <dbl>
1 1000015157.jpg     0
2 1000201771.jpg     3
3 100042118.jpg      1
4 1000723321.jpg     1
5 1000812911.jpg     3
6 1000837476.jpg     3
levels(as.factor(labels$label))
[1] "0" "1" "2" "3" "4"
idx0<-which(labels$label==0)
idx1<-which(labels$label==1)
idx2<-which(labels$label==2)
idx3<-which(labels$label==3)
idx4<-which(labels$label==4)
labels$CBB<-0
labels$CBSD<-0
labels$CGM<-0
labels$CMD<-0
labels$Healthy<-0
labels$CBB[idx0]<-1
labels$CBSD[idx1]<-1
labels$CGM[idx2]<-1
labels$CMD[idx3]<-1

“Would it have been easier to create a function to convert the labelling ?” You may ask.

labels$Healthy[idx4]<-1

Probably.
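For the curious, such a function could look like the sketch below. `one_hot` is a hypothetical helper, not part of the original code; it assumes the labels are coded from 0, in the same class order as above, and the three-row data frame is a toy stand-in for `labels`.

```r
# Hypothetical helper: turn an integer label column into one 0/1 column per class.
one_hot <- function(label, class_names) {
  # label is 0-based, so class k ends up in column k + 1
  m <- sapply(seq_along(class_names) - 1, function(k) as.numeric(label == k))
  m <- matrix(m, nrow = length(label))   # keep a matrix even for a single row
  colnames(m) <- class_names
  as.data.frame(m)
}

# Toy data standing in for the real labels table
toy <- data.frame(image_id = c("a.jpg", "b.jpg", "c.jpg"), label = c(0, 3, 4))
toy <- cbind(toy, one_hot(toy$label, c("CBB", "CBSD", "CGM", "CMD", "Healthy")))
toy
```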

#labels$label<-NULL
head(labels)
# A tibble: 6 x 7
  image_id       label   CBB  CBSD   CGM   CMD Healthy
  <chr>          <dbl> <dbl> <dbl> <dbl> <dbl>   <dbl>
1 1000015157.jpg     0     1     0     0     0       0
2 1000201771.jpg     3     0     0     0     1       0
3 100042118.jpg      1     0     1     0     0       0
4 1000723321.jpg     1     0     1     0     0       0
5 1000812911.jpg     3     0     0     0     1       0
6 1000837476.jpg     3     0     0     0     1       0
val_labels<-read_csv('validation_set.csv')
train_labels<-labels[which(!labels$image_id %in% val_labels$image_id),]
table(train_labels$image_id %in% val_labels$image_id)

FALSE 
19256 
train_labels$label<-NULL
val_labels$label<-NULL

head(train_labels)
# A tibble: 6 x 6
  image_id         CBB  CBSD   CGM   CMD Healthy
  <chr>          <dbl> <dbl> <dbl> <dbl>   <dbl>
1 1000015157.jpg     1     0     0     0       0
2 1000201771.jpg     0     0     0     1       0
3 100042118.jpg      0     1     0     0       0
4 1000723321.jpg     0     1     0     0       0
5 1000812911.jpg     0     0     0     1       0
6 1000837476.jpg     0     0     0     1       0
head(val_labels)
# A tibble: 6 x 6
  image_id         CBB  CBSD   CGM   CMD Healthy
  <chr>          <dbl> <dbl> <dbl> <dbl>   <dbl>
1 1003442061.jpg     0     0     0     0       1
2 1004672608.jpg     0     0     0     1       0
3 1007891044.jpg     0     0     0     1       0
4 1009845426.jpg     0     0     0     1       0
5 1010648150.jpg     0     0     0     1       0
6 1011139244.jpg     0     0     0     1       0
image_path<-'cassava-leaf-disease-classification/train_images/'
#data augmentation
datagen <- image_data_generator(
  rotation_range = 40,
  width_shift_range = 0.2,
  height_shift_range = 0.2,
  shear_range = 0.2,
  zoom_range = 0.5,
  horizontal_flip = TRUE,
  fill_mode = "reflect"
)
img_path<-"cassava-leaf-disease-classification/train_images/1000015157.jpg"

img <- image_load(img_path, target_size = c(448, 448))
img_array <- image_to_array(img)
img_array <- array_reshape(img_array, c(1, 448, 448, 3))
img_array<-img_array/255
# Generator that will flow augmented images
augmentation_generator <- flow_images_from_data(
  img_array, 
  generator = datagen, 
  batch_size = 1 
)
op <- par(mfrow = c(2, 2), pty = "s", mar = c(1, 0, 1, 0))
for (i in 1:4) {
  batch <- generator_next(augmentation_generator)
  plot(as.raster(batch[1,,,]))
}
par(op)

Data generator

Okay, so here is an interesting thing: I will wrap the arguments of the generator call in a list, to make it easier to call again.

Why? Well, apparently a generator does not yield infinite batches, and the for loop of the distiller will stop working for no obvious reason at epoch 7, when reaching the end of the validation generator.

When we iterate over it, validation_generator yields 8 images and 8 labels per batch, until batch 267, which contains only 5 images (and creates the bug when we try to add the loss of the batch to the loss of the epoch). Batch 268 does not exist. So the solution seems to be to recreate the validation generator on the fly and restart the iterations.
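The numbers above can be checked with a bit of base-R arithmetic. This is only a sketch: `val_window` is a name made up here for the index formula used later in the training loop, and it assumes, as the text implies, that index 267 is the last (short) batch and 266 the last full one.

```r
n_images   <- 2141   # rows in val_labels
batch_size <- 8
val_step   <- 40     # validation batches consumed per epoch

n_batches  <- ceiling(n_images / batch_size)   # number of batches in the generator
last_size  <- n_images %% batch_size           # images left for the last batch

# Window of batch indices read at a given epoch (same formula as in the training loop)
val_window <- function(epoch) (1 + val_step * (epoch - 1)):(val_step * epoch)

# First epoch whose window goes past the last full batch (index 266)
first_bad_epoch <- which(sapply(1:12, function(e) max(val_window(e)) > 266))[1]
```

With 40 validation batches per epoch, the window of epoch 7 is 241 to 280, which is where the loop runs into the short batch and then off the end of the generator.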

arg.list <- list(dataframe = val_labels, directory = image_path,
                                              class_mode = "other",
                                              x_col = "image_id",
                                              y_col = c("CBB","CBSD", "CGM", "CMD", "Healthy"),
                                              target_size = c(228, 228),
                                              batch_size=8)
validation_generator <- do.call(flow_images_from_dataframe, arg.list)
dim(validation_generator[266][[1]])
[1]   8 228 228   3
dim(validation_generator[267][[1]])
[1]   5 228 228   3
dim(val_labels)
[1] 2141    6
2141/8
[1] 267.625
train_generator <- flow_images_from_dataframe(dataframe = train_labels, 
                                              directory = image_path,
                                              generator = datagen,
                                              class_mode = "other",
                                              x_col = "image_id",
                                              y_col = c("CBB","CBSD", "CGM", "CMD", "Healthy"),
                                              target_size = c(228, 228),
                                              batch_size=8)

validation_generator <- flow_images_from_dataframe(dataframe = val_labels, 
                                              directory = image_path,
                                              class_mode = "other",
                                              x_col = "image_id",
                                              y_col = c("CBB","CBSD", "CGM", "CMD", "Healthy"),
                                              target_size = c(228, 228),
                                              batch_size=8)
train_generator
<tensorflow.python.keras.preprocessing.image.DataFrameIterator>
conv_base<-keras::application_efficientnet_b0(weights = "imagenet", include_top = FALSE, input_shape = c(228, 228, 3))

freeze_weights(conv_base)

model <- keras_model_sequential() %>%
    conv_base %>% 
    layer_global_max_pooling_2d() %>% 
    layer_batch_normalization() %>% 
    layer_dropout(rate=0.5) %>%
    layer_dense(units=5, activation="softmax")
#unfreeze_weights(model, from = 'block5a_expand_conv')
unfreeze_weights(conv_base, from = 'block5a_expand_conv')
model %>% load_model_weights_hdf5("fine_tuned_eff_net_weights.15.hdf5")
summary(model)
Model: "sequential"
______________________________________________________________________
Layer (type)                   Output Shape                Param #    
======================================================================
efficientnetb0 (Functional)    (None, 8, 8, 1280)          4049571    
______________________________________________________________________
global_max_pooling2d (GlobalMa (None, 1280)                0          
______________________________________________________________________
batch_normalization (BatchNorm (None, 1280)                5120       
______________________________________________________________________
dropout (Dropout)              (None, 1280)                0          
______________________________________________________________________
dense (Dense)                  (None, 5)                   6405       
======================================================================
Total params: 4,061,096
Trainable params: 3,707,853
Non-trainable params: 353,243
______________________________________________________________________
conv_base_student<-keras::application_efficientnet_b0(weights = "imagenet", include_top = FALSE, input_shape = c(228, 228, 3))

freeze_weights(conv_base_student)

student <- keras_model_sequential() %>%
    conv_base_student %>% 
    layer_global_max_pooling_2d() %>% 
    layer_batch_normalization() %>% 
    layer_dropout(rate=0.5) %>%
    layer_dense(units=5, activation="softmax")

student
Model
Model: "sequential_1"
______________________________________________________________________
Layer (type)                   Output Shape                Param #    
======================================================================
efficientnetb0 (Functional)    (None, 8, 8, 1280)          4049571    
______________________________________________________________________
global_max_pooling2d_1 (Global (None, 1280)                0          
______________________________________________________________________
batch_normalization_1 (BatchNo (None, 1280)                5120       
______________________________________________________________________
dropout_1 (Dropout)            (None, 1280)                0          
______________________________________________________________________
dense_1 (Dense)                (None, 5)                   6405       
======================================================================
Total params: 4,061,096
Trainable params: 8,965
Non-trainable params: 4,052,131
______________________________________________________________________
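The parameter counts in the two summaries can be verified by hand. The sketch below relies on standard layer arithmetic: a dense layer has one weight per input and output pair plus one bias per output, and a batch normalization layer keeps four vectors per feature, of which only two are trained. The variable names are mine.

```r
features <- 1280   # channels coming out of the EfficientNetB0 base
classes  <- 5

dense_params <- features * classes + classes   # weights + biases
bn_params    <- 4 * features                   # gamma, beta, moving mean, moving variance
bn_trainable <- 2 * features                   # only gamma and beta are trained

# With the convolutional base frozen, only the head is trainable
trainable_head <- dense_params + bn_trainable
total_params   <- 4049571 + bn_params + dense_params
```

This gives the 8,965 trainable parameters of the student, whose base is fully frozen, against 3,707,853 for the teacher, whose base is unfrozen from block5a_expand_conv onwards.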

Source code and knowledge distillation

Source code for knowledge distillation with Keras : https://keras.io/examples/vision/knowledge_distillation/
Help for eager execution details in R and various useful code : https://keras.rstudio.com/articles/eager_guide.html
Other source code in R : https://tensorflow.rstudio.com/tutorials/advanced/

I am using an alpha parameter of 0.9 as suggested by this article.

i=1
alpha=0.9 #On_the_Efficacy_of_Knowledge_Distillation_ICCV_2019
temperature=3
optimizer <- optimizer_adam()
train_loss <- tf$keras$metrics$Mean(name='student_loss')
train_accuracy <-  tf$keras$metrics$CategoricalAccuracy(name='train_accuracy')
nb_epoch<-12
nb_batch<-300
val_step<-40
train_loss_plot<-c()
accuracy_plot<-c()
distilation_loss_plot <- c()
val_loss_plot <- c()
val_accuracy_plot <- c()
count_epoch<-0
for (epoch in 1:nb_epoch) {
    cat("Epoch: ", epoch, " -----------\n")
    # Init metrics
    train_loss_epoch <- 0
    accuracies_on_epoch <- c()
    distilation_loss_epoch <- 0
    val_loss_epoch <- 0
    val_accuaries_on_epoch <- c()
    
    #Formula to avoid seeing the same batches again at each epoch
    #Uses count_epoch instead of epoch, because count_epoch is reset when a generator is recreated
    count_epoch<-count_epoch+1
    idx_batch <- (1+nb_batch*(count_epoch-1)):(nb_batch*count_epoch)
    idx_val_set <- (1+val_step*(count_epoch-1)):(val_step*count_epoch)
    
    #Dirty solution to restart on a new validation batch generator before reaching the end of the other one 
    if (as.integer((dim(val_labels)[1]/8)-1) %in% idx_val_set) {
        count_epoch<-1
        idx_val_set <- (1+val_step*(count_epoch-1)):(val_step*count_epoch)
        validation_generator <- do.call(flow_images_from_dataframe, arg.list)
    }
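    #Same check for the train generator, rebuilt with the training arguments (not arg.list, which describes the validation set)
    if (as.integer((dim(train_labels)[1]/8)-1) %in% idx_batch) {
        count_epoch<-1
        idx_batch <- (1+nb_batch*(count_epoch-1)):(nb_batch*count_epoch)
        train_generator <- flow_images_from_dataframe(dataframe = train_labels, 
                                              directory = image_path,
                                              generator = datagen,
                                              class_mode = "other",
                                              x_col = "image_id",
                                              y_col = c("CBB","CBSD", "CGM", "CMD", "Healthy"),
                                              target_size = c(228, 228),
                                              batch_size=8)
    }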
    
    for (batch in idx_batch) {
        x = train_generator[batch][[1]]
        y = train_generator[batch][[2]]
        # Forward pass of teacher
        teacher_predictions = model(x)

        with(tf$GradientTape() %as% tape, {
            student_predictions = student(x)
            student_loss = tf$losses$categorical_crossentropy(y, student_predictions)
        
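            # Softmax over the class axis (axis 1, 0-based, of a batch x class tensor); axis 0 would normalise across the batch
            distillation_loss = tf$losses$categorical_crossentropy(tf$nn$softmax(teacher_predictions/temperature, axis=1L), 
                                                           tf$nn$softmax(student_predictions/temperature, axis=1L))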
        
            loss = alpha * student_loss + (1 - alpha) * distillation_loss
            })
        
        # Compute gradients
        # Variating learning rate :
        # optimizer <- optimizer_adam(lr = 0.0001)
        gradients <- tape$gradient(loss, student$trainable_variables)
        optimizer$apply_gradients(purrr::transpose(list(gradients, student$trainable_variables)))
        
        #Collect the metrics of the student
        train_loss_epoch <- train_loss_epoch + student_loss
        distilation_loss_epoch <- distilation_loss_epoch + distillation_loss
        
        accuracy_on_batch <- train_accuracy(y_true=y, y_pred=student_predictions)
        accuracies_on_epoch <- c(accuracies_on_epoch, as.numeric(accuracy_on_batch))
        
    }

    #Collect info on current epoch and for graphs and cat()
    train_loss_epoch <- mean(as.vector(as.numeric(train_loss_epoch))/nb_batch)
    train_loss_plot <- c(train_loss_plot, train_loss_epoch)
    
    distilation_loss_epoch <- mean(as.vector(as.numeric(distilation_loss_epoch))/nb_batch)
    distilation_loss_plot <- c(distilation_loss_plot, distilation_loss_epoch)
    
    accuracies_on_epoch <- mean(accuracies_on_epoch)
    accuracy_plot <- c(accuracy_plot, accuracies_on_epoch)
    
    
    for (step in idx_val_set) {
        # Unpack the data
        x = validation_generator[step][[1]]
        y = validation_generator[step][[2]]

        # Compute predictions
        student_predictions = student(x)

        # Calculate the loss
        student_loss = tf$losses$categorical_crossentropy(y, student_predictions)

        #Collect the metrics of the student
        #This line raises a shape error on the last, smaller batch of the generator.
        val_loss_epoch <- val_loss_epoch + student_loss
        
        #Stateless accuracy, so that validation batches do not update the running train_accuracy metric
        accuracy_on_val_step <- tf$keras$metrics$categorical_accuracy(y, student_predictions)
        val_accuaries_on_epoch <- c(val_accuaries_on_epoch, as.numeric(accuracy_on_val_step))
    }
    
    #Collect info on current epoch and for graphs and cat()
    val_loss_epoch <- mean(as.vector(as.numeric(val_loss_epoch))/val_step)
    val_loss_plot <- c(val_loss_plot, val_loss_epoch)
    
    val_accuaries_on_epoch <- mean(val_accuaries_on_epoch)
    val_accuracy_plot <- c(val_accuracy_plot, val_accuaries_on_epoch)
    
    #Plotting
    cat("Total loss (epoch): ", epoch, ": ", train_loss_epoch, "\n")
    cat("Distillater loss : ", epoch, ": ", distilation_loss_epoch, "\n")
    cat("Accuracy (epoch): ", epoch, ": ", accuracies_on_epoch, "\n")
    cat("Val loss : ", epoch, ": ", val_loss_epoch, "\n")
    cat("Val Accuracy (epoch): ", epoch, ": ", val_accuaries_on_epoch, "\n")
}
Epoch:  1  -----------
Total loss (epoch):  1 :  1.970847 
Distillater loss :  1 :  1.006515 
Accuracy (epoch):  1 :  0.5028956 
Val loss :  1 :  1.683065 
Val Accuracy (epoch):  1 :  0.5337647 
Epoch:  2  -----------
Total loss (epoch):  2 :  1.671758 
Distillater loss :  2 :  1.006187 
Accuracy (epoch):  2 :  0.5482197 
Val loss :  2 :  1.746699 
Val Accuracy (epoch):  2 :  0.5590533 
Epoch:  3  -----------
Total loss (epoch):  3 :  1.618646 
Distillater loss :  3 :  1.006112 
Accuracy (epoch):  3 :  0.5649438 
Val loss :  3 :  1.531488 
Val Accuracy (epoch):  3 :  0.5679042 
Epoch:  4  -----------
Total loss (epoch):  4 :  1.562328 
Distillater loss :  4 :  1.005987 
Accuracy (epoch):  4 :  0.575272 
Val loss :  4 :  1.508584 
Val Accuracy (epoch):  4 :  0.5776999 
Epoch:  5  -----------
Total loss (epoch):  5 :  1.406053 
Distillater loss :  5 :  1.005917 
Accuracy (epoch):  5 :  0.5815135 
Val loss :  5 :  1.372146 
Val Accuracy (epoch):  5 :  0.5892469 
Epoch:  6  -----------
Total loss (epoch):  6 :  1.520737 
Distillater loss :  6 :  1.005878 
Accuracy (epoch):  6 :  0.5893831 
Val loss :  6 :  1.34237 
Val Accuracy (epoch):  6 :  0.5902871 
Epoch:  7  -----------
Total loss (epoch):  7 :  1.508101 
Distillater loss :  7 :  1.005872 
Accuracy (epoch):  7 :  0.5920453 
Val loss :  7 :  2.097656 
Val Accuracy (epoch):  7 :  0.5925921 
Epoch:  8  -----------
Total loss (epoch):  8 :  1.267969 
Distillater loss :  8 :  1.005815 
Accuracy (epoch):  8 :  0.5949023 
Val loss :  8 :  1.78513 
Val Accuracy (epoch):  8 :  0.5982342 
Epoch:  9  -----------
Total loss (epoch):  9 :  1.510699 
Distillater loss :  9 :  1.005925 
Accuracy (epoch):  9 :  0.5991197 
Val loss :  9 :  1.387395 
Val Accuracy (epoch):  9 :  0.5990839 
Epoch:  10  -----------
Total loss (epoch):  10 :  1.495111 
Distillater loss :  10 :  1.005821 
Accuracy (epoch):  10 :  0.6014644 
Val loss :  10 :  2.015202 
Val Accuracy (epoch):  10 :  0.6017212 
Epoch:  11  -----------
Total loss (epoch):  11 :  1.454717 
Distillater loss :  11 :  1.00589 
Accuracy (epoch):  11 :  0.6021615 
Val loss :  11 :  1.659231 
Val Accuracy (epoch):  11 :  0.6034705 
Epoch:  12  -----------
Total loss (epoch):  12 :  1.37668 
Distillater loss :  12 :  1.005826 
Accuracy (epoch):  12 :  0.6035553 
Val loss :  12 :  1.438908 
Val Accuracy (epoch):  12 :  0.604799 

What about global_step = tf.train.get_or_create_global_step(), described here? It seems to refer only to the number of batches seen by the graph. Source.

Plotting

total_loss_plot<-c()
#instead of collecting them during the training : 
total_loss_plot <- alpha * train_loss_plot + (1 - alpha) * distilation_loss_plot
data <- data.frame("Student_loss" = train_loss_plot, 
                    "Distillation_loss" = distilation_loss_plot,
                   "Total_loss" = total_loss_plot,
                    "Epoch" = 1:length(train_loss_plot),
                    "Val_loss" = val_loss_plot,
                    "Train_accuracy"= accuracy_plot,
                    "Val_accuracy"= val_accuracy_plot)
head(data)
  Student_loss Distillation_loss Total_loss Epoch Val_loss
1     1.970847          1.006515   1.874414     1 1.683065
2     1.671758          1.006187   1.605201     2 1.746699
3     1.618646          1.006112   1.557393     3 1.531488
4     1.562328          1.005987   1.506694     4 1.508584
5     1.406053          1.005917   1.366040     5 1.372146
6     1.520737          1.005878   1.469251     6 1.342370
  Train_accuracy Val_accuracy
1      0.5028956    0.5337647
2      0.5482197    0.5590533
3      0.5649438    0.5679042
4      0.5752720    0.5776999
5      0.5815135    0.5892469
6      0.5893831    0.5902871

Where total_loss is alpha * train_loss_plot + (1 - alpha) * distilation_loss_plot
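As a quick sanity check, plugging the logged values of the first three epochs into the weighted sum used to build total_loss_plot reproduces the Total_loss column of the table. The vectors below are simply copied by hand from the training log.

```r
alpha <- 0.9

# First three epochs, copied from the training log
student_loss      <- c(1.970847, 1.671758, 1.618646)
distillation_loss <- c(1.006515, 1.006187, 1.006112)

# Same weighting as in the training loop
total_loss <- alpha * student_loss + (1 - alpha) * distillation_loss
round(total_loss, 6)
```

Since alpha is 0.9, the total loss follows the student loss closely and the distillation term only contributes about 0.1 to each value.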

ggplot(data, aes(Epoch)) +
  scale_colour_manual(values=c(Student_loss="#F8766D",Val_loss="#00BFC4", Distillation_loss="#DE8C00", Total_loss="#1aff8c")) +
  geom_line(aes(y = Student_loss, colour = "Student_loss")) + 
  geom_line(aes(y = Val_loss, colour = "Val_loss")) + 
  geom_line(aes(y = Total_loss, colour = "Total_loss")) + 
  geom_line(aes(y = Distillation_loss, colour = "Distillation_loss"))
#Training and validation accuracy
ggplot(data, aes(Epoch)) + 
  geom_line(aes(y = Train_accuracy, colour = "Train_accuracy")) + 
  geom_line(aes(y = Val_accuracy, colour = "Val_accuracy"))

Fine tuning and conclusion

Is that all? Well, no. Here we performed knowledge distillation to train only the head of the student network.

The next step would be to repeat the knowledge distillation after unfreezing part of the student, with something like:

unfreeze_weights(conv_base_student, from = 'block5a_expand_conv')

But I will not bet my small GPU card on this or start a fire in my basement for the sake of the tutorial.

As I mentioned earlier, I adapted my code from Kaggle, where the GPU is much bigger. Take a look if you want to see it, but basically the final output looks like this:

[Figure: loss]
[Figure: accuracy]

Well, that’s it for this post, which is probably already lengthy enough for a blog post !