Building a High Five Counter Using Deep Learning

This post is from Brian Douglas, YouTube Content Creator for Control Systems and Deep Learning Applications

For about a decade, I've wanted to implement this silly idea I had of measuring the acceleration of a person's hand to count the number of times they high five throughout the day. I wasn't sure how to accomplish this using the rule-based approaches to algorithm development that I was familiar with, and so the project sat on hold. It was only while I was making the MATLAB Tech Talk video series on deep learning that I realized deep learning was perfect for solving this problem!

The topic of the 4th video in the series was transfer learning, and it turned out to be exactly the concept I needed to get a high five counting algorithm up and running quickly. In this blog post, I will walk through the details of the code I wrote and the tools I used to get my high five counter working for that video. Hopefully, you can use this as a starting point for those difficult classification problems you've been sitting on for the last 10 years as well.

This post is divided into the following sections:

• An overview of the hardware
• Reading the accelerometer in MATLAB
• Data preprocessing and the scalogram
• Creating the training data
• Transfer learning and GoogLeNet
• Training the network
• Testing the high five counter

So, let's get to it!

An Overview of the Hardware

The hardware setup is pretty simple. I have an accelerometer connected to an Arduino Uno through an I2C bus. The Arduino is then connected to my computer through USB.

To sense acceleration, I'm using the MPU-9250, a 9-degree-of-freedom inertial measurement unit from TDK InvenSense. Instead of integrating the chip into my own custom circuit design, I am using a breakout board that exposes power, ground, and the I2C communication pins. The only reason I am using this particular chip is that I already had one lying around; any accelerometer will work as long as it's small enough to be moved around quickly by hand.

You can see that my hardware setup was pretty crudely constructed with a breadboard and some jumper wires, but I think it's kind of nice that you don't need to set up anything too fancy for this to work.

Reading the Accelerometer in MATLAB

To read acceleration from the MPU-9250 through the Arduino, I'm using the MATLAB Support Package for Arduino Hardware. This package allows you to communicate with an Arduino without having to compile code for it. Plus, there is a built-in mpu9250 function that lets you read the sensor with a one-line command.

It only takes three lines of code to connect to the Arduino, instantiate an MPU9250 object, and read the accelerometer.
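In my setup, those three lines looked something like this (the serial port name is an assumption and will differ on your machine):

```matlab
% Connect to the Arduino Uno over USB (replace 'COM3' with your port)
a = arduino('COM3', 'Uno');

% Instantiate the MPU-9250 object over the I2C bus
imu = mpu9250(a);

% Read the current acceleration as a 1x3 vector [x y z] in m/s^2
accel = readAcceleration(imu);
```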

Data Preprocessing and the Scalogram

If you watched the 4th video in the Tech Talk series on deep learning, you'll know that I opted to convert the 3-axis acceleration data into an image to take advantage of GoogLeNet - a network trained to recognize images. In particular, I used a continuous wavelet transform to create a scalogram.

A scalogram is a time-frequency representation that is suitable for signals that exist at multiple scales. That is, signals that are mostly low frequency and slowly varying but are occasionally interrupted by high-frequency transients. This makes them useful for visualizing acceleration data: the occasional high-frequency high five within an otherwise slowly moving hand.

A cleaned-up version of the MATLAB code that I used to make the above plot is shown below.

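As a rough sketch (the sample rate, buffer length, and the use of the acceleration magnitude are my assumptions, not necessarily the exact choices from the video), the scalogram image can be produced like this:

```matlab
Fs = 100;                 % sample rate in Hz (assumption)
bufferLen = 4 * Fs;       % 4-second rolling buffer of samples (assumption)

% accelBuffer is a bufferLen x 3 matrix of [x y z] acceleration samples;
% collapsing it to a magnitude gives a single signal for the transform
sig = vecnorm(accelBuffer, 2, 2);

% Continuous wavelet transform via a precomputed filter bank (Wavelet Toolbox)
fb = cwtfilterbank('SignalLength', bufferLen, 'SamplingFrequency', Fs);
cfs = abs(fb.wt(sig));

% Map the coefficients to an RGB image sized for GoogLeNet's 224x224x3 input
im = ind2rgb(im2uint8(rescale(cfs)), jet(128));
im = imresize(im, [224 224]);
imshow(im)
```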

Note that this code uses a function called cwtfilterbank, part of the Wavelet Toolbox, to create the scalogram. If you don't have access to this toolbox and you don't want to write the code yourself, try giving another type of time-frequency visualization a shot. Maybe a spectrogram will work, or some other algorithm that you come up with. Whatever you choose, the idea here is to create an image that makes the unique and identifiable features of a high five pattern obvious. I have shown that a scalogram works, but other methods may work as well.

Creating the Training Data

To train a network to recognize high fives, we need multiple examples of what a high five looks like and what a high five doesn't look like. Since we are starting from a pre-trained network, we won't need as many training examples as we would if we were training a network from scratch. I don't know exactly how much training data is needed to fully capture the solution space of all possible high fives; however, I collected data for 100 high fives and 100 non-high fives, and that seemed to work pretty well. I suspect I could have gotten away with fewer for the video I was making, but if I were really creating a product I would have used many more examples. You can play around with the amount of labeled training data and see how it affects the result.

Collecting 200 images seems like a lot of work, but I wrote a script that cycles through them one after another and saves the images to the appropriate folder. I ran the script twice: once with the 'high_five' label, saving images to the data/high_five folder, and once with the 'no_high_five' label, saving images to the data/no_high_five folder.

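A sketch of what such a collection script might look like is below. Here, makeScalogram is a hypothetical helper that wraps the CWT-and-resize step described earlier, the buffer-filling loop is an assumption about how the samples were gathered, and imu is the mpu9250 object from the hardware section:

```matlab
label = 'high_five';                 % second run: 'no_high_five'
outDir = fullfile('data', label);
if ~exist(outDir, 'dir')
    mkdir(outDir);
end

numExamples = 100;
bufferLen = 400;                     % samples per example (assumption)

for k = 1:numExamples
    fprintf('Recording example %d of %d: make your motion now...\n', k, numExamples);
    buf = zeros(bufferLen, 3);
    for n = 1:bufferLen              % fill the buffer with fresh samples
        buf(n, :) = readAcceleration(imu);
    end
    img = makeScalogram(buf);        % hypothetical helper: CWT -> RGB image
    imwrite(img, fullfile(outDir, sprintf('%s_%03d.png', label, k)));
end
```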

After running the script, I manually went through my training data and removed images that I thought would corrupt the training: images where the high five wasn't in the middle of the frame, or where I knew I did a poor high five motion. In the gif below, I deleted high five image 49 because it wasn't in the center of the frame.

Transfer Learning and GoogLeNet

With all of my training data in their appropriate folders, the next step is to set up the network. For this part, I was following along with the MATLAB example Classify Time Series Using Wavelet Analysis and Deep Learning, except, rather than run everything through a MATLAB script, I found it easier to set up and train the network using the Deep Network Designer app.

I started from the pre-trained GoogLeNet to take advantage of all of the knowledge this network has for recognizing objects in images. GoogLeNet was trained to recognize things like fish and hotdogs in images - clearly not what I'm looking for - but this is where transfer learning is useful. With transfer learning, I can keep much of the existing network in place and only replace two layers at the end of the network which combine those generic features into the specific patterns I'm looking for. Then when I retrain the network, it's pretty much just these two layers that need to be trained, which is why training is so much faster with transfer learning.

I recommend you follow along with the MATLAB example I used, or watch the Tech Talk, if you'd like to know exactly how I replaced the layers and what training parameters I used. However, this is again a great place to experiment with something different. You can start from a different pretrained network like SqueezeNet, replace more layers in GoogLeNet, or change the training parameters. There are a lot of options here, and I think deviating from what I did can help you develop some intuition for how all of these variables affect the result.
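For reference, doing the same two-layer swap in code rather than in the app looks roughly like this; 'loss3-classifier' and 'output' are the names of GoogLeNet's final learnable and classification layers, and the boosted learn-rate factors follow the common transfer-learning recipe (an assumption, not necessarily my exact settings):

```matlab
net = googlenet;                     % requires the GoogLeNet support package
lgraph = layerGraph(net);

% Replace the final fully connected layer with one sized for our two classes
newFC = fullyConnectedLayer(2, 'Name', 'high_five_fc', ...
    'WeightLearnRateFactor', 10, 'BiasLearnRateFactor', 10);
lgraph = replaceLayer(lgraph, 'loss3-classifier', newFC);

% Replace the classification output layer so it learns the new class labels
newOut = classificationLayer('Name', 'high_five_out');
lgraph = replaceLayer(lgraph, 'output', newOut);
```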

Training the Network

With the network ready to go, training in the Deep Network Designer app is very simple. In the Data tab, I imported the training data by selecting the folder where I saved the set of high five and no high five images. I also set aside 20 percent of the images for validation during the training process.

Then, in the Training tab, I set my training options. Here I used the same options as the MATLAB example I was following; however, once again I encourage you to play around with some of these values and see how they impact the results.
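Doing the same thing programmatically might look like the following sketch. The hyperparameter values here are illustrative rather than my exact settings, and lgraph is GoogLeNet with its last two layers replaced as described above:

```matlab
% Labels come from the folder names: data/high_five and data/no_high_five
imds = imageDatastore('data', 'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');

% Hold out 20 percent of the images for validation
[imdsTrain, imdsVal] = splitEachLabel(imds, 0.8, 'randomized');

opts = trainingOptions('sgdm', ...
    'MiniBatchSize', 10, ...
    'MaxEpochs', 10, ...
    'InitialLearnRate', 1e-4, ...
    'ValidationData', imdsVal, ...
    'Shuffle', 'every-epoch', ...
    'Verbose', false, ...
    'Plots', 'training-progress');

net = trainNetwork(imdsTrain, lgraph, opts);
```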

Training took just over 4 minutes on my single CPU and reached about 97% validation accuracy. Not too bad for a couple hours of work!

Testing the High Five Counter

Now that I have a trained network, I use the classify function from the Deep Learning Toolbox to pass in the scalogram at each sample time and have the network return a label. If the returned label is "high_five", I increment a counter. To keep from counting the same high five several times as the acceleration data streaks across the buffer, I added a timeout: a new high five isn't counted unless at least 2 seconds have passed since the previous one.

Below is a cleaned-up version of the code that I used to count high fives.

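A sketch of the counting loop, including the 2-second timeout, might look like this. As before, makeScalogram is a hypothetical helper for the CWT-image step, imu and net come from the earlier steps, and the one-sample-per-iteration buffer update is a simplifying assumption:

```matlab
count = 0;
lastCount = -inf;                    % time of the last counted high five
timeout = 2;                         % seconds; ignore repeats within this window
bufferLen = 400;                     % samples, matching the training images (assumption)
buf = zeros(bufferLen, 3);
t0 = tic;

while true
    % Shift the rolling buffer and append the newest acceleration sample
    buf = [buf(2:end, :); readAcceleration(imu)];

    img = makeScalogram(buf);        % hypothetical helper: CWT -> RGB image
    label = classify(net, img);      % Deep Learning Toolbox

    if label == "high_five" && (toc(t0) - lastCount) > timeout
        count = count + 1;
        lastCount = toc(t0);
        fprintf('High five count: %d\n', count);
    end
end
```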

And here it is in action!
