# Source code for autoreject.ransac

"""RANSAC code

The code is adapted from the PREP pipeline written in MATLAB:
https://github.com/VisLab/EEG-Clean-Tools. This implementation
also works for MEG data.
"""

# Authors: Mainak Jas <mainak.jas@telecom-paristech.fr>

import numpy as np
from joblib import Parallel, delayed

import mne
from mne.channels.interpolation import _make_interpolation_matrix
from mne.parallel import check_n_jobs
from mne.utils import check_random_state

from .utils import _pbar, _handle_picks
from .utils import _check_data


def _iterate_epochs(ransac, epochs, idxs, ch_subset, verbose):
    """Compute RANSAC correlations for a chunk of epochs.

    Parameters
    ----------
    ransac : Ransac
        The estimator. Its ``mappings_`` attribute is (re)computed from
        ``ch_subset`` and its ``corr_`` attribute is overwritten with the
        correlations of the last processed epoch.
    epochs : object
        Epochs-like object providing ``get_data()``; the returned array is
        indexed as ``[epoch, channel]``.
    idxs : array-like of int
        Indices of the epochs (into ``epochs``) handled by this call.
    ch_subset : list of list of str
        Channel-name subsets handed to ``ransac._get_mappings``.
    verbose : bool
        Forwarded to the progress bar.

    Returns
    -------
    corrs : ndarray, shape (len(idxs), len(ransac.picks))
        Row ``i`` holds the per-channel correlations for epoch ``idxs[i]``.
    """
    ransac.mappings_ = ransac._get_mappings(epochs, ch_subset)
    n_channels = len(ransac.picks)
    corrs = np.zeros((len(idxs), n_channels))
    # Fetch the data once instead of once per epoch.
    data = epochs.get_data()
    # ``row`` addresses the output array, ``epoch_idx`` addresses the data.
    # They only coincide when ``idxs`` starts at 0, i.e. for the first chunk.
    for row, epoch_idx in enumerate(_pbar(idxs, desc='Iterating epochs',
                                          verbose=verbose)):
        ransac.corr_ = ransac._compute_correlations(
            data[epoch_idx, ransac.picks])
        corrs[row, :] = ransac.corr_
    return corrs


def _get_channel_type(epochs, picks):
    """Return the single sensor family ('eeg' or 'meg') covered by picks.

    Parameters
    ----------
    epochs : object
        Epochs-like object exposing ``info`` and supporting
        ``ch_type in epochs`` membership tests.
    picks : sequence of int
        Indices of the channels to inspect.

    Returns
    -------
    ch_type : str
        ``'eeg'`` if the picked channels are EEG, ``'meg'`` if they are
        magnetometers and/or gradiometers.

    Raises
    ------
    ValueError
        If an unsupported channel type is picked, if EEG and MEG channels
        are mixed, or if no known channel type is found.
    """
    picked_info = mne.io.pick.pick_info(epochs.info, picks)

    # Collect the distinct channel types among the picked channels.
    ch_types_picked = set()
    for ch_idx in range(len(picks)):
        ch_types_picked.add(
            mne.io.meas_info.channel_type(picked_info, ch_idx))

    supported = ['mag', 'grad', 'eeg']
    invalid_ch_types_present = [
        ch_type for ch_type in ch_types_picked
        if not (ch_type in supported or ch_type not in epochs)]
    if invalid_ch_types_present:
        raise ValueError('Invalid channel types present in epochs.'
                         ' Expected ONLY `meg` or ONLY `eeg`. Got %s'
                         % ', '.join(invalid_ch_types_present))

    has_eeg = 'eeg' in ch_types_picked
    has_meg = 'mag' in ch_types_picked or 'grad' in ch_types_picked
    if has_eeg and has_meg:
        raise ValueError('Got mixed channel types. Pick either eeg or meg'
                         ' but not both')
    if has_eeg:
        return 'eeg'
    if has_meg:
        return 'meg'
    raise ValueError('Oh no! Your channel type is not known.')


