Deep Learning Supplemental

Introduction

Deep learning is a powerful machine learning technique that has exploded over the past decade, due to increases in computational capacity, the availability of big data, and the development of new algorithms and architectures for more effective training. Since spike-sorting is a machine learning task - in particular, a classification task, an area in which deep learning has excelled - there is growing interest in whether deep learning tools can be used to improve spike-sorting algorithms. We undertook an exploratory investigation at the end of Spruce.

However, spike-sorting is not a perfect fit for deep learning, for a few key reasons. First, deep learning thrives when there is a stable, consistent task, whereas the circumstances of spike-sorting change from dataset to dataset, both in the noise characteristics and in the number and types of units that are firing. Second, deep networks take a long time to train and require large quantities of training data, whereas computational efficiency is an important concern for spike-sorting, and each electrophysiological recording contains only a limited number of good spikes. Third, although certain loss functions and methods such as bootstrapping allow for unsupervised deep learning, deep learning is best suited to tasks where the training data is labeled, whereas very little electrophysiological data comes with ground truth information attached, and that which does is either synthetic or limited (for instance, to a single cellular unit’s worth of patch data).

Prior work applying deep learning to spike sorting has therefore tended to focus on specific, limited integrations of deep learning into existing pipelines. To overcome the heterogeneity problem, the deep neural network is generally trained from scratch on each individual recording. Because the training is part of the sorting, this creates an efficiency problem, which is simply ignored. To overcome the supervision problem, the standard techniques - bootstrapping and careful loss function design - are used, and their shortcomings are accepted. Not using ground truth data during training has the added benefit of leaving it available for validation.

In 2017, Yang et al. debuted PCAnet, one of the first attempts of this type, although it did not contain a deep neural network per se. They used PCA followed by k-means clustering to initialize a set of templates, then used those templates to fit a spike-train model. This approach was picked up in 2019 by Park et al., who took the 10% most confident predictions from PCA + k-means as training data for their deep model, then used the deep model to sort the remaining 90%.

In 2020, Racz et al. used two different deep networks for the task of spike-sorting - one for detecting the existence of a spiking event, and one for classifying the resulting spikes. They had a BCI application in mind, which requires efficient inference, but they avoided the slow-training problem by training the networks ahead of time on a particular probe’s data, with the intention of using them on the same probe later. They acquired ground truth training data by running Kilosort.

All of these approaches hew relatively closely to the “classical pipeline” approach to spike-sorting: detection, followed by feature extraction, followed by classification. The classical pipeline is poorly suited to high-dimensional MEA data, because it cannot use insights from later stages to inform earlier stages, which is often necessary to deal with overlapping spikes or to detect weak spikes in high-noise environments. Although modern algorithms like Kilosort are moving away from the classical pipeline, these experiments were done inside it because of the convenience of its modularity.

Approach

We used a fixed spike detection method across all experiments. First, the MEA data is preprocessed with a 300-6000 Hz bandpass filter, followed by median dereferencing (removing the median across channels at each point in time). All peaks which exceed 6 times the median absolute deviation across all channels, and which are not within 1 ms of a higher peak (on any channel), are retained as the center of an “event” which is 2 ms wide and contains data from all channels.
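
The following is a minimal sketch of this detection stage, assuming the recording is held in memory as an (n_samples, n_channels) NumPy array sampled at fs Hz; the function name, filter order, and exact MAD estimator are illustrative rather than taken from our implementation:

    import numpy as np
    from scipy.signal import butter, sosfiltfilt

    def detect_events(recording, fs, thresh_mads=6.0, refractory_ms=1.0, event_ms=2.0):
        """Bandpass, median-dereference, and cut threshold-crossing events.

        recording: (n_samples, n_channels) array.
        Returns an (n_events, event_samples, n_channels) array.
        """
        # 300-6000 Hz bandpass (zero-phase, so spike peaks are not shifted)
        sos = butter(4, (300.0, 6000.0), btype="bandpass", fs=fs, output="sos")
        filtered = sosfiltfilt(sos, recording, axis=0)

        # Median dereferencing: remove the across-channel median at each time point
        filtered = filtered - np.median(filtered, axis=1, keepdims=True)

        # Threshold at 6x the median absolute deviation, estimated over all channels
        mad = np.median(np.abs(filtered - np.median(filtered)))
        peak = np.max(np.abs(filtered), axis=1)            # strongest channel at each sample
        candidates = np.flatnonzero(peak > thresh_mads * mad)

        # Keep a candidate only if no higher peak (on any channel) lies within 1 ms
        refractory = int(refractory_ms * fs / 1000)
        half = int(event_ms * fs / 2000)
        centers = [t for t in candidates
                   if peak[t] >= peak[max(0, t - refractory):t + refractory + 1].max()]

        # Cut 2 ms wide events containing all channels around each retained peak
        events = [filtered[c - half:c + half] for c in centers
                  if c - half >= 0 and c + half <= len(filtered)]
        return np.stack(events) if events else np.empty((0, 2 * half, recording.shape[1]))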

We compared five different algorithms for feature extraction and classification:

  1. Template matching k-means. Each spike event is a point in a high-dimensional vector space, and in this approach, we simply cluster by k-means in the high-dimensional space. The center of each cluster is therefore a “template” event which is the average of all events assigned to it, and events are assigned to clusters according to mean squared template matching error.

  2. PCA plus k-means. Principal component analysis is used to project each spike event to a 5-dimensional space, as k-means is known to be sensitive to outliers in high-dimensional spaces. Then, k-means is used to cluster spikes in the 5-dimensional space (methods 1 and 2 are sketched after this list).

  3. Deep encoder plus k-means. A deep convolutional autoencoder (3 convolutional layers, 2 pooling layers, and 1 dense linear layer on each side) is used to encode each spike event into a 5-dimensional space and then decode it back into a reconstructed event; a sketch of this architecture also appears after the list. Following an approach inspired by Xie et al., the autoencoder is layerwise pretrained and then finetuned. After the encoder is finetuned, k-means is used to identify cluster centers in the projected space; the decoder is not used for classification. (We modified Xie et al.’s approach in two ways: we added noise in the 5-dimensional space during finetuning, to make the encoding more robust, and we eliminated the third training stage, in which the encoder and the cluster centers are trained simultaneously, because it was not found to be helpful in our limited experimentation.)

  4. Kilosort classification. The three k-means methods above process isolated spiking events after a single detection stage; Kilosort, however, utilizes the temporal sequencing of spiking events and has its own detection pipeline. To put Kilosort’s classification engine on an even footing with the other methods, we synthesized an artificial waveform by concatenating the isolated events in a random order, with smooth transitions between them as a buffer (one possible resynthesis is sketched after this list). This ensures that any alignment mistakes in our detection process, and any information lost through the random reordering, also affect the performance of Kilosort. Kilosort was run in 18 trials, with 3 trials on each of 6 different synthetic waveforms (3 synthesized from f2n1 data, 3 from f2n3 data).

  5. KiloSort2 classification. Each time Kilosort was run on a synthetic waveform, KiloSort2 was also run. Again, the synthetic waveform ensured that the classification methods were fairly compared.
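
Below are minimal sketches of how methods 1-3 and the resynthesis in method 4 could be implemented. They are illustrations rather than our exact code: NumPy, scikit-learn, and PyTorch are assumed, and all function names, layer widths, kernel sizes, and the cross-fade length are placeholders; only the 5-dimensional code, the 3-convolutional / 2-pooling / 1-dense layout, and the bottleneck noise come from the descriptions above.

For methods 1 and 2, clustering reduces to a few library calls. In the raw space, each k-means cluster center is a “template” event, and nearest-centroid assignment is exactly mean-squared template matching:

    import numpy as np
    from sklearn.decomposition import PCA
    from sklearn.cluster import KMeans

    def kmeans_sort(events, n_clusters=50, n_components=5, use_pca=True):
        """Method 1 (use_pca=False) or method 2 (use_pca=True)."""
        X = events.reshape(len(events), -1)        # flatten (time x channels) per event
        if use_pca:
            X = PCA(n_components=n_components).fit_transform(X)
        km = KMeans(n_clusters=n_clusters, n_init=10).fit(X)
        return km.labels_, km.cluster_centers_     # raw-space centers are the templates

For method 3, the encoder half of an autoencoder replaces PCA as the projection into the 5-dimensional space:

    import torch
    import torch.nn as nn

    class SpikeAutoencoder(nn.Module):
        """Convolutional autoencoder with a 5-dimensional bottleneck.

        Layer sizes are illustrative; assumes n_samples is divisible by 4.
        """

        def __init__(self, n_channels, n_samples, code_dim=5):
            super().__init__()
            # Encoder: 3 convolutional layers, 2 pooling layers, 1 dense layer to the code
            self.encoder_conv = nn.Sequential(
                nn.Conv1d(n_channels, 16, kernel_size=5, padding=2), nn.ReLU(),
                nn.MaxPool1d(2),
                nn.Conv1d(16, 32, kernel_size=5, padding=2), nn.ReLU(),
                nn.MaxPool1d(2),
                nn.Conv1d(32, 32, kernel_size=3, padding=1), nn.ReLU(),
            )
            self.encoder_fc = nn.Linear(32 * (n_samples // 4), code_dim)
            # Decoder: mirror image, back up to a reconstructed event
            self.decoder_fc = nn.Linear(code_dim, 32 * (n_samples // 4))
            self.decoder_conv = nn.Sequential(
                nn.Conv1d(32, 32, kernel_size=3, padding=1), nn.ReLU(),
                nn.Upsample(scale_factor=2),
                nn.Conv1d(32, 16, kernel_size=5, padding=2), nn.ReLU(),
                nn.Upsample(scale_factor=2),
                nn.Conv1d(16, n_channels, kernel_size=5, padding=2),
            )
            self.n_samples = n_samples

        def encode(self, x):                       # x: (batch, channels, samples)
            return self.encoder_fc(self.encoder_conv(x).flatten(1))

        def forward(self, x, code_noise=0.0):
            z = self.encode(x)
            if code_noise > 0:                     # noise at the bottleneck during finetuning
                z = z + code_noise * torch.randn_like(z)
            h = self.decoder_fc(z).view(x.size(0), 32, self.n_samples // 4)
            return self.decoder_conv(h)

After layerwise pretraining and finetuning on reconstruction error, events are passed through encode() and clustered with the same k-means call as above; the decoder is discarded.

For method 4, one possible resynthesis (the exact form of the smooth transitions is our assumption here) is to cross-fade between the end of one event and the start of the next:

    def resynthesize(events, gap_samples=30):
        """Concatenate events in random order, with short linear cross-fades as buffers."""
        order = np.random.permutation(len(events))
        pieces = [events[order[0]]]
        for idx in order[1:]:
            prev_end, next_start = pieces[-1][-1], events[idx][0]
            w = np.linspace(0.0, 1.0, gap_samples)[:, None]
            pieces.append((1 - w) * prev_end + w * next_start)   # smooth buffer segment
            pieces.append(events[idx])
        return np.concatenate(pieces, axis=0)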

For the algorithms which used k-means, we ran at least 15 trials with each of 20, 50, and 100 clusters. For Kilosort and KiloSort2, we ran 9 trials on each dataset.

Results

The three different k-means algorithms - deep encoder, template matching (i.e. no embedding), and PCA - all performed at roughly similar levels, as shown in the figures below. Increasing the number of clusters used for k-means made the clustering more selective, generally causing an increase in precision and a decrease in recall, but the three embedding algorithms all appear to lie on roughly the same tradeoff curve. It is particularly notable that template matching is about as good as PCA for feature extraction, as this suggests that the dimensionality of the feature space (in this case, 5 dimensions) was not a constraint on performance.

_images/f2n1_rough.png
_images/f2n3_rough.png

Plots of the average performance across 10-20 trials of the various feature extraction and clustering algorithms on the f2n1 dataset (top) and f2n3 dataset (bottom). Data points for algorithms using k-means are annotated with the value of k (number of clusters) used. “Reference (KS/KS2)” refers to running Kilosort and KiloSort2 on the raw datasets (i.e. without resynthesizing), with the top 58 channels retained.

The Kilosort-based algorithms (KS and KS2), on the other hand, performed notably better than the k-means algorithms. Kilosort’s core is a modified template matching algorithm, but the cost function used to fit templates is more sophisticated than simple k-means, and the templates are allowed to vary over the course of the dataset. For the higher-quality dataset (f2n1), KS and KS2 performed about as well on the resynthesized data as on the original dataset, indicating that template variation was not a necessary feature there (resynthesis, by randomly reordering events, precludes any positive effect of template variation). For the lower-quality dataset (f2n3), however, the opposite is the case: the resynthesized dataset is significantly harder than the original, and the consistently high recall rules out explanations based on the different detection pipelines, leaving the temporal structure destroyed by resynthesis - such as slow template variation - as the most plausible source of the gap.

These results are interesting primarily as confirmation that the traditional spike-sorting pipeline (detection, feature extraction, clustering) is not as powerful as more sophisticated approaches, since Kilosort outperforms all of the traditional-pipeline algorithms even on resynthesized data. However, for the motivating question - whether deep learning can improve on state-of-the-art performance - these results are inconclusive, as deep learning was only investigated within the traditional pipeline framework, and within those limits it performed about as well as the classical alternatives.