# Source code for autoreject.autoreject

"""Automated rejection and repair of trials in M/EEG."""

# Authors: Mainak Jas <mainak.jas@telecom-paristech.fr>
#          Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
#          Denis A. Engemann <denis.engemann@gmail.com>

import os.path as op
from functools import partial

import numpy as np
from scipy.stats.distributions import uniform

from joblib import Parallel, delayed

import mne
from mne import pick_types
from mne.externals.h5io import read_hdf5, write_hdf5
from mne.viz import plot_epochs as plot_mne_epochs

from sklearn.base import BaseEstimator
from sklearn.model_selection import RandomizedSearchCV
from sklearn.model_selection import KFold, StratifiedShuffleSplit
from sklearn.model_selection import cross_val_score

from .utils import (_clean_by_interp, interpolate_bads, _get_epochs_type,
                    _pbar, _handle_picks, _check_data, _compute_dots,
                    _get_picks_by_type, _pprint)
from .bayesopt import expected_improvement, bayes_opt

_INIT_PARAMS = ('consensus', 'n_interpolate', 'picks',
                'verbose', 'n_jobs', 'cv', 'random_state',
                'thresh_method')

_FIT_PARAMS = ('threshes_', 'n_interpolate_', 'consensus_',
               'dots', 'picks_', 'loss_')


def _slicemean(obj, this_slice, axis):
    mean = np.nan
    if len(obj[this_slice]) > 0:
        mean = np.mean(obj[this_slice], axis=axis)
    return mean


def validation_curve(epochs, y=None, param_name="thresh", param_range=None,
                     cv=None, return_param_range=False, n_jobs=1):
    """Validation curve on epochs for global autoreject.

    Parameters
    ----------
    epochs : instance of mne.Epochs.
        The epochs.
    y : array | None
        The labels.
    param_name : str
        Name of the estimator parameter that will be varied.
        Defaults to 'thresh'.
    param_range : array | None
        The values of the parameter that will be evaluated.
        If None, 15 values between the min and the max peak-to-peak
        amplitude found in the data will be tested.
    cv : int, cross-validation generator or an iterable, optional
        Determines the cross-validation strategy.
    return_param_range : bool
        If True the used param_range is returned.
        Defaults to False.
    n_jobs : int
        The number of thresholds to compute in parallel.

    Returns
    -------
    train_scores : array
        The scores in the training set
    test_scores : array
        The scores in the test set
    param_range : array
        The thresholds used to build the validation curve.
        Only returned if `return_param_range` is True.

    Raises
    ------
    ValueError
        If ``epochs`` is not an MNE epochs object.
    """
    # Imported under an alias so that the scikit-learn function does not
    # shadow the name of the function being defined here.
    from sklearn.model_selection import validation_curve as _sk_curve

    BaseEpochs = _get_epochs_type()
    if not isinstance(epochs, BaseEpochs):
        raise ValueError('Only accepts MNE epochs objects.')

    data_picks = _handle_picks(info=epochs.info, picks=None)
    X = epochs.get_data()[:, data_picks, :]
    n_epochs, n_channels, n_times = X.shape

    if param_range is None:
        ptps = np.ptp(X, axis=2)
        param_range = np.linspace(ptps.min(), ptps.max(), 15)

    # scikit-learn hands the estimator 2D data, so it needs the original
    # dimensions to restore the (n_epochs, n_channels, n_times) layout.
    estimator = _GlobalAutoReject(n_channels=n_channels, n_times=n_times)

    train_scores, test_scores = _sk_curve(
        estimator, X.reshape(n_epochs, -1), y=y,
        # honour the caller's choice (it used to be hard-coded to "thresh")
        param_name=param_name, param_range=param_range,
        cv=cv, n_jobs=n_jobs, verbose=0)

    out = (train_scores, test_scores)
    if return_param_range:
        out += (param_range,)
    return out
def read_auto_reject(fname):
    """Load a saved AutoReject object.

    Parameters
    ----------
    fname : str
        The filename where the AutoReject object is saved.

    Returns
    -------
    ar : instance of autoreject.AutoReject
        The object rebuilt from the stored init parameters, with the
        remaining stored entries restored through ``__setstate__``.
    """
    state = read_hdf5(fname, title='autoreject')

    init_kwargs = dict()
    for param in _INIT_PARAMS:
        init_kwargs[param] = state[param]

    # an integer ``verbose`` is turned back into a boolean flag
    verbose = init_kwargs['verbose']
    if isinstance(verbose, int):
        init_kwargs['verbose'] = bool(verbose)

    ar = AutoReject(**init_kwargs)
    ar.__setstate__(state)
    return ar


class BaseAutoReject(BaseEstimator):
    """Base class for rejection.

    Subclasses are expected to set ``mean_`` when fitted; ``score`` compares
    it against the median of the data it is given.
    """

    def score(self, X, y=None):
        """Return the negative RMSE between median(X) and ``mean_``.

        Parameters
        ----------
        X : array
            The data to score. If the estimator has an ``n_channels``
            attribute, X is reshaped to (-1, n_channels, n_times) first.
        y : None
            Unused.

        Returns
        -------
        score : float
            ``-inf`` if ``mean_`` contains NaN, otherwise minus the root
            mean squared difference between the median of X over its first
            axis and ``mean_``.
        """
        if hasattr(self, 'n_channels'):
            X = X.reshape(-1, self.n_channels, self.n_times)

        if np.any(np.isnan(self.mean_)):
            return -np.inf

        residual = np.median(X, axis=0) - self.mean_
        return -np.sqrt(np.mean(residual ** 2))


class _GlobalAutoReject(BaseAutoReject):
    """Class to compute global rejection thresholds.

    Parameters
    ----------
    n_channels : int | None
        The number of channels in the epochs. Defaults to None.
    n_times : int | None
        The number of time points in the epochs. Defaults to None.
    thresh : float
        Boilerplate API. The rejection threshold.
    """

    def __init__(self, n_channels=None, n_times=None, thresh=40e-6):
        """Store the parameters."""
        self.n_channels = n_channels
        self.n_times = n_times
        self.thresh = thresh

    def fit(self, X, y=None):
        """Compute ``mean_`` from the epochs that pass the threshold.

        Parameters
        ----------
        X : array
            The data; it is reshaped to (-1, n_channels, n_times).
        y : None
            Unused.

        Returns
        -------
        self : instance of _GlobalAutoReject
            The instance.
        """
        if self.n_channels is None or self.n_times is None:
            raise ValueError('Cannot fit without knowing n_channels'
                             ' and n_times')

        X = X.reshape(-1, self.n_channels, self.n_times)
        # peak-to-peak of every channel in every epoch ...
        ptp_per_channel = np.array([np.ptp(trial, axis=1) for trial in X])
        # ... of which the largest one decides the fate of the epoch
        worst_ptp = ptp_per_channel.max(axis=1)
        self.mean_ = _slicemean(X, worst_ptp <= self.thresh, axis=0)
        return self
