Machine Learning: Understanding Mean Teacher Model

Bhupender saini
4 min read · Oct 28, 2020


Recently, many intuitive machine learning models have appeared, but one that gained particular attention was the Mean Teacher model by the Curious AI Company¹. The mean teacher mechanism of learning is intuitive and has shown state-of-the-art results in the computer vision domain. Observing the performance of this model in the natural language processing domain with different noise strategies was quite a journey in learning machine learning. Here, I will share my understanding of the mean teacher model.

In the mean teacher approach, two identical models, called the student and the teacher, are trained with two different strategies. Only the student model is trained by backpropagation. At every step, the teacher model's weights are updated as an exponential moving average of the student model's weights, so only a small fraction of the student's weights is mixed in each time; this is why the method is called the Mean Teacher. The ability to utilize abundant unlabeled data (semi-supervised learning) is one of the major advantages of the mean teacher.

Figure 1: Mean teacher model working diagram.

As shown in Figure 1, two cost functions play an important role during backpropagation: the classification cost and the consistency cost. The classification cost C(θ) is calculated as the binary cross-entropy between the label predicted by the student model and the original label. The consistency cost J(θ) is the mean squared difference between the predicted outputs of the student model (weights θ, noise η) and the teacher model (weights θ′, noise η′).

The consistency cost measures the difference between the two prediction distributions (student and teacher), so the original label is not required. During training, the model tries to minimize this difference between the student and teacher predictions. Therefore, in addition to labeled data, we can utilize unlabeled data here. The mathematical definition of the consistency cost is as follows.

Consistency cost
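In the notation of the original paper², with f denoting the model's prediction function, the consistency cost described above can be written as:

```latex
J(\theta) = \mathbb{E}_{x,\eta,\eta'}\left[\, \left\lVert f(x, \theta', \eta') - f(x, \theta, \eta) \right\rVert^{2} \,\right]
```

That is, the expected squared distance between the teacher's prediction (weights θ′, noise η′) and the student's prediction (weights θ, noise η) on the same input x.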

One of the factors that plays a crucial role in adding robustness to the model is the introduction of noise during training. It is like fooling the model with noisy data, so that the model does not become biased towards a particular target and can also perform well when predicting unseen data.
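As an illustration only (these particular noise functions are my own sketch, not the article's code), two common noise strategies for NLP inputs are additive Gaussian noise on embeddings and word dropout, where whole token vectors are randomly zeroed:

```python
import numpy as np

rng = np.random.default_rng(0)

def add_gaussian_noise(embeddings, std=0.1):
    """Perturb input embeddings with zero-mean Gaussian noise."""
    return embeddings + rng.normal(0.0, std, size=embeddings.shape)

def word_dropout(token_embeddings, p=0.1):
    """Randomly zero out whole token vectors (word-dropout style noise)."""
    mask = rng.random(len(token_embeddings)) >= p
    return token_embeddings * mask[:, None]

x = rng.normal(size=(5, 8))          # toy input: 5 tokens, 8-dim embeddings
x_student = add_gaussian_noise(x)    # noise eta for the student
x_teacher = add_gaussian_noise(x)    # independent noise eta' for the teacher
```

The key point is that the student and teacher see independently noised versions of the same input, which is what makes the consistency cost a meaningful regularizer.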

While back-propagating through the student model, the overall cost O(θ) is calculated with the given formula:

Overall Cost
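Based on the definitions above, the overall cost combines the two terms (the original paper² additionally ramps up the consistency weight, here written as a coefficient λ, over the early epochs):

```latex
O(\theta) = C(\theta) + \lambda \, J(\theta)
```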

During training, the exponential moving average (EMA) of the student model's weights is assigned to the teacher model at every step, and the proportion of weights retained is controlled by the parameter alpha (α). As shown in equation 3, when the weights are assigned, the teacher model keeps its previous weights in proportion α and takes a (1−α) portion of the student's weights.

Exponential moving average
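The EMA update can be sketched per weight tensor as follows (a minimal numpy illustration of the rule θ′ₜ = α·θ′ₜ₋₁ + (1−α)·θₜ, not the repository's code):

```python
import numpy as np

def ema_update(teacher_weights, student_weights, alpha=0.99):
    """theta'_t = alpha * theta'_{t-1} + (1 - alpha) * theta_t, per tensor."""
    return [alpha * tw + (1.0 - alpha) * sw
            for tw, sw in zip(teacher_weights, student_weights)]

student = [np.array([1.0, 2.0]), np.array([[0.5]])]
teacher = [np.array([0.0, 0.0]), np.array([[0.0]])]
teacher = ema_update(teacher, student, alpha=0.9)
# teacher now holds 0.9 of its previous weights plus 0.1 of the student's
```

Note that the teacher is never updated by gradient descent; this averaging is its only learning mechanism.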

We ran the mean teacher model for fake news detection, and per our observations the teacher model starts performing better than the student model after a certain number of epochs (in our case, the teacher starts overtaking the student at epoch 15). However, the convergence of the teacher model depends on the number of epochs, the batch size, the training data size, and alpha (α), so parameter tuning is required to get better results.

Table 1

Algorithm:

The complete algorithm of the mean teacher methodology is as follows:

Mean Teacher Algorithm
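As a toy end-to-end sketch of the loop above (my own numpy illustration with a logistic-regression student, not the fake news code from the repository), one training step applies independent noise, computes both costs, updates the student by gradient descent, and then updates the teacher by EMA:

```python
import numpy as np

rng = np.random.default_rng(0)
sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))

# Toy binary-classification data: a labeled set and an unlabeled set
X_lab = rng.normal(size=(32, 4))
y_lab = (X_lab[:, 0] > 0).astype(float)
X_unl = rng.normal(size=(64, 4))

w_student = np.zeros(4)
w_teacher = np.zeros(4)
alpha, lr, lam = 0.95, 0.5, 1.0   # EMA decay, learning rate, consistency weight

for step in range(100):
    # 1) Independent input noise: eta for the student, eta' for the teacher
    Xs_l = X_lab + rng.normal(0, 0.1, X_lab.shape)
    Xs_u = X_unl + rng.normal(0, 0.1, X_unl.shape)
    Xt_u = X_unl + rng.normal(0, 0.1, X_unl.shape)

    # 2) Classification cost gradient (binary cross-entropy, labeled data only)
    p_lab = sigmoid(Xs_l @ w_student)
    grad_cls = Xs_l.T @ (p_lab - y_lab) / len(y_lab)

    # 3) Consistency cost gradient (MSE between student and teacher predictions
    #    on unlabeled data; the teacher's output is treated as a fixed target)
    p_s = sigmoid(Xs_u @ w_student)
    p_t = sigmoid(Xt_u @ w_teacher)
    grad_con = Xs_u.T @ (2 * (p_s - p_t) * p_s * (1 - p_s)) / len(p_s)

    # 4) Gradient step on the student only
    w_student -= lr * (grad_cls + lam * grad_con)

    # 5) EMA update of the teacher at every step
    w_teacher = alpha * w_teacher + (1 - alpha) * w_student
```

In a real implementation the student and teacher would be full neural networks and the consistency weight λ would be ramped up over the first epochs, but the structure of the step is the same.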

Code:

A code snippet of the mean teacher algorithm is shown below; if interested, you can find the complete code for fake news detection using weakly supervised learning on GitHub³.

Note: in my case, the classification cost is binary cross-entropy for two classes, so it is not shown here.
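For completeness, a minimal numpy sketch of that binary cross-entropy (my own illustration, not the author's snippet) would look like:

```python
import numpy as np

def binary_cross_entropy(y_true, y_pred, eps=1e-7):
    """Mean BCE between true labels and predicted probabilities."""
    p = np.clip(y_pred, eps, 1 - eps)  # avoid log(0)
    return float(-np.mean(y_true * np.log(p) + (1 - y_true) * np.log(1 - p)))

loss = binary_cross_entropy(np.array([1.0, 0.0]), np.array([0.9, 0.1]))
```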

Conclusion:

The mean teacher model is quite a simple and intuitive way to get better predictions, and it has the option of utilizing unlabeled data during training. In the end, the teacher model performs better on test or unseen data than the student model. However, the convergence of the teacher model depends on the number of epochs, the batch size, the training data size, and alpha (α), so parameter tuning is required to get better results. Furthermore, different noise strategies or attention modeling can be introduced into this model for better performance and robustness.

References:

  1. https://thecuriousaicompany.com/
  2. Antti Tarvainen and Harri Valpola. Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results. NeurIPS 2017.
  3. https://github.com/bksaini078/fake_news_detection
