Knowledge distillation with R and tensorflow

A quick tutorial on how to perform knowledge distillation with R, in eager mode.

Etienne Rolland


Hi everyone! Welcome to my blog. Here I will share tutorials on things that were complicated for me and that other R users may find useful. Not surprisingly, many of these tutorials will involve tensorflow or other deep learning tools.

Some things are possible in R but, since our community is smaller, we have far fewer resources and tutorials than the Python community. This makes some particular tasks cumbersome in R, especially when the few available tutorials or interface packages start to accumulate errors or bugs because they are not used often by an active community.

I am not an expert, so I will try to cite sources for as much of my code and as many of my parameters as I can. I used a small image size so as not to run out of GPU memory; there is an example with fine-tuning and a bigger GPU here.

There is probably room for optimization, but at least it is a working skeleton. If you have suggestions for improvement, comments are welcome :D

About the data

I first wrote this code for the Cassava Leaf Disease Classification, a Kaggle competition where the goal was to train a model to identify diseases on cassava leaves. Here the distillation is done from one EfficientNetB0 to another.

What is knowledge distillation

As presented in this discussion thread on Kaggle, knowledge distillation is described as a method that "simply trains another individual model to match the output of an ensemble" (Source). It is in fact slightly more complicated: the second neural net (the student) makes predictions on the images, and its total loss combines its own loss on the true labels with a loss based on the difference between its predictions and those of its teacher (or of the ensemble).
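To make that combined loss concrete, here is a minimal base-R sketch (no tensorflow needed) of the kind of loss the Keras distiller example uses: a weighted sum of the student's cross-entropy on the hard labels and the Kullback-Leibler divergence between the temperature-softened teacher and student outputs. The function names and the default values of `alpha` and `temperature` are illustrative assumptions, not the values from my training code.

```r
# Row-wise softmax, shifted by the row maximum for numerical stability
softmax_rows <- function(logits) {
  z <- logits - apply(logits, 1, max)
  e <- exp(z)
  e / rowSums(e)
}

# Hypothetical helper: combined distillation loss for one batch.
# y_true: one-hot matrix (n x k); student_logits / teacher_logits: raw outputs (n x k)
distillation_loss <- function(y_true, student_logits, teacher_logits,
                              alpha = 0.1, temperature = 10) {
  # Student loss: categorical cross-entropy against the true labels
  p_student <- softmax_rows(student_logits)
  student_loss <- -mean(rowSums(y_true * log(p_student)))

  # Distillation loss: KL(teacher || student) on temperature-softened outputs
  p_teacher_soft <- softmax_rows(teacher_logits / temperature)
  p_student_soft <- softmax_rows(student_logits / temperature)
  kl <- mean(rowSums(p_teacher_soft * (log(p_teacher_soft) - log(p_student_soft))))

  # Weighted sum of the two terms
  alpha * student_loss + (1 - alpha) * kl
}
```

A higher temperature flattens the teacher's distribution, so the student also sees how the teacher ranks the wrong classes. Some implementations additionally multiply the soft term by `temperature^2` to keep gradient magnitudes comparable; this sketch leaves that out.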

This approach makes it possible to compress an ensemble into a single model and thereby reduce inference time or, when the student is trained to match the output of a single model, to improve its overall performance. I discovered this approach by looking at the top solutions of the Plant Pathology 2020 competition, another computer vision competition on leaves, such as this one.
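When the teacher is an ensemble, one common way to obtain a single set of soft targets is to average the members' temperature-softened probabilities. A minimal base-R sketch, assuming each member's logits are available as a matrix with one row per image; `ensemble_soft_targets` is a hypothetical helper, not part of any package.

```r
# Row-wise softmax, shifted by the row maximum for numerical stability
softmax_rows <- function(logits) {
  z <- logits - apply(logits, 1, max)
  e <- exp(z)
  e / rowSums(e)
}

# Hypothetical helper: average the softened probabilities of several teachers.
# member_logits: list of n x k logit matrices, one per ensemble member
ensemble_soft_targets <- function(member_logits, temperature = 1) {
  probs <- lapply(member_logits, function(l) softmax_rows(l / temperature))
  Reduce(`+`, probs) / length(probs)
}
```

The averaged matrix can then stand in for a single teacher's softened output in the distillation term of the loss.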

I'll let you read the source mentioned above to understand how it might work. The mechanism is not settled, but it seems related to learning specific features versus forcing the student to learn "multiple views", that is, multiple types of features to detect in the images.

There is, of course, no starting material for doing this in R. Thankfully, there is a code example on the Keras website. In this example, they create a model class, a distiller, to perform the knowledge distillation. There is, however, one problem: models cannot be inherited from in R. There are examples of inheritance with R6 for callbacks, like here, but models are not R6 classes. To overcome this problem, I used the code example as a guide and reproduced its steps by following the approach in this guide to eager execution with Keras in R. I took other code from the TensorFlow for R website.

The code is quite hard to understand at first glance because, in eager mode, everything is executed inside a single for loop; it did not seem possible to do it differently. As a result, there are a lot of variables around to collect metrics during training. If you want to understand the code, take the body out of the loop and run it on its own, then rebuild the loop around it. I did not use tfdatasets as shown in the guide to eager execution, so instead of make_iterator_one_shot() and iterator_get_next(), here we loop over the train_generator to produce the batches.
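To show the shape of that single loop without requiring tensorflow or a GPU, here is a toy base-R sketch. It is an assumption-heavy stand-in: random features replace the images, a fixed linear "teacher" replaces the EfficientNet teacher, the student is a linear softmax model, and the gradient is written by hand where the real code would use `tf$GradientTape()`. What carries over is the structure: shuffle, loop over batches, compute the combined loss, update the student, and accumulate metrics in variables declared outside the loop.

```r
set.seed(42)

# Row-wise softmax, shifted by the row maximum for numerical stability
softmax_rows <- function(logits) {
  z <- logits - apply(logits, 1, max)
  e <- exp(z)
  e / rowSums(e)
}

# Toy stand-ins (hypothetical): 200 "images" with 4 features, 3 classes
n <- 200; n_feat <- 4; n_class <- 3
x <- matrix(rnorm(n * n_feat), n, n_feat)
w_teacher <- matrix(rnorm(n_feat * n_class), n_feat, n_class)
teacher_logits <- x %*% w_teacher               # frozen teacher outputs
y <- diag(n_class)[max.col(teacher_logits), ]   # one-hot hard labels

alpha <- 0.1; temperature <- 3; lr <- 0.5; batch_size <- 20
w_student <- matrix(0, n_feat, n_class)         # student weights
loss_history <- numeric(0)                      # metrics collected across epochs

for (epoch in 1:20) {
  idx <- sample(n)                              # shuffle, like a generator would
  epoch_loss <- 0
  for (start in seq(1, n, by = batch_size)) {
    rows <- idx[start:(start + batch_size - 1)]
    xb <- x[rows, , drop = FALSE]
    yb <- y[rows, , drop = FALSE]

    # Forward pass of the student, hard and softened
    s_logits <- xb %*% w_student
    p_s <- softmax_rows(s_logits)
    p_s_soft <- softmax_rows(s_logits / temperature)
    p_t_soft <- softmax_rows(teacher_logits[rows, , drop = FALSE] / temperature)

    # Combined loss: alpha * cross-entropy + (1 - alpha) * KL(teacher || student)
    student_loss <- -mean(rowSums(yb * log(p_s)))
    distill_loss <- mean(rowSums(p_t_soft * (log(p_t_soft) - log(p_s_soft))))
    loss <- alpha * student_loss + (1 - alpha) * distill_loss

    # Hand-written gradient w.r.t. the student logits (GradientTape's job in tf)
    d_logits <- (alpha * (p_s - yb) +
                 (1 - alpha) * (p_s_soft - p_t_soft) / temperature) / batch_size
    w_student <- w_student - lr * crossprod(xb, d_logits)

    epoch_loss <- epoch_loss + loss / (n / batch_size)   # running mean over batches
  }
  loss_history <- c(loss_history, epoch_loss)
}
```

In the real loop, the batch would presumably come from `generator_next()` on the train_generator, the teacher's outputs from calling the teacher model on that batch, and the weight update from the Keras optimizer.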

[1] TRUE
[1] '2.3'

Here I flex with my own version of keras. Basically, it is a fork with application wrappers for EfficientNet.

Disclaimer: I did not write the code for the really handy application wrappers. It came from this commit, whose PR was on hold until the full release of tf 2.3, as stated in this PR. I am not sure why the PR was closed.

devtools::install_github("Cdk29/keras", dependencies = FALSE)
# A tibble: 6 x 2
  image_id       label
  <chr>          <dbl>
1 1000015157.jpg     0
2 1000201771.jpg     3
3 100042118.jpg      1
4 1000723321.jpg     1
5 1000812911.jpg     3
6 1000837476.jpg     3
[1] "0" "1" "2" "3" "4"

“Would it have been easier to write a function to convert the labels?” you may ask.



# A tibble: 6 x 7
  image_id       label   CBB  CBSD   CGM   CMD Healthy
  <chr>          <dbl> <dbl> <dbl> <dbl> <dbl>   <dbl>
1 1000015157.jpg     0     1     0     0     0       0
2 1000201771.jpg     3     0     0     0     1       0
3 100042118.jpg      1     0     1     0     0       0
4 1000723321.jpg     1     0     1     0     0       0
5 1000812911.jpg     3     0     0     0     1       0
6 1000837476.jpg     3     0     0     0     1       0
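For reference, a minimal base-R sketch of such a function, assuming the `label` column holds the integers 0 to 4 in the order CBB, CBSD, CGM, CMD, Healthy (which matches the table above). `label_to_onehot` is a hypothetical name.

```r
# Hypothetical helper: add one-hot columns for the five classes.
# Assumes labels$label holds integers 0-4 in the order of class_names.
label_to_onehot <- function(labels,
                            class_names = c("CBB", "CBSD", "CGM", "CMD", "Healthy")) {
  # Row (label + 1) of the identity matrix is the one-hot vector for that label
  onehot <- diag(length(class_names))[labels$label + 1, , drop = FALSE]
  colnames(onehot) <- class_names
  cbind(labels, as.data.frame(onehot))
}
```

Selecting `image_id` and the class columns afterwards gives the six-column layout of the train and validation tables below.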
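# Keep only the images that are not in the validation set
train_labels <- labels[!labels$image_id %in% val_labels$image_id, ]
# Sanity check: should return only FALSE (no overlap between train and validation)
table(train_labels$image_id %in% val_labels$image_id)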


# A tibble: 6 x 6
  image_id         CBB  CBSD   CGM   CMD Healthy
  <chr>          <dbl> <dbl> <dbl> <dbl>   <dbl>
1 1000015157.jpg     1     0     0     0       0
2 1000201771.jpg     0     0     0     1       0
3 100042118.jpg      0     1     0     0       0
4 1000723321.jpg     0     1     0     0       0
5 1000812911.jpg     0     0     0     1       0
6 1000837476.jpg     0     0     0     1       0
# A tibble: 6 x 6
  image_id         CBB  CBSD   CGM   CMD Healthy
  <chr>          <dbl> <dbl> <dbl> <dbl>   <dbl>
1 1003442061.jpg     0     0     0     0       1
2 1004672608.jpg     0     0     0     1       0
3 1007891044.jpg     0     0     0     1       0
4 1009845426.jpg     0     0     0     1       0
5 1010648150.jpg     0     0     0     1       0
6 1011139244.jpg     0     0     0     1       0
# Data augmentation
datagen <- image_data_generator(
  rotation_range = 40,
  width_shift_range = 0.2,
  height_shift_range = 0.2,
  shear_range = 0.2,
  zoom_range = 0.5,
  horizontal_flip = TRUE,
  fill_mode = "reflect"
)

# Load one image (img_path: path to a single training image) and add a batch dimension
img <- image_load(img_path, target_size = c(448, 448))
img_array <- image_to_array(img)
img_array <- array_reshape(img_array, c(1, 448, 448, 3))
# Generator that will flow augmented images
augmentation_generator <- flow_images_from_data(
  img_array,
  generator = datagen,
  batch_size = 1
)
op <- par(mfrow = c(2, 2), pty = "s", mar = c(1, 0, 1, 0))
for (i in 1:4) {
  batch <- generator_next(augmentation_generator)