def get_rejection_threshold(epochs, decim=1, random_state=None,
                            ch_types=None, cv=5, verbose=True):
    """Compute global rejection thresholds.

    Parameters
    ----------
    epochs : mne.Epochs object
        The epochs from which to estimate the rejection dictionary
    decim : int
        The decimation factor: Increment for selecting every nth time slice.
    random_state : int seed, RandomState instance, or None (default)
        The seed of the pseudo random number generator to use. It is
        passed to the bayesian optimization.
    ch_types : str | list of str | None
        The channel types for which to find the rejection dictionary.
        e.g., ['mag', 'grad']. If None, the rejection dictionary
        will have keys ['mag', 'grad', 'eeg', 'eog'].
    cv : int
        The number of folds used. Defaults to 5.
    verbose : bool
        If False, suppress all output messages.

    Returns
    -------
    reject : dict
        The rejection dictionary with keys as specified by ch_types.
        Channel types that are not present in ``epochs`` are left out.

    Raises
    ------
    ValueError
        If ``ch_types`` has the wrong type, or if it names a channel type
        that is present in ``epochs`` but not one of 'mag', 'grad', 'eeg'
        or 'eog'.

    Note
    ----
    Sensors marked as bad by user will be excluded when estimating the
    rejection dictionary.
    """
    if ch_types is not None and not isinstance(ch_types, (list, str)):
        raise ValueError('ch_types must be of type None, list, '
                         'or str. Got %s' % type(ch_types))

    if ch_types is None:
        ch_types = ['mag', 'grad', 'eeg', 'eog']
    elif isinstance(ch_types, str):
        ch_types = [ch_types]

    # ``pick_types`` keyword arguments selecting exactly one channel type
    pick_kwargs = {
        'mag': dict(meg='mag', eeg=False),
        'eeg': dict(meg=False, eeg=True),
        'eog': dict(meg=False, eog=True),
        'grad': dict(meg='grad', eeg=False),
    }

    if decim > 1:
        # decimate a copy so that the caller's epochs stay untouched
        epochs = epochs.copy()
        epochs.decimate(decim=decim)

    # No ``random_state`` here: the folds are not shuffled, and recent
    # scikit-learn versions reject a seed when ``shuffle`` is False.
    cv = KFold(n_splits=cv)

    reject = dict()
    for ch_type in ch_types:
        if ch_type not in epochs:
            continue

        if ch_type not in pick_kwargs:
            # previously this left ``picks`` unbound, or silently reused
            # the picks of the channel type handled just before
            raise ValueError('Unsupported channel type %r. Supported '
                             'types are %s.'
                             % (ch_type, sorted(pick_kwargs)))
        picks = pick_types(epochs.info, **pick_kwargs[ch_type])

        X = epochs.get_data()[:, picks, :]
        n_epochs, n_channels, n_times = X.shape
        deltas = np.array([np.ptp(d, axis=1) for d in X])
        # candidate thresholds: the largest peak-to-peak of each epoch
        all_threshes = np.sort(deltas.max(axis=1))

        if verbose:
            print('Estimating rejection dictionary for %s' % ch_type)
        cache = dict()
        est = _GlobalAutoReject(n_channels=n_channels, n_times=n_times)

        def func(thresh):
            # snap to the largest candidate that does not exceed ``thresh``
            idx = np.where(thresh - all_threshes >= 0)[0][-1]
            thresh = all_threshes[idx]
            if thresh not in cache:
                est.set_params(thresh=thresh)
                obj = -np.mean(cross_val_score(est, X, cv=cv))
                cache.update({thresh: obj})
            return cache[thresh]

        n_epochs = all_threshes.shape[0]
        idx = np.concatenate((
            np.linspace(0, n_epochs, 5, endpoint=False, dtype=int),
            [n_epochs - 1]))  # ensure last point is in init
        idx = np.unique(idx)  # linspace may be non-unique if n_epochs < 5
        initial_x = all_threshes[idx]
        best_thresh, _ = bayes_opt(func, initial_x,
                                   all_threshes,
                                   expected_improvement,
                                   max_iter=10, debug=False,
                                   random_state=random_state)
        reject[ch_type] = best_thresh

    return reject
class _ChannelAutoReject(BaseAutoReject):
    """Estimator for the rejection threshold of a single channel.

    Parameters
    ----------
    thresh : float
        The peak-to-peak rejection threshold. Defaults to 40e-6.

    Attributes
    ----------
    deltas_ : array, shape (n_epochs,)
        The peak-to-peak amplitude of each epoch seen in ``fit``.
    mean_ : array | float
        The mean over the epochs that were kept in ``fit``.
    """

    def __init__(self, thresh=40e-6):
        self.thresh = thresh

    def fit(self, X, y=None):
        """Fit it.

        Parameters
        ----------
        X : array, shape (n_epochs, n_times)
            The data for one channel.
        y : None
            Redundant. Necessary to be compatible with sklearn API.

        Returns
        -------
        self : instance of _ChannelAutoReject
            The instance.
        """
        deltas = np.ptp(X, axis=1)
        self.deltas_ = deltas

        # XXX: actually go over all the folds before setting the min
        # in skopt. Otherwise, may confuse skopt.
        # A threshold below the smallest peak-to-peak would keep nothing,
        # so it is raised to that smallest value.
        thresh = self.thresh
        min_delta = np.min(deltas)
        if thresh < min_delta:
            thresh = min_delta

        self.mean_ = _slicemean(X, deltas <= thresh, axis=0)
        return self


def _pick_exclusive_channels(info, ch_type):
    """Pick one and only one type.

    Parameters
    ----------
    info : instance of mne.Info
        The measurement info passed on to ``mne.pick_types``.
    ch_type : str
        One of 'eeg', 'eog', 'meg', 'grad' or 'mag'.

    Returns
    -------
    picks : array
        The result of ``mne.pick_types`` for that channel type.

    Raises
    ------
    ValueError
        If ``ch_type`` is not one of the supported values.
    """
    if ch_type == 'eeg':
        picks = mne.pick_types(info, meg=False, eeg=True)
    elif ch_type == 'eog':
        picks = mne.pick_types(info, meg=False, eog=True)
    elif ch_type == 'meg':
        picks = mne.pick_types(info, meg=True)
    elif ch_type == 'grad' or ch_type == 'mag':
        picks = mne.pick_types(info, meg=ch_type)
    else:
        # used to end in an UnboundLocalError on ``return picks``
        raise ValueError('Unsupported channel type %r.' % (ch_type,))
    return picks


