Artificial Intelligence

Apply machine learning and deep learning

Auto-Categorization of Content using Deep Learning

This post is from Anshul Varma, developer at MathWorks, who will talk about a project where MATLAB is used for a real production application: Applying Deep Learning to categorize MATLAB Answers.

In the spring of 2019, I had a serious problem. I had just been given the task of putting individual MATLAB Answers into categories for the new Help Center, which integrates different documentation and community resources into a single, category-based design. The categories organize content by topic and make it easier for you to find information.

Let me give you an example: Here's an answer related to ANOVA statistical analysis:

I put it in the ANOVA category under the AI, Data Science, and Statistics > Statistics and Machine Learning Toolbox > ANOVA > Analysis of Variance and Covariance. Easy enough, right?

This wouldn't be so hard except for one thing: I had more than 300,000 MATLAB Answers that needed to be slotted into more than 3,500 categories.

There are over 40,000 pages of MathWorks product documentation, spread across the 90+ products MathWorks offers. These documentation pages are organized into over 3,500 categories to make topics and reference pages easy to find and navigate. You've probably used these categories in the left-hand navigation when viewing a documentation page.

Building Help Center required the integration of over 300,000 MATLAB Answers, aligning them by category to the existing documentation pages.

Obviously, manually assigning categories to 300,000+ Answers wasn't a realistic solution:

  • How do I go through thousands of Answers in a short amount of time?
  • How do I get help from content experts who know the entire Category Taxonomy?
  • How do I manually categorize thousands of new Answers created each week and make it a scalable process?

I needed a lot of help. Fortunately for me, the help I needed was right here: Deep Learning Toolbox and Text Analytics Toolbox. My team and I put together a plan to pull off this remarkable feat in a short amount of time by automatically categorizing thousands of Answers. And it worked! It worked even better than we could have hoped.

Let me tell you the story of how we did it...

Deep Learning for Auto-Categorization

I chose MATLAB with Deep Learning Toolbox and Text Analytics Toolbox to build my solution for the following two main reasons:

There are three steps to solving a classic supervised text classification problem, in which we want to assign each piece of text content to its most relevant class:

  1. Prepare the data that you want to use for training
  2. Train a model using the training data
  3. Validate the trained model

Step 1: Prepare the data for training

I built the model using the following MathWorks resources, whose content had already been manually curated into categories:

I used Text Analytics Toolbox for the data preparation. The first part of the text data preparation pipeline cleans the data: we decode HTML entities into characters, convert documents to lowercase, erase HTTP and HTTPS URLs, erase HTML and XML tags, and finally remove all email addresses from the text:

rawTextArray = decodeHTMLEntities(rawTextArray); % Convert HTML and XML entities into characters
rawTextArray = lower(rawTextArray); % Convert documents to lowercase
rawTextArray = eraseURLs(rawTextArray); % Erase HTTP and HTTPS URLs from text
rawTextArray = eraseTags(rawTextArray); % Erase HTML and XML tags from text
rawTextArray = regexprep(rawTextArray, '[^@\s]*@[^@\s]*\.[^@\s]*', ''); % Erase email addresses
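To see what the email-erasing regular expression in the last line actually matches, here is a small, toolbox-free sketch that runs it on a made-up sentence. Only base MATLAB's regexprep is used, and the sample text is hypothetical.

```matlab
% Pattern from the pipeline above: non-space text, '@', non-space text, '.', non-space text.
emailPattern = '[^@\s]*@[^@\s]*\.[^@\s]*';

% Hypothetical sample text.
sample = "contact jane.doe@example.com for the script";

% Replace every match with an empty string.
cleaned = regexprep(sample, emailPattern, '');
disp(cleaned)
```

Because both * quantifiers allow zero characters, the pattern also matches tokens such as @handle.name, which may or may not be what you want in your own data.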

The next step of data preparation preserves valuable information that would easily be lost when we strip punctuation from the text. Removing punctuation turns a term like 'c++' into 'c', which the short-word filter later in the pipeline then discards, so the term disappears completely. In this part of the pipeline, we therefore rewrite programming language names and some MathWorks-specific terms as plain words. Here is an excerpt:

% Preserve programming language names before removing punctuation
rawTextArray = replace(rawTextArray, 'c++',' cplusplus ');
rawTextArray = replace(rawTextArray, 'c#',' csharp ');
rawTextArray = replace(rawTextArray, '.net',' dotnet ');

% Do more custom preservation of terms specific to MathWorks data...

rawTextArray = regexprep(rawTextArray, '[\n\r]+',' '); % Replace \n and \r with spaces
rawTextArray = erasePunctuation(rawTextArray); % Erase punctuation from text
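The sketch below contrasts a sample string with and without the preservation step. So that it runs without Text Analytics Toolbox, it assumes a simplified stand-in for erasePunctuation: a regular expression that deletes every character that is neither a word character nor whitespace. The real function's rules differ in detail, and the sample text is hypothetical.

```matlab
% ASSUMPTION: a simplified, toolbox-free stand-in for erasePunctuation that
% deletes every character that is neither a word character nor whitespace.
stripPunct = @(s) regexprep(s, '[^\w\s]', '');

txt = "calling c++ and c# code from .net";   % hypothetical sample text

% Without preservation: the language names collapse to 'c', 'c', and 'net'.
naive = stripPunct(txt);

% With preservation (same replacements as the pipeline excerpt), they survive.
kept = replace(txt, "c++", " cplusplus ");
kept = replace(kept, "c#", " csharp ");
kept = replace(kept, ".net", " dotnet ");
kept = stripPunct(kept);

disp(naive)
disp(kept)
```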

The last part of the text data preparation pipeline tokenizes and analyzes the text data. We tokenize the documents, remove stop words, remove short and long words, and finally normalize the words:

preprocessed_tokens = tokenizedDocument(rawTextArray, 'Language', 'en'); % Array of tokenized documents for text analysis
preprocessed_tokens = removeWords(preprocessed_tokens, stopWords); % Remove selected words from documents
preprocessed_tokens = removeShortWords(preprocessed_tokens, 2); % Remove words with 2 or fewer characters
preprocessed_tokens = removeLongWords(preprocessed_tokens, 15); % Remove words with 15 or more characters
output_analyzed_text = normalizeWords(preprocessed_tokens); % Normalize words (Porter stemming by default)

Once all documents are preprocessed, we can use them to create a deep learning text classifier.

Step 2: Train the model using the training data

I followed the Classify Text Data Using Deep Learning example to create a deep learning LSTM text classifier, and I used the preprocessed and analyzed data to train the model.

Prepare data for training, testing, and validation:

% Remove the rows of the table with empty documents.
idxEmpty = strlength(data.preprocessed_text) == 0;
data(idxEmpty, :) = [];

% To divide the data into classes, convert these labels to categorical.
data.id = categorical(data.id);

% Find the classes containing fewer than ten observations.
classNames = categories(data.id);   % list of category labels
classCounts = countcats(data.id);   % number of observations per category
idxLowCounts = classCounts < 10;
infrequentClasses = classNames(idxLowCounts);

% Remove these infrequent classes from the data. Use removecats to remove the unused categories from the categorical data.
idxInfrequent = ismember(data.id, infrequentClasses);
data(idxInfrequent,:) = [];
data.id = removecats(data.id);

% Partition the data into a training partition and a held-out partition (10%) for validation and testing.
cvp = cvpartition(data.id, 'Holdout', 0.1);
dataTrain = data(training(cvp),:);
dataHeldOut = data(test(cvp),:);

% Split the held-out partition in half: one half for validation, one half for testing.
cvp = cvpartition(dataHeldOut.id, 'Holdout', 0.5);
dataValidation = dataHeldOut(training(cvp),:);
dataTest = dataHeldOut(test(cvp),:);
 
% Extract the preprocessed text data and labels from the partitioned tables.
documentsTrain = tokenizedDocument(dataTrain.preprocessed_text, 'Language', 'en');
documentsValidation = tokenizedDocument(dataValidation.preprocessed_text, 'Language', 'en');
documentsTest = tokenizedDocument(dataTest.preprocessed_text, 'Language', 'en');
 
YTrain = dataTrain.id;
YValidation = dataValidation.id;
YTest = dataTest.id;
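Here is a small sketch of the infrequent-class filter applied to toy labels, using only base MATLAB categorical functions. The category names and the threshold of 2 (instead of the 10 used above) are hypothetical, chosen to keep the example short.

```matlab
% Hypothetical labels; the code above uses a threshold of 10, this toy uses 2.
labels = categorical(["anova"; "anova"; "fft"; "anova"; "fft"; "plot"]);
minCount = 2;

classNames = categories(labels);      % {'anova'; 'fft'; 'plot'}
classCounts = countcats(labels);      % [3; 2; 1]
infrequentClasses = classNames(classCounts < minCount);

% Drop observations of infrequent classes, then drop the now-unused category.
labels = labels(~ismember(labels, infrequentClasses));
labels = removecats(labels);
disp(categories(labels))
```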

Here is an excerpt of the training data:

% View the first few preprocessed training documents.
documentsTrain(1:5)

5×1 tokenizedDocument:
 
    26 tokens: integr fire neuron model simulink integr fire neuron model simulink shufan simulink hdl coder simulink model integr fire neuron network model simulink integr fire neuron model
    33 tokens: audio watermark selvakarna audio watermark selvakarna audio process selva extract forman frequenc extract process wave selva karna matlab simulink dsp system fixedpoint design matlab coder signal process audio watermark selvakarna audio watermark selvakarna
    45 tokens: simulink initi schedul simulink initi sequenti block execut exampl enabl execut initi schedul sequenti trigger giampiero campa simulink file contain simpl exampl base enabl subsystem block show execut simulink block initi execut block initi sequenti schedul execut differ subsystem simulink initi sequenti block execut exampl
    10 tokens: power transfer limit simul kathir vel simulink simul paper simul
    19 tokens: matlab graphic simulink techniqu matlab graphic simulink simul graphic simulink live script mike garriti techniqu matlab graphic simulink simul

Generate a word cloud.

% Visualize the training text data using a word cloud
figure
wordcloud(documentsTrain);
title("Training Data")

Convert the documents to sequences.

% Create a word encoding.
enc = wordEncoding(documentsTrain);
 
% Convert the documents to sequences of numeric indices using doc2sequence.
seq_length = 800;
XTrain = doc2sequence(enc, documentsTrain, 'Length', seq_length);
XValidation = doc2sequence(enc, documentsValidation, 'Length', seq_length);
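Conceptually, the word encoding maps each word to an integer index, and doc2sequence turns every document into a sequence of seq_length indices. The toy sketch below mimics that idea in base MATLAB; it is not doc2sequence itself. The vocabulary, the use of 0 for both padding and out-of-vocabulary words, left padding, and keeping the first words when truncating are all assumptions of this sketch, so check the doc2sequence documentation for how it actually pads, truncates, and handles unknown words.

```matlab
% Hypothetical vocabulary (index = position) and tokenized document.
vocab = ["simulink", "model", "neuron", "fire"];
doc = ["fire", "neuron", "model", "simulink", "model"];
seqLength = 8;                      % the code above uses seq_length = 800

% Map each word to its vocabulary index (0 if the word is not in vocab).
[~, idx] = ismember(doc, vocab);

% Truncate to seqLength words (this sketch keeps the first words) ...
idx = idx(1:min(end, seqLength));

% ... and left-pad shorter sequences with zeros up to seqLength.
seq = [zeros(1, seqLength - numel(idx)), idx];
disp(seq)
```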

Prepare training options, configure LSTM network layers, and perform training.

% Create and Train LSTM Network
% Set training configuration.
inputSize = 1;                  % each time step is a single word index
embeddingDimension = 100;       % length of each word-embedding vector
numWords = enc.NumWords;        % vocabulary size of the word encoding
hiddenSize = 180;               % number of LSTM hidden units
numClasses = numel(categories(YTrain));

layers = [ ...
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(embeddingDimension, numWords)
    lstmLayer(hiddenSize, 'OutputMode', 'last')
    dropoutLayer(0.4)
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

options = trainingOptions('adam', ...
    'LearnRateSchedule', 'piecewise', ...
    'LearnRateDropFactor', 0.1, ...
    'LearnRateDropPeriod', 20, ...
    'MaxEpochs', 25, ...
    'GradientThreshold', 1, ...
    'InitialLearnRate', 0.01, ...
    'ValidationData', {XValidation, YValidation}, ...
    'Plots', 'training-progress', ...
    'Verbose', true, ...
    'Shuffle', 'every-epoch');


% Train the LSTM network using the trainNetwork function.
net = trainNetwork(XTrain, YTrain, layers, options);
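With the piecewise schedule, as we understand it, the learning rate is multiplied by LearnRateDropFactor once every LearnRateDropPeriod epochs. This short sketch computes the resulting per-epoch rates for the options above: 0.01 for epochs 1 to 20 and 0.001 for the remaining five epochs.

```matlab
% Values copied from the trainingOptions call above.
initialLearnRate = 0.01;
dropFactor = 0.1;
dropPeriod = 20;
maxEpochs = 25;

% Piecewise schedule: multiply the rate by dropFactor once every dropPeriod epochs.
epochs = 1:maxEpochs;
learnRate = initialLearnRate * dropFactor .^ floor((epochs - 1) / dropPeriod);

fprintf('Epochs 1-%d: %g, epochs %d-%d: %g\n', ...
    dropPeriod, learnRate(1), dropPeriod + 1, maxEpochs, learnRate(end));
```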

Step 3: Validate the trained model

We used both objective and subjective validation to verify that the model worked as expected.

Objective Validation

Test the accuracy of the LSTM network using the test documents.

%% Test the LSTM network
% Convert the test documents to sequences using doc2sequence with the same options
% as when creating the training sequences.
XTest = doc2sequence(enc, documentsTest, 'Length', seq_length);

% Classify the test documents using the trained LSTM network.
% PBScore holds one row of class scores per test document.
[YPred, PBScore] = classify(net, XTest);

% Extract the top 3 predictions for each test document.
classes = net.Layers(end).Classes;
[~, I] = sort(PBScore, 2, 'descend');   % class indices, best score first
Top3_YPred = classes(I(:, 1:3));        % N-by-3 categorical array of top-3 labels

% Calculate model accuracy: a prediction counts as correct if the true label is among the top 3.
model_accuracy = nnz(table2array(rowfun(@ismember, table(YTest, Top3_YPred))))/numel(YPred);

The category prediction accuracy of our model is about 70%. We compute this accuracy by checking whether the labeled category of each test document matches any of the top 3 categories predicted by the model. We chose this approach because the category data is hierarchical. Categorization is also very subjective, because an answer could belong in multiple categories; in most cases, it is not black and white, due to the nature of the categories and their parent-child relationships. We therefore could not rely exclusively on this objective accuracy computation for validation, so we also performed the subjective validation described below.
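To make the top-3 metric concrete, here is a toy sketch with four hypothetical classes, three test documents, and made-up scores. It applies the same sort-and-check logic as the test code above and compares the result with plain top-1 accuracy.

```matlab
% Four hypothetical classes.
classes = categorical(["anova"; "fft"; "plot"; "regression"]);

% Made-up scores: one row per test document, one column per class.
scores = [0.50 0.30 0.15 0.05
          0.10 0.20 0.30 0.40
          0.35 0.05 0.15 0.45];

% True labels, taken from the same categorical array so the categories match.
yTrue = classes([2; 1; 3]);     % fft, anova, plot

% Rank the classes for each document and keep the 3 best.
[~, order] = sort(scores, 2, 'descend');
top3 = classes(order(:, 1:3));  % 3-by-3 categorical

% A document counts as correct if its true label is anywhere in its top 3.
hit = false(size(yTrue));
for k = 1:numel(yTrue)
    hit(k) = ismember(yTrue(k), top3(k, :));
end
top3Accuracy = mean(hit);

% For comparison: plain top-1 accuracy uses only the highest-scoring class.
top1Accuracy = mean(top3(:, 1) == yTrue);

fprintf('Top-1: %.2f, top-3: %.2f\n', top1Accuracy, top3Accuracy);
```

In this toy case none of the highest-scoring predictions is correct, yet two of the three true labels appear among the top 3, which shows how much more lenient the top-3 metric is when neighboring categories overlap.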

Subjective Validation

We enlisted the help of our employees to test the accuracy of our predictions. We held numerous testing sessions ("bashes"), tempting employees with pizza, chocolate, and prizes, to get enough participation for the results to be statistically meaningful.

Testing predictions
  1. First, we set up a list of common phrases that might be used by customers in our MATLAB Central Community. These came from search phrases captured during customer visits.
  2. Then, we invited staff members who are familiar with our product line, our customer base, or how AI works, in general, to participate in a bash. We asked them to use our internal Categorizations Tool to search for content in a particular topic space. We then asked them to examine the suggested categories for each result and make a subjective judgment on the categories being proposed. We asked them to capture:
    1. What they searched for.
    2. An overall score for the predictions: 1 for "good suggestions" or 0 for "poor suggestions".
    3. What category, in general, they were examining.
    4. Anything else they wanted to share as an observation.
Testing categorizations
  1. We assembled a list of the categories in use and the number of community-asked questions that had been auto-categorized for that topic.
  2. Then, we invited technical support staff who are familiar with our product line to participate in another bash. This time, we asked them to go to the website and navigate to a category for which they had technical expertise. They were to look at the top 10-20 answers shown for that category, and assess whether the results were good, bad, or unclear. They captured:
    1. What category they examined.
    2. A score of 1 (good), 0 (poor), or unknown.
    3. Specific reasons why they chose the score they chose.
    4. If they rated the category as poor, suggestions for more applicable answers, if they knew of any.
After performing the subjective assessments, we performed a numerical analysis of the captured data and identified:
  • Areas where we felt confident in our categorization.
  • Areas where the algorithm could be improved.
  • Areas where we needed additional testing.

We generated metrics and analyzed the data using objective and subjective validations to gauge the quality of categorizations. Overall, we were happy with the results.

 

That's it! 🙂 We have created a text classification model using MATLAB Deep Learning and Text Analytics Toolboxes that can automatically assign categories to more than 300,000 Answers. Using MATLAB, we can automatically categorize the in-flow of new and updated Answers daily as well. It was a unique solution that helped save us a lot of time and manual effort. It made our processes scalable and manageable.

Let me end this post by sharing with you some interesting challenges of using AI techniques in a production environment, along with some pro tips and resources:

Interesting challenges of using AI techniques in the production environment
  • Refreshing the model regularly to prevent model decay: The MathWorks Documentation and the Category Taxonomy are updated every 6 months as part of the new release of MATLAB. We retrain the model with the updated training data and category labels. The updated model is then deployed to production so that the content is classified into the correct categories.

  • Identifying the categories that need our attention for an optimal user experience: There are categories where we would like to improve our model's predictions. One way we identify such categories is by examining the training data set for categories that have little or no human-curated training input; the quality of classification in those categories is sub-optimal. We review this data every 6 months when we retrain the model, and we enlist content experts within our organization to help with curation. This activity provides feedback that improves the model's future classifications.
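Finding categories with little human-curated training data can be as simple as counting. In this sketch, the labels, the humanCurated flag, and the threshold are hypothetical.

```matlab
% Hypothetical training labels and whether each example was curated by a human.
category = categorical(["ANOVA"; "ANOVA"; "ANOVA"; "FFT"; "FFT"; "Plotting"]);
humanCurated = [true; true; false; false; false; true];
minCurated = 2;   % hypothetical threshold

% countcats keeps every category, so categories with no curated examples get a 0.
catNames = categories(category);
curatedCounts = countcats(category(humanCurated));
needsCuration = catNames(curatedCounts < minCurated)
```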

  • Developing automatic quality measurements and relevancy tests: We are developing ways to automate the validation of the model and classification process. We envision a process to execute our relevancy testing suite to analyze metrics automatically and notify us if there is an anomaly. The automation will help us save development time during data refreshes and model upgrades.
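One way to picture such an automated check is a comparison of freshly computed metrics against a stored baseline. In the sketch below, the metric names, values, and tolerance are all hypothetical.

```matlab
% Hypothetical stored baseline and freshly computed metrics after a retrain.
baseline = struct('top3Accuracy', 0.75, 'meanBashScore', 0.80);
current = struct('top3Accuracy', 0.68, 'meanBashScore', 0.81);
tolerance = 0.03;   % hypothetical allowed drop before we get notified

% Flag every metric that dropped by more than the tolerance.
metricNames = fieldnames(baseline);
anomalies = {};
for k = 1:numel(metricNames)
    name = metricNames{k};
    if current.(name) < baseline.(name) - tolerance
        anomalies{end + 1} = name; %#ok<AGROW>
    end
end

if isempty(anomalies)
    disp('All metrics within tolerance.')
else
    fprintf('Anomaly detected in: %s\n', strjoin(anomalies, ', '));
end
```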

Pro Tips
  • My general recommendation when you are thinking of using AI in production to solve a problem would be to start with something that is easy to understand and debug, and then evolve to a more complicated process (like using neural networks) if it yields better results.

  • Prototype, prototype, and prototype! Don't be afraid of failing. AI can be complicated, and applying it to complicated problems is even harder. Sometimes you need to get your hands in the clay to understand whether AI techniques can or cannot work for you.

  • The truth about your data might surprise you. Using AI techniques often reveals some interesting patterns about your data. The patterns could help to identify interesting use-cases/issues that you might not have thought about.

Resources to learn more
