The Question of Distillation from Weak Teachers
During my PhD, I worked on the problem of Distillation from Weak Teachers. While model distillation (or knowledge distillation) is a well-known concept in the domain of deep learning, Reverse Knowledge Distillation (also known as Distillation from Weak Teachers) is a relatively obscure topic, even within the ML community.
Fast forward a year from my defense (and the NeurIPS 2024 presentation of my work), I have noticed some common questions people have when trying to understand the methodology we introduced in that paper (i.e. Induced Model Matching, or IMM). To this end, I decided to write a blog post on the topic. This post can be thought of as a distilled version of my thesis.
The questions are legitimate:
- First of all, how can smaller models possibly teach larger models? What information do they have that the larger model does not?
- Secondly, how do we even compare the two, given that models taking inputs of different dimensionalities cannot be directly compared?
- Finally, how can a regularization objective that works in a restricted dimensionality possibly improve performance in the original dimensionality?
We start off by answering the first question.
Core Question 1: Merits of restricted models
Surprisingly, it is indeed possible for a model to have extra knowledge while (1) being trained on (or constructed from) the same data and (2) having the same, or even smaller, dimensionality (model size). In the domain of language modeling, for example, this means a bigram model can contain information that another bigram model (constructed from the same data), or even a larger N-gram (represented as a conditional probability table or even a recurrent neural network), does not have. More precisely, we are talking about information that is useful for prediction tasks, at any scale (bigram or N-gram).
We start off with an intuitive example. More precisely, we choose an example from language modeling (which was the starting point of our experimentation). We compare the empirical bigram with the Kneser-Ney bigram. If you are not familiar with the Kneser-Ney bigram, Chen & Goodman, 1999 is a great read on it. In this post, I’ll stay away from the math and explain the concept intuitively. Any bigram model (empirical or Kneser-Ney) has a backoff component (a unigram, in the case of bigram models).
The empirical bigram constructs the unigram component using occurrence counts (just like the bigram component). Kneser-Ney takes advantage of the fact that, once we have the bigram matrix, there exists a better way to construct the unigram component.
First of all, let’s discuss a potential problem with building the backoff model using just occurrence counts (as the empirical bigram does). Suppose we have a corpus on, say, soccer, where the token united appears many times in club names. If we build a unigram model from an article about a soccer match between two teams that have united in their name (say “Manchester United” and “Newcastle United”), the word united will likely be assigned a huge probability, possibly comparable to that of a common function word in the same corpus. This unigram model will not generalize well. If we use it for a prediction task on another corpus and have to back off to the unigram (due to sparsity in the bigram counts), it will predict united with a probability similar to that of a common function word, say them, which is not good.
Structural information is key. Kneser-Ney, instead of using occurrence counts, takes a different perspective on the same data. Precisely, to assign a count to a token, it looks at the number of unique contexts that the token follows. Continuing with the example of Manchester United and Newcastle United, the token united will be assigned a count of just two, regardless of how many times these names occur in the corpus (provided there was no other usage of united besides these two team names). A function word like them, however, would still get a large count, since it will likely follow many unique contexts.
Bottom line: the Kneser-Ney bigram utilizes certain structural information about natural language that the empirical bigram does not.
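To make this concrete, here is a toy sketch (the mini-corpus and code are my own illustration, not from the paper) contrasting raw occurrence counts with Kneser-Ney-style continuation counts for the unigram backoff:

```python
from collections import Counter

# Toy corpus: "united" is frequent, but only ever follows two distinct words.
corpus = ("manchester united beat newcastle united . "
          "newcastle united pressed them . "
          "manchester united fouled them . "
          "fans cheered them .").split()

bigrams = list(zip(corpus[:-1], corpus[1:]))

occurrence   = Counter(corpus)                 # empirical unigram: raw occurrence counts
continuation = Counter()                       # Kneser-Ney-style continuation counts
for prev, word in set(bigrams):                # count *distinct* preceding words only
    continuation[word] += 1

for w in ("united", "them"):
    print(f"{w}: occurrences={occurrence[w]}, continuation count={continuation[w]}")
# united: occurrences=4, continuation count=2   (only follows manchester/newcastle)
# them:   occurrences=3, continuation count=3   (follows pressed, fouled, cheered)
```

Even though united occurs more often, its continuation count is smaller than that of them; that is exactly the structural signal the raw-count unigram misses.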
Quantitative Proof of the merits of Kneser-Ney bigram over a larger N-gram
Our work is an improvement over Xie et al, who first used restricted models in the training of larger ones, via a different approach (data noising). Precisely, they successfully incorporated Kneser-Ney bigram information into a larger learner model, an LSTM RNN with a context length of 35 (i.e. an N-gram with N=35). Our approach solves the mathematical inconsistencies of theirs.
Since both Xie et al and we use a smaller model (the Kneser-Ney bigram) to improve a larger N-gram, there is some merit in first comparing the performance of the Kneser-Ney bigram and the LSTM RNN on the same task (i.e. bigram prediction).
Table 5 in our paper does precisely this. Empirically, we observe that the Kneser-Ney bigram outperforms the N-gram, so there is indeed a reason for the N-gram to learn from the bigram.
It was not possible for Xie et al to produce this comparison, because it requires a way to make the large model perform the smaller task. Our proposed methodology allows us to align the larger model with the smaller task (by inducing a small model out of the larger model via marginalization), which makes this comparison possible.
Our methodology becomes clearer in the following section, which addresses the second most common question from our readers: how our methodology is the right way to compare models of different scales and, consequently, to transfer knowledge across models of different scales.
Core Question 2: Comparing models of different scales
It is not appropriate to directly compare two models that work at different scales.
In both Xie et al and our work (which builds on Xie et al), the target model, the Kneser-Ney model, is just a bigram and takes a context of length 1. For notational convenience, we separate the context of the learner into two parts: the short context, which is the same as that of the bigram and which we represent with $\overline{x}$, and the remainder, called the extended context and represented with $\underline{x}$. The complete context is $x = (\overline{x}, \underline{x})$.
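To make the split concrete (the token indexing here is my own illustration), for the 35-token learner context and the bigram target:

\[ x = (\overline{x}, \underline{x}), \qquad \overline{x} = w_{t-1}, \qquad \underline{x} = (w_{t-35}, \ldots, w_{t-2}), \]

so the bigram target only ever sees $\overline{x}$, while the learner sees all of $x$.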
The noising methodology of Xie et al implicitly adds a regularization term to the training that compares $Q$ (i.e. the 35-gram) directly with $\overline P$ (i.e. the bigram).
\[\tag{1} \sum_{x} \pi(x) \sum_{y} \underbrace{\overline P(y|\overline{x})}_{\text{only looks at restricted context}} \log \frac{ P(y| \overline{x}, \underline{x}) }{ \underbrace{Q(y| \overline{x}, \underline{x})}_{\text{looks at complete context}} }\]This term does not define a valid KL-divergence between the two models, as the target and learner models being compared work at different scales (lengths of context). As a consequence, the non-negativity property also does not hold. In Appendix B.2 of our paper, we show a numerical example where such an objective can go negative!
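Appendix B.2 contains the paper’s actual counterexample; the following is just a toy sketch of the same phenomenon (the numbers are mine, not from the paper), for a single context $x$ and a binary $y$. Because the outer weights come from $\overline P(y\mid\overline{x})$ rather than from $P(y\mid\overline{x},\underline{x})$, the weighted log-ratio can dip below zero:

```python
import math

# One fixed context x = (x̄, x̲) and a binary y.  Toy numbers, not from the paper.
P_bar  = {0: 0.1, 1: 0.9}    # restricted target  P̄(y | x̄)
P_full = {0: 0.9, 1: 0.1}    # full-context model P(y | x̄, x̲)
Q_full = {0: 0.5, 1: 0.5}    # learner            Q(y | x̄, x̲)

# Inner sum of Equation (1) for this single context:
term = sum(P_bar[y] * math.log(P_full[y] / Q_full[y]) for y in (0, 1))
print(term)   # ≈ -1.39 < 0: not a valid (non-negative) KL divergence
```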
While noising implicitly adds this objective, another methodology, reverse knowledge distillation (Yuan et al, 2020; Qin et al, 2022), adds this regularization objective explicitly. Essentially, the noising methodology of Xie et al and reverse knowledge distillation are implicit and explicit versions of the same idea and add the same regularization objective to the training.
Having mentioned “regularization objective”, I find it necessary to talk about the importance of a correct regularization technique. Before we do that, we need to understand how the extra knowledge of the Kneser-Ney bigram is infused into the training through “regularization”.
Data Augmentation as Regularization
It’s important to first understand that:
- Any extra knowledge used in training is a form of data augmentation, including knowledge coming from models of smaller scales (which is what noising/reverse-KD use).
- Any data augmentation is also, in a way, regularization.
To this end, we momentarily take a step away from deep learning and take a general perspective on optimization. Following is the general structure of an optimization problem:
\[\begin{aligned} \min_{x} \quad & f_0(x)\\ \textrm{s.t.} \quad & x \in \cal D \end{aligned}\]We have an objective function $f_0(x)$ that we minimize (or maximize) over some optimization variable $x$ (which may be a vector), subject to (abbreviated as s.t.) constraints that take the form of $x$ belonging to a subset of possibilities, $\cal D$.
Constraints are optional. In fact, if there are no constraints, we are said to be doing unconstrained optimization. The problem is feasible if there is some $x$ that satisfies the constraints (i.e., the feasibility set is not empty). If there is no $x$ that satisfies the constraints, we can say that the problem is infeasible, or equivalently, the feasibility set is a null set.
In deep learning, we optimize over the parameters of the deep neural network. While at first sight deep learning comes across as unconstrained optimization (the optimizers used in deep learning do not allow for explicit constraints), constraints are in fact imposed implicitly through regularization, which we can think of as a Lagrangian relaxation of what would otherwise be explicit constraints. The choice of optimization algorithm, along with the parameterization, can provide further implicit constraints. All of these together can be interpreted as a secondary source of knowledge, in addition to the original data.
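As a textbook illustration of this correspondence (not specific to our paper), an explicit norm constraint and the familiar L2 penalty are related through exactly such a relaxation: for a suitable $\lambda \geq 0$,

\[\begin{aligned} \min_{\theta} \quad & f_0(\theta)\\ \textrm{s.t.} \quad & \|\theta\|_2^2 \leq c \end{aligned} \qquad \longleftrightarrow \qquad \min_{\theta} \; f_0(\theta) + \lambda \|\theta\|_2^2 .\]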
When the objective is approximated empirically, data itself is a form of constraint, attempting to keep us close to the idealized objective that we are after. In a huge parameter space, data imposes constraints that help us locate the solution (i.e. the optimal parameters). The lack of data may make this constraint too loose to pinpoint the optimal solution, which is exactly why we need additional regularization. Once we take the perspective that data itself is a constraint, it becomes evident that data augmentation represents constraints too. This means that data augmentation is a form of regularization and, conversely, any regularization can be thought of as extra data that tilts the optimal model toward obeying certain constraints. In short, any regularization imposes extra constraints beyond the ones imposed by the original data, and these extra constraints are what prevent the model from overfitting it.
Core Question 3: How does regularization in a restricted dimensionality improve performance in the original dimensionality?
This question hits at the core of how regularization works. The discussion above should have given you some basis for how regularization works in general.
This question is essentially the same as asking how L1 or L2 regularization can improve performance on the task at hand, for any network trained under any objective. The additional constraints imposed by the regularization help guide us towards the right parameters in the parameter space, as long as they do not “rule out the right parameters”. Ideally, regularization should rule out the regions where the optimizer has no business being, as a favor to the optimizer: it closes some doors in the parameter space to guide the optimizer towards where it needs to go.
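As a quick sanity check of this intuition (a self-contained toy experiment of my own, unrelated to the paper’s setup), here is ordinary least squares against an L2-penalized (ridge) fit on a small, noisy problem with more features than samples:

```python
import numpy as np

rng = np.random.default_rng(0)

# Tiny, noisy regression problem: fewer samples than features.
n, d = 20, 50
X = rng.normal(size=(n, d))
w_true = np.zeros(d)
w_true[:5] = 1.0                                  # only a few features actually matter
y = X @ w_true + rng.normal(size=n)

def fit(X, y, lam):
    # Penalized least squares: (XᵀX + λI)⁻¹ Xᵀy ; lam = 0 is plain (minimum-norm) least squares.
    return np.linalg.pinv(X.T @ X + lam * np.eye(X.shape[1])) @ X.T @ y

X_test = rng.normal(size=(2000, d))
y_test = X_test @ w_true
for lam in [0.0, 1.0, 10.0]:
    w = fit(X, y, lam)
    print(f"lambda={lam:5.1f}  test MSE={np.mean((X_test @ w - y_test) ** 2):.2f}")
# With so little data, the unpenalized fit overfits the noise; the penalty
# typically pulls the solution back toward better-generalizing parameters.
```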
We discussed above the caveats of the regularization objectives added by prior approaches, and how they do not compute a valid KL divergence in the restricted dimensionality. In those scenarios, yes, the methodology can actually hurt (refer again to Appendix B.2 of our paper, where we show a numerical example where such an objective goes negative). Provided the objective is consistent, this issue does not arise.
While Appendix B.2 gives a numerical counterexample that proves this theoretically, the paper also contains an empirical demonstration (Table 4), where we implement both noising and IMM (albeit in the logistic regression experiments) and vary the contribution of the regularization term. With increasing contribution, noising falls below the baseline, while IMM does not.
Importance of correct regularization technique
Regularization, just like the dataset, has the potential to hurt the optimization process if done incorrectly. It is possible for the extra constraints imposed by the regularization to rule out the region of the parameter space containing the optimal idealized solution, or to make it suboptimal from an empirical perspective. This is precisely the problem we want to address when the regularization, or data augmentation, comes from a model of restricted dimensionality.
Regularization via data augmentation can be done incorrectly either because the augmentation dataset is not good (e.g., biased), or because an incorrect methodology is used to incorporate an otherwise good dataset. In the NeurIPS 2024 paper, we are concerned with the latter.
This is precisely why the mathematical correctness and consistency of the objective, which we studied above in Core Question 2, are so critical for us :)
Having discussed the inconsistency of noising/reverse-KD objectives, we now move towards the solution we propose.
The right way to compare models of different scales (via Induced Models)
We propose that, in order for a large model (a model working on the complete context) to learn from a small model (a model working on a restricted context), the large model must be marginalized and a small model induced out of it.
Consequently, it should be the induced smaller model of the larger model that should learn from the smaller target model. In other words, for an adult to learn from a child, it’s the “inner child” of the adult that should learn from the child.
Throughout the paper, we use the notation $\overline{Q}(y \mid \overline{x})$ to represent the induced smaller model of $Q(y \mid \overline{x}, \underline{x})$. Once we have that, we can add the “corrected” regularization term,
Induced Models: The induced model $\overline{Q}(y \mid \overline{x})$ is obtained from $Q(y \mid \overline{x}, \underline{x})$ via marginalization. This marginalization depends on the kind of dataset (discrete vs continuous).
\[\tag{2} \sum_{x} \pi(x) \sum_{y} \underbrace{\overline P(y|\overline{x})}_{\text{only looks at restricted context}} \log \frac{ P(y| \overline{x}, \underline{x}) }{ \underbrace{\overline{Q}(y| \overline{x})}_{\text{only looks at restricted context}} }\]The marginalization process is the core workhorse of the IMM methodology. For brevity, further details are skipped in this blog post, but I highly encourage you to read the paper (or the thesis, linked in the concluding remarks below) to fully understand it. It is computationally more complex, the cost we pay for mathematical consistency. In my thesis, the whole of Chapter 5 is dedicated to the computational tractability of the approach.
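To give a flavor of what inducing a model means in the discrete case (this is a rough illustration with names of my own choosing, not the paper’s exact algorithm), $\overline{Q}$ averages $Q$ over the extended contexts that share the same short context, weighted by their probability:

```python
from collections import defaultdict

def induce(Q, contexts, pi):
    """Rough illustration (not the paper's exact algorithm): marginalize the extended
    context to obtain an induced restricted model Q̄(y | x̄).

    Q(x_bar, x_under) -> dict {y: Q(y | x̄, x̲)}
    contexts          -> iterable of (x_bar, x_under) pairs
    pi                -> dict {(x_bar, x_under): probability of that full context}
    """
    totals = defaultdict(lambda: defaultdict(float))
    mass = defaultdict(float)
    for x_bar, x_under in contexts:
        w = pi[(x_bar, x_under)]
        mass[x_bar] += w
        for y, p in Q(x_bar, x_under).items():
            totals[x_bar][y] += w * p          # accumulate π(x̄, x̲) · Q(y | x̄, x̲)
    # Normalizing by π(x̄) gives Q̄(y | x̄) = Σ_x̲ π(x̲ | x̄) Q(y | x̄, x̲)
    return {x_bar: {y: p / mass[x_bar] for y, p in dist.items()}
            for x_bar, dist in totals.items()}
```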
Performance on smaller task (prediction using restricted context)
$\overline{Q}(y \mid \overline{x})$ can now be made to perform the smaller task (i.e. bigram prediction), and its performance on that task can be measured. This is precisely what we do in Table 5 of our paper, where we show that the Kneser-Ney bigram performs better than $Q(y \mid \overline{x}, \underline{x})$ (evaluated through its induced model), and therefore contains knowledge that $Q(y \mid \overline{x}, \underline{x})$ does not have! This result is a concrete manifestation of what Xie et al based their work on, except that they did not have a way to quantify it!
Since we minimize the regularization objective shown in Equation 2, it is evident that once trained using this objective (using Kneser-Ney as the target model $\overline P$), $\overline{Q}(y \mid \overline{x})$ will try to approach Kneser-Ney performance.
Note: Since $P$ is often used for true distribution which we don’t have, we use an empirical proxy for $\overline P$, which is $\hat P$. For simplicity, we do not use the $\hat P$ notation in this post.
Performance on larger task (prediction using complete context)
What is less evident is that minimizing the objective shown in Equation 2 also improves performance on the original task (i.e. 35-gram prediction, the task of the LSTM RNN in Xie et al). This is something we show empirically: with the corrected regularization objective, we achieve better performance on the original objective (i.e. the cross-entropy loss of the 35-gram LSTM RNN).
As is usually the case, the regularization objective is only added during training and removed during evaluation on the original task. This is the only way to perform a fair comparison with Xie et al.
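Schematically, the training objective looks like the following minimal PyTorch-style sketch (the function names, arguments, and the `imm_term` helper implementing Equation 2 are placeholders of mine, not the paper’s code); the regularizer enters only the training loss:

```python
import torch
import torch.nn.functional as F

def train_step(model, optimizer, batch, imm_term, kn_bigram, lambda_reg):
    """One training step: original cross-entropy plus the IMM regularizer.

    `imm_term(model, x, kn_bigram)` is a placeholder for a function computing the
    Equation (2) regularizer against the Kneser-Ney target; it is not the paper's code.
    """
    x, y = batch
    logits = model(x)                                 # learner Q(y | x̄, x̲)
    ce_loss = F.cross_entropy(logits, y)
    reg = imm_term(model, x, kn_bigram)               # induced-model matching term
    loss = ce_loss + lambda_reg * reg                 # regularizer is used only during training

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return ce_loss.item()                             # evaluation reports plain cross-entropy
```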
By a small margin, our methodology outperformed theirs. This is shown in Table 1 of our paper.
Infusing desirable properties in LM training
Through this methodology, we were able to find the mathematically correct way to infuse the smoothing property of Kneser-Ney in the training of larger language models!
Applicability beyond language models
While showing a small improvement over Xie et al on the benchmark used in their paper, we decided to implement IMM in other domains to better show its merits. More precisely, we came up with a simple proof of concept using just logistic regression (as surprising as it sounds, yes, we came up with the proof of concept later).
Additionally, we went forward and implemented the same idea in the domain of reinforcement learning. More precisely, we used POMDPs as the restricted models to improve the training of MDPs! Details are deferred to the paper.
Concluding remarks
I hope this short blog post was able to give you an overview of the IMM methodology. We found the mathematically correct way of distilling from weak teachers, and solved the inconsistencies of prior approaches.
It is, at the very least, a stepping stone in the domain of distillation from weak teachers. We also discovered a tradeoff between consistency and computational efficiency: while mathematically correct and empirically more accurate, our methodology (at least with the current examples) requires more compute to induce the smaller model. Perhaps this can be improved in the future.
Further Reading
As an elaboration of my paper, you can read my thesis.
References
[1] Muneeb, U., & Ohannessian, M. I. (2024). Induced model matching: Restricted models help train full-featured models. Advances in Neural Information Processing Systems, 37, 62617-62647.
[2] Xie, Z., Wang, S. I., Li, J., Lévy, D., Nie, A., Jurafsky, D., & Ng, A. Y. (2017). Data noising as smoothing in neural network language models. arXiv preprint arXiv:1703.02573.
[3] Chen, S. F., & Goodman, J. (1999). An empirical study of smoothing techniques for language modeling. Computer Speech & Language, 13(4), 359-394.
[4] Yuan, L., Tay, F. E., Li, G., Wang, T., & Feng, J. (2020). Revisiting knowledge distillation via label smoothing regularization. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 3903-3911).
[5] Qin, Y., Lin, Y., Yi, J., Zhang, J., Han, X., Zhang, Z., … & Zhou, J. (2022, July). Knowledge inheritance for pre-trained language models. In Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (pp. 3921-3937).