def _compute_thresh(this_data, method='bayesian_optimization',
                    cv=10, y=None, random_state=None):
    """Compute the rejection threshold for one channel.

    Parameters
    ----------
    this_data: array (n_epochs, n_times)
        Data for one channel.
    method : str
        'bayesian_optimization' or 'random_search'
    cv : iterator
        Iterator for cross-validation.
    y : array | None
        The labels handed to the cross-validation.
    random_state : int seed, RandomState instance, or None (default)
        The seed of the pseudo random number generator to use.

    Returns
    -------
    best_thresh : float
        The best threshold.

    Raises
    ------
    ValueError
        If ``method`` is not recognized.

    Notes
    -----
    For method='random_search', the random_state parameter gives
    deterministic results only for scipy versions >= 0.16. This is why we
    recommend using autoreject with scipy version 0.16 or greater.
    """
    est = _ChannelAutoReject()
    # candidate thresholds: the peak-to-peak amplitude of every epoch
    all_threshes = np.sort(np.ptp(this_data, axis=1))

    if method == 'random_search':
        lowest, highest = all_threshes[0], all_threshes[-1]
        if highest == lowest:
            # a single candidate: nothing to search (and a zero-width
            # uniform distribution is not valid)
            return highest
        # scipy's uniform is parametrized as (loc, scale) and spans
        # [loc, loc + scale]; the scale is therefore the width of the
        # interval and not its upper bound.
        param_dist = dict(thresh=uniform(loc=lowest,
                                         scale=highest - lowest))
        rs = RandomizedSearchCV(est,
                                param_distributions=param_dist,
                                n_iter=20, cv=cv,
                                random_state=random_state)
        rs.fit(this_data, y)
        best_thresh = rs.best_estimator_.thresh
    elif method == 'bayesian_optimization':
        cache = dict()

        def func(thresh):
            # snap to the largest candidate that does not exceed ``thresh``
            idx = np.where(thresh - all_threshes >= 0)[0][-1]
            thresh = all_threshes[idx]
            if thresh not in cache:
                est.set_params(thresh=thresh)
                obj = -np.mean(cross_val_score(est, this_data, y=y, cv=cv))
                cache.update({thresh: obj})
            return cache[thresh]

        n_epochs = all_threshes.shape[0]
        idx = np.concatenate((
            np.linspace(0, n_epochs, 40, endpoint=False, dtype=int),
            [n_epochs - 1]))  # ensure last point is in init
        idx = np.unique(idx)  # linspace may be non-unique if n_epochs < 40
        initial_x = all_threshes[idx]
        best_thresh, _ = bayes_opt(func, initial_x,
                                   all_threshes,
                                   expected_improvement,
                                   max_iter=10, debug=False,
                                   random_state=random_state)
    else:
        # used to end in an UnboundLocalError on ``return best_thresh``
        raise ValueError('`method` param not recognized')

    return best_thresh
def compute_thresholds(epochs, method='bayesian_optimization',
                       random_state=None, picks=None, augment=True,
                       verbose='progressbar', n_jobs=1):
    """Compute thresholds for each channel.

    Public entry point that forwards all of its arguments to
    ``_compute_thresholds``.

    Parameters
    ----------
    epochs : instance of mne.Epochs
        The epochs objects whose thresholds must be computed.
    method : str
        'bayesian_optimization' or 'random_search'
    random_state : int seed, RandomState instance, or None (default)
        The seed of the pseudo random number generator to use
    picks : ndarray, shape(n_channels,) | None
        The channels to be considered for autoreject. If None, defaults
        to data channels {'meg', 'eeg'}.
    augment : boolean
        Whether to augment the data or not. By default it is True, but
        set it to False, if the channel locations are not available.
    verbose : 'tqdm', 'tqdm_notebook', 'progressbar' or False
        The verbosity of progress messages.
        If `'progressbar'`, use `mne.utils.ProgressBar`.
        If `'tqdm'`, use `tqdm.tqdm`.
        If `'tqdm_notebook'`, use `tqdm.tqdm_notebook`.
        If False, suppress all output messages.
    n_jobs : int
        Number of jobs to run in parallel

    Returns
    -------
    threshes : dict
        The channel-level rejection thresholds

    Examples
    --------
    For example, we can compute the channel-level thresholds for all the
    EEG sensors this way:

    >>> compute_thresholds(epochs)
    """
    options = dict(
        method=method,
        random_state=random_state,
        picks=picks,
        augment=augment,
        verbose=verbose,
        n_jobs=n_jobs,
    )
    return _compute_thresholds(epochs, **options)
def _compute_thresholds(epochs, method='bayesian_optimization',
                        random_state=None, picks=None, augment=True,
                        dots=None, verbose='progressbar', n_jobs=1):
    """Compute the channel-level thresholds, one channel type at a time.

    Parameters
    ----------
    epochs : instance of mne.Epochs
        The epochs objects whose thresholds must be computed.
    method : str
        'bayesian_optimization' or 'random_search'
    random_state : int seed, RandomState instance, or None (default)
        The seed of the pseudo random number generator to use.
    picks : ndarray, shape(n_channels,) | None
        The channels to be considered; resolved by ``_handle_picks``.
    augment : boolean
        If True, the epochs returned by ``_clean_by_interp`` are stacked
        below the original ones before the thresholds are estimated.
    dots : object | None
        Passed on to ``_clean_by_interp``.
    verbose : 'tqdm', 'tqdm_notebook', 'progressbar' or False
        The verbosity of progress messages.
    n_jobs : int
        Number of jobs to run in parallel

    Returns
    -------
    threshes : dict
        The thresholds, with the channel names as keys.

    Raises
    ------
    ValueError
        If ``method`` is not recognized.
    """
    if method not in ['bayesian_optimization', 'random_search']:
        raise ValueError('`method` param not recognized')
    picks = _handle_picks(info=epochs.info, picks=picks)
    _check_data(epochs, picks, verbose=verbose,
                ch_constraint='data_channels')
    picks_by_type = _get_picks_by_type(picks=picks, info=epochs.info)
    picks_by_type = None if len(picks_by_type) == 1 else picks_by_type  # XXX
    if picks_by_type is not None:
        # several channel types: recurse once per type and merge
        threshes = dict()
        for ch_type, this_picks in picks_by_type:
            threshes.update(_compute_thresholds(
                epochs=epochs, method=method, random_state=random_state,
                picks=this_picks, augment=augment, dots=dots,
                verbose=verbose, n_jobs=n_jobs))
    else:
        n_epochs = len(epochs)
        data, y = epochs.get_data(), np.ones((n_epochs, ))
        if augment:
            epochs_interp = _clean_by_interp(epochs, picks=picks,
                                             dots=dots, verbose=verbose)
            # non-data channels will be duplicate
            data = np.concatenate((epochs.get_data(),
                                   epochs_interp.get_data()), axis=0)
            # label 0 for the original and 1 for the interpolated epochs,
            # so that the splits below are stratified over the two sets
            y = np.r_[np.zeros((n_epochs, )), np.ones((n_epochs, ))]
        cv = StratifiedShuffleSplit(n_splits=10, test_size=0.2,
                                    random_state=random_state)

        ch_names = epochs.ch_names

        my_thresh = delayed(_compute_thresh)
        parallel = Parallel(n_jobs=n_jobs, verbose=0)
        desc = 'Computing thresholds ...'
        threshes = parallel(
            my_thresh(data[:, pick], cv=cv, method=method, y=y,
                      random_state=random_state)
            for pick in _pbar(picks, desc=desc, verbose=verbose))
        threshes = {ch_names[p]: thresh for p, thresh in
                    zip(picks, threshes)}
    return threshes