class Ransac(object):
    """RANSAC algorithm to find bad sensors and repair them."""

    def __init__(self, n_resample=50, min_channels=0.25, min_corr=0.75,
                 unbroken_time=0.4, n_jobs=1,
                 random_state=435656, picks=None,
                 verbose=True):
        """Implement RAndom SAmple Consensus (RANSAC) to detect bad sensors.

        Parameters
        ----------
        n_resample : int
            Number of times the sensors are resampled.
        min_channels : float
            Fraction of sensors for robust reconstruction.
        min_corr : float
            Cut-off correlation for abnormal wrt neighbours.
        unbroken_time : float
            Cut-off fraction of time sensor can have poor RANSAC
            predictability.
        n_jobs : int
            Number of parallel jobs.
        random_state : None | int | np.random.RandomState
            The seed of the pseudo random number generator to use.
        picks : str | list | slice | None
            Channels to include. Slices and lists of integers will be
            interpreted as channel indices. In lists, channel *name* strings
            (e.g., ``['MEG0111', 'MEG2623']``) will pick the given channels.
            None (default) will pick data channels {'meg', 'eeg'}. Note that
            channels in ``info['bads']`` *will be included* if their names or
            indices are explicitly provided.
        verbose : boolean
            The verbosity of progress messages.
            If False, suppress all output messages.

        Notes
        -----
        The window_size is automatically set to the epoch length.

        References
        ----------
        [1] Bigdely-Shamlo, Nima, et al.
            "The PREP pipeline: standardized preprocessing for large-scale EEG
            analysis." Frontiers in neuroinformatics 9 (2015).
        [2] Mainak Jas, Denis Engemann, Yousra Bekhti, Federico Raimondo, and
            Alexandre Gramfort, "Autoreject: Automated artifact rejection for
            MEG and EEG." arXiv preprint arXiv:1612.08194, 2016.
        """
        self.n_resample = n_resample
        self.min_channels = min_channels
        self.min_corr = min_corr
        self.unbroken_time = unbroken_time
        self.n_jobs = n_jobs
        self.random_state = random_state
        self.verbose = verbose
        self.picks = picks

    def _get_random_subsets(self, info, random_state):
        """Draw ``n_resample`` random subsets of the picked channel names.

        Parameters
        ----------
        info : object
            Measurement info; restricted to ``self.picks`` before sampling.
        random_state : None | int | np.random.RandomState
            Seed (or generator) used for the permutations.

        Returns
        -------
        ch_subsets : list of list of str
            ``n_resample`` lists, each holding
            ``round(min_channels * n_channels)`` channel names in the order
            they were drawn.
        """
        # have to set the seed here
        rng = check_random_state(random_state)
        picked_info = mne.io.pick.pick_info(info, self.picks)
        ch_names = picked_info['ch_names']
        n_channels = len(ch_names)

        # number of channels to interpolate from
        n_samples = int(np.round(self.min_channels * n_channels))

        ch_subsets = []
        for _ in range(self.n_resample):
            pick = rng.permutation(n_channels)[:n_samples]
            ch_subsets.append([ch_names[p] for p in pick])
        return ch_subsets

    def _get_mappings(self, inst, ch_subsets):
        """Build the stacked interpolation matrices for all channel subsets.

        Parameters
        ----------
        inst : object
            Instance exposing ``info``; restricted to ``self.picks``.
        ch_subsets : list of list of str
            Channel-name subsets to interpolate from.

        Returns
        -------
        mappings : ndarray, shape (len(ch_subsets) * n_channels, n_channels)
            One ``(n_channels, n_channels)`` matrix per subset, concatenated
            along the first axis. Columns of channels outside a subset are
            zero. If ``self.ch_type`` is neither ``'meg'`` nor ``'eeg'`` the
            matrix for that subset stays all zeros.
        """
        from .utils import _fast_map_meg_channels

        picked_info = mne.io.pick.pick_info(inst.info, self.picks)
        pos = np.array([ch['loc'][:3] for ch in picked_info['chs']])
        ch_names = picked_info['ch_names']
        n_channels = len(ch_names)
        pick_to = range(n_channels)
        mappings = []

        # Try different channel subsets
        for ch_subset in ch_subsets:
            # don't use pick_channels here as it will sort the channels!
            pick_from = np.array([ch_names.index(name)
                                  for name in ch_subset])
            mapping = np.zeros((n_channels, n_channels))
            if self.ch_type == 'meg':
                mapping[:, pick_from] = _fast_map_meg_channels(
                    picked_info.copy(), pick_from, pick_to)
            elif self.ch_type == 'eeg':
                mapping[:, pick_from] = _make_interpolation_matrix(
                    pos[pick_from], pos[pick_to], alpha=1e-5)
            mappings.append(mapping)
        mappings = np.concatenate(mappings)
        return mappings

    def _compute_correlations(self, data):
        """Compute correlation between prediction and real data.

        Parameters
        ----------
        data : ndarray, shape (n_channels, n_times)
            Data of one epoch, restricted to the picked channels.

        Returns
        -------
        corr : ndarray, shape (n_channels,)
            Uncentered correlation between each channel and the median of
            its ``n_resample`` predictions obtained from ``self.mappings_``.
            A channel whose data or prediction is entirely zero yields NaN
            because the denominator is zero.
        """
        mappings = self.mappings_
        n_channels, n_times = data.shape

        # get the predictions
        y_pred = data.T.dot(mappings.T)
        y_pred = y_pred.reshape((n_times, len(self.picks),
                                 self.n_resample), order='F')
        # pool them using median
        # XXX: weird that original implementation sorts and takes middle
        # value. Isn't really the median if n_resample even
        y_pred = np.median(y_pred, axis=-1)
        # compute correlation
        num = np.sum(data.T * y_pred, axis=0)
        denom = (np.sqrt(np.sum(data.T ** 2, axis=0)) *
                 np.sqrt(np.sum(y_pred ** 2, axis=0)))

        corr = num / denom
        return corr

    def fit(self, epochs):
        """Find the RANSAC-bad channels in ``epochs``.

        Sets ``picks`` (resolved to channel indices), ``ch_type``,
        ``ch_subsets_``, ``corr_``, ``bad_log`` and ``bad_chs_``.

        Parameters
        ----------
        epochs : object
            The epochs to fit on.

        Returns
        -------
        self : Ransac
            The fitted instance.
        """
        self.picks = _handle_picks(info=epochs.info, picks=self.picks)
        _check_data(epochs, picks=self.picks,
                    ch_constraint='single_channel_type', verbose=self.verbose)
        self.ch_type = _get_channel_type(epochs, self.picks)
        n_epochs = len(epochs)

        n_jobs = check_n_jobs(self.n_jobs)
        show_output = self.verbose is not False
        # Keep joblib quiet too when the caller asked for no output.
        parallel = Parallel(n_jobs, verbose=10 if show_output else 0)
        my_iterator = delayed(_iterate_epochs)
        # Decide on the resolved job count so that e.g. n_jobs=-1 is
        # recognised as a parallel run.
        run_parallel = n_jobs > 1
        if show_output and run_parallel:
            print('Iterating epochs ...')
        verbose = False if run_parallel else self.verbose

        rng = check_random_state(self.random_state)
        base_random_state = rng.randint(np.iinfo(np.int16).max)
        epoch_idxs = np.array_split(np.arange(n_epochs), n_jobs)
        # One set of channel subsets per chunk of epochs. Generating fewer
        # subsets than chunks would make zip() below silently drop epochs.
        # Seeds are spaced by n_jobs from the base seed.
        self.ch_subsets_ = [
            self._get_random_subsets(epochs.info,
                                     base_random_state + chunk * n_jobs)
            for chunk in range(len(epoch_idxs))]
        corrs = parallel(my_iterator(self, epochs, idxs, chs, verbose)
                         for idxs, chs in zip(epoch_idxs, self.ch_subsets_))
        self.corr_ = np.concatenate(corrs)
        if show_output and run_parallel:
            print('[Done]')

        # compute how many windows is a sensor RANSAC-bad
        # (NaN correlations compare False and are therefore not counted)
        self.bad_log = np.zeros_like(self.corr_)
        self.bad_log[self.corr_ < self.min_corr] = 1
        bad_log = self.bad_log.sum(axis=0)

        bad_idx = np.where(bad_log > self.unbroken_time * n_epochs)[0]
        self.bad_chs_ = [
            epochs.info['ch_names'][self.picks[p]] for p in bad_idx]
        return self

    def transform(self, epochs):
        """Interpolate the channels found bad by :meth:`fit`.

        Parameters
        ----------
        epochs : object
            The epochs to repair. The input is not modified.

        Returns
        -------
        epochs : object
            A copy of the input whose ``info['bads']`` was replaced by
            ``bad_chs_`` and then interpolated with ``reset_bads=True``.
        """
        epochs = epochs.copy()
        _check_data(epochs, picks=self.picks,
                    ch_constraint='single_channel_type', verbose=self.verbose)
        # Hand over a copy so the fitted list is never shared with the info.
        epochs.info['bads'] = list(self.bad_chs_)
        epochs.interpolate_bads(reset_bads=True)
        return epochs

    def fit_transform(self, epochs):
        """Fit on ``epochs`` and return the repaired copy.

        Parameters
        ----------
        epochs : object
            The epochs to fit on and repair.

        Returns
        -------
        epochs : object
            The result of ``self.fit(epochs).transform(epochs)``.
        """
        return self.fit(epochs).transform(epochs)
# Fork me on GitHub