Loss Functions
Target Values
There are two main kinds of target values:
- Numerical (vector of real values)
- Categorical (type of object)
For the first kind, a common loss is the mean squared error loss. For the second kind, the most common approach is to view classification as predicting a discrete probability distribution over the set of possible objects. Then one chooses losses that measure the dissimilarity between the predicted distribution over the objects and the target distribution (one-hot encoding). In practice, numerical stability issues mean the prediction and loss might not be based on literal discrete probability distributions.
One-Hot Encoding
For classification tasks, we can associate an integer from \(1\) to \(n\) to each type of object out of the \(n\) possible ones. Then, the target for an object of class \(i\) is the \(n\)-dimensional vector with a \(1\) as the \(i^{\mathrm{th}}\) element and zeros for all other elements. This representation is called a one-hot encoding.
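As a concrete sketch, assuming a small example with \(n = 4\) classes: PyTorch's torch.nn.functional.one_hot builds this encoding directly (note that PyTorch labels classes from \(0\) to \(n-1\) rather than \(1\) to \(n\)).

```python
import torch
import torch.nn.functional as F

# Class labels for three objects, each an integer in {0, ..., 3}.
labels = torch.tensor([2, 0, 3])

# One-hot encoding: a 1 at the position of the class, zeros elsewhere.
targets = F.one_hot(labels, num_classes=4)
print(targets)
# tensor([[0, 0, 1, 0],
#         [1, 0, 0, 0],
#         [0, 0, 0, 1]])
```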
Mean Squared Error
This loss is the most widely used when the target values are vectors of real values, for example when predicting the age and market value of a house from a large list of information about the house.
The mean squared error loss is
\[ L(\theta) = \frac{1}{N} \mathbf{e}^\top \mathbf{e}, \]
where \(\theta\) denotes the trainable parameters, \(\mathbf{e}\) is the vector of errors (differences between predictions and targets) over the data set, and \(N\) is its length, so the loss is the average squared error.
This loss is popular because it is:
- Convex
- Differentiable
In PyTorch, the loss is defined in torch.nn.MSELoss.
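Below is a minimal sketch of the correspondence between the formula and torch.nn.MSELoss; the predictions and targets are made up for illustration.

```python
import torch
import torch.nn as nn

# Made-up predictions and targets: 3 houses, 2 real-valued outputs each
# (say, age in years and value in thousands).
predictions = torch.tensor([[25.0, 310.0], [40.0, 150.0], [12.0, 480.0]])
targets = torch.tensor([[30.0, 300.0], [35.0, 160.0], [10.0, 500.0]])

# Built-in loss: the mean of the squared errors over all elements.
loss = nn.MSELoss()(predictions, targets)

# The same quantity computed directly from the error vector e.
e = (predictions - targets).reshape(-1)
manual = (e @ e) / e.numel()

print(loss.item(), manual.item())  # identical values
```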
Cross-Entropy Loss
Under the most common case of one-hot encoding, the cross-entropy loss for a single datum boils down to the negative log of the predicted probability of the true class. Thus, if the true class is \(j\), and our model says that the probability that the input belongs to class \(j\) is \(p_j\), then the loss for this input is \(-\log p_j\). We would like \(p_j\) to be \(1\), and if this happens the loss is zero. If instead \(p_j\) is zero, a maximally incorrect prediction, the loss is \(\infty\). The loss smoothly decreases from \(\infty\) to \(0\) as \(p_j\) increases from \(0\) to \(1\). Moreover, it is convex in \(p_j\).
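A quick numerical illustration of this behavior (the probabilities are arbitrary):

```python
import math

# Per-sample cross-entropy loss -log(p_j) for several predicted
# probabilities of the true class j.
for p_j in [0.01, 0.1, 0.5, 0.9, 1.0]:
    print(f"p_j = {p_j:4.2f}  ->  loss = {-math.log(p_j):.3f}")
# The loss is large when the true class is given low probability
# and reaches 0 when p_j = 1.
```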
Theory
Given a distribution \(p(x)\) over some set \(X \ni x\), we can define the surprise as \(-\log p(x)\). The surprise is high when we obtain sample \(x\) and \(p(x)\) is low, while the surprise is \(0\) if we sample \(x\) and \(p(x)=1\). That is, we are surprised when unexpected things happen and not surprised at all when the sure thing happens.
The entropy \(H(p)\) of a source of information is then the expected value of the surprise, i.e., the probability-weighted average of the surprise over all outcomes:
\[ H(p) = \sum_{x_i \in X} p(x_i) ( - \log p(x_i) ).\]
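A small sketch with made-up distributions, comparing a uniform distribution with a peaked and a deterministic one:

```python
import torch

def entropy(p):
    # H(p) = sum_i p_i * (-log p_i); outcomes with p_i = 0 contribute 0.
    p = p[p > 0]
    return -(p * torch.log(p)).sum()

uniform = torch.tensor([0.25, 0.25, 0.25, 0.25])
peaked  = torch.tensor([0.97, 0.01, 0.01, 0.01])
certain = torch.tensor([1.0, 0.0, 0.0, 0.0])

print(entropy(uniform).item())  # log(4) ≈ 1.386, the maximum for 4 outcomes
print(entropy(peaked).item())   # ≈ 0.168, low
print(entropy(certain).item())  # 0.0: exactly one thing will happen
```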
Entropy is low when the set of events that will probably occur is small relative to the set of possible outcomes. A uniform distribution has the highest entropy (anything can occur with equal probability), and a distribution with all weight on one outcome has zero entropy (exactly one thing will happen).
In entropy, the surprise and expected value come from the same distribution.
In cross entropy, the surprise is based on the predicted probabilities. The expected value is with respect to the true class probabilities. So, if we predict \(q^{pred}(x)\), and the true distribution is \(p(x)\), then cross entropy is
\[H(p,q^{pred}) = \sum_{x_i \in X} p(x_i) ( - \log q^{pred}(x_i) ). \tag{1}\]
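A direct implementation of Equation 1 as a sketch (the distributions are illustrative, and the small eps is only there to guard against log(0)):

```python
import torch

def cross_entropy(p, q_pred, eps=1e-12):
    # H(p, q_pred) = sum_i p_i * (-log q_pred_i)
    return -(p * torch.log(q_pred + eps)).sum()

p = torch.tensor([0.0, 1.0, 0.0])        # true distribution (one-hot, class 1)
q_good = torch.tensor([0.05, 0.90, 0.05])
q_bad  = torch.tensor([0.70, 0.10, 0.20])

print(cross_entropy(p, q_good).item())   # ≈ -log(0.90) ≈ 0.105
print(cross_entropy(p, q_bad).item())    # ≈ -log(0.10) ≈ 2.303
```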
If we are bad at predicting the class probabilities, we should expect a lot of surprise, since we will more often see true outcomes that we predicted to be unlikely. In fact, \(H(p, q^{pred})\) is minimized over \(q^{pred}\) when \(q^{pred} = p\).
Finally, we can relate cross-entropy to the Kullback-Leibler (KL) divergence, which provides a distance-like measure between distributions. The KL divergence of \(q^{pred}\) from \(p\) is \(H(p, q^{pred}) - H(p)\). Since \(H(p)\) does not depend on the model, cross-entropy and KL divergence are equivalent as far as gradients with respect to the model parameters are concerned, and the latter is a true dissimilarity measure between distributions. When \(q^{pred} = p\), the KL divergence is zero.
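A quick numerical check of this relationship, with made-up distributions:

```python
import torch

def entropy(p):
    return -(p * torch.log(p)).sum()

def cross_entropy(p, q):
    return -(p * torch.log(q)).sum()

def kl_divergence(p, q):
    # D_KL(p || q) = sum_i p_i * log(p_i / q_i)
    return (p * torch.log(p / q)).sum()

p = torch.tensor([0.2, 0.5, 0.3])
q = torch.tensor([0.3, 0.4, 0.3])

print(kl_divergence(p, q).item())                  # direct definition
print((cross_entropy(p, q) - entropy(p)).item())   # H(p, q) - H(p): same value
print(kl_divergence(p, p).item())                  # 0 when q = p
```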
PyTorch Implementation
There are two ways to implement the theoretical cross-entropy loss in PyTorch; both are compared in the sketch after this list.
- Use the cross-entropy loss torch.nn.CrossEntropyLoss and ensure that the model outputs the class logits, not the probability distribution. In practice this just means using a final layer with no activation. The softmax operation applied to the class logits produces the probability distribution \(q^{pred}\), and this PyTorch function implements Equation 1 given the logits. When the target is a single class, the implementation of Equation 1 is cheaper than when it is a full distribution over classes.
- Use the NLL (negative log likelihood) loss torch.nn.functional.nll_loss and ensure that the model outputs log probabilities. This latter requirement is often achieved by adding a LogSoftmax layer at the end of the model. The loss function returns the negative of the \(i^{\mathrm{th}}\) element of the model's output, where \(i\) is the target class label. Thus, it implements the cross-entropy loss for the case of one-hot encoding.
The basic reason for these two differing methods is that computing the logarithm of a softmax in one fused operation is numerically more stable than the seemingly intuitive approach of first computing the softmax to obtain probabilities and then taking the logarithm of those probabilities. Either you perform the log-softmax in the model (and use the NLL loss), or you let the loss function do it (pass logits to torch.nn.CrossEntropyLoss). The NLL loss assumes the target distribution is one-hot, which makes it more efficient but less general.
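A small illustration of the stability issue; the extreme logits are chosen specifically to force the failure in single precision.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([200.0, 0.0])

# Naive approach: softmax first, then log. The smaller probability
# underflows to exactly 0, and log(0) = -inf.
naive = torch.log(F.softmax(logits, dim=0))

# Fused log-softmax stays finite.
fused = F.log_softmax(logits, dim=0)

print(naive)  # tensor([0., -inf])
print(fused)  # tensor([0., -200.])
```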