class _AutoReject(BaseAutoReject):
    r"""Automatically reject bad epochs and repair bad trials.

    Parameters
    ----------
    consensus : float (0 to 1.0)
        Percentage of channels that must agree as a fraction of
        the total number of channels. This sets :math:`\kappa/Q`.
    n_interpolate : int (default 0)
        Number of channels for which to interpolate. This is :math:`\rho`.
    thresh_func : callable | None
        Function which returns the channel-level thresholds. If None,
        defaults to ``_compute_thresholds``.
    method : str
        Accepted for compatibility; it is not used by this class.
    picks : ndarray, shape(n_channels,) | None
        The channels to be considered for autoreject. If None, defaults
        to data channels {'meg', 'eeg'}.
    dots : object | None
        Passed on to ``thresh_func`` in ``fit`` and to the interpolation
        in ``transform``.
    verbose : 'tqdm', 'tqdm_notebook', 'progressbar' or False
        The verbosity of progress messages.
        If `'progressbar'`, use `mne.utils.ProgressBar`.
        If `'tqdm'`, use `tqdm.tqdm`.
        If `'tqdm_notebook'`, use `tqdm.tqdm_notebook`.
        If False, suppress all output messages.

    Attributes
    ----------
    threshes_ : dict
        The sensor-level thresholds with channel names as keys
        and the peak-to-peak thresholds as the values.
    consensus_ : dict
        The consensus, keyed by channel type.
    n_interpolate_ : dict
        The number of channels to interpolate, keyed by channel type.
    picks_ : array
        The picks resolved in ``fit``.
    mean_ : array | float
        The mean of the repaired, non-rejected epochs computed in ``fit``.
    """

    def __init__(self, consensus=0.1, n_interpolate=0, thresh_func=None,
                 method='bayesian_optimization',
                 picks=None, dots=None,
                 verbose='progressbar'):
        """Init it."""
        if thresh_func is None:
            thresh_func = _compute_thresholds
        if not (0 <= consensus <= 1):
            raise ValueError('"consensus" must be between 0 and 1. '
                             'You gave me %s.' % consensus)
        self.consensus = consensus
        self.n_interpolate = n_interpolate
        self.thresh_func = thresh_func
        self.picks = picks
        self.verbose = verbose
        self.dots = dots

    def __repr__(self):
        """repr."""
        class_name = self.__class__.__name__
        params = dict(n_interpolate=self.n_interpolate,
                      consensus=self.consensus,
                      verbose=self.verbose, picks=self.picks)
        return '%s(%s)' % (class_name, _pprint(params,
                                               offset=len(class_name),),)

    def _vote_bad_epochs(self, epochs, picks):
        """Each channel votes for an epoch as good or bad.

        Parameters
        ----------
        epochs : instance of mne.Epochs
            The epochs object for which bad epochs must be found.
        picks : array-like
            The indices of the channels to consider.

        Returns
        -------
        labels : array, shape (n_epochs, n_channels)
            NaN for channels outside ``picks``; for the picked channels,
            1 where the peak-to-peak exceeds the channel threshold and 0
            elsewhere.
        bad_sensor_counts : array, shape (n_epochs,)
            The number of channels labelled 1 in each epoch.
        """
        labels = np.zeros((len(epochs), len(epochs.ch_names)))
        labels.fill(np.nan)

        this_ch_names = [epochs.ch_names[p] for p in picks]
        # one row of peak-to-peak values per picked channel
        deltas = np.ptp(epochs.get_data()[:, picks], axis=-1).T
        threshes = [self.threshes_[ch_name] for ch_name in this_ch_names]
        for ch_idx, (delta, thresh) in enumerate(zip(deltas, threshes)):
            bad_epochs_idx = np.where(delta > thresh)[0]
            labels[:, picks[ch_idx]] = 0
            labels[bad_epochs_idx, picks[ch_idx]] = 1

        bad_sensor_counts = np.sum(labels == 1, axis=1)
        return labels, bad_sensor_counts

    def _get_epochs_interpolation(self, epochs, labels,
                                  picks, n_interpolate,
                                  verbose='progressbar'):
        """Mark the bad segments that are to be interpolated.

        Parameters
        ----------
        epochs : instance of mne.Epochs
            The epochs the labels belong to.
        labels : array, shape (n_epochs, n_channels)
            The labels; they are copied, not modified in place.
        picks : array-like
            The indices of the channels to consider.
        n_interpolate : int
            The maximum number of channels to interpolate per epoch.
        verbose : object
            Unused.

        Returns
        -------
        labels : array, shape (n_epochs, n_channels)
            A copy of the labels in which the segments chosen for
            interpolation are set to 2.
        """
        # 1: bad segment, # 2: interpolated
        assert labels.shape[0] == len(epochs)
        assert labels.shape[1] == len(epochs.ch_names)
        labels = labels.copy()
        non_picks = np.setdiff1d(range(epochs.info['nchan']), picks)
        for epoch_idx in range(len(epochs)):
            n_bads = labels[epoch_idx, picks].sum()
            if n_bads == 0:
                continue

            if n_bads <= n_interpolate:
                interp_chs_mask = labels[epoch_idx] == 1
            else:
                # get peak-to-peak for channels in that epoch
                data = epochs[epoch_idx].get_data()[0]
                peaks = np.ptp(data, axis=-1)
                peaks[non_picks] = -np.inf
                # find channels which are bad by rejection threshold
                interp_chs_mask = labels[epoch_idx] == 1
                # ignore good channels
                peaks[~interp_chs_mask] = -np.inf
                # find the ordering of channels amongst the bad channels
                sorted_ch_idx_picks = np.argsort(peaks)[::-1]
                # then select only the worst n_interpolate channels
                interp_chs_mask[
                    sorted_ch_idx_picks[n_interpolate:]] = False

            labels[epoch_idx][interp_chs_mask] = 2
        return labels

    def _get_bad_epochs(self, bad_sensor_counts, ch_type, picks):
        """Get the mask of bad epochs.

        Parameters
        ----------
        bad_sensor_counts : array, shape (n_epochs,)
            The number of bad channels in each epoch.
        ch_type : str
            The channel type whose consensus is used.
        picks : array-like
            The channels considered; only their number matters.

        Returns
        -------
        bad_epochs : array of bool, shape (n_epochs,)
            True for the epochs in which the number of bad channels
            reaches ``consensus_[ch_type] * len(picks)``.
        """
        n_consensus = self.consensus_[ch_type] * len(picks)
        # Plain comparison: it selects the same epochs as sorting the
        # counts and taking the leading ones, and copes with empty input.
        return np.asarray(bad_sensor_counts) >= n_consensus

    def get_reject_log(self, epochs, threshes=None, picks=None):
        """Get rejection logs from epochs.

        .. note::
           If multiple channel types are present, reject_log.bad_epochs
           reflects the union of bad epochs across channel types.

        Parameters
        ----------
        epochs : instance of mne.Epochs
            The epochs from which to get the drop logs.
        threshes : object
            Unused; the thresholds stored in ``threshes_`` are used.
        picks : np.ndarray, shape(n_channels, ) | list | None
            The channel indices to be used. If None, the .picks_ attribute
            will be used. They must belong to a single channel type.

        Returns
        -------
        reject_log : instance of autoreject.RejectLog
            The rejection log.
        """
        picks = (self.picks_ if picks is None else picks)
        picks_by_type = _get_picks_by_type(picks=picks, info=epochs.info)
        assert len(picks_by_type) == 1
        ch_type, this_picks = picks_by_type[0]
        del picks

        labels, bad_sensor_counts = self._vote_bad_epochs(
            epochs, picks=this_picks)

        labels = self._get_epochs_interpolation(
            epochs, labels=labels, picks=this_picks,
            n_interpolate=self.n_interpolate_[ch_type])

        assert len(labels) == len(epochs)

        bad_epochs = self._get_bad_epochs(
            bad_sensor_counts, ch_type=ch_type, picks=this_picks)

        reject_log = RejectLog(labels=labels, bad_epochs=bad_epochs,
                               ch_names=epochs.ch_names)
        return reject_log

    def fit(self, epochs):
        """Compute the thresholds.

        Parameters
        ----------
        epochs : instance of mne.Epochs
            The epochs object from which the channel-level thresholds are
            estimated.

        Returns
        -------
        self : instance of _AutoReject
            The instance.
        """
        self.picks_ = _handle_picks(info=epochs.info, picks=self.picks)
        _check_data(epochs, picks=self.picks_, verbose=self.verbose,
                    ch_constraint='single_channel_type')

        picks_by_type = _get_picks_by_type(picks=self.picks_,
                                           info=epochs.info)
        assert len(picks_by_type) == 1
        ch_type, this_picks = picks_by_type[0]

        self.consensus_ = dict()
        self.n_interpolate_ = dict()
        self.n_interpolate_[ch_type] = self.n_interpolate
        self.consensus_[ch_type] = self.consensus

        self.threshes_ = self.thresh_func(
            epochs.copy(), dots=self.dots, picks=self.picks_,
            verbose=self.verbose)

        reject_log = self.get_reject_log(epochs=epochs, picks=self.picks_)

        epochs_copy = epochs.copy()
        interp_channels = _get_interp_chs(
            reject_log.labels, reject_log.ch_names, this_picks)

        # interpolate copy to compute the clean .mean_
        _interpolate_bad_epochs(
            epochs_copy, interp_channels=interp_channels,
            picks=self.picks_, verbose=self.verbose)
        self.mean_ = _slicemean(
            epochs_copy.get_data(),
            np.nonzero(np.invert(reject_log.bad_epochs))[0], axis=0)
        del epochs_copy  # I can't wait for garbage collection.
        return self

    def transform(self, epochs, return_log=False):
        """Fix and find the bad epochs.

        Parameters
        ----------
        epochs : instance of mne.Epochs
            The epochs object for which bad epochs must be found.
        return_log : bool
            If true the rejection log is also returned.

        Returns
        -------
        epochs_clean : instance of mne.Epochs
            The cleaned epochs.
        reject_log : instance of autoreject.RejectLog
            The rejection log. Returned only if return_log is True.

        Raises
        ------
        ValueError
            If every epoch is marked as bad.
        """
        _check_data(epochs, picks=self.picks, verbose=self.verbose,
                    ch_constraint='data_channels')

        reject_log = self.get_reject_log(epochs, picks=None)
        if np.all(reject_log.bad_epochs):
            raise ValueError('All epochs are bad. Sorry.')

        epochs_clean = epochs.copy()
        # this one knows how to handle picks.
        # (the helpers take no estimator argument: passing ``self`` along
        # did not match their signatures)
        _apply_interp(reject_log, epochs_clean, self.threshes_,
                      self.picks_, self.dots, self.verbose)

        _apply_drop(reject_log, epochs_clean, self.threshes_, self.picks_,
                    self.verbose)

        if return_log:
            return epochs_clean, reject_log
        else:
            return epochs_clean


