Deep Wine Designer
- Train a long short-term memory (LSTM) network that treats text data as a time series and learns long-term dependencies between time steps.
- Train a convolutional neural network (CNN) that treats text data as hyperspectral images and learns localized features by applying sliding convolutional filters.
Background: Word Embeddings
To convert text to hyperspectral images, we can use a word embedding that maps a sequence of words to a 2-D array representing an image. In other words, a word embedding maps words to high-dimensional vectors. These vectors sometimes have interesting properties. For example, given the word vectors corresponding to Italy, Rome, Paris, and France, you might discover the relationship:
Italy – Rome + Paris ≈ France
That is, the vector corresponding to the word Italy without the components of the word Rome but with the added components of the word Paris is approximately equal to the vector corresponding to France. To do this in MATLAB, we can use the Text Analytics Toolbox™ Model for fastText English 16 Billion Token Word Embedding support package. This word embedding maps approximately 1,000,000 English words to 1-by-300 vectors. Let's load the word embedding using the fastTextWordEmbedding function.
emb = fastTextWordEmbedding;
Let's visualize the word vectors Italy, Rome, Paris, and France. By calculating the word vectors using the word2vec function, we can see the relationship:
italy = word2vec(emb,"Italy");
rome = word2vec(emb,"Rome");
paris = word2vec(emb,"Paris");
word = vec2word(emb,italy - rome + paris)
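To see why this arithmetic works without downloading the embedding, here is a minimal sketch using a made-up 3-D "embedding". The vectors and their meanings are assumptions for illustration only and are not real fastText values. The nearest word is found by cosine similarity, which is one common way to do this kind of lookup.

```matlab
% Toy 3-D "embedding" with made-up values (NOT real fastText vectors).
% The columns loosely mean: [is-a-capital, Italy-ness, France-ness].
words = ["Italy" "Rome" "France" "Paris"];
vecs  = [0 1 0;    % Italy
         1 1 0;    % Rome
         0 0 1;    % France
         1 0 1];   % Paris

% Italy - Rome + Paris
query = vecs(words == "Italy",:) - vecs(words == "Rome",:) + vecs(words == "Paris",:);

% Cosine similarity between the query and every word vector.
cosSim = (vecs * query') ./ (vecnorm(vecs,2,2) * norm(query));

% The most similar word is the answer to the analogy.
[~,idx] = max(cosSim);
nearest = words(idx)
```

In this toy setup the "capital" component cancels out, which leaves only the "France" component.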
Words to Images
To use a word embedding to map a sequence of words to an image, let's split the text into words using tokenizedDocument and convert the words to a sequence of vectors using doc2sequence.
str = "The rain in Spain falls mainly on the plain.";
document = tokenizedDocument(str);
sequence = doc2sequence(emb,document);
Let's view the hyperspectral images corresponding to this sequence of words.
figure
I = sequence{1};
imagesc(I,[-1 1])
colorbar
xlabel("Word Index")
ylabel("Embedding Feature")
title("Word Vectors")
Load Wine Reviews Data
Let's download the Wine Reviews data from Kaggle and extract the data into a folder named wine-reviews. After downloading the data, we can read the data from winemag-data-130k-v2.csv into a table using the readtable function. The data contains special characters such as the é in Rosé, so we must specify the text encoding option too.
filename = fullfile("wine-reviews","winemag-data-130k-v2.csv");
data = readtable(filename,"Encoding","UTF-8");
data.variety = categorical(data.variety);
Explore Wine Reviews Data
To get a feel for the data, let's visualize the text data using word clouds. First create a word cloud of the different grape varieties.
figure
wordcloud(data.variety);
title("Grape Varieties")
labels = ["Gewürztraminer" "Chardonnay" "Nebbiolo" "Malbec"];
commonWords = ["wine" "Drink" "drink" "flavors" "finish" "palate" "notes" "aromas"];
figure
for i = 1:4
    subplot(2,2,i)
    label = labels(i);
    idx = data.variety == label;
    str = data.description(idx);
    documents = tokenizedDocument(str);
    documents = removeWords(documents,commonWords);
    documents = removeWords(documents,labels);
    str = joinWords(documents);
    wordcloud(str);
    title(label)
end
Prepare Text Data for Deep Learning
To classify text data using convolutions, we need to convert the text data into images. To do this, let's pad or truncate the observations to have a constant length S and convert the documents into sequences of word vectors of length C using the pretrained word embedding. We can then represent a document as a 1-by-S-by-C image (an image with height 1, width S, and C channels).
To convert text data from a CSV file to images, I have a helper function at the end of this post called transformTextData. It creates a tabularTextDatastore object and uses the transform function with a custom transformation function that converts the data read from the tabularTextDatastore object to images for deep learning.
In this example, we'll train a network with 1-D convolutional filters of varying widths. The width of each filter corresponds to the number of words the filter can see (the n-gram length). The network has multiple branches of convolutional layers, so it can use different n-gram lengths.
Clean up Data
Remove reviews without a label.
idxMissing = ismissing(data.variety);
data(idxMissing,:) = [];
Remove any reviews where the grape variety is not one of the top 200 varieties in the data. (If you can't find a wine you like in the top 200 choices available, MATLAB probably can't help you.)
numClasses = 200;
[classCounts,classNames] = histcounts(data.variety);
[~,idx] = maxk(classCounts,numClasses);
classNames = classNames(idx);
idx = ismember(data.variety,classNames);
data = data(idx,:);
Remove the unused categories from the data.
data.variety = removecats(data.variety);
classNames = categories(data.variety);
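To check the logic of the top-N filtering on something small, here is a sketch on a hypothetical ten-review data set, keeping the top 2 varieties instead of the top 200. It uses countcats rather than histcounts to get the per-category counts, which is a swap I made for the illustration; the labels are made up.

```matlab
% Hypothetical labels for ten reviews (made up for illustration).
variety = categorical(["Merlot" "Merlot" "Merlot" "Syrah" "Syrah" ...
    "Riesling" "Malbec" "Malbec" "Malbec" "Malbec"])';

k = 2;                              % keep the k most frequent classes
counts = countcats(variety);        % one count per category
names = categories(variety);        % category names, in the same order

[~,idx] = maxk(counts,k);           % indices of the k largest counts
keep = ismember(variety,names(idx));

variety = removecats(variety(keep)); % drop rows, then drop unused categories
topClasses = categories(variety)
```

Without the removecats call, the discarded varieties would remain as empty categories and the network would end up with output classes that have no training data.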
Partition Data
To help evaluate the performance of the network, let's partition the data into training, testing, and validation sets. Let's set aside 30% of the data for validation and testing (two partitions of 15%).
cvp = cvpartition(data.variety,'HoldOut',0.3);
filenameTrain = fullfile("wine-reviews","wineReviews_" + numClasses + "_classes_Train.csv");
dataTrain = data(training(cvp),:);
writetable(dataTrain,filenameTrain,"Encoding","UTF-8");
dataHeldOut = data(test(cvp),:);
cvp = cvpartition(dataHeldOut.variety,'HoldOut',0.5);
filenameValidation = fullfile("wine-reviews","wineReviews_" + numClasses + "_classes_Validation.csv");
dataValidation = dataHeldOut(training(cvp),:);
writetable(dataValidation,filenameValidation,"Encoding","UTF-8");
filenameTest = fullfile("wine-reviews","wineReviews_" + numClasses + "_classes_Test.csv");
dataTest = dataHeldOut(test(cvp),:);
writetable(dataTest,filenameTest,"Encoding","UTF-8");
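The two-stage hold-out above gives a 70/15/15 split. As a rough sketch of the same arithmetic using only base MATLAB, the following shuffles row indices and slices them. Note that, unlike cvpartition with a grouping variable, this version is not stratified by class, and the data set size is a hypothetical value.

```matlab
rng(0)                      % fixed seed so the split is repeatable
numObservations = 1000;     % hypothetical data set size
idx = randperm(numObservations);

numTrain = round(0.70 * numObservations);
numValidation = round(0.15 * numObservations);

% Slice the shuffled indices into three disjoint sets.
idxTrain = idx(1:numTrain);
idxValidation = idx(numTrain+1 : numTrain+numValidation);
idxTest = idx(numTrain+numValidation+1 : end);
```

For a data set with many rare varieties, the stratified split is the safer choice because it keeps the class proportions similar across the three sets.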
miniBatchSize = 128;
ttdsTrain = tabularTextDatastore(filenameTrain, ...
    'SelectedVariableNames',["description" "variety"], ...
    'ReadSize',miniBatchSize);
Specify Input Size
To input the text data into the network, we need to convert the text to images with a fixed size by padding or truncating the sequences. Ideally, we need to choose a value that minimizes both the amount of padding added to the sequences and the amount of data discarded due to truncation. Let's try to approximate the number of words in each review by counting the number of spaces and plotting the sequence lengths in a histogram.
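The idea of approximating review length by counting spaces can be sketched as follows. The three reviews are made up for illustration, and the candidate length of 80 is only used here to show how you might check what fraction of reviews would fit without truncation.

```matlab
% Made-up reviews standing in for the description column.
reviews = ["Crisp and dry with citrus notes."
           "Ripe black fruit, firm tannins and a long, smoky finish."
           "Light and fruity."];

% Approximate word count: number of spaces plus one.
numSpaces = count(reviews," ");
numWords = numSpaces + 1;

% Plot the distribution of approximate sequence lengths.
figure
histogram(numWords)
xlabel("Approximate Number of Words")
ylabel("Number of Reviews")

% Fraction of reviews that fit within a candidate sequence length.
candidateLength = 80;
fractionCovered = mean(numWords <= candidateLength)
```

Counting spaces is only an approximation, since the tokenizer may split punctuation into separate tokens, but it is cheap and good enough for picking a ballpark length.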
sequenceLength = 80;
tdsTrain = transform(ttdsTrain, @(data) transformTextData(data,sequenceLength,emb,classNames));
The predictors are 1-by-S-by-C arrays, where S is the sequence length and C is the number of features. The responses are the categorical labels.
preview(tdsTrain)
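The transformTextData helper itself is not reproduced in this section, so here is a minimal sketch of the core step it needs to perform: taking a C-by-N array of word vectors and turning it into a 1-by-S-by-C image. Padding with zeros on the right is an assumption on my part, and the small sizes are chosen only to keep the example readable.

```matlab
C = 4;                  % number of embedding features (300 for fastText)
S = 6;                  % target sequence length (80 in this post)
seq = rand(C,3);        % hypothetical review containing 3 words

len = size(seq,2);
if len >= S
    seq = seq(:,1:S);               % truncate long sequences
else
    seq = [seq zeros(C,S-len)];     % right-pad short sequences with zeros
end

% Rearrange C-by-S into 1-by-S-by-C: height 1, width S, C channels.
X = permute(seq,[3 2 1]);
size(X)
```

With this layout, a [1 N] convolutional filter slides along the word dimension and sees all C embedding features of N consecutive words at once.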
ttdsValidation = tabularTextDatastore(filenameValidation, ...
    'SelectedVariableNames',["description" "variety"], ...
    'ReadSize',miniBatchSize);
tdsValidation = transform(ttdsValidation, @(data) transformTextData(data,sequenceLength,emb,classNames));
Define Network Architecture
Let's now define the network architecture for the classification task. We can use deepNetworkDesigner to create the network.
The following describes the network architecture:
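The network in this post was built interactively, and its exact layers are not listed here. As a rough sketch of what a multi-branch network of this kind could look like when built programmatically, the following creates one convolutional branch per n-gram length and concatenates the results. The n-gram lengths, number of filters, and dropout probability are all assumptions of mine, not the values used for the trained network.

```matlab
sequenceLength = 80;        % S, as chosen above
embeddingDimension = 300;   % C, the fastText vector length
numClasses = 200;

ngramLengths = [2 3 4 5];   % assumed filter widths
numFilters = 200;           % assumed number of filters per branch
numBlocks = numel(ngramLengths);

% Input: a 1-by-S-by-C "image".
inputSize = [1 sequenceLength embeddingDimension];
lgraph = layerGraph(imageInputLayer(inputSize,'Normalization','none','Name','input'));

% One branch per n-gram length, each connected to the input layer.
for j = 1:numBlocks
    N = ngramLengths(j);
    block = [
        convolution2dLayer([1 N],numFilters,'Name',"conv"+N,'Padding','same')
        batchNormalizationLayer('Name',"bn"+N)
        reluLayer('Name',"relu"+N)
        dropoutLayer(0.2,'Name',"drop"+N)
        maxPooling2dLayer([1 sequenceLength],'Name',"max"+N)];
    lgraph = addLayers(lgraph,block);
    lgraph = connectLayers(lgraph,'input',"conv"+N);
end

% Concatenate the branch outputs and classify.
layers = [
    depthConcatenationLayer(numBlocks,'Name','depth')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','soft')
    classificationLayer('Name','classification')];
lgraph = addLayers(lgraph,layers);

for j = 1:numBlocks
    N = ngramLengths(j);
    lgraph = connectLayers(lgraph,"max"+N,"depth/in"+j);
end
```

Each max pooling layer spans the full sequence width, so every branch reduces a review to a single vector of filter responses regardless of where in the text the n-gram occurred.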
Specify Training Options
numObservationsTrain = height(dataTrain); % number of training reviews
numIterationsPerEpoch = floor(numObservationsTrain/miniBatchSize);
options = trainingOptions('adam', ...
    'MaxEpochs',50, ...
    'Shuffle','never', ...
    'MiniBatchSize',miniBatchSize, ...
    'ValidationData',tdsValidation, ...
    'ValidationFrequency',numIterationsPerEpoch, ...
    'Plots','training-progress', ...
    'Verbose',false);
Train Network
Finally, we can train the network! Let's train the network using the trainNetwork function. Depending on your hardware, this can take a long time. If you are having trouble with hardware or training, you can email Johanna for a copy of the trained network.
caberNet = trainNetwork(tdsTrain,lgraph,options);
save("caberNet.mat","caberNet")
*Note from Johanna to Ieuan: I have a no-pun policy on this blog, and "caberNet" is borderline, so consider this a warning.
Test Network
Now that the network is trained, we can test it using the held-out test data. First, let's create a transformed datastore containing the held-out test data.
ttdsTest = tabularTextDatastore(filenameTest, ...
    'SelectedVariableNames',["description" "variety"], ...
    'ReadSize',miniBatchSize);
tdsTest = transform(ttdsTest, @(data) transformTextData(data,sequenceLength,emb,classNames));
tbl = readall(ttdsTest);
labelsTest = tbl.variety;
YTest = categorical(labelsTest,classNames);
YPred = classify(caberNet,tdsTest,'MiniBatchSize',miniBatchSize);
accuracy = mean(YPred == YTest)
accuracy = 0.6397
Here, we can see that the network is about 64% accurate on the held-out test data. Given the varied and subjective nature of wine tasting notes, I think this is a good score!
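Overall accuracy is simply the fraction of predictions that match the true labels. With 200 classes of very different sizes, it can also be worth looking at how often each individual variety is recovered. Here is a small sketch on made-up labels (three classes, five reviews) that computes both; none of these values come from the trained network.

```matlab
% Made-up true and predicted labels for five reviews.
classNames = ["Chardonnay" "Malbec" "Riesling"];
YTest = categorical(["Malbec" "Riesling" "Chardonnay" "Malbec" "Riesling"],classNames);
YPred = categorical(["Malbec" "Chardonnay" "Chardonnay" "Malbec" "Riesling"],classNames);

% Overall accuracy: fraction of correct predictions.
accuracy = mean(YPred == YTest)

% Per-class recall: of the reviews that truly belong to a class,
% the fraction that the classifier labelled correctly.
recall = arrayfun(@(c) mean(YPred(YTest == c) == c),classNames)
```

In this toy example one of the two Riesling reviews is mistaken for Chardonnay, which the overall accuracy alone would not tell you.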
Make Predictions on New Data
The next step is to try out the classifier in the real world! Here are some notes from a recent wine tasting I attended.
- “A crisp, golden coloured, bubbly wine. On the nose, there are aromas of citrus fruits alongside ripe stone fruits. On the palate, vibrant notes of apple and creamy textures.”
- “Straw coloured with a slight hint of green. Notes of peaches and nectarines. Rich and slightly sweet, intense notes of lychee. Strong minerality with some sweetness.”
- “Pale straw in colour with zesty citrus fruit on the nose. On the palate, intense gooseberry and crisp lime flavours with slight hints of oak.”
- “Deep golden colour. Strong aromas of toast and butter with strong hints of oak. On the palate, intense flavours of ripe banana and cooked apples.”
- “Very light bodied wine and pale in colour. Aromas of strawberries and forest fruits. Slightly oaked with slight tannins. Vibrant taste of red cherries.”
- “Medium bodied and brick-red in colour. On the nose, black cherry, and violet. Complex flavours including strong tannins coupled with flavours of black fruits and pepper.”
- “Deep ruby red in colour. Aromas of dark cherries, oak, and clove. Slightly smokey in taste with strong hints of blackberries and licorice.”
- “Strong aromas of blackcurrant and blueberries. A very big wine with high alcohol content. Intense flavour on the palate with a long finish. Vibrant flavors of black fruits and spices.”
sequencesNew = text2sequence(emb,str,sequenceLength);
[YNewPred,scoresNew] = classify(caberNet,sequencesNew);
tbl = table;
tbl.PredictedVariety = YNewPred;
tbl.TrueVariety = YNewTest
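The classify function also returns a matrix of scores with one row per observation and one column per class. Since many varieties taste alike, the second and third guesses can be informative too. Here is a sketch of pulling out the top 2 classes per observation using maxk, with made-up scores and only four classes for readability.

```matlab
% Made-up class scores for two tasting notes (each row sums to 1).
classNames = ["Chardonnay" "Gewürztraminer" "Riesling" "Zinfandel"];
scores = [0.10 0.55 0.30 0.05;
          0.20 0.05 0.05 0.70];

% Top 2 scores along each row, and the column indices they came from.
[topScores,topIdx] = maxk(scores,2,2);

% Map the column indices back to class names.
topClasses = classNames(topIdx)
```

Each row of topClasses lists the most likely variety first, followed by the runner-up.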
Visualize Network Predictions
For image classification problems, you can visualize the predictions of a network by taking an image, deleting a patch of the image, measuring whether the classification gets better or worse, and then overlaying the results on the image. In other words, if you delete a patch of the image and the classification gets worse, then that patch must contain features pertaining to the true class. Similarly, if you delete a patch of the image and the classification gets better, then that patch must contain features pertaining to a different class and thus confuses the classifier. We can do this using the occlusionSensitivity function. Let's select one of the observations of the text data where the network has predicted the correct label.
idxObservation = 2;
strNew = str(idxObservation)
labelTest = YNewTest(idxObservation)
strNew = "Slightly straw colored with a hint of greenness. Notes of peaches and nectarines. Rich and slightly sweet, intense notes of lychee. A soft finish with some sweetness."
labelTest = "Gewürztraminer"
Let's view the occlusion sensitivity scores using the function plotOcclusion, which I have listed at the end of the blog post. This shows which patches of words contribute most to the prediction.
h = figure;
h.Position(3) = 1.5 * h.Position(3);
plotOcclusion(caberNet,emb,strNew,sequenceLength,labelTest)
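To make the occlusion idea concrete for text, here is a toy sketch that does not use the network at all. A hypothetical scoring function stands in for the classifier: each word carries a made-up amount of evidence for the class, and the score is a logistic function of the evidence that remains visible. We then hide a sliding window of words and record how much the score drops. This illustrates the principle only and is not how occlusionSensitivity or plotOcclusion are implemented.

```matlab
% Words of a short note and made-up evidence each word carries for the class.
words   = ["rich" "and" "sweet" "with" "notes" "of" "lychee"];
weights = [0.2    0     0.5     0      0.1     0    1.5];

% Hypothetical stand-in for the network: score from the visible words only.
scoreFcn = @(mask) 1 ./ (1 + exp(-(sum(weights(mask)) - 1)));

baseline = scoreFcn(true(size(words)));   % score with nothing hidden

windowSize = 2;                            % number of words hidden at a time
numWindows = numel(words) - windowSize + 1;
drop = zeros(1,numWindows);
for i = 1:numWindows
    mask = true(size(words));
    mask(i:i+windowSize-1) = false;        % occlude one window of words
    drop(i) = baseline - scoreFcn(mask);   % large drop = important words
end

% The window whose removal hurts the score the most.
[~,iMax] = max(drop);
mostImportant = words(iMax:iMax+windowSize-1)
```

A large positive drop marks words that support the class, while a negative drop would mark words that point towards a different class.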
idxObservation = 8;
strNew = str(idxObservation)
labelTest = YNewTest(idxObservation)
strNew = "Strong aromas of black cherry. Powerful taste with a high alcohol content. Rich flavor with strong tannins and a long finish. Vibrant flavors of cherries and a hint of pepper."
labelTest = "Zinfandel"
h = figure;
h.Position(3) = 1.5 * h.Position(3);
h.Position(4) = 1.5 * h.Position(4);
plotOcclusion(caberNet,emb,strNew,sequenceLength,labelTest)
Copyright 2018 The MathWorks, Inc.