Commit 8031f9c: Merge remote-tracking branch 'public/master'
hanhou committed Dec 1, 2022
2 parents: 902a07f + 979935c

Showing 3 changed files with 102 additions and 39 deletions.
93 changes: 58 additions & 35 deletions pipeline/export/nwb.py
@@ -13,15 +13,6 @@
from pipeline.ingest import ephys as ephys_ingest
from pipeline.ingest import tracking as tracking_ingest
from pipeline.ingest.utils.paths import get_ephys_paths
- from nwb_conversion_tools.tools.spikeinterface.spikeinterfacerecordingdatachunkiterator import (
-     SpikeInterfaceRecordingDataChunkIterator
- )
- from spikeinterface import extractors
- from nwb_conversion_tools.datainterfaces.behavior.movie.moviedatainterface import MovieInterface


- ephys_root_data_dir = pathlib.Path(get_ephys_paths()[0])
- tracking_root_data_dir = pathlib.Path(tracking_ingest.get_tracking_paths()[0])


# Helper functions for raw ephys data import
@@ -100,7 +91,7 @@ def datajoint_to_nwb(session_key, raw_ephys=False, raw_video=False):
institution='Janelia Research Campus',
experiment_description=experiment_description,
related_publications='',
- keywords=[])
+ keywords=['electrophysiology'])

# ==================================== SUBJECT ==================================
subject = (lab.Subject * lab.WaterRestriction.proj('water_restriction_number') & session_key).fetch1()
@@ -109,27 +100,29 @@ def datajoint_to_nwb(session_key, raw_ephys=False, raw_video=False):
date_of_birth=datetime.combine(subject['date_of_birth'], zero_time) if subject['date_of_birth'] else None,
description=subject['water_restriction_number'],
sex=subject['sex'],
- species='mus musculus')
+ species='Mus musculus')

# ==================================== EPHYS ==================================
# add additional columns to the electrodes table
electrodes_query = lab.ProbeType.Electrode * lab.ElectrodeConfig.Electrode
- for additional_attribute in ['shank', 'shank_col', 'shank_row']:
+ for additional_attribute in ['electrode', 'shank', 'shank_col', 'shank_row']:
nwbfile.add_electrode_column(
name=electrodes_query.heading.attributes[additional_attribute].name,
description=electrodes_query.heading.attributes[additional_attribute].comment)

# add additional columns to the units table
if dj.__version__ >= '0.13.0':
units_query = (ephys.ProbeInsertion.RecordingSystemSetup
- * ephys.Unit & session_key).proj('unit_amp', 'unit_snr').join(
+ * ephys.Unit & session_key).proj(
+ ..., '-spike_times', '-spike_sites', '-spike_depths').join(
ephys.UnitStat, left=True).join(
ephys.MAPClusterMetric.DriftMetric, left=True).join(
ephys.ClusterMetric, left=True).join(
ephys.WaveformMetric, left=True)
else:
units_query = (ephys.ProbeInsertion.RecordingSystemSetup
- * ephys.Unit & session_key).proj('unit_amp', 'unit_snr').aggr(
+ * ephys.Unit & session_key).proj(
+ ..., '-spike_times', '-spike_sites', '-spike_depths').aggr(
ephys.UnitStat, ..., **{n: n for n in ephys.UnitStat.heading.names if n not in ephys.UnitStat.heading.primary_key},
keep_all_rows=True).aggr(
ephys.MAPClusterMetric.DriftMetric, ..., **{n: n for n in ephys.MAPClusterMetric.DriftMetric.heading.names if n not in ephys.MAPClusterMetric.DriftMetric.heading.primary_key},
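
Note: the branch on dj.__version__ exists because join(..., left=True) requires DataJoint 0.13+; on older versions a left join is emulated by renaming every non-key column of the joined table through aggr() with keep_all_rows=True, exactly as the pre-0.13 branch above does. A minimal sketch of that pattern, assuming any parent relation and any secondary table:

    import datajoint as dj

    def left_join_compat(parent, secondary):
        """Left-join `secondary` onto `parent` on both old and new DataJoint."""
        if dj.__version__ >= '0.13.0':
            return parent.join(secondary, left=True)
        # pre-0.13: carry each non-key attribute over via aggregation,
        # keeping parent rows that have no match in `secondary`
        renames = {n: n for n in secondary.heading.names
                   if n not in secondary.heading.primary_key}
        return parent.aggr(secondary, ..., **renames, keep_all_rows=True)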
@@ -157,11 +150,12 @@ def datajoint_to_nwb(session_key, raw_ephys=False, raw_video=False):
# ---- Probe Insertion Location ----
if ephys.ProbeInsertion.InsertionLocation & insert_key:
insert_location = {
- k: str(v) for k, v in (ephys.ProbeInsertion.InsertionLocation
- & insert_key).aggr(
- ephys.ProbeInsertion.RecordableBrainRegion.proj(
- ..., brain_region='CONCAT(hemisphere, " ", brain_area)'),
- ..., brain_regions='GROUP_CONCAT(brain_region SEPARATOR ", ")').fetch1().items()
+ k: str(v) for k, v in (
+ (ephys.ProbeInsertion.proj() & insert_key).aggr(
+ ephys.ProbeInsertion.RecordableBrainRegion.proj(
+ ..., brain_region='CONCAT(hemisphere, " ", brain_area)'),
+ ..., brain_regions='GROUP_CONCAT(brain_region SEPARATOR ", ")')
+ * ephys.ProbeInsertion.InsertionLocation).fetch1().items()
if k not in ephys.ProbeInsertion.primary_key}
insert_location = json.dumps(insert_location)
else:
@@ -173,7 +167,8 @@ def datajoint_to_nwb(session_key, raw_ephys=False, raw_video=False):
ephys_device_name = f'{electrode_config["probe"]} ({electrode_config["probe_type"]})'
ephys_device = (nwbfile.get_device(ephys_device_name)
if ephys_device_name in nwbfile.devices
- else nwbfile.create_device(name=ephys_device_name))
+ else nwbfile.create_device(name=ephys_device_name,
+ description=electrode_config["probe_type"]))

electrode_group = nwbfile.create_electrode_group(
name=f'{electrode_config["probe"]} {electrode_config["electrode_config_name"]}',
@@ -190,15 +185,15 @@ def datajoint_to_nwb(session_key, raw_ephys=False, raw_video=False):

for electrode in electrode_query.fetch(as_dict=True):
nwbfile.add_electrode(
- id=electrode['electrode'], group=electrode_group,
+ electrode=electrode['electrode'], group=electrode_group,
filtering='', imp=-1.,
**electrode_ccf.get(electrode['electrode'], {'x': np.nan, 'y': np.nan, 'z': np.nan}),
rel_x=electrode['x_coord'], rel_y=electrode['y_coord'], rel_z=np.nan,
shank=electrode['shank'], shank_col=electrode['shank_col'], shank_row=electrode['shank_row'],
location=electrode_group.location)

electrode_df = nwbfile.electrodes.to_dataframe()
- electrode_ind = electrode_df.index[electrode_df.group_name == electrode_group.name]
+ electrode_ind = electrode_df.electrode[electrode_df.group_name == electrode_group.name]
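
Note: since add_electrode() no longer passes id=, pynwb auto-assigns sequential row ids across all probes, and the probe-local electrode number survives only in the custom 'electrode' column (added to the column list above). Code that needs a channel region should therefore select rows by position rather than by id. A hedged sketch of one common way to do that, assuming the standard pynwb API (electrode_region is a hypothetical name, not this repo's exact code):

    import numpy as np

    electrode_df = nwbfile.electrodes.to_dataframe()
    group_rows = np.where(electrode_df.group_name == electrode_group.name)[0]
    electrode_region = nwbfile.create_electrode_table_region(
        region=group_rows.tolist(),  # positional rows for this probe
        description=f'electrodes of {electrode_group.name}')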

# ---- Units ----
unit_query = units_query & insert_key
@@ -220,6 +215,12 @@ def datajoint_to_nwb(session_key, raw_ephys=False, raw_video=False):

# ---- Raw Ephys Data ---
if raw_ephys:
+ from spikeinterface import extractors
+ from nwb_conversion_tools.tools.spikeinterface.spikeinterfacerecordingdatachunkiterator import (
+ SpikeInterfaceRecordingDataChunkIterator
+ )
+ ephys_root_data_dir = pathlib.Path(get_ephys_paths()[0])

ks_dir_relpath = (ephys_ingest.EphysIngest.EphysFile.proj(
..., insertion_number='probe_insertion_number')
& insert_key).fetch('ephys_file')
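
Note: the heavy spikeinterface / nwb_conversion_tools imports now happen lazily inside the raw_ephys branch, so metadata-only exports no longer require those packages. The fetched kilosort-relative path is resolved against ephys_root_data_dir and the recording is streamed into the file in chunks rather than loaded whole. A hedged sketch of that write path (the folder layout, stream id, and electrode_region variable are assumptions, not this repo's exact code):

    from pynwb.ecephys import ElectricalSeries

    # open the raw recording lazily with spikeinterface
    recording = extractors.read_spikeglx(
        folder_path=(ephys_root_data_dir / ks_dir_relpath[0]).as_posix(),
        stream_id='imec0.ap')
    nwbfile.add_acquisition(ElectricalSeries(
        name='ElectricalSeries',
        data=SpikeInterfaceRecordingDataChunkIterator(recording=recording),
        electrodes=electrode_region,  # DynamicTableRegion over this probe's rows
        rate=float(recording.get_sampling_frequency())))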
@@ -310,13 +311,13 @@ def datajoint_to_nwb(session_key, raw_ephys=False, raw_video=False):

tracking_timestamps = np.hstack([np.arange(nsample) / trk_fs + float(trial_start_time)
for nsample, trial_start_time in zip(samples, start_time)])
- position_data = np.vstack([np.hstack(d) for d in position_data])
+ position_data = np.vstack([np.hstack(d) for d in position_data]).T

behav_ts_name = f'{trk_device_name}_{feature}' + (f'_{r["whisker_name"]}' if r else '')

behav_acq.create_timeseries(name=behav_ts_name,
data=position_data,
- timestamps=tracking_timestamps,
+ timestamps=tracking_timestamps[:position_data.shape[0]],
description=f'Time series for {feature} position: {tuple(ft_attrs)}',
unit='a.u.',
conversion=1.0)
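
Note: position_data is stacked feature by feature, giving shape (n_features, n_samples); the new .T puts time on axis 0 as pynwb expects, and slicing tracking_timestamps to position_data.shape[0] guards against a small length mismatch between the per-trial sample counts and the concatenated data. A toy shape check of the same stacking:

    import numpy as np

    per_trial = [np.arange(5.0), np.arange(3.0)]          # one feature, two trials
    position_data = np.vstack([np.hstack(per_trial)]).T   # (8, 1): time on axis 0
    timestamps = np.arange(10) / 100.0                    # may run longer than the data
    timestamps = timestamps[:position_data.shape[0]]      # trim so lengths agree
    assert position_data.shape[0] == timestamps.shape[0]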
@@ -327,12 +328,20 @@ def datajoint_to_nwb(session_key, raw_ephys=False, raw_video=False):
* experiment.Photostim & session_key).proj(
'photostim_event_time', 'power', 'duration')
q_trial = experiment.SessionTrial * experiment.BehaviorTrial & session_key
- q_trial = q_trial.aggr(
- q_photostim, ...,
- photostim_onset='IFNULL(GROUP_CONCAT(photostim_event_time SEPARATOR ", "), "N/A")',
- photostim_power='IFNULL(GROUP_CONCAT(power SEPARATOR ", "), "N/A")',
- photostim_duration='IFNULL(GROUP_CONCAT(duration SEPARATOR ", "), "N/A")',
- keep_all_rows=True)
+ if dj.__version__ >= '0.13.0':
+ q_trial = q_trial.proj().aggr(
+ q_photostim, ...,
+ photostim_onset='IFNULL(GROUP_CONCAT(photostim_event_time SEPARATOR ", "), "N/A")',
+ photostim_power='IFNULL(GROUP_CONCAT(power SEPARATOR ", "), "N/A")',
+ photostim_duration='IFNULL(GROUP_CONCAT(duration SEPARATOR ", "), "N/A")',
+ keep_all_rows=True) * q_trial
+ else:
+ q_trial = q_trial.aggr(
+ q_photostim, ...,
+ photostim_onset='IFNULL(GROUP_CONCAT(photostim_event_time SEPARATOR ", "), "N/A")',
+ photostim_power='IFNULL(GROUP_CONCAT(power SEPARATOR ", "), "N/A")',
+ photostim_duration='IFNULL(GROUP_CONCAT(duration SEPARATOR ", "), "N/A")',
+ keep_all_rows=True)
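
Note: a trial may have zero or many photostim events, so GROUP_CONCAT collapses them into one comma-separated string per trial row and IFNULL(..., "N/A") keeps the trials without any events; the 0.13+ branch aggregates on the bare primary key (q_trial.proj()) and joins the result back so the trial attributes are not re-grouped. A plain-Python analogue of the collapsing expression:

    def collapse(values, sep=', '):
        """Analogue of IFNULL(GROUP_CONCAT(... SEPARATOR ', '), 'N/A') above."""
        return sep.join(str(v) for v in values) if values else 'N/A'

    assert collapse([1.2, 2.4]) == '1.2, 2.4'
    assert collapse([]) == 'N/A'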

skip_adding_columns = experiment.Session.primary_key

@@ -367,19 +376,21 @@ def datajoint_to_nwb(session_key, raw_ephys=False, raw_video=False):
for trial_event_type in (experiment.TrialEventType & q_trial_event).fetch('trial_event_type'):
trial, event_starts, event_stops = (q_trial_event
& {'trial_event_type': trial_event_type}).fetch(
- 'trial', 'event_start', 'event_stop', order_by='trial')
+ 'trial', 'event_start', 'event_stop', order_by='trial, event_start')

behavioral_event.create_timeseries(
name=trial_event_type + '_start_times',
unit='a.u.', conversion=1.0,
data=np.full_like(event_starts.astype(float), 1),
- timestamps=event_starts.astype(float))
+ timestamps=event_starts.astype(float),
+ description=f'Timestamps for event type: {trial_event_type} - Start Time')

behavioral_event.create_timeseries(
name=trial_event_type + '_stop_times',
unit='a.u.', conversion=1.0,
data=np.full_like(event_stops.astype(float), 1),
- timestamps=event_stops.astype(float))
+ timestamps=event_stops.astype(float),
+ description=f'Timestamps for event type: {trial_event_type} - Stop Time')
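
Note: ordering by 'trial, event_start' keeps repeated events within a trial in time order, and the added description strings silence "missing description" warnings from the validators introduced below. behavioral_event here is, per pynwb conventions, a BehavioralEvents container; a minimal sketch of how such a container is typically set up (the module and container names are assumptions):

    from pynwb.behavior import BehavioralEvents

    behavior_module = nwbfile.create_processing_module(
        name='behavior', description='behavioral trial events')
    behavioral_event = BehavioralEvents(name='BehavioralEvents')
    behavior_module.add(behavioral_event)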

# ---- action events

@@ -396,7 +407,8 @@ def datajoint_to_nwb(session_key, raw_ephys=False, raw_video=False):
name=action_event_type.replace(' ', '_') + '_times',
unit='a.u.', conversion=1.0,
data=np.full_like(event_starts.astype(float), 1),
- timestamps=event_starts.astype(float))
+ timestamps=event_starts.astype(float),
+ description=f'Timestamps for event type: {action_event_type}')

# ---- photostim events ----

@@ -426,6 +438,10 @@ def datajoint_to_nwb(session_key, raw_ephys=False, raw_video=False):

# ----- Raw Video Files -----
if raw_video:
+ from nwb_conversion_tools.datainterfaces.behavior.movie.moviedatainterface import MovieInterface
+
+ tracking_root_data_dir = pathlib.Path(tracking_ingest.get_tracking_paths()[0])

tracking_files_info = (tracking_ingest.TrackingIngest.TrackingFile & session_key).fetch(
as_dict=True, order_by='tracking_device, trial')
for tracking_file_info in tracking_files_info:
@@ -448,7 +464,7 @@ def datajoint_to_nwb(session_key, raw_ephys=False, raw_video=False):
return nwbfile


- def export_recording(session_keys, output_dir='./', overwrite=False):
+ def export_recording(session_keys, output_dir='./', overwrite=False, validate=False):
if not isinstance(session_keys, list):
session_keys = [session_keys]

@@ -464,3 +480,10 @@ def export_recording(session_keys, output_dir='./', overwrite=False):
with NWBHDF5IO(output_fp.as_posix(), mode='w') as io:
io.write(nwbfile)
print(f'\tWrite NWB 2.0 file: {save_file_name}')
+ if validate:
+ import nwbinspector
+ with NWBHDF5IO(output_fp.as_posix(), mode='r') as io:
+ validation_status = pynwb.validate(io=io)
+ print(validation_status)
+ for inspection_message in nwbinspector.inspect_all(path=output_fp):
+ print(inspection_message)
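
Note: validation is opt-in. pynwb.validate(io=io) checks the written file against the NWB schema and returns a list of errors, while nwbinspector.inspect_all(path=...) yields best-practice messages for the NWB files found under path. A hedged usage sketch, reusing this pipeline's own query pattern:

    # export one session and run both checks on the result
    session_key = (experiment.Session & ephys.Unit).fetch('KEY', limit=1)[0]
    export_recording(session_key, output_dir='./nwb_out',
                     overwrite=True, validate=True)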
6 changes: 2 additions & 4 deletions requirements.txt
@@ -18,12 +18,10 @@ statannot
streamlit
streamlit-aggrid
watchdog
+ pynwb==2.2.0
+ h5py==3.3.0
- pynwb==2.0.1
- streamlit
- streamlit-aggrid
- watchdog
spikeinterface
nwb-conversion-tools
nwbinspector
opencv-python
dspca
42 changes: 42 additions & 0 deletions scripts/delay_response_NWB_export_Nov2022.py
@@ -0,0 +1,42 @@
import os
import datajoint as dj
import pathlib

from pipeline import lab, experiment, ephys
from pipeline.export.nwb import export_recording


output_dir = pathlib.Path(r'D:/map/NWB_EXPORT/delay_response')

subjects_to_export = ("SC011", "SC013", "SC015", "SC016", "SC017",
"SC022", "SC023", "SC026", "SC027", "SC030",
"SC031", "SC032", "SC033", "SC035", "SC038",
"SC043", "SC045", "SC048", "SC049", "SC050",
"SC052", "SC053", "SC060", "SC061", "SC064",
"SC065", "SC066", "SC067")


def main(limit=None):
subject_keys = (lab.Subject * lab.WaterRestriction.proj('water_restriction_number')
& f'water_restriction_number in {subjects_to_export}').fetch('KEY')
session_keys = (experiment.Session & ephys.Unit & subject_keys).fetch('KEY', limit=limit)
export_recording(session_keys, output_dir=output_dir, overwrite=False, validate=False)


dandiset_id = os.getenv('DANDISET_ID')
dandi_api_key = os.getenv('DANDI_API_KEY')


def publish_to_dandi(dandiset_id, dandi_api_key):
from element_interface.dandi import upload_to_dandi

dandiset_dir = output_dir / 'dandi'
dandiset_dir.mkdir(parents=True, exist_ok=True)

upload_to_dandi(
data_directory=output_dir,
dandiset_id=dandiset_id,
staging=False,
working_directory=dandiset_dir,
api_key=dandi_api_key,
sync=True)
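
Note: the script defines main() and publish_to_dandi() but never invokes them, so it is meant to be driven interactively or extended. A hedged driver sketch:

    if __name__ == '__main__':
        main(limit=1)  # smoke-test a single session before the full export
        if dandiset_id and dandi_api_key:
            publish_to_dandi(dandiset_id, dandi_api_key)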