def _interpolate_bad_epochs(
        epochs, interp_channels, picks, dots=None, verbose='progressbar'):
    """Actually do the interpolation.

    Parameters
    ----------
    epochs : instance of mne.Epochs
        The epochs to repair. Their data are modified in place.
    interp_channels : list of list of str
        For every epoch, the names of the channels to interpolate.
    picks : array-like
        Passed on to ``interpolate_bads``.
    dots : object | None
        Passed on to ``interpolate_bads``.
    verbose : 'tqdm', 'tqdm_notebook', 'progressbar' or False
        The verbosity of progress messages.
    """
    assert len(epochs) == len(interp_channels)
    pos = 2

    for epoch_idx, interp_chs in _pbar(
            list(enumerate(interp_channels)),
            desc='Repairing epochs',
            position=pos, leave=True, verbose=verbose):
        epoch = epochs[epoch_idx]
        epoch.info['bads'] = interp_chs
        interpolate_bads(epoch, dots=dots, picks=picks, reset_bads=True)
        # write the repaired single epoch back into the full data array
        epochs._data[epoch_idx] = epoch._data


class _CVSplits(object):
    """Give a generator of CV splits a length.

    Hack to allow len(self.cv_.split(X)) as ProgressBar
    assumes an iterable whereas self.cv_.split(X) is a
    generator
    """

    def __init__(self, gen, length):
        self.gen = gen
        self.length = length

    def __len__(self):
        return self.length

    def __iter__(self):
        return self.gen


def _run_local_reject_cv(epochs, thresh_func, picks_, n_interpolate, cv,
                         consensus, dots, verbose):
    """Cross-validate consensus and n_interpolate for one channel type.

    Parameters
    ----------
    epochs : instance of mne.Epochs
        The epochs to fit on.
    thresh_func : callable
        The function computing the channel-level thresholds.
    picks_ : array-like
        The channels to consider (a single channel type).
    n_interpolate : array-like
        The candidate numbers of channels to interpolate.
    cv : cross-validation object
        Must provide ``get_n_splits`` and ``split``.
    consensus : array-like
        The candidate consensus fractions.
    dots : object | None
        Passed on to the thresholding and the interpolation.
    verbose : 'tqdm', 'tqdm_notebook', 'progressbar' or False
        The verbosity of progress messages.

    Returns
    -------
    local_reject : instance of _AutoReject
        The estimator fitted on all the epochs.
    loss : array, shape (len(consensus), len(n_interpolate), n_folds)
        The cross-validation loss; ``inf`` for the combinations in which
        ``consensus * n_channels <= n_interpolate``.
    """
    n_folds = cv.get_n_splits()
    loss = np.zeros((len(consensus), len(n_interpolate),
                     n_folds))

    # The thresholds must be learnt from the entire data
    local_reject = _AutoReject(thresh_func=thresh_func,
                               verbose=verbose, picks=picks_,
                               dots=dots)
    local_reject.fit(epochs)

    assert len(local_reject.consensus_) == 1  # works with one ch_type
    ch_type = next(iter(local_reject.consensus_))

    labels, bad_sensor_counts = \
        local_reject._vote_bad_epochs(epochs, picks=picks_)

    # ``epochs`` is not modified below, so these are computed only once
    X = epochs.get_data()[:, picks_]
    n_channels = len(picks_)

    desc = 'n_interp'
    for jdx, n_interp in enumerate(_pbar(n_interpolate, desc=desc,
                                         position=1, verbose=verbose)):
        # we can interpolate before doing cross-validation
        # because interpolation is independent across trials.
        local_reject.n_interpolate_[ch_type] = n_interp
        # NOTE(review): ``labels`` is carried over from one n_interp to
        # the next, so segments already marked 2 stay marked - confirm
        # that this accumulation is intended.
        labels = local_reject._get_epochs_interpolation(
            epochs, labels=labels, picks=picks_, n_interpolate=n_interp)

        interp_channels = _get_interp_chs(labels, epochs.ch_names, picks_)
        epochs_interp = epochs.copy()
        # for learning we need to go by channel type, even for meg
        _interpolate_bad_epochs(
            epochs_interp, interp_channels=interp_channels,
            picks=picks_, dots=dots, verbose=verbose)

        cv_splits = _CVSplits(cv.split(X), n_folds)
        pbar = _pbar(cv_splits, desc='Fold',
                     position=3, verbose=verbose)
        for fold, (train, test) in enumerate(pbar):
            # neither depends on the consensus; the training data are
            # only fetched if at least one consensus value needs them
            X_test = X[test]
            train_data = None
            for idx, this_consensus in enumerate(consensus):
                # \kappa must be greater than \rho
                if this_consensus * n_channels <= n_interp:
                    loss[idx, jdx, fold] = np.inf
                    continue

                local_reject.consensus_[ch_type] = this_consensus
                bad_epochs = local_reject._get_bad_epochs(
                    bad_sensor_counts[train], picks=picks_, ch_type=ch_type)

                good_epochs_idx = np.nonzero(np.invert(bad_epochs))[0]

                if train_data is None:
                    train_data = \
                        epochs_interp[train].get_data()[:, picks_]
                local_reject.mean_ = _slicemean(
                    train_data, good_epochs_idx, axis=0)
                loss[idx, jdx, fold] = -local_reject.score(X_test)

    return local_reject, loss
