Skip to content

Commit

Permalink
Merge branch 'release/2.3.0'
Browse files Browse the repository at this point in the history
  • Loading branch information
oliche committed Nov 4, 2021
2 parents 5dee8e4 + a7ce057 commit 5ca5764
Show file tree
Hide file tree
Showing 19 changed files with 1,337 additions and 221 deletions.
100 changes: 99 additions & 1 deletion brainbox/ephys_plots.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy as np
from matplotlib import cm

import matplotlib.pyplot as plt
from brainbox.plot_base import (ImagePlot, ScatterPlot, ProbePlot, LinePlot, plot_line,
plot_image, plot_probe, plot_scatter, arrange_channels2banks)
from brainbox.processing import bincount2D, compute_cluster_average
from ibllib.atlas.regions import BrainRegions


def image_lfp_spectrum_plot(lfp_power, lfp_freq, chn_coords, chn_inds, freq_range=(0, 300),
Expand Down Expand Up @@ -372,3 +373,100 @@ def line_amp_plot(spike_amps, spike_depths, spike_times, chn_coords, d_bin=10, d
fig, ax = plot_line(data.convert2dict())
return data.convert2dict(), fig, ax
return data


def plot_brain_regions(channel_ids, channel_depths=None, brain_regions=None, display=True, ax=None):
"""
Plot brain regions along probe, if channel depths is provided will plot along depth otherwise along channel idx
:param channel_ids: atlas ids for each channel
:param channel_depths: depth along probe for each channel
:param brain_regions: BrainRegions object
:param display: whether to output plot
:param ax: axis to plot on
:return:
"""

if channel_depths is not None:
assert channel_ids.shape[0] == channel_depths.shape[0]

br = brain_regions or BrainRegions()

region_info = br.get(channel_ids)
boundaries = np.where(np.diff(region_info.id) != 0)[0]
boundaries = np.r_[0, boundaries, region_info.id.shape[0] - 1]

regions = np.c_[boundaries[0:-1], boundaries[1:]]
if channel_depths is not None:
regions = channel_depths[regions]
region_labels = np.c_[np.mean(regions, axis=1), region_info.acronym[boundaries[1:]]]
region_colours = region_info.rgb[boundaries[1:]]

if display:
if ax is None:
fig, ax = plt.subplots()

for reg, col in zip(regions, region_colours):
height = np.abs(reg[1] - reg[0])
color = col / 255
ax.bar(x=0.5, height=height, width=1, color=color, bottom=reg[0], edgecolor='w')
ax.set_yticks(region_labels[:, 0].astype(int))
ax.yaxis.set_tick_params(labelsize=8)
ax.get_xaxis().set_visible(False)
ax.set_yticklabels(region_labels[:, 1])
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_visible(False)

return fig, ax
else:
return regions, region_labels, region_colours


def plot_cdf(spike_amps, spike_depths, spike_times, n_amp_bins=10, d_bin=40, amp_range=None, d_range=None,
display=False, cmap='hot'):
"""
Plot cumulative amplitude of spikes across depth
:param spike_amps:
:param spike_depths:
:param spike_times:
:param n_amp_bins: number of amplitude bins to use
:param d_bin: the value of the depth bins in um (default is 40 um)
:param amp_range: amp range to use [amp_min, amp_max], if not given automatically computed from spike_amps
:param d_range: depth range to use, by default [0, 3840]
:param display: whether or not to display plot
:param cmap:
:return:
"""

amp_range = amp_range or np.quantile(spike_amps, (0, 0.9))
amp_bins = np.linspace(amp_range[0], amp_range[1], n_amp_bins)
d_range = d_range or [0, 3840]
depth_bins = np.arange(d_range[0], d_range[1] + d_bin, d_bin)
t_bin = np.max(spike_times)

def histc(x, bins):
map_to_bins = np.digitize(x, bins) # Get indices of the bins to which each value in input array belongs.
res = np.zeros(bins.shape)

for el in map_to_bins:
res[el - 1] += 1 # Increment appropriate bin.
return res

cdfs = np.empty((len(depth_bins) - 1, n_amp_bins))
for d in range(len(depth_bins) - 1):
spikes = np.bitwise_and(spike_depths > depth_bins[d], spike_depths <= depth_bins[d + 1])
h = histc(spike_amps[spikes], amp_bins) / t_bin
hcsum = np.cumsum(h[::-1])
cdfs[d, :] = hcsum[::-1]

cdfs[cdfs == 0] = np.nan

data = ImagePlot(cdfs.T, x=amp_bins * 1e6, y=depth_bins[:-1], cmap=cmap)
data.set_labels(title='Cumulative Amplitude', xlabel='Spike amplitude (uV)',
ylabel='Distance from probe tip (um)', clabel='Firing Rate (Hz)')

if display:
fig, ax = plot_image(data.convert2dict(), fig_kwargs={'figsize': [3, 7]})
return data.convert2dict(), fig, ax

return data
55 changes: 37 additions & 18 deletions brainbox/io/one.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,24 @@ def _channels_traj2bunch(xyz_chans, brain_atlas):
return channels


def _channels_bunch2alf(channels):
channels_ = {
'mlapdv': np.c_[channels['x'], channels['y'], channels['z']] * 1e6,
'brainLocationIds_ccf_2017': channels['atlas_id'],
'localCoordinates': np.c_[channels['lateral_um'], channels['axial_um']]}
return channels_


def _channels_alf2bunch(channels, brain_regions=None):
# reformat the dictionary according to the standard that comes out of Alyx
channels_ = {
'x': channels['mlapdv'][:, 0].astype(np.float64) / 1e6,
'y': channels['mlapdv'][:, 1].astype(np.float64) / 1e6,
'z': channels['mlapdv'][:, 2].astype(np.float64) / 1e6,
'acronym': None,
'atlas_id': channels['brainLocationIds_ccf_2017']
'atlas_id': channels['brainLocationIds_ccf_2017'],
'axial_um': channels['localCoordinates'][:, 1],
'lateral_um': channels['localCoordinates'][:, 0],
}
if brain_regions:
channels_['acronym'] = brain_regions.get(channels_['atlas_id'])['acronym']
Expand Down Expand Up @@ -207,24 +217,33 @@ def _load_channels_locations_from_disk(eid, collection=None, one=None, revision=
channels_aligned = one.load_object(eid, 'channels', collection=ac_collection)
channels[probe] = channel_locations_interpolation(channels_aligned, channels[probe])
# only have to reformat channels if we were able to load coordinates from disk
channels[probe] = _channels_alf2bunch(channels[probe], brain_regions=brain_regions)
channels[probe] = _channels_alf2bunch(channels[probe], brain_regions=brain_regions)
return channels


def channel_locations_interpolation(channels_aligned, channels):
def channel_locations_interpolation(channels_aligned, channels, brain_regions=None):
"""
oftentimes the channel map for different spike sorters may be different so interpolate the alignment onto
if there is no spike sorting in the base folder, the alignment doesn't have the localCoordinates field
so we reconstruct from the Neuropixel map. This only happens for early pykilosort sorts
:param channels_aligned: Bunch or dictionary of aligned channels containing at least keys
'mlapdv' and 'brainLocationIds_ccf_2017' - those are the guide for the interpolation
'localCoordinates', 'mlapdv' and 'brainLocationIds_ccf_2017'
OR
'x', 'y', 'z', 'acronym', 'axial_um'
those are the guide for the interpolation
:param channels: Bunch or dictionary of aligned channels containing at least keys 'localCoordinates'
:return: Bunch or dictionary of channels with extra keys 'mlapdv' and 'brainLocationIds_ccf_2017'
:param brain_regions: None (default) or ibllib.atlas.BrainRegions object
if None will return a dict with keys 'localCoordinates', 'mlapdv', 'brainLocationIds_ccf_2017
if a brain region object is provided, outputts a dict with keys
'x', 'y', 'z', 'acronym', 'atlas_id', 'axial_um', 'lateral_um'
:return: Bunch or dictionary of channels with brain coordinates keys
"""
nch = channels['localCoordinates'].shape[0]
if set(['x', 'y', 'z']).issubset(set(channels_aligned.keys())):
channels_aligned = _channels_bunch2alf(channels_aligned)
if 'localCoordinates' in channels_aligned.keys():
aligned_depths = channels_aligned['localCoordinates'][:, 1]
else:
else: # this is a edge case for a few spike sorting sessions
assert channels_aligned['mlapdv'].shape[0] == 384
NEUROPIXEL_VERSION = 1
from ibllib.ephys.neuropixel import trace_header
Expand All @@ -238,7 +257,10 @@ def channel_locations_interpolation(channels_aligned, channels):
# the brain locations have to be interpolated by nearest neighbour
fcn_interp = interp1d(depth_aligned, channels_aligned['brainLocationIds_ccf_2017'][ind_aligned], kind='nearest')
channels['brainLocationIds_ccf_2017'] = fcn_interp(depths)[iinv].astype(np.int32)
return channels
if brain_regions is not None:
return _channels_alf2bunch(channels, brain_regions=brain_regions)
else:
return channels


def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, aligned=False,
Expand Down Expand Up @@ -531,7 +553,7 @@ def merge_clusters_channels(dic_clus, channels, keys_to_add_extra=None):
dic_clus : dict of one.alf.io.AlfBunch
1 bunch per probe, containing cluster information
channels : dict of one.alf.io.AlfBunch
1 bunch per probe, containing channels bunch with keys ('acronym', 'atlas_id')
1 bunch per probe, containing channels bunch with keys ('acronym', 'atlas_id', 'x', 'y', z', 'localCoordinates')
keys_to_add_extra : list of str
Any extra keys to load into channels bunches
Expand All @@ -541,7 +563,7 @@ def merge_clusters_channels(dic_clus, channels, keys_to_add_extra=None):
clusters (1 bunch per probe) with new keys values.
"""
probe_labels = list(channels.keys()) # Convert dict_keys into list
keys_to_add_default = ['acronym', 'atlas_id', 'x', 'y', 'z']
keys_to_add_default = ['acronym', 'atlas_id', 'x', 'y', 'z', 'axial_um', 'lateral_um']

if keys_to_add_extra is None:
keys_to_add = keys_to_add_default
Expand All @@ -550,10 +572,9 @@ def merge_clusters_channels(dic_clus, channels, keys_to_add_extra=None):
keys_to_add = list(set(keys_to_add_extra + keys_to_add_default))

for label in probe_labels:
try:
clu_ch = dic_clus[label]['channels']

for key in keys_to_add:
clu_ch = dic_clus[label]['channels']
for key in keys_to_add:
try:
assert key in channels[label].keys() # Check key is in channels
ch_key = channels[label][key]
nch_key = len(ch_key) if ch_key is not None else 0
Expand All @@ -564,11 +585,9 @@ def merge_clusters_channels(dic_clus, channels, keys_to_add_extra=None):
f'Probe {label}: merging channels and clusters for key "{key}" has {nch_key} on channels'
f' but expected {max(clu_ch)}. Data in new cluster key "{key}" is returned empty.')
dic_clus[label][key] = []
except AssertionError:
_logger.warning(
f'Either clusters or channels does not have key {label}, could not'
f' merge')
continue
except AssertionError:
_logger.warning(f'Either clusters or channels does not have key {key}, could not merge')
continue

return dic_clus

Expand Down
6 changes: 3 additions & 3 deletions brainbox/io/spikeglx.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ def stream(pid, t0, nsecs=1, one=None, cache_folder=None, remove_cached=False, t
samples_folder = Path(one.alyx._par.CACHE_DIR).joinpath('cache', typ)

eid, pname = one.pid2eid(pid)
cbin_rec = one.list_datasets(eid, collection=f"*{pname}", filename='*ap.*bin', details=True)
ch_rec = one.list_datasets(eid, collection=f"*{pname}", filename='*ap.ch', details=True)
meta_rec = one.list_datasets(eid, collection=f"*{pname}", filename='*ap.meta', details=True)
cbin_rec = one.list_datasets(eid, collection=f"*{pname}", filename=f'*{typ}.*bin', details=True)
ch_rec = one.list_datasets(eid, collection=f"*{pname}", filename=f'*{typ}.ch', details=True)
meta_rec = one.list_datasets(eid, collection=f"*{pname}", filename=f'*{typ}.meta', details=True)
ch_file = one._download_datasets(ch_rec)[0]
one._download_datasets(meta_rec)[0]

Expand Down
6 changes: 4 additions & 2 deletions brainbox/task/passive.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,14 @@ def get_stim_aligned_activity(stim_events, spike_times, spike_depths, z_score_fl
stim_activity = {}
for stim_type, stim_times in stim_events.items():

# Get rid of any nan values
stim_times = stim_times[~np.isnan(stim_times)]
stim_intervals = np.c_[stim_times - pre_stim, stim_times + post_stim]
base_intervals = np.c_[stim_times - base_stim, stim_times - pre_stim]
out_intervals = stim_intervals[:, 1] > times[-1]

idx_stim = np.searchsorted(times, stim_intervals)[np.invert(out_intervals)]
idx_base = np.searchsorted(times, base_intervals)[np.invert(out_intervals)]
idx_stim = np.searchsorted(times, stim_intervals, side='right')[np.invert(out_intervals)]
idx_base = np.searchsorted(times, base_intervals, side='right')[np.invert(out_intervals)]

stim_trials = np.zeros((depths.shape[0], n_bins, idx_stim.shape[0]))
noise_trials = np.zeros((depths.shape[0], n_bins_base, idx_stim.shape[0]))
Expand Down
11 changes: 11 additions & 0 deletions ibllib/atlas/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,17 @@ def _mapping_from_regions_list(self, new_map, lateralize=False):
mapind = mapind[iregion]
return mapind

def remap(self, region_ids, source_map='Allen', target_map='Beryl'):
"""
Remap atlas regions ids from source map to target map
:param region_ids: atlas ids to map
:param source_map: map name which original region_ids are in
:param target_map: map name onto which to map
:return:
"""
_, inds = ismember(region_ids, self.id[self.mappings[source_map]])
return self.id[self.mappings[target_map][inds]]


def regions_from_allen_csv():
"""
Expand Down
3 changes: 2 additions & 1 deletion ibllib/dsp/voltage.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ def destripe(x, fs, tr_sel=None, neuropixel_version=1, butter_kwargs=None, k_kwa
x = scipy.signal.sosfiltfilt(sos, x)
# apply ADC shift
if neuropixel_version is not None:
x = fshift(x, h['sample_shift'], axis=1)
sample_shift = h['sample_shift'] if (30000 / fs) < 10 else h['sample_shift'] * fs / 30000
x = fshift(x, sample_shift, axis=1)
# apply spatial filter on good channel selection only
x_ = kfilt(x, **k_kwargs)
return x_
Expand Down
Loading

0 comments on commit 5ca5764

Please sign in to comment.