"""Utility functions for autoreject."""# Authors: Mainak Jas <mainak.jas@telecom-paristech.fr># Denis A. Engemann <denis.engemann@gmail.com>fromcollectionsimportdefaultdictimportinspectimportwarningsimportnumpyasnpimportmnefrommneimportpick_types,pick_info,channel_typefrommne.io.pickimport_picks_to_idxfrommne.channels.interpolationimport_do_interp_dotsfrommne.forward._field_interpolationimport_setup_dots_GDKW=dict()# MNE >= 1.6if"copy"ininspect.getfullargspec(mne.Epochs.get_data).kwonlyargs:_GDKW["copy"]=Falsedef_check_ch_locs(chs):"""Check if channel locations exist. Parameters ---------- chs : dict The channels from info['chs'] """locs3d=np.array([ch['loc'][:3]forchinchs])returnnot((locs3d==0).all()or(~np.isfinite(locs3d)).any()ornp.allclose(locs3d,0.))def_check_data(epochs,picks,ch_constraint='data_channels',check_loc=True,verbose=True):BaseEpochs=_get_epochs_type()ifnotisinstance(epochs,BaseEpochs):raiseValueError('Only accepts MNE epochs objects.')ifepochs.preloadisFalse:raiseValueError('Data must be preloaded.')n_bads=len(epochs.info['bads'])picked_info=pick_info(epochs.info,picks)ch_types_picked={channel_type(picked_info,idx)foridxinrange(len(picks))}ifcheck_locandnot_check_ch_locs(picked_info['chs']):raiseRuntimeError('Valid channel positions are needed ''for autoreject to work')# XXX : ch_constraint -> allow_many_types=True | Falsesupported_channels=['mag','grad','eeg','hbo','hbr','ecog','seeg']channel_text=', '.join(xforxinsupported_channels)message=f'AutoReject only supports {channel_text} currenly.'ifch_constraint=='data_channels':ifnotall(chinsupported_channelsforchinch_types_picked):raiseValueError(message)elifch_constraint=='single_channel_type':ifsum(chinch_types_pickedforchinsupported_channels)>1:raiseValueError(message)# XXX: to checkelse:raiseValueError('bad value for ch_constraint.')ifn_bads>0:ifverboseisnotFalse:warnings.warn('%i channels are marked as bad. These will be ignored. 
''If you want them to be considered by autoreject please ''remove them from epochs.info["bads"].'%n_bads)def_handle_picks(info,picks):"""Pick the data channls or return picks."""ifpicksisNone:out=mne.pick_types(info,meg=True,eeg=True,ref_meg=False,fnirs=True,ecog=True,seeg=True,exclude='bads')else:out=_picks_to_idx(info,picks,exclude='bads')returnoutdef_get_picks_by_type(info,picks):"""Get the picks grouped by channel type."""# do magic heresub_picks_=defaultdict(list)keys=list()forppinpicks:key=channel_type(info=info,idx=pp)sub_picks_[key].append(pp)ifkeynotinkeys:keys.append(key)picks_by_type=[(kk,sub_picks_[kk])forkkinkeys]returnpicks_by_type
def set_matplotlib_defaults(plt, style='ggplot'):
    """Set publication quality defaults for matplotlib.

    Parameters
    ----------
    plt : instance of matplotlib.pyplot
        The plt instance.
    style : str
        The style passed to ``matplotlib.style.use``. Defaults to
        ``'ggplot'``.
    """
    # Import the submodule explicitly so that ``matplotlib.style`` is
    # guaranteed to be loaded.
    import matplotlib.style

    matplotlib.style.use(style)

    fontsize = 17
    params = {'axes.labelsize': fontsize + 2,
              'legend.fontsize': fontsize,
              'xtick.labelsize': fontsize,
              'ytick.labelsize': fontsize,
              'axes.titlesize': fontsize + 2}
    plt.rcParams.update(params)