class AutoReject(object):
    r"""Efficiently find n_interpolate and consensus.

    .. note::
       AutoReject by design supports multiple channels.
       If no picks are passed, separate solutions will be computed for each
       channel type and internally combined. This then readily supports
       cleaning unseen epochs from the different channel types used during fit.

    Parameters
    ----------
    n_interpolate : array | None
        The values to try for the number of channels for which to
        interpolate. This is :math:`\rho`. If None, defaults to
        np.array([1, 4, m]) where m is the smaller of 32 and the number of
        picked channels minus one.
    consensus : array | None
        The values to try for percentage of channels that must agree as a
        fraction of the total number of channels. This sets
        :math:`\kappa/Q`. If None, defaults to `np.linspace(0, 1.0, 11)`
    thresh_func : object
        Accepted for compatibility; it is not used.
    cv : a scikit-learn cross-validation object | int
        If an int, ``KFold`` with that many splits is used.
        Defaults to cv=10
    picks : ndarray, shape(n_channels) | None
        The channels to be considered for autoreject. If None, defaults
        to data channels {'meg', 'eeg'}, which will lead fitting and combining
        autoreject solutions across these channel types. Note that, if picks is
        None, autoreject ignores channels marked bad in epochs.info['bads'].
    thresh_method : str
        'bayesian_optimization' or 'random_search'
    n_jobs : int
        The number of jobs.
    random_state : int seed, RandomState instance, or None (default)
        The seed of the pseudo random number generator to use.
    verbose : 'tqdm', 'tqdm_notebook', 'progressbar' or False
        The verbosity of progress messages.
        If `'progressbar'`, use `mne.utils.ProgressBar`.
        If `'tqdm'`, use `tqdm.tqdm`.
        If `'tqdm_notebook'`, use `tqdm.tqdm_notebook`.
        If False, suppress all output messages.

    Attributes
    ----------
    local_reject_ : dict
        The instances of _AutoReject, keyed by channel type.
    threshes_ : dict
        The sensor-level thresholds with channel names as keys
        and the peak-to-peak thresholds as the values.
    loss_ : dict of array, shape (len(consensus), len(n_interpolate), n_folds)
        The cross validation error for different parameter values.
    consensus_ : dict
        The estimated consensus per channel type.
    n_interpolate_ : dict
        The estimated n_interpolate per channel type.
    picks_ : array-like, shape (n_data_channels,)
        The data channels considered by autoreject. By default
        only data channels, not already marked as bads are considered.
    """

    def __init__(self, n_interpolate=None, consensus=None,
                 thresh_func=None, cv=10, picks=None,
                 thresh_method='bayesian_optimization',
                 n_jobs=1, random_state=None, verbose='progressbar'):
        """Init it."""
        self.n_interpolate = n_interpolate
        self.consensus = consensus
        self.thresh_method = thresh_method
        self.cv = cv
        self.verbose = verbose
        self.picks = picks  # XXX : should maybe be ch_types?
        self.n_jobs = n_jobs
        self.random_state = random_state

        if self.consensus is None:
            self.consensus = np.linspace(0, 1.0, 11)

    def __repr__(self):
        """repr."""
        class_name = self.__class__.__name__
        params = dict(n_interpolate=self.n_interpolate,
                      consensus=self.consensus,
                      cv=self.cv, verbose=self.verbose, picks=self.picks,
                      thresh_method=self.thresh_method,
                      random_state=self.random_state, n_jobs=self.n_jobs)
        return '%s(%s)' % (class_name, _pprint(params,
                                               offset=len(class_name),),)

    def __getstate__(self):
        """Get the state of autoreject as a dictionary.

        The dictionary holds the init parameters, the fit attributes that
        are present and, under 'local_reject_', the first four init and
        fit parameters of every per-channel-type estimator.
        """
        state = dict()

        for param in _INIT_PARAMS:
            state[param] = getattr(self, param)
        for param in _FIT_PARAMS:
            if hasattr(self, param):
                state[param] = getattr(self, param)

        if hasattr(self, 'local_reject_'):
            state['local_reject_'] = dict()
            for ch_type in self.local_reject_:
                state['local_reject_'][ch_type] = dict()
                for param in _INIT_PARAMS[:4] + _FIT_PARAMS[:4]:
                    state['local_reject_'][ch_type][param] = \
                        getattr(self.local_reject_[ch_type], param)
        return state

    def __setstate__(self, state):
        """Set the state of autoreject.

        Init parameters found in ``state`` are skipped; the per-channel-type
        estimators are rebuilt from the 'local_reject_' entry.
        """
        for param in state.keys():
            if param == 'local_reject_':
                local_reject_ = dict()
                for ch_type in state['local_reject_']:
                    init_kwargs = {
                        key: state['local_reject_'][ch_type][key]
                        for key in _INIT_PARAMS[:4]
                    }
                    if isinstance(init_kwargs['verbose'], int):
                        init_kwargs['verbose'] = bool(init_kwargs['verbose'])
                    local_reject_[ch_type] = _AutoReject(**init_kwargs)

                    for key in _FIT_PARAMS[:4]:
                        setattr(local_reject_[ch_type], key,
                                state['local_reject_'][ch_type][key])
                self.local_reject_ = local_reject_
            elif param not in _INIT_PARAMS:
                setattr(self, param, state[param])

    def fit(self, epochs):
        """Fit the epochs on the AutoReject object.

        Parameters
        ----------
        epochs : instance of mne.Epochs
            The epochs object to be fit.

        Returns
        -------
        self : instance of AutoReject
            The instance.

        Raises
        ------
        ValueError
            If ``n_interpolate`` is None and fewer than 4 channels are
            picked.
        """
        self.picks_ = _handle_picks(picks=self.picks, info=epochs.info)
        _check_data(epochs, picks=self.picks_, verbose=self.verbose)
        self.cv_ = self.cv
        if isinstance(self.cv_, int):
            self.cv_ = KFold(n_splits=self.cv_)

        # XXX : maybe use an mne function in pick.py ?
        picks_by_type = _get_picks_by_type(info=epochs.info,
                                           picks=self.picks_)
        ch_types = [ch_type for ch_type, _ in picks_by_type]
        self.dots = None
        if 'mag' in ch_types or 'grad' in ch_types:
            # computed once on all the MEG channels and shared below
            meg_picks = pick_types(epochs.info, meg=True,
                                   eeg=False, exclude=[])
            this_info = mne.pick_info(epochs.info, meg_picks, copy=True)
            self.dots = _compute_dots(this_info)

        thresh_func = partial(_compute_thresholds, n_jobs=self.n_jobs,
                              method=self.thresh_method,
                              random_state=self.random_state,
                              dots=self.dots)

        if self.n_interpolate is None:
            if len(self.picks_) < 4:
                raise ValueError('Too few channels. autoreject is unlikely'
                                 ' to be effective')
            # XXX: dont interpolate all channels
            max_interp = min(len(self.picks_) - 1, 32)
            self.n_interpolate = np.array([1, 4, max_interp])

        self.n_interpolate_ = dict()  # rho
        self.consensus_ = dict()  # kappa
        self.threshes_ = dict()  # update
        self.loss_ = dict()
        self.local_reject_ = dict()

        for ch_type, this_picks in picks_by_type:
            if self.verbose is not False:
                print('Running autoreject on ch_type=%s' % ch_type)
            this_local_reject, this_loss = \
                _run_local_reject_cv(epochs, thresh_func, this_picks,
                                     self.n_interpolate, self.cv_,
                                     self.consensus, self.dots,
                                     self.verbose)
            self.threshes_.update(this_local_reject.threshes_)

            # position of the smallest loss, averaged over the folds
            best_idx, best_jdx = \
                np.unravel_index(this_loss.mean(axis=-1).argmin(),
                                 this_loss.shape[:2])

            self.consensus_[ch_type] = self.consensus[best_idx]
            self.n_interpolate_[ch_type] = self.n_interpolate[best_jdx]
            self.loss_[ch_type] = this_loss

            # update local reject with best and store it
            this_local_reject.consensus_[ch_type] = self.consensus_[ch_type]
            this_local_reject.n_interpolate_[ch_type] = \
                self.n_interpolate_[ch_type]

            # needed for generating reject logs by channel
            self.local_reject_[ch_type] = this_local_reject

            if self.verbose is not False:
                print('\n\n\n\nEstimated consensus=%0.2f and n_interpolate=%d'
                      % (self.consensus_[ch_type],
                         self.n_interpolate_[ch_type]))
        return self

    def get_reject_log(self, epochs, picks=None):
        """Get rejection logs of epochs.

        .. note::
           If multiple channel types are present, reject_log.bad_epochs
           reflects the union of bad trials across channel types.

        Parameters
        ----------
        epochs : instance of mne.Epochs
            The epoched data for which the reject log is computed.
        picks : np.ndarray, shape(n_channels, ) | list | None
            Currently unused: the picks stored in the .picks_ attribute at
            fit time are always used.

        Returns
        -------
        reject_log : instance of autoreject.RejectLog
            The reject log.
        """
        # XXX gut feeling that there is a bad condition that we miss
        ch_names = [cc for cc in epochs.ch_names]
        labels = np.ones((len(epochs), len(ch_names)))
        labels.fill(np.nan)
        # the builtin ``bool``: the ``np.bool`` alias no longer exists in
        # recent NumPy versions
        reject_log = RejectLog(
            labels=labels,
            bad_epochs=np.zeros(len(epochs), dtype=bool),
            ch_names=ch_names)

        picks_by_type = _get_picks_by_type(info=epochs.info,
                                           picks=self.picks_)
        for ch_type, this_picks in picks_by_type:
            this_reject_log = self.local_reject_[ch_type].get_reject_log(
                epochs, threshes=self.threshes_, picks=this_picks)
            reject_log.labels[:, this_picks] = \
                this_reject_log.labels[:, this_picks]
            reject_log.bad_epochs = np.logical_or(
                reject_log.bad_epochs, this_reject_log.bad_epochs)
            reject_log.ch_names = this_reject_log.ch_names
        return reject_log

    def transform(self, epochs, return_log=False):
        """Remove bad epochs, repairs sensors and returns clean epochs.

        Parameters
        ----------
        epochs : instance of mne.Epochs
            The epochs object which must be cleaned.
        return_log : bool
            If true the rejection log is also returned.

        Returns
        -------
        epochs_clean : instance of mne.Epochs
            The cleaned epochs
        reject_log : instance of autoreject.RejectLog
            The rejection log. Returned only if return_log is True.

        Raises
        ------
        ValueError
            If the object has not been fitted.
        """
        # XXX : should be a check_fitted method
        if not hasattr(self, 'n_interpolate_'):
            raise ValueError('Please run autoreject.fit() method first')

        _check_data(epochs, picks=self.picks_, verbose=self.verbose)

        reject_log = self.get_reject_log(epochs)
        epochs_clean = epochs.copy()
        _apply_interp(reject_log, epochs_clean, self.threshes_,
                      self.picks_, self.dots, self.verbose)

        _apply_drop(reject_log, epochs_clean, self.threshes_, self.picks_,
                    self.verbose)

        if return_log:
            return epochs_clean, reject_log
        else:
            return epochs_clean

    def fit_transform(self, epochs, return_log=False):
        """Estimate the rejection params and finds bad epochs.

        Parameters
        ----------
        epochs : instance of mne.Epochs
            The epochs object which must be cleaned.
        return_log : bool
            If true the rejection log is also returned.

        Returns
        -------
        epochs_clean : instance of mne.Epochs
            The cleaned epochs.
        reject_log : instance of autoreject.RejectLog
            The rejection log. Returned only if return_log is True.
        """
        return self.fit(epochs).transform(epochs, return_log=return_log)

    def save(self, fname, overwrite=False):
        """Save autoreject object.

        Parameters
        ----------
        fname : str
            The filename to save to. The filename must end
            in '.h5' or '.hdf5'.
        overwrite : bool
            If True, overwrite file if it already exists. Defaults to False.

        Raises
        ------
        ValueError
            If the file exists and ``overwrite`` is False.
        """
        fname = op.realpath(fname)
        if not overwrite and op.isfile(fname):
            raise ValueError('%s already exists. Please make overwrite=True '
                             'if you want to overwrite this file' % fname)

        write_hdf5(fname, self.__getstate__(), overwrite=overwrite,
                   title='autoreject')
