In November, Yarden Cohen and I gave a talk at PyData NYC about segmenting vocalizations with neural networks. More specifically, we are segmenting birdsong into elements called syllables, although in principle the same thing can be done with human speech. That is a harder problem, though, because in human speech the beginning and end of a syllable are not as clearly defined as they are in birdsong. Because the song of many bird species is easily segmented into syllables, we proposed that birdsong could provide a test bed for benchmarking different neural net models for segmentation.
We have some scientific questions about birdsong that we are using the neural network to answer. I won't go into the details of that here, but I thought it might be interesting for people to see the process we're going through to make sure the network does what we want it to. I find a lot of tutorials on the web that walk users through applying an established architecture to a toy problem, but not a lot of writing about the process of developing a neural network for a specific real-world application. So here's a sneak peek at the results I've been sending to Yarden. Full disclosure: I'm also applying for the Google AI Residency and want to demonstrate what I've been working on.
We are bravely coding in the open. If you want to see the code I've
developed, with scripts for creating learning curves like those shown below,
you can check out my fork of Yarden's code:
If you are a songbird researcher and you'd like to try out the network on your own data, you might prefer to work with the Jupyter notebook that Yarden originally shared:
The goal: segment birdsong into syllables
Let's look at some example songs to see what I mean by segmenting song into syllables. Below are spectrograms of four songs from four birds, taken from an open repository of Bengalese finch song shared by the lab that I work in, the Sober lab. If you haven't seen a spectrogram before, it is (in this context) an image made from an audio signal, with frequencies on the Y axis and time on the X axis.
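If you want to make one yourself, here is a minimal sketch of computing a spectrogram with NumPy using a short-time Fourier transform. The window length (`nfft`), step size (`hop`), and Hann window are illustrative assumptions, not the parameters used for the figures in this post; in practice you would probably use a library function such as `scipy.signal.spectrogram`.

```python
import numpy as np

def spectrogram(audio, fs, nfft=512, hop=64):
    """Return (freqs, times, power) for a 1-D audio signal.

    Slides a Hann window of `nfft` samples along the signal in steps of
    `hop` samples and takes the squared magnitude of each window's FFT.
    """
    window = np.hanning(nfft)
    n_frames = 1 + (len(audio) - nfft) // hop
    frames = np.stack([audio[i * hop : i * hop + nfft] * window
                       for i in range(n_frames)])
    # rfft keeps only the non-negative frequencies: shape (n_frames, nfft // 2 + 1)
    power = np.abs(np.fft.rfft(frames, axis=1)) ** 2
    freqs = np.fft.rfftfreq(nfft, d=1.0 / fs)
    # Time of each frame is the center of its window, in seconds
    times = (np.arange(n_frames) * hop + nfft / 2) / fs
    # Transpose so frequency runs along the rows (Y axis) and time along the columns (X axis)
    return freqs, times, power.T
```

For example, a pure 4 kHz tone sampled at 32 kHz produces a single bright horizontal band at 4 kHz.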
Notice that each bird has a handful of repeating elements in its song that we call syllables. Often when we want to study a bird's song, we name these elements by giving them labels (as shown in the bottom subplot of each figure). The labels are arbitrary, by which I mean that just because one syllable is labeled 'a' for bird 1 does not mean that it resembles syllable 'a' from bird 2. All these birds are from the same species but each individual has its own song. (Typically a bird's adult song resembles the song of the tutor that it learned from as a juvenile.) We want our neural network to generalize across birds (see note 1).
How birdsong is typically segmented
Typically, scientists that study birdsong segment it into syllables with a simple algorithm as shown in the gif below:
- measure the amplitude of the sound (as plotted in the axis under the spectrogram)
- set a threshold and find all time points at which the amplitude is above that threshold. Call each continuous series of time points above threshold a syllable segment (pink lines); each series below the threshold is a silent period segment
- set a minimum silent period duration; if any silent period segment is shorter than that minimum, it is removed and the two syllable segments on either side of it are joined (red lines).
- set a minimum syllable duration; if any syllable segment is shorter than that minimum, it is discarded. The remaining syllable segments (dark red lines) are then given labels.
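The steps above can be sketched in NumPy as follows. This is a hypothetical implementation for illustration (the name `segment_by_threshold` and its arguments are mine, not any lab's code). It assumes you have already computed an amplitude trace sampled at the times in `times`, and it measures each duration from the first to the last above-threshold sample.

```python
import numpy as np

def segment_by_threshold(amplitude, times, threshold, min_silent_dur, min_syl_dur):
    """Return (onsets, offsets), in seconds, of syllable segments."""
    amplitude, times = np.asarray(amplitude), np.asarray(times)
    above = amplitude > threshold
    # Pad with False so segments touching either edge still get an onset and offset
    padded = np.concatenate(([False], above, [False]))
    diffs = np.diff(padded.astype(int))
    onset_inds = np.flatnonzero(diffs == 1)        # first sample above threshold
    offset_inds = np.flatnonzero(diffs == -1) - 1  # last sample above threshold
    if onset_inds.size == 0:
        return np.array([]), np.array([])
    onsets, offsets = times[onset_inds], times[offset_inds]

    # Remove silent periods shorter than the minimum by joining the syllables around them
    silent_durs = onsets[1:] - offsets[:-1]
    keep_gap = silent_durs >= min_silent_dur
    onsets = np.concatenate((onsets[:1], onsets[1:][keep_gap]))
    offsets = np.concatenate((offsets[:-1][keep_gap], offsets[-1:]))

    # Discard syllable segments shorter than the minimum duration
    long_enough = (offsets - onsets) >= min_syl_dur
    return onsets[long_enough], offsets[long_enough]
```

Note that the order matters: short silent periods are removed before short syllables are discarded, so two brief fragments separated by a tiny gap can survive as one syllable.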
For many species of songbird, this algorithm works fine. But it can fail when there is background noise, e.g. when recording birds in the wild or when other sounds are present during a behavioral experiment. You can see this in the animation above: there are two syllables between 1.9 and 2.1 seconds, but the algorithm lumps them into one long segment because of some low-frequency background noise in the recording. This algorithm also won't work for species whose songs contain elements that cannot easily be segmented into syllables by simply finding where the amplitude crosses some threshold.
Network architecture: hybrid convolutional and bidirectional LSTM (CNN-biLSTM)
So we would like some machine learning magic to segment song for us, in a way that is robust to noise. Neural networks may not be the only algorithm that can do this, but they are definitely one of the first that come to mind. There are several neural network architectures that could be applied to segmenting song, and I'll discuss a couple of others at the end of this blog post. The approach that Yarden developed combines convolutional layers, which are typically used for image recognition, with recurrent layers, which are often used to capture dependencies across time. For previous work that employed similar architectures, see note 2.
Here's a schematic that shows the network architecture with a little more detail.
Roughly speaking, the network consists of two convolutional layers (where each convolutional layer includes both a convolution and a max pooling operation) followed by a recurrent layer, a bidirectional long short-term memory (LSTM) layer. I'll call this architecture "CNN-biLSTM".
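To make the data flow concrete, the sketch below traces array shapes through a CNN-biLSTM of this general kind without building an actual model. Every size in it is an assumption for illustration (filter counts, pooling factors, LSTM units, number of classes), as is the choice to pool only along the frequency axis so that every time bin survives to the LSTM; the actual values are in Yarden's code.

```python
def cnn_bilstm_shapes(n_freq_bins, n_time_steps, conv_filters=(32, 64),
                      freq_pools=(8, 8), lstm_units=512, n_classes=10):
    """Return (layer name, output shape) pairs for a hypothetical CNN-biLSTM.

    Shapes are (frequency, time, channels) for the convolutional layers and
    (time, features) once the data are handed to the recurrent layer.
    """
    shapes = [("input spectrogram", (n_freq_bins, n_time_steps, 1))]
    freq = n_freq_bins
    for i, (filters, pool) in enumerate(zip(conv_filters, freq_pools), start=1):
        # Assume 'same' padding, so the convolution keeps the spatial size, and
        # max pooling only along frequency, so every time bin is preserved.
        freq //= pool
        shapes.append((f"conv{i} + max pool", (freq, n_time_steps, filters)))
    # Flatten frequency x channels into one feature vector per time step
    n_features = freq * conv_filters[-1]
    shapes.append(("reshape", (n_time_steps, n_features)))
    # A bidirectional LSTM concatenates its forward and backward outputs
    shapes.append(("bidirectional LSTM", (n_time_steps, 2 * lstm_units)))
    # One score per class for every time bin
    shapes.append(("dense + softmax", (n_time_steps, n_classes)))
    return shapes
```

The key point the sketch illustrates is that the time axis is never pooled away, so the final layer can emit one label per time bin.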
As an input, the network takes spectrograms, and then it outputs a class label for each time bin in the spectrogram. This implicitly gives us segmentation, because we find each continuous series of one label and call that a segment.
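Here is one way to recover segments from the per-time-bin output, sketched with NumPy. It assumes a hypothetical integer label `0` for silent periods and that `times` holds the time of each bin; both are illustrative choices.

```python
import numpy as np

def labels_to_segments(frame_labels, times, silent_label=0):
    """Turn per-time-bin labels into a list of (label, onset, offset) tuples.

    Each run of consecutive bins with the same label becomes one segment;
    runs of `silent_label` are treated as silent periods and skipped.
    """
    frame_labels = np.asarray(frame_labels)
    times = np.asarray(times)
    if frame_labels.size == 0:
        return []
    # Index of every bin whose label differs from the previous bin's label
    change = np.flatnonzero(frame_labels[1:] != frame_labels[:-1]) + 1
    starts = np.concatenate(([0], change))
    stops = np.concatenate((change, [frame_labels.size])) - 1  # last bin of each run
    segments = []
    for start, stop in zip(starts, stops):
        label = frame_labels[start].item()
        if label != silent_label:
            segments.append((label, float(times[start]), float(times[stop])))
    return segments
```

For example, `labels_to_segments([0, 0, 1, 1, 1, 0, 2, 2, 0], np.arange(9) * 0.5)` returns `[(1, 1.0, 2.0), (2, 3.0, 3.5)]`.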
To measure performance of the CNN-biLSTM, we're looking at learning curves, which are simply plots of error as a function of training set size. I'll talk more about the metrics we use for error below.
You might be familiar with learning curves from their typical use within a supervised learning context: to evaluate whether the classifier you're using is underfitting or overfitting the training data, and whether more data could improve the classifier's accuracy.
For our study, we're plotting learning curves to give someone studying songbirds an idea of how much training data they will need labeled by hand to achieve a desired accuracy. It is more informative to see the accuracy across a range of training set sizes, rather than just one or two.
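A learning curve like this can be generated with a loop over training set sizes, training several replicates at each size on random subsets of the training data. In the sketch below, `learning_curve` and `train_and_score` are hypothetical names: `train_and_score` stands in for a callback that trains a fresh network on the given subset of training songs and returns its error on a fixed test set. The number of replicates is also an assumption.

```python
import numpy as np

def learning_curve(train_and_score, n_train_total, train_sizes, n_replicates=5, seed=0):
    """Return the mean and standard deviation of error for each training set size."""
    rng = np.random.default_rng(seed)
    means, stds = [], []
    for size in train_sizes:
        errors = []
        for _ in range(n_replicates):
            # Draw a random subset of training examples without replacement
            subset = rng.choice(n_train_total, size=size, replace=False)
            errors.append(train_and_score(subset))
        means.append(np.mean(errors))
        stds.append(np.std(errors))
    return np.array(means), np.array(stds)
```

Training replicates on different random subsets shows how much the error depends on which songs happen to be in the training set, not just on how many there are.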
Another reason to generate learning curves is to fit actual curves to the data points, as proposed by Cortes et al. 1994. See note 3 for more about why you'd want to do this. Because we have a manuscript in preparation, I can't share actual values here. Instead, I fit curves to the values and present these as a schematized form of the learning curves I've generated. If you have questions about the results, please feel free to contact me. I'm using the learning curves in this post mainly to illustrate why I chose to explore which hyperparameters matter for this network.
If you want to play around with fitting learning curves, the functions I use are in the curvefit module, adapted from code I wrote for a talk about the Cortes et al. paper that I gave at Jupyter Day Atlanta 2016.
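The following is not the curvefit module itself but an independent sketch of the idea. It fits the power-law form that Cortes et al. use for test error, error ≈ a + b * n**(-alpha), where n is the training set size and a is the asymptotic error. For each alpha on a grid it solves for a and b by linear least squares, so it needs only NumPy; the grid range is an assumption.

```python
import numpy as np

def fit_learning_curve(train_sizes, errors, alphas=np.linspace(0.05, 2.0, 391)):
    """Fit errors ~ a + b * n**(-alpha) and return (a, b, alpha)."""
    n = np.asarray(train_sizes, dtype=float)
    y = np.asarray(errors, dtype=float)
    best = None
    for alpha in alphas:
        # For a fixed alpha the model is linear in a and b
        X = np.column_stack((np.ones_like(n), n ** -alpha))
        coef, *_ = np.linalg.lstsq(X, y, rcond=None)
        sse = np.sum((X @ coef - y) ** 2)  # sum of squared residuals
        if best is None or sse < best[0]:
            best = (sse, coef[0], coef[1], alpha)
    _, a, b, alpha = best
    return a, b, alpha
```

The fitted `a` is the quantity to compare between models: it estimates the error you would reach with unlimited training data.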
The frame error rate shows the network learns to accurately classify each time bin
The first metric we looked at I'll call the frame error rate. As you might guess from the name, this metric looks at the label for every frame, in this case every time bin of the spectrogram, and asks whether it is correct. The metric gives you a sense of how well you're doing overall: if you have a very low frame error rate, you can reasonably assume you are finding the right segments, since they consist of sequences of frames.
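Computed over a pair of label arrays, the frame error rate is just the fraction of time bins that disagree. A minimal NumPy sketch (the function name is hypothetical):

```python
import numpy as np

def frame_error_rate(y_true, y_pred):
    """Fraction of time bins whose predicted label differs from the ground truth."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same shape")
    return np.mean(y_true != y_pred)
```

For example, `frame_error_rate([0, 1, 1, 2], [0, 1, 2, 2])` is 0.25, because one of four bins is wrong.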
This figure presents learning curves for all four birds whose songs are shown above. Again, I can't give numbers here because of the manuscript in preparation, but as shown by the curves fit to those numbers, the model learns to accurately classify each frame, i.e. each time bin, for every bird.
Notice however that the curve for bird 3 appears to have a higher frame error rate at asymptote.
The syllable error rate shows the network accurately segments and classifies syllables
Frame error does not tell us directly how well we are finding and classifying segments. To measure how well words are classified in speech recognition tasks, researchers use the word error rate. Here I measure how well segments are classified as syllables, and by analogy, I call that the syllable error rate.
The syllable error rate is a normalized edit distance between the sequence of labels given to syllables in a song by a human annotator and the sequence of labels predicted by a machine learning algorithm. The edit distance is the minimum number of insertions, deletions, and substitutions required to convert the predicted labels into the actual labels from the ground truth set. (There are other forms of edit distance, but this one, the Levenshtein distance, is the most widely used.) The distance is normalized by dividing it by the number of syllables in the ground truth set. Normalization makes it possible to compare error across data sets of different sizes, e.g., if one bird sings many more syllables per song than another.
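A minimal pure-Python sketch of the syllable error rate, using the standard dynamic-programming algorithm for Levenshtein distance. For brevity each character of a string stands for one syllable label, which is an assumption; lists of labels work the same way.

```python
def levenshtein(source, target):
    """Minimum number of insertions, deletions, and substitutions
    needed to turn the sequence `source` into the sequence `target`."""
    # prev[j] is the distance between the first i-1 items of source and the
    # first j items of target (the previous row of the dynamic-programming table)
    prev = list(range(len(target) + 1))
    for i, s in enumerate(source, start=1):
        curr = [i]  # turning i items into an empty sequence takes i deletions
        for j, t in enumerate(target, start=1):
            curr.append(min(prev[j] + 1,              # delete s
                            curr[j - 1] + 1,          # insert t
                            prev[j - 1] + (s != t)))  # substitute (free if equal)
        prev = curr
    return prev[-1]


def syllable_error_rate(true_labels, pred_labels):
    """Edit distance from predicted to true labels, normalized by the
    number of syllables in the ground truth."""
    return levenshtein(pred_labels, true_labels) / len(true_labels)
```

For example, `syllable_error_rate("iiiabcd", "iiabce")` is 2/7, about 0.29: the prediction is missing one 'i' and has 'e' where the ground truth has 'd'.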
This figure again presents learning curves for all four birds, but this time plotting the syllable error rate.
Although, again, I can't give exact numbers here, I can say that the network achieves a syllable error rate lower than 0.46 with much less than eight minutes of training data; those are the values reported for the network of Koumura and Okanoya 2016. (In their paper they call this metric the "note error rate".) These results suggest the network will be accurate enough to answer the scientific questions we are interested in, although of course we will need to test quantitatively how error will affect our results. In addition, Yarden has found that he can achieve near-perfect accuracy by simply taking a "majority vote" for each segment that has multiple labels within it.
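I don't have the details of how Yarden implemented the majority vote, so the following is only a sketch of one plausible version: treat each contiguous run of non-silent time bins as a segment, and relabel every bin in it with the segment's most common label. The function name and the silent label `0` are assumptions.

```python
import numpy as np
from collections import Counter

def majority_vote(frame_labels, silent_label=0):
    """Relabel each contiguous run of non-silent bins with its most common label."""
    labels = np.array(frame_labels)  # copy, so the input is left untouched
    is_syllable = labels != silent_label
    # Pad with False so runs touching either edge still have a start and a stop
    padded = np.concatenate(([False], is_syllable, [False]))
    edges = np.flatnonzero(np.diff(padded.astype(int)))
    # Edges alternate: the first bin of a run, then one past its last bin
    for start, stop in zip(edges[::2], edges[1::2]):
        winner, _ = Counter(labels[start:stop].tolist()).most_common(1)[0]
        labels[start:stop] = winner
    return labels
```

For example, the bins `[0, 1, 1, 2, 1, 0]` become `[0, 1, 1, 1, 1, 0]`: the single stray '2' inside the segment is outvoted.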
What goes wrong when the network fails to segment properly?
These results are good, but still, what is the network getting wrong when it misclassifies a time bin? I did a little more work to get at that question. In particular, I focused on bird 3, which had the highest frame error and syllable error rates.
Frames from squawky "intro" notes are often misclassified, as shown by a confusion matrix
I began by calculating a confusion matrix for all frames. As seen in the bottom row of the matrix, most of the error arises from syllable 'm' being mislabeled as syllable 'i'.
If you look at the spectrogram of bird 3's song again, you'll see that both 'i' and 'm' are low-amplitude, high-entropy syllables: what power there is, is distributed randomly across the frequency spectrum. To put that in less technical terms: these syllables are squawky, not pitchy. People who study birdsong often refer to them as "introductory notes". These syllables are often hard to tell apart, even for scientists who spend all day staring at birdsong.
It is also apparent from the leftmost column of the matrix that many errors are due to some frames of syllables being misclassified as silent periods. One possible reason is that syllable onsets and offsets are harder to detect, so there is likely a lot of variation around them in the training data.
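A frame-level confusion matrix like the one described above can be computed as follows. This sketch uses the common convention of rows for ground-truth labels and columns for predicted labels (the same convention as scikit-learn's `confusion_matrix`), with the silent-period label listed first so that it forms the leftmost column. The function name and label names are hypothetical.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, labels):
    """Count frames: rows are ground-truth labels, columns are predicted labels."""
    index = {label: i for i, label in enumerate(labels)}
    cm = np.zeros((len(labels), len(labels)), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[index[t], index[p]] += 1
    return cm
```

With labels `['-', 'i', 'm']`, a true 'm' frame predicted as 'i' adds one to the bottom row, middle column, which is exactly the kind of error described above.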
Onsets and offsets are misclassified across syllables, but squawky notes are misclassified throughout
But from the confusion matrix we cannot tell where in the segment the network makes these misclassifications. To visualize where in each segment the network makes errors, I assigned a time to each misclassified time bin. The unit of time I used is percent of the total duration of the syllable. In this way I controlled for the variance in duration of each syllable from one rendition to another. (The bird does not sing the syllable for the exact same duration every time.) I then plotted a histogram of the errors with times on the x axis and the number of misclassifications on the y axis.
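Here is a minimal sketch of that analysis. It assumes integer frame labels with a hypothetical silent label `0`, treats each ground-truth syllable as a contiguous run of one label, and places each misclassified bin at its center, expressed as a percent of the syllable's duration. Those details are my choices and may differ from how the figures were made.

```python
import numpy as np

def error_positions(y_true, y_pred, silent_label=0):
    """Return (label, position) for every misclassified bin inside a syllable.

    `position` is the bin's center as a percent of that syllable's duration,
    which controls for the syllable being sung slightly longer or shorter
    from one rendition to the next.
    """
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    if y_true.size == 0:
        return []
    # Ground-truth syllables: runs of consecutive bins with the same label
    change = np.flatnonzero(y_true[1:] != y_true[:-1]) + 1
    starts = np.concatenate(([0], change))
    stops = np.concatenate((change, [y_true.size]))  # one past the last bin
    positions = []
    for start, stop in zip(starts, stops):
        label = y_true[start].item()
        if label == silent_label:
            continue
        n_bins = stop - start
        for i in np.flatnonzero(y_pred[start:stop] != label):
            positions.append((label, float(100.0 * (i + 0.5) / n_bins)))
    return positions
```

The positions can then be binned with, e.g., `np.histogram([pos for _, pos in errs], bins=10, range=(0, 100))`, separately for each label.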
As shown below, across all syllables the most errors occur in the first or last time bin. Notably, though, for syllables 'i' and 'm' (red histograms), the error rate is high throughout the syllable.
The drop in error decays exponentially as a function of time steps
Okay, now we know what the network misclassifies and where, but ... why? One possibility might be that the LSTM layer of the network does not capture enough of the long term dependencies that would help the network classify each time bin. If this is true, then the error should vary with the number of time steps that the network uses for each prediction.
To test this possibility, I re-ran the learning curve script for bird 3, using the exact same training data while changing the number of time steps. As expected, the frame error increases when we drop the number of time steps to a low number like 10. Somewhat surprisingly, though, the frame error does not seem to decrease linearly as we increase the number of time steps. In the learning curves shown above, the networks were trained with 88 time steps. Going from 88 to 150 time steps does not yield much of a drop in frame error, even though this is a larger increase in time steps than going from 10 to 50.
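One way to picture what "number of time steps" means: the spectrogram is fed to the network in windows of that many time bins, so the LSTM can only draw on context within a window. The sketch below shows one simple way to cut a spectrogram and its frame labels into such windows. The function name is hypothetical, and using non-overlapping windows and dropping the leftover bins at the end are assumptions for illustration, not necessarily what our training script does.

```python
import numpy as np

def make_windows(spect, labels, time_steps):
    """Split a (freq_bins, time_bins) spectrogram and its per-bin labels
    into non-overlapping windows of `time_steps` bins."""
    n_windows = spect.shape[1] // time_steps
    usable = n_windows * time_steps  # drop leftover bins that don't fill a window
    x = spect[:, :usable].reshape(spect.shape[0], n_windows, time_steps)
    x = x.transpose(1, 0, 2)  # -> (n_windows, freq_bins, time_steps)
    y = np.asarray(labels)[:usable].reshape(n_windows, time_steps)
    return x, y
```

With 10 time steps each window sees only a small slice of song, while with 88 or 150 each window spans much more of the surrounding context.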
The same holds true for the syllable error rate.
To get a better sense of the relation between time steps and error rate, I fit curves to the frame error learning curves and then plotted the asymptotic error parameter from each fit as a function of time steps.
From this plot it appears that the decrease in error decays exponentially as the number of time steps increases. It will not surprise anyone who has studied recurrent neural networks that there is a ceiling effect: a point of diminishing returns past which error barely decreases even though training time increases substantially. However, it is valuable to put a specific number to that point. I would guesstimate from this plot that 88 time steps is near the sweet spot in the trade-off between error rate and training time. 88 time steps equals about 180 milliseconds, which is 2-3 syllables in Bengalese finch song.
Based on my results, we will be able to answer the questions we have about birdsong using the
CNN-biLSTM network architecture. The CNN-biLSTM network achieves a lower syllable error rate
than previously proposed architectures, and it does so with less training data.
In the one case where the network did misclassify two syllables, it seemed to do so because the syllables themselves are similar and hard to distinguish. By systematically varying the number of time steps used by the LSTM layer while holding everything else about the network architecture and training data constant, I showed that we are already near the maximum accuracy the network can achieve for segmenting and classifying those two syllables. Of course that still leaves other hyperparameters to explore, such as the shape of the filters in the convolutional layers. I'll save that for another blog post.
As Yarden pointed out during our PyData NYC talk, a spectrogram is an image, but it is an image with correlations across time. In the big picture, it remains unknown whether networks with a recurrent layer will always outperform convolutional architectures, such as fully convolutional networks used for image segmentation. Dilated convolutions are also supposed to be able to capture context across multiple scales.
To my knowledge, no theoretical work has yet put any boundaries on the abilities of convolutional or recurrent networks to segment images. It would be particularly interesting to know
the extent to which convolutional networks can capture correlations across time. I hope I've convinced you that birdsong provides a good starting point to answer these questions.
Thanks to Yarden Cohen for helpful comments, and to NVIDIA for the GPU grant that lets us run twice as many models.
1. Generalizing across birds is different from a machine learning algorithm generalizing from training data to unseen data. I do not quantify here how well the network generalizes across individuals, in the sense I'm using the term. I will point out that a lot of work has been done to derive metrics for how similar songs are within and across birds. To my knowledge, this has not been done for human speech, and no work has been done measuring how well neural networks for speech recognition generalize across speakers.
2. The CNN-biLSTM is based on similar architectures described in these papers:
S. Böck and M. Schedl, "Polyphonic piano note transcription with recurrent neural networks," 2012 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), Kyoto, 2012, pp. 121-124.
G. Parascandolo, H. Huttunen, and T. Virtanen, "Recurrent neural networks for polyphonic sound event detection in real life recordings," 2016 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2016.
You could also think of this as an extension of the hybrid deep neural network-hidden Markov model (DNN-HMM) approach, with LSTMs in place of the HMM.
3. In Cortes et al. 1994, the authors propose fitting curves in order to estimate which of two models will give better accuracy without actually training both models on a very large data set, which would be computationally expensive. Instead they measure the accuracy for a range of smaller training data sets, and then fit a curve to those accuracies. The parameters of the fitted curve include the asymptote, which you can think of as the accuracy given infinite training data. You can then use points on the fitted curve to predict which model will do better when trained with the very large data set.