Student Lounge

Sharing technical and real-life examples of how students can use MATLAB and Simulink in their everyday projects #studentsuccess

Using Ensemble Learning to Create Accurate Machine Learning Algorithms

In today’s post, Grace from the Student Programs Team will show how you can get started with ensemble learning. Over to you, Grace!
When building a predictive machine learning model, there are many ways to improve its performance: try out different algorithms, optimize the parameters of the algorithm, find the best way to divide and process your data, and more. Another great way to create accurate predictive models is through ensemble learning.

What is ensemble learning?

Ensemble learning is the practice of combining multiple machine learning models into one predictive model. Some types of machine learning algorithms are considered weak learners: they are highly sensitive to the data used to train them and, on their own, tend to be inaccurate. Training an ensemble of weak learners and aggregating their results to make predictions on new observations often yields a single, higher-quality model. At its simplest, ensemble learning can be represented with the animation below:
[Animation: EnsembleGif.gif]
Ensemble learning can be used for a wide variety of machine and deep learning methods. Today, I will show how to create an ensemble of machine learning models for a regression problem, though the workflow will be similar for classification problems as well. Let’s get started!

1. Prepare the data

For this problem, we have a set of tabular data pertaining to cars, as shown below:
load carbig
carTable = table(Acceleration,Cylinders,Displacement, ...
    Horsepower,Model_Year,Weight,MPG);
head(carTable)
Acceleration    Cylinders    Displacement    Horsepower    Model_Year    Weight    MPG
____________    _________    ____________    __________    __________    ______    ___
     12             8            307            130            70         3504     18
    11.5            8            350            165            70         3693     15
     11             8            318            150            70         3436     18
     12             8            304            150            70         3433     16
    10.5            8            302            140            70         3449     17
     10             8            429            198            70         4341     15
      9             8            454            220            70         4354     14
     8.5            8            440            215            70         4312     14
Our goal is to create a model that can accurately predict a car’s fuel efficiency in miles per gallon (MPG). With any data problem, you should take the time to explore and preprocess the data, but for this tutorial I will just do some simple steps. For more information on cleaning your data, check out this example that shows a lot of great ways you can preprocess tabular data!
First, I’ll check if our set has any rows with missing data, as this will inform some of our decisions later.
missingElements = ismissing(carTable);
rowsWithMissingValues = any(missingElements,2);
missingValuesTable = carTable(rowsWithMissingValues,:)
missingValuesTable = 14×7 table
Acceleration Cylinders Displacement Horsepower Model_Year Weight MPG
1 17.5000 4 133 115 70 3090 NaN
2 11.5000 8 350 165 70 4142 NaN
3 11 8 351 153 70 4034 NaN
4 10.5000 8 383 175 70 4166 NaN
5 11 8 360 175 70 3850 NaN
6 8 8 302 140 70 3353 NaN
7 19 4 98 NaN 71 2046 25
8 20 4 97 48 71 1978 NaN
9 17 6 200 NaN 74 2875 21
10 17.3000 4 85 NaN 80 1835 40.9000
11 14.3000 4 140 NaN 80 2905 23.6000
12 15.8000 4 100 NaN 81 2320 34.5000
13 15.4000 4 121 110 81 2800 NaN
14 20.5000 4 151 NaN 82 3035 23
There are a total of 14 rows with missing data, 8 of which are missing the ‘MPG’ value. I’ll remove those 8 rows, since an observation without a response value can’t be used for training, but I’ll keep the other 6 rows: they are only missing ‘Horsepower’, and their remaining predictors can still provide helpful information during training.
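Before deleting anything, a per-variable count is a quick way to confirm these numbers. This is a small sketch that rebuilds the table from carbig so it runs on its own; the names missingPerVariable and numMissingPredictorOnly are just illustrative. On this data it should report 6 missing Horsepower values and 8 missing MPG values, matching the table above.

```matlab
% Rebuild the table from the carbig sample data
load carbig
carTable = table(Acceleration, Cylinders, Displacement, ...
    Horsepower, Model_Year, Weight, MPG);

% Count missing entries in each column (1-by-7 row vector)
missingPerVariable = sum(ismissing(carTable), 1);
disp(array2table(missingPerVariable, ...
    'VariableNames', carTable.Properties.VariableNames))

% Rows missing the response versus rows missing only a predictor
numMissingMPG = sum(ismissing(carTable.MPG));
numMissingPredictorOnly = sum(any(ismissing(carTable), 2) & ...
    ~ismissing(carTable.MPG));
```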
rowsMissingMPG = ismissing(carTable.MPG);
carTable(rowsMissingMPG,:) = []
carTable = 398×7 table
Acceleration Cylinders Displacement Horsepower Model_Year Weight MPG
1 12 8 307 130 70 3504 18
2 11.5000 8 350 165 70 3693 15
3 11 8 318 150 70 3436 18
4 12 8 304 150 70 3433 16
5 10.5000 8 302 140 70 3449 17
6 10 8 429 198 70 4341 15
7 9 8 454 220 70 4354 14
8 8.5000 8 440 215 70 4312 14
9 10 8 455 225 70 4425 14
10 8.5000 8 390 190 70 3850 15
11 10 8 383 170 70 3563 15
12 8 8 340 160 70 3609 14
13 9.5000 8 400 150 70 3761 15
14 10 8 455 225 70 3086 14
Finally, I will split the data into a training set and a testing set, which will be used to train and evaluate the ensemble, respectively. Using the dividerand function, I put 70% of the data into the training set and 30% into the testing set as a starting point, but you can test out different divisions of data when building your own models.
numRows = size(carTable,1);
[trainInd, ~, testInd] = dividerand(numRows, .7, 0, .3);
trainingData = carTable(trainInd, :);
testingData = carTable(testInd, :);
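Because dividerand assigns rows at random, the split (and therefore every loss value later in this post) changes from run to run. A minimal sketch of a reproducible alternative, assuming you have the Statistics and Machine Learning Toolbox that fitrensemble already requires: seed the random number generator with rng and use cvpartition with a 30% hold-out. The seed value 0 is arbitrary.

```matlab
% Data preparation from step 1
load carbig
carTable = table(Acceleration, Cylinders, Displacement, ...
    Horsepower, Model_Year, Weight, MPG);
carTable(ismissing(carTable.MPG), :) = [];   % drop rows without a response

rng(0)                                       % arbitrary seed for reproducibility
c = cvpartition(height(carTable), 'HoldOut', 0.3);
trainingData = carTable(training(c), :);     % about 70% of the rows
testingData  = carTable(test(c), :);         % the remaining rows
```

Re-running this block gives the same split every time, which makes it easier to tell whether a change to the ensemble, rather than a lucky split, improved the loss.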

2. Create an ensemble

Now that our data is ready, it’s time to start creating the ensemble! I’ll start by showing the easiest way to create an ensemble using the default parameters for each individual learner, and then I’ll also show how to use templates to customize your weak learners.

Using Built-In Algorithms

You can create an ensemble for regression by using fitrensemble (fitcensemble for classification). With just this function and your data, you could have an ensemble ready by executing the following line of code:
Mdl = fitrensemble(trainingData, 'MPG');
This uses all the default settings for training an ensemble of weak regression learners: 100 regression trees are trained and aggregated using the least-squares boosting (LSBoost) algorithm.
However, fitrensemble also lets you customize the ensemble settings, so I will specify a few of them. First, I want to use the ‘Bag’ method of aggregation instead of ‘LSBoost’, because bagging often gives higher accuracy on relatively small datasets like this one. For a full list of aggregation algorithms and some suggestions on how to choose a starting algorithm, check out this documentation page!
I also want to specify how many learners will be in the ensemble, which is set by the ‘NumLearningCycles’ Name-Value argument. To choose how many learners the ensemble will have, try starting with several dozen, training the ensemble, and then checking the ensemble quality. If the ensemble is not accurate and could benefit from more learners, you can increase this number or add learners later! For now, I’ll start with 30 learners.
Both of these options are set using Name-Value arguments, as shown below.
Mdl = fitrensemble(trainingData, 'MPG', 'Method', 'Bag', 'NumLearningCycles', 30);
And just like that, we’ve trained an ensemble of 30 learners that is ready to be used on new data!
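To confirm which settings an ensemble actually used, you can query its Method and NumTrained properties. This sketch trains the default and the customized ensemble side by side; it repeats the step 1 data preparation (with an rng seed added for reproducibility) so it runs on its own, and the names defaultMdl and bagMdl are just illustrative. It should print LSBoost with 100 trees and Bag with 30 trees.

```matlab
% Data preparation from step 1
load carbig
carTable = table(Acceleration, Cylinders, Displacement, ...
    Horsepower, Model_Year, Weight, MPG);
carTable(ismissing(carTable.MPG), :) = [];
rng(0)
[trainInd, ~, testInd] = dividerand(height(carTable), 0.7, 0, 0.3);
trainingData = carTable(trainInd, :);

% Default settings versus the customized bagged ensemble
defaultMdl = fitrensemble(trainingData, 'MPG');
bagMdl = fitrensemble(trainingData, 'MPG', 'Method', 'Bag', ...
    'NumLearningCycles', 30);

fprintf('Default: %s, %d trees\n', defaultMdl.Method, defaultMdl.NumTrained);
fprintf('Custom:  %s, %d trees\n', bagMdl.Method, bagMdl.NumTrained);
```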

Using Templates

There may be times when you want to change some parameters of the individual learners, not just of the ensemble. To do that, we can use learner templates.
Unless otherwise specified, fitrensemble creates an ensemble of default regression tree learners, but this may not always be what you want. As we saw earlier, our data has some missing values, which can decrease the performance of these trees. Trees that use surrogate splits tend to perform better with missing data than trees that don’t, so I will use the templateTree function to specify that I want the learners to use surrogate splits.
templ = templateTree('Surrogate', 'all', 'Type', 'regression');
templMdl = fitrensemble(trainingData, 'MPG', 'Method', 'Bag', ...
    'NumLearningCycles', 30, 'Learners', templ);
As before, we end up with an ensemble that can be used to make predictions on new data! While you can give fitcensemble and fitrensemble a cell array of learner templates, the most common usage is to give just one weak learner template.
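To see that the surrogate-split ensemble still produces predictions for the 6 cars whose Horsepower is missing, you can pass just those rows to predict. This is a sketch that repeats the earlier steps so it runs on its own; some of these rows may also be in the training set, so it checks that prediction works rather than measuring accuracy. The names rowsMissingHP and predHP are illustrative.

```matlab
% Data preparation from step 1
load carbig
carTable = table(Acceleration, Cylinders, Displacement, ...
    Horsepower, Model_Year, Weight, MPG);
carTable(ismissing(carTable.MPG), :) = [];
rng(0)
[trainInd, ~, testInd] = dividerand(height(carTable), 0.7, 0, 0.3);
trainingData = carTable(trainInd, :);

% Bagged ensemble of trees that use surrogate splits
templ = templateTree('Surrogate', 'all', 'Type', 'regression');
templMdl = fitrensemble(trainingData, 'MPG', 'Method', 'Bag', ...
    'NumLearningCycles', 30, 'Learners', templ);

% Rows that still have a missing predictor (Horsepower)
rowsMissingHP = carTable(ismissing(carTable.Horsepower), :);
predHP = predict(templMdl, rowsMissingHP);   % MPG column is ignored
disp([rowsMissingHP.MPG predHP])            % true vs. predicted MPG
```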

3. Evaluate the Ensemble

Once you have a trained ensemble, it’s time to see how well it performs! The predictive quality of an ensemble cannot be judged by its performance on the training data, because the ensemble has already seen that data. It will very likely perform well on the training data, but that does not mean it will perform well on any other data. To obtain a better idea of the quality of an ensemble, you can use one of these methods:
  • Cross-validation
  • Evaluation on a separate test set
I will show how to use both of these methods to evaluate a model in the following sections.

Evaluate through Cross-Validation

Cross-validation is a common technique for evaluating a model’s performance by partitioning the total dataset and using some partitions for training and others for testing. There are multiple types of cross-validation, and many of them eventually use all of the training data to train the model, which makes cross-validation well suited to smaller datasets. If you are not familiar with cross-validation, check out this discovery page to learn more!
For this example, I will be using k-fold cross-validation (10 folds by default). You can perform cross-validation when creating your ensemble by using the ‘CrossVal’ Name-Value argument of fitrensemble, as outlined below:
Mdl = fitrensemble(trainingData, 'MPG', 'CrossVal', 'on');
Since we already have an ensemble, however, we can use the crossval function to cross-validate our model, then use kfoldLoss to compute the mean squared error (MSE), or loss, averaged across the folds.
cvens = crossval(templMdl);
kfoldLoss(cvens)
ans = 9.0388
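crossval uses 10 folds unless you specify otherwise. A sketch with 5 folds, which trains fewer models and leaves more observations in each test fold; the choice of 5 is only an example, and the block repeats the earlier steps so it runs on its own.

```matlab
% Data preparation and ensemble from the previous steps
load carbig
carTable = table(Acceleration, Cylinders, Displacement, ...
    Horsepower, Model_Year, Weight, MPG);
carTable(ismissing(carTable.MPG), :) = [];
rng(0)
[trainInd, ~, testInd] = dividerand(height(carTable), 0.7, 0, 0.3);
trainingData = carTable(trainInd, :);
templ = templateTree('Surrogate', 'all', 'Type', 'regression');
templMdl = fitrensemble(trainingData, 'MPG', 'Method', 'Bag', ...
    'NumLearningCycles', 30, 'Learners', templ);

% 5-fold instead of the default 10-fold cross-validation
cvens5 = crossval(templMdl, 'KFold', 5);
mse5 = kfoldLoss(cvens5);
fprintf('Folds: %d, cross-validated MSE: %.4f\n', cvens5.KFold, mse5);
```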
We can also set the ‘mode’ of kfoldLoss to ‘cumulative’ and then plot the results to show how the loss value changes as more trees are trained.
cValLoss = kfoldLoss(cvens,'mode','cumulative')
cValLoss = 30×1
15.6090
11.8164
10.9555
10.7656
10.5454
9.9553
9.8663
9.8305
9.6588
9.4307
plot(cValLoss, 'r--')
xlabel('Number of trees')
ylabel('Cross-validation loss')
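The cumulative loss also tells you how many trees were actually useful. A minimal sketch that finds the ensemble size with the lowest cross-validation loss and marks it on the plot; the names minLoss and bestNumTrees are illustrative, and the block repeats the earlier steps so it runs on its own.

```matlab
% Data preparation and ensemble from the previous steps
load carbig
carTable = table(Acceleration, Cylinders, Displacement, ...
    Horsepower, Model_Year, Weight, MPG);
carTable(ismissing(carTable.MPG), :) = [];
rng(0)
[trainInd, ~, testInd] = dividerand(height(carTable), 0.7, 0, 0.3);
trainingData = carTable(trainInd, :);
templ = templateTree('Surrogate', 'all', 'Type', 'regression');
templMdl = fitrensemble(trainingData, 'MPG', 'Method', 'Bag', ...
    'NumLearningCycles', 30, 'Learners', templ);
cvens = crossval(templMdl);

% Loss after 1, 2, ..., 30 trees, and the size with the lowest loss
cValLoss = kfoldLoss(cvens, 'Mode', 'cumulative');
[minLoss, bestNumTrees] = min(cValLoss);

plot(cValLoss, 'r--')
hold on
plot(bestNumTrees, minLoss, 'ko')   % mark the lowest cross-validation loss
hold off
xlabel('Number of trees')
ylabel('Cross-validation loss')
```

If bestNumTrees sits at the very end of the curve, the loss may still be falling, which suggests trying more learners.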

Evaluate on test set

If you have enough data to use only a portion of it for training, you can use the rest of the data to test how well your model performs. First, make sure you separate your data into a training and testing set, as we did earlier, then train your model using only the training set.
Once the ensemble is trained, we can use it on the testing data and then calculate the loss of the model on this data:
loss(templMdl,testingData,'MPG')
ans = 8.5888
plot(loss(templMdl,testingData,'MPG','mode','cumulative'))
xlabel('Number of trees')
ylabel('Test loss')
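loss returns the mean squared error by default, so its units are MPG squared. As a sketch of what it computes, this block reproduces the same number from predict and takes its square root to get an error in MPG, which is easier to read; the block repeats the earlier steps so it runs on its own, and manualMSE and testRMSE are illustrative names.

```matlab
% Data preparation and ensemble from the previous steps
load carbig
carTable = table(Acceleration, Cylinders, Displacement, ...
    Horsepower, Model_Year, Weight, MPG);
carTable(ismissing(carTable.MPG), :) = [];
rng(0)
[trainInd, ~, testInd] = dividerand(height(carTable), 0.7, 0, 0.3);
trainingData = carTable(trainInd, :);
testingData  = carTable(testInd, :);
templ = templateTree('Surrogate', 'all', 'Type', 'regression');
templMdl = fitrensemble(trainingData, 'MPG', 'Method', 'Bag', ...
    'NumLearningCycles', 30, 'Learners', templ);

% Mean squared error computed by hand and by loss
predMPG = predict(templMdl, testingData);
manualMSE = mean((testingData.MPG - predMPG).^2);
builtinMSE = loss(templMdl, testingData, 'MPG');

testRMSE = sqrt(manualMSE);   % root mean squared error, in MPG
fprintf('MSE: %.4f (loss: %.4f), RMSE: %.2f MPG\n', ...
    manualMSE, builtinMSE, testRMSE);
```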
You can also use the ensemble to make predictions using the predict function. With a test set, you can compare the expected results from the testing data to the results predicted by the ensemble. In the plot below, the blue line represents the expected results, and the red circles are the predicted results; the further away a circle is from the blue line, the less accurate the prediction was.
predMPG = predict(templMdl, testingData);
expectMPG = testingData.MPG;
plot(expectMPG, expectMPG);
hold on
scatter(expectMPG, predMPG)
xlabel('True Response')
ylabel('Predicted Response')
hold off
You can use these evaluation metrics to compare multiple ensembles and choose the one that performs the best.
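For example, here is a sketch that compares a bagged ensemble of default trees against the surrogate-split version on the same test set and reports which has the lower MSE. The block repeats the earlier steps so it runs on its own, and the names plainMdl and results are illustrative. Which ensemble wins can change with the random split and the random bootstrap samples, so it is worth repeating the comparison with a few different seeds (or with cross-validation) before settling on one.

```matlab
% Data preparation from step 1
load carbig
carTable = table(Acceleration, Cylinders, Displacement, ...
    Horsepower, Model_Year, Weight, MPG);
carTable(ismissing(carTable.MPG), :) = [];
rng(0)
[trainInd, ~, testInd] = dividerand(height(carTable), 0.7, 0, 0.3);
trainingData = carTable(trainInd, :);
testingData  = carTable(testInd, :);

% Two candidate ensembles: default trees vs. trees with surrogate splits
plainMdl = fitrensemble(trainingData, 'MPG', 'Method', 'Bag', ...
    'NumLearningCycles', 30);
templ = templateTree('Surrogate', 'all', 'Type', 'regression');
templMdl = fitrensemble(trainingData, 'MPG', 'Method', 'Bag', ...
    'NumLearningCycles', 30, 'Learners', templ);

% Compare test-set MSE and pick the lower one
names = ["Default trees"; "Surrogate splits"];
testMSE = [loss(plainMdl, testingData, 'MPG'); ...
           loss(templMdl, testingData, 'MPG')];
results = table(names, testMSE)
[~, best] = min(testMSE);
fprintf('Lower test MSE: %s\n', names(best));
```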

4. Iterate and Improve!

As with any machine learning workflow, it’s important to try out different algorithms until you get an ensemble that you are happy with. When I first started creating this ensemble, I used the ‘LSBoost’ aggregation method instead of ‘Bag’, and the performance was consistently poor, so I changed the ‘Method’ argument in the fitrensemble calls and re-ran the entire Live Script, resulting in a new, fully evaluated model in a matter of seconds. In addition to testing out different aggregation algorithms, here are some other suggestions for improving your ensemble:
  • If it appears that the loss of your ensemble is still decreasing when all members have finished training, this could indicate that you need more members. You can add them using the resume method. Repeat until adding more members does not improve ensemble quality.
  • Try optimizing your hyperparameters automatically by using the ‘OptimizeHyperparameters’ and ‘HyperparameterOptimizationOptions’ Name-Value arguments when calling fitrensemble. Check out this example in the documentation to learn more: Hyperparameter optimization.
  • Test out different weak learners! There are lots of different settings and templates you can use, especially if you’re creating a classification ensemble. Try different parameters when calling fitrensemble or fitcensemble, use different template types, and play around with the different options of each template.
  • At the end of the day, a model is only as good as the data it is trained on, so make sure your data is clean and test out different divisions of training and testing data to see what works best for your ensemble. There are many different ways to clean data depending on what format it is in, so use this documentation page as a starting point to find resources based on the format and patterns of your dataset!
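The first suggestion above (adding members with resume until the loss stops improving) can be sketched as a loop. This is a minimal sketch under stated assumptions: it adds 10 trees per round, stops when a round improves the loss by less than a hypothetical 1% tolerance, caps the number of rounds at 10, and uses the held-out test set from step 1 as the yardstick; a separate validation set or cross-validation would keep the test set unbiased for the final evaluation. The block repeats the earlier steps so it runs on its own.

```matlab
% Data preparation from step 1
load carbig
carTable = table(Acceleration, Cylinders, Displacement, ...
    Horsepower, Model_Year, Weight, MPG);
carTable(ismissing(carTable.MPG), :) = [];
rng(0)
[trainInd, ~, testInd] = dividerand(height(carTable), 0.7, 0, 0.3);
trainingData = carTable(trainInd, :);
testingData  = carTable(testInd, :);

Mdl = fitrensemble(trainingData, 'MPG', 'Method', 'Bag', ...
    'NumLearningCycles', 30);

tol = 0.01;                              % hypothetical: stop below 1% gain
prevLoss = loss(Mdl, testingData, 'MPG');
newLoss = prevLoss;
for round = 1:10                         % at most 10 extra rounds
    Mdl = resume(Mdl, 10);               % train 10 more trees
    newLoss = loss(Mdl, testingData, 'MPG');
    if prevLoss - newLoss < tol * prevLoss
        break                            % more trees no longer help much
    end
    prevLoss = newLoss;
end
fprintf('Stopped at %d trees (test MSE %.4f)\n', Mdl.NumTrained, newLoss);
```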
If you are interested in deep learning and would like to learn about ensemble learning with neural networks, check out this blog post next!