def_pprint(params,offset=0,printer=repr):"""Pretty print the dictionary 'params' (copied from sklearn) Parameters ---------- params : dict The dictionary to pretty print offset : int The offset in characters to add at the begin of each line. printer : callable The function to convert entries to strings, typically the builtin str or repr Returns ------- lines : str The pretty print of the dictionary as a string. """# Do a multi-line justified repr:options=np.get_printoptions()np.set_printoptions(precision=5,threshold=64,edgeitems=2)params_list=list()this_line_length=offsetline_sep=',\n'+(1+offset//2)*' 'fori,(k,v)inenumerate(sorted(params.items())):iftype(v)isfloat:# use str for representing floating point numbers# this way we get consistent representation across# architectures and versions.this_repr='%s=%s'%(k,str(v))else:# use repr of the restthis_repr='%s=%s'%(k,printer(v))iflen(this_repr)>500:this_repr=this_repr[:300]+'...'+this_repr[-100:]ifi>0:if(this_line_length+len(this_repr)>=75or'\n'inthis_repr):params_list.append(line_sep)this_line_length=len(line_sep)else:params_list.append(', ')this_line_length+=2params_list.append(this_repr)this_line_length+=len(this_repr)np.set_printoptions(**options)lines=''.join(params_list)# Strip trailing space to avoid nightmare in doctestslines='\n'.join(li.rstrip(' ')forliinlines.split('\n'))returnlinesdef_pbar(iterable,desc,verbose=True,**kwargs):raise_error=Falseifisinstance(verbose,str):ifverbosenotin{"tqdm","tqdm_notebook","progressbar"}:raise_error=Trueverbose=bool(verbose)elifisinstance(verbose,(int,bool)):verbose=bool(verbose)# this can happen with picklingelse:raise_error=Truewarnings.warn((f"verbose flag only supports boolean inputs. Option {verbose} "f"coerced into type {bool(verbose)}"),DeprecationWarning)verbose=bool(verbose)ifraise_error:raiseValueError(f"verbose must be a boolean value. 
Got {repr(verbose)}")ifverbose:frommne.utils.progressbarimportProgressBarpbar=ProgressBar(iterable,mesg=desc,**kwargs)else:pbar=iterablereturnpbardef_get_epochs_type():ifhasattr(mne.epochs,'_BaseEpochs'):BaseEpochs=mne.epochs._BaseEpochselse:BaseEpochs=mne.epochs.BaseEpochsreturnBaseEpochsdefclean_by_interp(inst,picks=None,verbose=True):"""Clean epochs/evoked by LOOCV. Parameters ---------- inst : mne.Evoked | mne.Epochs The evoked or epochs object. picks : str | list | slice | None Channels to include. Slices and lists of integers will be interpreted as channel indices. In lists, channel *type* strings (e.g., ``['meg', 'eeg']``) will pick channels of those types, channel *name* strings (e.g., ``['MEG0111', 'MEG2623']`` will pick the given channels. Can also be the string values ``'all'`` to pick all channels, or ``'data'`` to pick data channels. None (default) will pick data channels {'meg', 'eeg'}. Note that channels in ``info['bads']`` *will be included* if their names or indices are explicitly provided. verbose : bool The verbosity of progress messages. If False, suppress all output messages. Returns ------- inst_clean : instance of mne.Evoked or mne.Epochs Instance after interpolation of bad channels. 
"""return_clean_by_interp(inst,picks=picks,verbose=verbose)def_clean_by_interp(inst,picks=None,dots=None,verbose=True):inst_interp=inst.copy()mesg='Creating augmented epochs'picks=_handle_picks(info=inst_interp.info,picks=picks)BaseEpochs=_get_epochs_type()ch_names=[inst.info['ch_names'][p]forpinpicks]forch_idx,(pick,ch)inenumerate(_pbar(list(zip(picks,ch_names)),desc=mesg,verbose=verbose)):inst.info['bads']=[ch]pick_interp=mne.pick_channels(inst.info['ch_names'],[ch])[0]data_orig=inst._data[:,pick_interp].copy()interpolate_bads(inst,picks=picks,dots=dots,reset_bads=True,mode='fast')ifisinstance(inst,mne.Evoked):inst_interp.data[pick]=inst.data[pick_interp]elifisinstance(inst,BaseEpochs):inst_interp._data[:,pick]=inst._data[:,pick_interp]else:raiseValueError('Unrecognized type for inst')inst._data[:,pick_interp]=data_orig.copy()inst.info['bads']=inst_interp.info['bads'].copy()returninst_interpdefinterpolate_bads(inst,picks,dots=None,reset_bads=True,mode='accurate'):"""Interpolate bad MEG and EEG channels."""importmneeeg_picks=set(pick_types(inst.info,meg=False,eeg=True,exclude=[]))eeg_picks_interp=[pforpinpicksifpineeg_picks]meg_picks=set(pick_types(inst.info,meg=True,eeg=False,exclude=[]))meg_picks_interp=[pforpinpicksifpinmeg_picks]# to prevent cobyla printf errorwithmne.utils.use_log_level('error'):iflen(eeg_picks_interp)>0:_interpolate_bads_eeg(inst,picks=eeg_picks_interp)iflen(meg_picks_interp)>0:_interpolate_bads_meg_fast(inst,picks=meg_picks_interp,dots=dots,mode=mode)ifreset_badsisTrue:inst.info['bads']=[]returninstdef_interpolate_bads_eeg(inst,picks=None):""" Interpolate bad EEG channels. Operates in place. Parameters ---------- inst : mne.io.Raw, mne.Epochs or mne.Evoked The data to interpolate. Must be preloaded. picks : str | list | slice | None Channels to include for interpolation. Slices and lists of integers will be interpreted as channel indices. In lists, channel *name* strings (e.g., ``['EEG 01', 'EEG 02']``) will pick the given channels. 
None (default) will pick all EEG channels. Note that channels in ``info['bads']`` *will be included* if their names or indices are explicitly provided. """frommne.bemimport_fit_spherefrommne.utilsimportlogger,warnfrommne.channels.interpolationimport_do_interp_dotsfrommne.channels.interpolationimport_make_interpolation_matriximportnumpyasnpinst.info._check_consistency()ifpicksisNone:picks=pick_types(inst.info,meg=False,eeg=True,exclude=[])else:picks=_handle_picks(inst.info,picks)bads_idx=np.zeros(len(inst.ch_names),dtype=bool)goods_idx=np.zeros(len(inst.ch_names),dtype=bool)bads_idx[picks]=[inst.ch_names[ch]ininst.info['bads']forchinpicks]iflen(picks)==0orbads_idx.sum()==0:returngoods_idx[picks]=Truegoods_idx[bads_idx]=Falsepos=inst._get_channel_positions(picks)# Make sure only good EEG are usedbads_idx_pos=bads_idx[picks]goods_idx_pos=goods_idx[picks]pos_good=pos[goods_idx_pos]pos_bad=pos[bads_idx_pos]# test spherical fitradius,center=_fit_sphere(pos_good)distance=np.sqrt(np.sum((pos_good-center)**2,1))distance=np.mean(distance/radius)ifnp.abs(1.-distance)>0.1:warn('Your spherical fit is poor, interpolation results are ''likely to be inaccurate.')logger.info('Computing interpolation matrix from {0} sensor ''positions'.format(len(pos_good)))interpolation=_make_interpolation_matrix(pos_good,pos_bad)logger.info('Interpolating {0} sensors'.format(len(pos_bad)))_do_interp_dots(inst,interpolation,goods_idx,bads_idx)def_interpolate_bads_meg_fast(inst,picks,mode='accurate',dots=None):"""Interpolate bad channels from data in good channels."""# We can have pre-picked instances or not.# And we need to handle it.inst_picked=Trueiflen(inst.ch_names)>len(picks):picked_info=pick_info(inst.info,picks)dots=_pick_dots(dots,picks,picks)inst_picked=Falseelse:picked_info=inst.info.copy()defget_picks_bad_good(info,picks_meg):picks_good=[pforpinpicks_megifinfo['ch_names'][p]notininfo['bads']]# select the bad meg channel to be 
interpolatediflen(info['bads'])==0:picks_bad=[]else:picks_bad=[pforpinpicks_megifinfo['ch_names'][p]ininfo['bads']]returnpicks_meg,picks_good,picks_badpicks_meg,picks_good,picks_bad=get_picks_bad_good(picked_info,range(picked_info['nchan']))# return without doing anything if there are no meg channelsiflen(picks_meg)==0orlen(picks_bad)==0:return# we need to make sure that only meg channels are passed here# as the MNE interpolation code is not fogriving.# This is why we picked the info.mapping=_fast_map_meg_channels(picked_info,pick_from=picks_good,pick_to=picks_bad,dots=dots,mode=mode)# the downside is that the mapping matrix now does not match# the unpicked info of the data.# Since we may have picked the info, we need to double map# the indices._,picks_good_,picks_bad_orig=get_picks_bad_good(inst.info,picks)ch_names_a=[picked_info['ch_names'][pp]forppinpicks_bad]ch_names_b=[inst.info['ch_names'][pp]forppinpicks_bad_orig]assertch_names_a==ch_names_bifnotinst_picked:picks_good_=[ppforppinpicksifppinpicks_good_]assertlen(picks_good_)==len(picks_good)ch_names_a=[picked_info['ch_names'][pp]forppinpicks_good]ch_names_b=[inst.info['ch_names'][pp]forppinpicks_good_]assertch_names_a==ch_names_b_do_interp_dots(inst,mapping,picks_good_,picks_bad_orig)def_compute_dots(info,mode='fast',*,templates):"""Compute all-to-all 
dots."""frommne.forward._lead_dotsimport_do_self_dots,_do_cross_dotsfrommne.forward._make_forwardimport_create_meg_coils,_read_coil_defsfrommne.bemimport_check_originiftemplatesisNone:templates=_read_coil_defs(verbose='error')coils=_create_meg_coils(info['chs'],'normal',info['dev_head_t'],templates)my_origin=_check_origin((0.,0.,0.04),info)int_rad,noise,lut_fun,n_fact=_setup_dots(mode,info,coils,'meg')self_dots=_do_self_dots(int_rad,False,coils,my_origin,'meg',lut_fun,n_fact,n_jobs=1)cross_dots=_do_cross_dots(int_rad,False,coils,coils,my_origin,'meg',lut_fun,n_fact).Treturnself_dots,cross_dotsdef_pick_dots(dots,pick_from,pick_to):ifdotsisNone:returndotsself_dots,cross_dots=dotsself_dots=self_dots[pick_from,:][:,pick_from]cross_dots=cross_dots[pick_to,:][:,pick_from]return[self_dots,cross_dots]def_fast_map_meg_channels(info,pick_from,pick_to,dots=None,mode='fast'):frommne.forward._field_interpolationimport_compute_mapping_matrixfrommne.forward._make_forwardimport_create_meg_coils,_read_coil_defsfrommne.bemimport_check_originmiss=1e-4# Smoothing criterion for MEGinfo_from=pick_info(info,pick_from,copy=True)templates=_read_coil_defs(verbose='error')coils_from=_create_meg_coils(info_from['chs'],'normal',info_from['dev_head_t'],templates)my_origin=_check_origin((0.,0.,0.04),info_from)int_rad,noise,lut_fun,n_fact=_setup_dots(mode,info_from,coils_from,'meg')# This function needs a clean input. It hates the presence of other# channels than MEG channels. 
Make sure all is picked.ifdotsisNone:dots=self_dots,cross_dots=_compute_dots(info,mode=mode,templates=templates)else:self_dots,cross_dots=dotsself_dots,cross_dots=_pick_dots(dots,pick_from,pick_to)ch_names=[c['ch_name']forcininfo_from['chs']]fmd=dict(kind='meg',ch_names=ch_names,origin=my_origin,noise=noise,self_dots=self_dots,surface_dots=cross_dots,int_rad=int_rad,miss=miss)# XXX: hack to silence _compute_mapping_matrixwithmne.utils.use_log_level('error'):fmd['data']=_compute_mapping_matrix(fmd,info_from)returnfmd['data']def_get_channel_type(epochs,picks):"""return whether a set of picks are all meg or all eeg channels."""picked_info=pick_info(epochs.info,picks)ch_types_picked={channel_type(picked_info,idx)foridxinrange(len(picks))}invalid_ch_types_present=[keyforkeyinch_types_pickedifkeynotin['mag','grad','eeg']andkeyinepochs]iflen(invalid_ch_types_present)>0:raiseValueError('Invalid channel types present in epochs.'' Expected ONLY `meg` or ONLY `eeg`. Got %s'%', '.join(invalid_ch_types_present))has_meg=any(kkinch_types_pickedforkkin('mag','grad'))if'eeg'inch_types_pickedandhas_meg:raiseValueError('Got mixed channel types. Pick either eeg or meg'' but not both')if'eeg'inch_types_picked:return'eeg'elifhas_meg:return'meg'else:raiseValueError('Oh no! Your channel type is not known.')