def _check_fit(epochs, threshes_, picks_):
    """Raise ValueError if a picked channel has no entry in ``threshes_``.

    ``picks_`` holds indices into ``epochs.ch_names``; each resulting
    channel name must be contained in ``threshes_``.
    """
    for pick in picks_:
        if epochs.ch_names[pick] not in threshes_:
            raise ValueError('You are passing channels which were not present '
                             'at fit-time. Please fit it again, this time '
                             'correctly.')


def _apply_interp(reject_log, epochs, threshes_, picks_, dots, verbose):
    """Interpolate, epoch by epoch, the channels flagged in ``reject_log``."""
    _check_fit(epochs, threshes_, picks_)
    # One list of channel names per epoch.
    chs_per_epoch = _get_interp_chs(
        reject_log.labels, reject_log.ch_names, picks_)
    _interpolate_bad_epochs(
        epochs, interp_channels=chs_per_epoch,
        picks=picks_, dots=dots, verbose=verbose)


def _apply_drop(reject_log, epochs, threshes_, picks_, verbose):
    """Drop from ``epochs`` the epochs marked bad in ``reject_log``."""
    _check_fit(epochs, threshes_, picks_)
    drop_idx = np.nonzero(reject_log.bad_epochs)[0]
    if drop_idx.size > 0:
        epochs.drop(drop_idx, reason='AUTOREJECT')
        return
    if verbose:
        print("No bad epochs were found for your data. Returning "
              "a copy of the data you wanted to clean. Interpolation "
              "may have been done.")


