Predicting Timely Diagnosis of Metastatic Breast Cancer for the WiDS Datathon 2024
In today’s blog, Grace Woolson will show how you can use MATLAB and machine learning to make meaningful deductions from healthcare data for patients who have been diagnosed with metastatic breast cancer. Over to you Grace!
Introduction
In this blog, I will show how you can use MATLAB for the WiDS Datathon 2024, using the dataset for WiDS Datathon #1, which runs from January 9th, 2024 to March 1st, 2024. This challenge tasks participants with creating a model that can predict whether or not a patient with metastatic breast cancer will receive a diagnosis within 90 days, based on patient and environmental data. This can help identify relationships between demographics or environmental hazards and the likelihood of getting timely treatment. Please note that this blog is based on a subset of the data, so there may be slight differences between this dataset and the one provided by WiDS.
MathWorks is happy to support participants of the Women in Data Science Datathon 2024 by providing complimentary MATLAB licenses, tutorials, workshops, and additional resources. To request complimentary licenses for you and your teammates, go to this MathWorks site, click the “Request Software” button, and fill out the software request form.
This tutorial will walk through the following steps of the model-making process:
- Importing a Tabular Dataset
- Preprocessing the Data
- Exploring and Analyzing Tabular Data
- Choosing and Creating Features
- Training a Machine Learning Model
- Evaluating a Machine Learning Model
- Making New Predictions and Exporting Submissions
Import Data
First, make sure the ‘Current Folder’ is the folder where you saved the data. If you have not already done so, you can download the data from Kaggle after you register for the datathon. The data is provided as a .CSV file, so we can use the readtable function to import the whole file as a table.
dataFolder = fullfile(pwd);
trainDataFilename = 'Training.csv';
allTrainData = readtable(fullfile(dataFolder, trainDataFilename))
I want to see some high-level statistics about the data, so I’ll use the summary function to get an idea of what kind of information we have.
summary(allTrainData)
Take some time to scroll through this summary and see what information or patterns you can learn! Here are some things I notice:
- There are a lot of variables that just say “cell array of character vectors”, which doesn’t tell us much about the data.
- There are a few variables that have a high ‘NumMissing’ value.
- The numeric variables can have dramatically different minimums and maximums.
We can use these observations to make decisions about how we want to explore and preprocess the dataset.
Process and Clean the Data
1. Convert text data to categorical
Text data can be hard for machine learning algorithms to understand, so let’s go through and change each “cell array of character vectors” to a categorical. This will help the algorithm sort the text into different categories instead of understanding it as a series of individual letters.
varTypes = varfun(@class, allTrainData, OutputFormat="cell");
catIdx = strcmp(varTypes, "cell");
varNames = allTrainData.Properties.VariableNames;
catVarNames = varNames(catIdx);
for catNameIdx = 1:length(catVarNames)
allTrainData.(catVarNames{catNameIdx}) = categorical(allTrainData.(catVarNames{catNameIdx}));
end
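As a side note, the same conversion can be done in one line with the convertvars function, which accepts the logical index we already built; this is simply an equivalent alternative to the loop above.
% Equivalent one-line conversion (alternative to the loop above)
allTrainData = convertvars(allTrainData, catIdx, "categorical");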
2. Handle Missing Data
Now I want to handle all that missing data I noticed earlier. I’ll go through each variable and specifically look at variables that are missing data for over half of the rows or observations.
dataSum = summary(allTrainData);
for nameIdx = 1:length(varNames)
varName = varNames{nameIdx};
varNumMissing = dataSum.(varName).NumMissing;
if varNumMissing > (height(allTrainData) / 2)
disp(varName);
disp(varNumMissing);
end
end
Let’s remove those variables entirely, since they might not be too helpful for our algorithm.
allTrainData = removevars(allTrainData, ["bmi", "metastatic_first_novel_treatment", "metastatic_first_novel_treatment_type"])
Now I want to look at each row and remove any that are missing too many values. It’s okay to have a couple of missing data points in your dataset, but if you have too many it could cause your machine learning algorithm to be less accurate. I’ll use the Clean Missing Data live task to remove any rows that are missing 2 or more data points.
% Remove missing data
[fullData,missingIndices] = rmmissing(allTrainData,"MinNumMissing",2);
% Display results
figure
% Get locations of missing data
indicesForPlot = ismissing(allTrainData.patient_age);
mask = missingIndices & ~indicesForPlot;
% Plot cleaned data
plot(find(~missingIndices),fullData.patient_age,"SeriesIndex",1,"LineWidth",1.5, ...
    "DisplayName","Cleaned data")
hold on
% Plot data in rows where other variables contain missing entries
plot(find(mask),allTrainData.patient_age(mask),"x","SeriesIndex","none", ...
    "DisplayName","Removed by other variables")
% Plot removed missing entries
x = repelem(find(indicesForPlot),3);
y = repmat([ylim(gca) missing]',nnz(indicesForPlot),1);
plot(x,y,"Color",[145 145 145]/255,"DisplayName","Removed missing entries")
title("Number of removed missing entries: " + nnz(indicesForPlot))
hold off
legend
ylabel("patient_age","Interpreter","none")
clear indicesForPlot mask x y
Explore the Data
Now that the data is cleaned up, spend some time exploring it: look at how different variables may interact with each other, see whether you can draw any meaningful conclusions from the data, and figure out which variables may be more or less important when it comes to predicting time to diagnosis.
Univariate Analysis
First, I want to separate the data into two datasets: one full of patients who were diagnosed in 90 days or less (the 1 or “True” values), and one full of patients who were not (the 0 or “False” values). This will allow me to explore the data patterns in each of these datasets and look for any meaningful differences.
allTrueIdx = fullData.DiagPeriodL90D == 1;
allTrueData = fullData(allTrueIdx, :);
allFalseIdx = fullData.DiagPeriodL90D == 0;
allFalseData = fullData(allFalseIdx, :);
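Before plotting, it can also help to quantify how unbalanced the two classes are. A quick optional check (not part of the original workflow) is to count the rows in each class:
% Count how many patients fall into each DiagPeriodL90D class
groupcounts(fullData, "DiagPeriodL90D")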
Now we can use the Create Plot live task to plot histograms of the different variables in each dataset. In the plot below, blue bars represent data from the folks who were diagnosed in a timely manner, and the red bars represent data from the folks who were not.
figure
% Create histogram of selected data
histogram(allTrueData.health_uninsured,"NumBins",40,"DisplayName","health_uninsured");
hold on
% Create histogram of selected data
histogram(allFalseData.health_uninsured,"NumBins",40,"DisplayName","health_uninsured");
hold off
legend
Take some time to explore these visualizations on your own, as I can only show one at a time in this blog. It is worth noting that we have fewer False observations than True observations, so the red bars will almost always be lower than the blue bars. If there are red bars that are higher, or if the shapes of the distributions are different, that may indicate a relationship between a variable and time to diagnosis.
I didn’t see many significant differences in shape, though I did notice that for the ‘health_uninsured’ histograms the red bars are relatively tall at the higher end of the range, indicating that there may be a correlation between populations with high rates of being uninsured and time to diagnosis.
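Because the two groups have different numbers of rows, comparing raw counts can be misleading. If you want to compare the shapes of the distributions directly, one option (a small variation on the plot above, not part of the original live task output) is to normalize each histogram so the bars show proportions instead of counts:
figure
% Normalized histograms put both groups on the same scale
histogram(allTrueData.health_uninsured,"NumBins",40,"Normalization","probability", ...
    "DisplayName","Diagnosed within 90 days");
hold on
histogram(allFalseData.health_uninsured,"NumBins",40,"Normalization","probability", ...
    "DisplayName","Not diagnosed within 90 days");
hold off
legend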
Bivariate and Multivariate Analysis
You can break the data down further and plot two (or more!) variables against each other to see if you can find any patterns. In the plot below, for example, we can see the percentage of the population that is uninsured and the state the patient is in, broken down by whether or not the patient was diagnosed within 90 days. Again, blue values indicate that the patient was, and red values indicate that the patient was not.
figure
% Create scatter of selected data
scatter(allTrueData,"patient_state","health_uninsured","DisplayName","health_uninsured");
hold on
% Create scatter of selected data
scatter(allFalseData,"patient_state","health_uninsured","DisplayName","health_uninsured");
hold off
legend
We can see that in some states, such as GA, OK, or TX, the red values come from populations that typically have higher rates of being uninsured. This could indicate that in some states, coming from a zip code with a high population of uninsured folks (or being uninsured yourself) means you are more likely to experience delays in your diagnosis.
Statistical Analysis
You can also make meaningful deductions by calculating various statistics from your data. For example, I want to calculate the skewness, or level of asymmetry, of each of my variables. A negative value indicates the data is left skewed when plotted, a positive value indicates the data is right skewed when plotted, and a value of 0 means the data is symmetric.
statsTrue = varfun(@skewness, allTrueData, "InputVariables", @isnumeric);
statsFalse = varfun(@skewness, allFalseData, "InputVariables", @isnumeric);
Now I want to see if any of the variables have a significant difference in their skewness, as differences in the data distributions between patients who were diagnosed in a timely manner vs patients who were not could indicate an underlying relationship between those variables and time to diagnosis.
statsDiffs = abs(statsTrue{:, :} - statsFalse{:, :});
statsTrue.Properties.VariableNames(statsDiffs > 0.2)
If we investigate the four variables that are returned, we can see that population density, the percentage of folks above 80 in your zip code, the median rent burden of your zip code, and the percentage of residents who reported their race as American Indian or Alaska Native in your zip code may have a relationship with time to diagnosis.
Feature Engineering
When it comes to machine learning, you don’t have to use all of the data as it is presented to you. Feature Engineering is the process of deciding what data you want to use, creating new data based on the provided data, and transforming the data to be in whatever format or range is suitable for your workflow. You can do this manually, and some of the exploration we just did should influence decisions you make if you want to play around with including or excluding different variables.
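As a small illustration of what manual feature engineering can look like, the sketch below standardizes one numeric variable and adds a simple derived flag. The new variable names and the 10% threshold are made up for this example and are not used anywhere else in this blog.
% A minimal sketch of manual feature engineering (illustration only)
manualData = fullData;
% Standardize a numeric variable so it has zero mean and unit variance
manualData.patient_age_z = normalize(manualData.patient_age);
% Flag zip codes with a high uninsured rate (threshold chosen arbitrarily)
manualData.high_uninsured = manualData.health_uninsured > 10;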
For this blog, I’ll use the gencfeatures function to automate this process. I want to use 90 features, which is 10 more than we currently have in our dataset; gencfeatures will generate a set of 90 meaningful features based on our processed data. It may keep some data as-is, but will often standardize numeric variables and create new variables by manipulating the provided data.
[T, augTrainData] = gencfeatures(fullData, "DiagPeriodL90D", 90)
To better understand the generated features, you can use the describe function of the returned FeatureTransformer object, ‘T’.
describe(T)
Split the Data
The last step before you can train a machine learning model is to split your data into a training and testing set. We’ll use the training data to fit the model, and the testing set to evaluate how well the model performs on new data before we use it to make a submission. Here I split the data into 80% training and 20% testing.
numRows = height(augTrainData);
[trainInd, ~, testInd] = dividerand(numRows, .8, 0, .2);
trainingData = augTrainData(trainInd, :);
testingData = augTrainData(testInd, :);
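Note that dividerand splits the rows purely at random, so the class balance of ‘DiagPeriodL90D’ in the two sets can drift a little. If you would rather keep the proportions of 0s and 1s the same in both sets, a stratified split with cvpartition is one alternative (shown only as a sketch; the rest of this blog uses the dividerand split above):
% Stratified 80/20 split that preserves the class balance of DiagPeriodL90D
c = cvpartition(augTrainData.DiagPeriodL90D, "HoldOut", 0.2);
trainingDataStrat = augTrainData(training(c), :);
testingDataStrat = augTrainData(test(c), :);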
Train a Machine Learning Model
In this example, I’ll create a binary decision tree using the fitctree function and set ‘Optimize Hyperparameters’ to ‘auto’, which will attempt to minimize the error of our algorithm by choosing the best value for the ‘MinLeafSize’ parameter. It visualizes the results of adjusting this value, as can be seen below.
classificationTree = fitctree(trainingData, "DiagPeriodL90D", ...
    OptimizeHyperparameters='auto');
I used a binary tree as my starting point, but it’s important to test out different types of algorithms to see what works best with your data! Check out the Classification Learner app documentation and this short video to learn how to train several machine learning models quickly and iteratively!
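As one example of trying a different algorithm, the sketch below trains a bagged ensemble of decision trees on the same training data with fitcensemble; it is only meant as a starting point for comparison and is not the model used in the rest of this blog.
% Try a bagged ensemble of decision trees as an alternative model
ensembleModel = fitcensemble(trainingData, "DiagPeriodL90D", Method="Bag");
% Estimate its accuracy with 5-fold cross-validation for comparison
ensembleAccuracy = 1 - kfoldLoss(crossval(ensembleModel, KFold=5))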
Test Your Model
There are many ways to evaluate the performance of a machine learning model, so in this blog I’ll show how to do so by computing validation accuracy and using testing data.
Validation Accuracy
Cross-validation is one method of evaluating a model, and at a high level is done by:
- Setting aside a subset of the training data, known as validation data
- Using the rest of the training data to fit the model
- Testing how well the model performs on the validation data
You can use the crossval function to do this:
% Perform cross-validation
partitionedModel = crossval(classificationTree, 'KFold', 5);
Then, extract the misclassification rate, and subtract it from 1 to get the model’s accuracy. The closer to 1 this value is, the more accurate our model is.
% Compute validation accuracy
validationAccuracy = 1 - kfoldLoss(partitionedModel, LossFun='classiferror')
Testing Data
In this section, we’ll use the ‘testingData’ dataset we created earlier. Similar to what we did with the validation data, we can use the loss function to compute the misclassification rate when the classification tree is applied to the testing data, and subtract it from 1 to get a measure of accuracy.
testAccuracy = 1 - loss(classificationTree, testingData, "DiagPeriodL90D", ...
    LossFun='classiferror')
I also want to compare the predictions that the model makes to the actual outputs, so let’s remove the ‘DiagPeriodL90D’ variable from our testing data:
testActual = testingData.DiagPeriodL90D;
testingData = removevars(testingData, "DiagPeriodL90D");
Now, use the model to make predictions on the testing set:
[testPreds, scores, ~, ~] = predict(classificationTree, testingData);
And use the confusionchart function to compare the predicted outputs to the actual outputs, to see how often they match or don’t.
confusionchart(testActual, testPreds)
This shows that the model almost always predicts the 1s correctly (patients who were diagnosed within 90 days), but it only has roughly a 50/50 chance of predicting the 0s correctly.
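To put rough numbers on this, you can compute per-class recall directly from the confusion matrix. The short sketch below assumes the classes are ordered 0 then 1, which is how confusionmat sorts them here:
% Rows of the confusion matrix are the true classes (0, then 1)
cm = confusionmat(testActual, testPreds);
recallClass0 = cm(1,1) / sum(cm(1,:))  % fraction of true 0s predicted as 0
recallClass1 = cm(2,2) / sum(cm(2,:))  % fraction of true 1s predicted as 1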
We can also use the test data and predictions to visualize receiver operating characteristic (ROC) metrics. The ROC curve shows the true positive rate (TPR) versus the false positive rate (FPR) for different thresholds of classification scores. The “Model Operating Point” shows the false positive rate and true positive rate of the model.
rocObj = rocmetrics(testActual, scores, classificationTree.ClassNames);
plot(rocObj)
Here we can see that the classifier correctly assigns about 90-95% of the 1 class observations to 1 (TPR), but incorrectly assigns about 40% of the 0 class observations as 1 (FPR). This is similar to what we observed with the confusion chart.
You can also extract the area under the curve (AUC) value, which is a measure of the overall quality of the classifier. The AUC values are in the range 0 to 1, and larger AUC values indicate better classifier performance.
rocObj.AUC
The AUC is pretty high, but shows that there is definitely room for improvement. To learn more about ROC metrics, check out this documentation page that explains it in more detail.
Create Submission
Once you have a model that performs well on the validation and testing data, it’s time to create a submission for the datathon! As a reminder, you will upload this file to Kaggle to be scored on the leaderboard.
First, import the ‘Test’ dataset:
testDataFilename = 'Test.csv';
allTestData = readtable(fullfile(dataFolder, testDataFilename))
Then we need to process this dataset in the same way that we did the training data. In this section, I use code instead of the live tasks for simplicity.
% replace cell arrays with categoricals
varTypes = varfun(@class, allTestData, OutputFormat="cell");
catIdx = strcmp(varTypes, "cell");
varNames = allTestData.Properties.VariableNames;
catVarNames = varNames(catIdx);
for catNameIdx = 1:length(catVarNames)
allTestData.(catVarNames{catNameIdx}) = categorical(allTestData.(catVarNames{catNameIdx}));
end
% remove variables with too many missing data points
fullTestData = removevars(allTestData, ["bmi", "metastatic_first_novel_treatment", "metastatic_first_novel_treatment_type"]);
We also need to use the transform function to apply the same feature transformations that gencfeatures generated for the training data.
augTestData = transform(T, fullTestData);
Now that the data is in the format our machine learning model expects it to be in, use the predict function to make predictions, and create a table to contain the patient IDs and corresponding predictions.
submissionPreds = predict(classificationTree, augTestData);
submissionTable = table(fullTestData.patient_id, submissionPreds, VariableNames=["patient_id", "DiagPeriodL90D"])
Last, export your predictions to a .CSV file, then upload to Kaggle for scoring.
writetable(submissionTable, "Predictions.csv");
And that’s it! Thank you for following along with this tutorial, and best of luck to all participants. If you have any questions about this tutorial or MATLAB, reach out to us at studentcompetitions@mathworks.com or by tagging gracewoolson in the forum. Keep your eye out for our upcoming WiDS Workshop on January 31st, where we will walk through this tutorial and answer any questions you have along the way!