Combining Deep Learning networks to increase prediction accuracy.
The following post is from Maria Duarte Rosa, who previously wrote a great post on neural network feature visualization. Here she talks about ways to increase your model's prediction accuracy.
- Have you tried training different architectures from scratch?
- Have you tried different weight initializations?
- Have you tried transfer learning using different pretrained models?
- Have you run cross-validation to find the best hyperparameters?
If you answered Yes to any of these questions, this post will show you how to take advantage of your trained models to increase the accuracy of your predictions. Even if you answered No to all four questions, the simple techniques below may still help to increase your prediction accuracy.
First, let's talk about ensemble learning.
What is ensemble learning?
Ensemble learning, or model ensembling, is a well-established set of machine learning and statistical techniques [LINK: https://doi.org/10.1002/widm.1249] for improving predictive performance through the combination of different learning algorithms. The combination of the predictions from different models is generally more accurate than any of the individual models making up the ensemble. Ensemble methods come in different flavours and levels of complexity (for a review see https://arxiv.org/pdf/1106.0257.pdf), but here we focus on combining the predictions of multiple deep learning networks that have been previously trained.
Different networks make different mistakes, and model ensembling exploits this: where one network is wrong, the others can outvote or outweigh it. Although not as popular in the deep learning literature as in more traditional machine learning research, model ensembling for deep learning has led to impressive results, especially in highly popular competitions such as ImageNet and Kaggle challenges. These competitions are commonly won by ensembles of deep learning architectures.
In this post, we focus on three very simple ways of combining predictions from different deep neural networks:
- Averaging: a simple average over all the predictions (output of the softmax layer) from the different networks.
- Weighted average: the weights are proportional to an individual model's performance. For example, the predictions of the best model could be weighted by 2, while the other models keep a weight of 1.
- Majority voting: for each test observation, the prediction is the most frequent class among the predictions of all networks.
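The three rules above can be sketched in a few lines of MATLAB. This is only a minimal illustration, not the code used for the results in this post: the `scores` array is made-up toy data standing in for the softmax outputs of three networks on four observations with three classes, and the weights and all variable names are assumptions chosen for the example.

```matlab
% Toy softmax outputs (assumed data): observations x classes x networks.
scores = zeros(4, 3, 3);
scores(:,:,1) = [0.7 0.2 0.1; 0.1 0.6 0.3; 0.4 0.5 0.1; 0.2 0.3 0.5];
scores(:,:,2) = [0.6 0.3 0.1; 0.2 0.5 0.3; 0.6 0.3 0.1; 0.1 0.2 0.7];
scores(:,:,3) = [0.5 0.4 0.1; 0.5 0.3 0.2; 0.3 0.6 0.1; 0.3 0.3 0.4];

% 1) Averaging: mean of the softmax outputs across networks.
avgScores = mean(scores, 3);
[~, avgPred] = max(avgScores, [], 2);

% 2) Weighted average: network 2 is counted twice (assumed weights).
w = [1 2 1];
wScores = zeros(size(scores, 1), size(scores, 2));
for k = 1:numel(w)
    wScores = wScores + w(k) * scores(:,:,k);
end
wScores = wScores / sum(w);
[~, wPred] = max(wScores, [], 2);

% 3) Majority voting: most frequent predicted class per observation.
[~, votes] = max(scores, [], 2);   % observations x 1 x networks
votes = squeeze(votes);            % observations x networks
votePred = mode(votes, 2);         % ties go to the smallest class index
```

On this toy data, averaging and majority voting both predict classes 1, 2, 2, 3, while the weighted average predicts 1, 2, 1, 3: doubling the weight of network 2 is enough to flip the third observation.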
We will use two examples to illustrate how these techniques can increase the accuracy in the following situations:
Example 1: combining different architectures trained from scratch.
Example 2: combining different pretrained models for transfer learning.
Even though we picked two specific use cases, these techniques apply to most situations where you have trained multiple deep learning networks, including networks trained on different datasets.
Example 1 – combining different architectures trained from scratch
Here we use the CIFAR-10 dataset to train 6 different ResNet architectures from scratch. We follow this example [LINK: https://www.mathworks.com/help/deeplearning/examples/train-residual-network-for-image-classification.html] but instead of training a single architecture we vary the number of units and the network width using the following 6 combinations: numUnits = [3 6 9 12 18 33]; and netWidth = [12 32 16 24 9 6]. We train each network using the same training options as in the example and estimate their individual validation errors (validation error = 100 - prediction accuracy):
Individual validation errors:
Network 1: 16.36%
Network 2: 7.83%
Network 3: 9.52%
Network 4: 7.68%
Network 5: 10.36%
Network 6: 12.04%
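As a small illustration of the error definition used here, the snippet below computes a validation error from predicted and true labels. The label vectors are made up for the example and the variable names are hypothetical; they are not the data behind the numbers in this post.

```matlab
% Assumed toy labels: class indices for 8 validation images.
trueLabels = [1 2 3 1 2 3 1 2];
predLabels = [1 2 3 1 3 3 2 2];

% Prediction accuracy in percent, and the corresponding validation error.
accuracy = 100 * mean(predLabels == trueLabels);
validationError = 100 - accuracy;
fprintf('Validation error: %.2f%%\n', validationError);
```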
We then calculated the errors for the three different ensembling techniques:
Model ensembling errors:
Average: 6.79%
Weighted average: 6.79% (Network 4 counted twice)
Majority vote: 7.16%
A quick chart of these numbers:
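% Plot errors (values copied from the individual and ensembling errors above)
example1Results = [16.36 7.83 9.52 7.68 10.36 12.04 6.79 6.79 7.16];
figure; bar(example1Results); title('Example 1: prediction errors (%)');
xticklabels({'Network 1','Network 2','Network 3', 'Network 4', 'Network 5', 'Network 6', ...
'Average', 'Weighted average', 'Majority vote'}); xtickangle(40)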
The ensemble prediction errors are smaller than those of any individual model. The difference is small, but on 10000 images it means that 89 more images are correctly classified than with the best individual model (7.68% - 6.79% = 0.89% of 10000). We can see some examples of these images:
% Plot some data (misclassified by the best model)
load Example1Results.mat
figure;
for i = 1:4
    subplot(2,2,i); imshow(dataVal(:,:,:,i))
    % string() lets %s handle categorical or string labels alike
    title(sprintf('Best model: %s / Ensemble: %s', string(bestModelPreds(i)), string(ensemblePreds(i))))
end
Example 2 – combining different pretrained models for transfer learning
In this example we again use the CIFAR-10 dataset, but this time we use different pretrained models for transfer learning. The models were originally trained on ImageNet and can be downloaded as support packages [LINK: https://www.mathworks.com/matlabcentral/profile/authors/8743315-mathworks-deep-learning-toolbox-team]. We used googlenet, squeezenet, resnet18, xception and mobilenetv2 and followed the transfer learning example [LINK: https://www.mathworks.com/help/deeplearning/examples/train-deep-learning-network-to-classify-new-images.html].
Individual validation errors:
googlenet: 7.23%
squeezenet: 12.89%
resnet18: 7.75%
xception: 3.92%
mobilenetv2: 6.96%
Model ensembling errors:
Average: 3.56%
Weighted average: 3.28% (Xception counted twice)
Majority vote: 4.04%
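% Plot errors (values copied from the individual and ensembling errors above)
example2Results = [7.23 12.89 7.75 3.92 6.96 3.56 3.28 4.04];
figure; bar(example2Results); title('Example 2: prediction errors (%)');
xticklabels({'GoogLeNet','SqueezeNet','ResNet-18', 'Xception', 'MobileNet-v2', ...
'Average', 'Weighted average', 'Majority vote'}); xtickangle(40)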
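Again, the averaging ensembles have smaller prediction errors than any individual model, although majority voting (4.04%) is slightly worse than Xception alone (3.92%). With the weighted average, 64 more images were correctly classified than with the best individual model (3.92% - 3.28% = 0.64% of the 10000 images). These included: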
% Plot some data (misclassified by the best model)
load Example2Results.mat
figure;
for i = 1:4
    subplot(2,2,i); imshow(dataVal(:,:,:,i))
    % string() lets %s handle categorical or string labels alike
    title(sprintf('Best model: %s / Ensemble: %s', string(bestModelPreds(i)), string(ensemblePreds(i))))
end
What else should I know?
Model ensembling can significantly increase prediction time, because every network in the ensemble has to be evaluated for each prediction. This makes it impractical in applications where the cost of inference time is higher than the cost of making wrong predictions.
One other thing to note is that performance does not increase monotonically with the number of networks. Typically, as this number increases, training time significantly increases but the return of combining all models diminishes.
There isn’t a single magic number for how many networks one should combine. This is heavily dependent on the networks, data, and computational resources. Having said this, performance tends to improve the more model variety we have in an ensemble.
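One way to explore this for your own models is to track the ensemble error as networks are added one at a time. The sketch below does this for a simple averaging ensemble. Everything in it is an assumption for illustration: the scores are synthetic random data rather than outputs of real networks, the variable names are hypothetical, and the numbers it prints say nothing about real models.

```matlab
% Synthetic example (assumed data): noisy scores from 5 simulated networks.
rng(0);                                   % fixed seed for repeatability
numObs = 200; numClasses = 3; numNets = 5;
trueLabels = randi(numClasses, numObs, 1);
scores = zeros(numObs, numClasses, numNets);
for k = 1:numNets
    s = rand(numObs, numClasses);         % random noise
    idx = sub2ind(size(s), (1:numObs)', trueLabels);
    s(idx) = s(idx) + 0.5;                % push the true class up a little
    scores(:,:,k) = s ./ sum(s, 2);       % normalise each row to sum to 1
end

% Error (%) of the averaging ensemble built from the first k networks.
ensembleError = zeros(1, numNets);
for k = 1:numNets
    [~, pred] = max(mean(scores(:,:,1:k), 3), [], 2);
    ensembleError(k) = 100 * mean(pred ~= trueLabels);
end
disp(ensembleError)
```

With real networks, `scores` would hold their softmax outputs on a validation set, and plotting `ensembleError` against the number of networks shows where adding more models stops paying off.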
I hope you found this useful. Have you tried ensemble learning, or are you thinking of trying it? Leave a comment below.