def _get_interp_chs(labels, ch_names, picks):
    """Convert labels to channel names.

    It returns a list of length n_epochs. Each entry contains
    the names of the channels to interpolate, i.e. those whose
    label equals 2 in that epoch.

    labels is of shape n_epochs x n_channels and picks is the sublist
    of channels to consider.
    """
    n_channels = labels.shape[1]
    assert n_channels == len(ch_names)
    assert n_channels > np.max(picks)

    # Columns holding at least one non-NaN label must be exactly ``picks``.
    labelled_cols = np.flatnonzero(np.any(~np.isnan(labels), axis=0))
    np.testing.assert_array_equal(picks, labelled_cols)

    return [[ch_names[col] for col in np.where(epoch_labels == 2)[0]]
            for epoch_labels in labels]
class RejectLog(object):
    """The Rejection Log.

    Parameters
    ----------
    bad_epochs : array-like, shape (n_epochs,)
        The boolean array with entries True for epochs that
        are marked as bad.
    labels : array, shape (n_epochs, n_channels)
        It contains integers that encode if a channel in a given
        epoch is good (value 0), bad (1), or bad and interpolated (2).
        NaN entries are tolerated and are drawn without color in
        ``plot_epochs``.
    ch_names : list of str
        The list of channels corresponding to the columns of the labels.
    """

    def __init__(self, bad_epochs, labels, ch_names):
        self.bad_epochs = bad_epochs
        self.labels = labels
        self.ch_names = ch_names
        # labels is indexed as [epoch, channel].
        assert len(bad_epochs) == labels.shape[0]
        assert len(ch_names) == labels.shape[1]

    def plot(self, orientation='vertical', show=True):
        """Plot.

        Parameters
        ----------
        orientation : 'vertical' or 'horizontal'
            If `'vertical'`, will plot sensors on x-axis and epochs on y-axis.
            If `'horizontal'`, will plot epochs on x-axis and sensors
            on y-axis.
        show : bool
            If True, display the figure immediately.

        Returns
        -------
        figure : Instance of matplotlib.figure.Figure

        Raises
        ------
        ValueError
            If ``orientation`` is neither 'vertical' nor 'horizontal'.
        """
        # Validate before touching matplotlib so that a bad argument does
        # not leave an orphaned, half-built figure behind.
        if orientation not in ('vertical', 'horizontal'):
            raise ValueError("orientation can be only 'horizontal' or "
                             "'vertical'. Got %s" % orientation)

        import matplotlib.pyplot as plt

        figure, ax = plt.subplots(figsize=(12, 6))
        ax.grid(False)
        # Label every 10th channel, starting from the 8th one.
        tick_positions = range(7, self.labels.shape[1], 10)
        tick_names = self.ch_names[7::10]
        if orientation == 'horizontal':
            ax.imshow(self.labels.T, cmap='Reds',
                      interpolation='nearest')
            ax.set_xlabel('Epochs')
            ax.set_ylabel('Channels')
            plt.setp(ax, yticks=tick_positions, yticklabels=tick_names)
        else:
            ax.imshow(self.labels, cmap='Reds',
                      interpolation='nearest')
            ax.set_xlabel('Channels')
            ax.set_ylabel('Epochs')
            plt.setp(ax, xticks=tick_positions, xticklabels=tick_names)

        # XXX to be fixed
        plt.setp(ax.get_yticklabels(), rotation=0)
        plt.setp(ax.get_xticklabels(), rotation=90)
        ax.tick_params(axis=u'both', which=u'both', length=0)
        plt.tight_layout(rect=[None, None, None, 1.1])
        if show:
            plt.show()
        return figure

    def plot_epochs(self, epochs, scalings=None, title=''):
        """Plot interpolated and dropped epochs.

        Traces of epochs marked bad are drawn entirely in red. In the other
        epochs, channels labelled 1 are red, channels labelled 2 are gray,
        and channels labelled 0 or NaN keep the default color.

        Parameters
        ----------
        epochs : instance of Epochs
            The epochs.
        scalings : dict | None
            Scaling factors for the traces. If None, defaults to::

                dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
                     emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1, resp=1,
                     chpi=1e-4, whitened=1e2)
        title : str
            The title to display.

        Returns
        -------
        fig : Instance of matplotlib.figure.Figure
            Epochs traces.

        Raises
        ------
        ValueError
            If the number of epochs or channels in ``epochs`` does not match
            the shape of the labels.
        """
        labels = self.labels
        n_epochs, n_channels = labels.shape

        if not labels.shape[0] == len(epochs.events):
            raise ValueError('The number of epochs should match the number of '
                             'epochs *before* autoreject. Please provide '
                             'the epochs object before running autoreject')
        if not labels.shape[1] == len(epochs.ch_names):
            raise ValueError('The number of channels should match the number'
                             ' of channels before running autoreject.')
        bad_epochs_idx = np.where(self.bad_epochs)[0]
        # Valid indices run from 0 to len(events) - 1, hence ``>=``.
        if len(bad_epochs_idx) > 0 and \
                bad_epochs_idx.max() >= len(epochs.events):
            raise ValueError('You had a bad_epoch with index '
                             '%d but there are only %d epochs. Make sure'
                             ' to provide the epochs *before* running '
                             'autoreject.'
                             % (bad_epochs_idx.max(),
                                len(epochs.events)))

        color_map = {0: None, 1: 'r', 2: (0.6, 0.6, 0.6, 1.0)}
        epoch_colors = list()
        for epoch_idx, label_epoch in enumerate(labels):
            if self.bad_epochs[epoch_idx]:
                # A dropped epoch is shown in red on every channel.
                epoch_colors.append(['r'] * n_channels)
                continue
            epoch_colors.append(
                [None if np.isnan(this_label) else color_map[this_label]
                 for this_label in label_epoch])

        # Forward the caller's title; it used to be discarded in favor of ''.
        return plot_mne_epochs(
            epochs=epochs,
            epoch_colors=epoch_colors, scalings=scalings,
            title=title)
Fork me on GitHub