Allow passing list of dict as argument to events parameter (=allow writing mne.annotations to VMRK) #86

Merged
merged 21 commits into from
May 27, 2022
Changes from 2 commits
220 changes: 171 additions & 49 deletions pybv/io.py
@@ -34,82 +34,120 @@ def write_brainvision(*, data, sfreq, ch_names,
unit='µV',
fmt='binary_float32',
meas_date=None):
"""Write raw data to BrainVision format [1]_.
"""Write raw data to the BrainVision format [1]_.

Parameters
----------
data : np.ndarray, shape (n_channels, n_times)
The raw data to export. Voltage data is assumed to be in **Volts** and
will be scaled as specified by ``unit``. Non-voltage channels (as
specified by ``unit``) are never scaled (e.g. ``'°C'``).
will be scaled as specified by `unit`. Non-voltage channels (as
specified by `unit`) are never scaled (e.g. '°C').
sfreq : int | float
The sampling frequency of the data.
The sampling frequency of the data in Hz.
ch_names : list of {str | int}, len (n_channels)
The names of the channels.
ref_ch_names : str | list of str, len (n_channels) | None
The name of the channel used as a reference during the recording. If
references differed between channels, you may supply a list of
reference channel names corresponding to each channel in ``ch_names``.
reference channel names corresponding to each channel in `ch_names`.
If ``None`` (default), assume that all channels are referenced to a
common channel that is not further specified (BrainVision default).

.. note:: The reference channel name specified here does not need to
appear in ``ch_names``. It is permissible to specify a
reference channel that is not present in ``data``.
appear in `ch_names`. It is permissible to specify a
reference channel that is not present in `data`.
fname_base : str
The base name for the output files. Three files will be created
(.vhdr, .vmrk, .eeg) and all will share this base name.
(*.vhdr*, *.vmrk*, *.eeg*) and all will share this base name.
folder_out : str
The folder where output files will be saved. Will be created if it does
not exist yet.
overwrite : bool
Whether or not to overwrite existing files. Defaults to False.
events : np.ndarray, shape (n_events, 2) or (n_events, 3) | None
Events to write in the marker file. This array has either two or three
columns. The first column is always the zero-based index of each event
(corresponding to the "time" dimension of the data array). The second
column is a number associated with the "type" of event. The (optional)
third column specifies the length of each event (default 1 sample).
Currently all events are written as type "Stimulus" and must be
numeric. Defaults to None (not writing any events).
Whether or not to overwrite existing files. Defaults to ``False``.
events : np.ndarray, shape (n_events, {2, 3}) | list of dict, len (n_events) | None
Events to write in the marker file (*.vmrk*). Defaults to ``None``
(not writing any events).

If an array is passed, it must have either two or three columns.
The first column is always the zero-based index of each event
(corresponding to the "time" dimension of the `data` array).
The second column is a number associated with the "description" of
the event. The (optional) third column specifies the length of each
event (default 1 sample). All events are written as type "Stimulus"
and must be numeric. For more fine-grained control over how to write
events, pass a list of dict as described next.

If a list of dict is passed, each dict in the list corresponds to an
event and may have the following entries:

- ``onset`` : int
The zero-based index of the event onset, corresponding to the
"time" dimension of the `data` array.
- ``duration`` : int
The duration of the event in samples (defaults to ``1``).
- ``description`` : str | int
The description of the event. Must be an integer when `type`
(see below) is either "Stimulus" or "Response", and may be
a string when `type` is "Comment".
- ``type`` : str
The type of the event, must be one of {"Stimulus", "Comment",
"Response"} (defaults to ``"Stimulus"``). The following
known BrainVision "types" are currently **not** supported:
"New Segment", "SyncStatus".
- ``channels`` : str | list of {str | int}
The channels that are impacted by the event. Can be "all"
(reflecting all channels) or a channel name, or a list of
channel names. Defaults to ``"all"``.

Note that ``onset`` and ``description`` MUST be specified in each
dict.

.. note:: When specifying more than one but fewer than "all" channels
that are impacted by an event, ``pybv`` will write the same
event once for each specified channel (see
:gh:`77` for a discussion). This is valid according to the
BrainVision specification; however, for maximum compatibility
with other BrainVision readers, we do not (yet) recommend
using this feature.

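As a rough illustration of how the two accepted `events` forms relate, the sketch below converts array-style rows into the list-of-dict form. The helper name `rows_to_event_dicts` is hypothetical and not part of ``pybv``; it only mirrors the defaults documented above (`duration` of 1, `type` "Stimulus", `channels` "all").

```python
def rows_to_event_dicts(rows):
    """Turn (onset, description[, duration]) rows into event dicts.

    Hypothetical helper for illustration only, not part of pybv.
    """
    events = []
    for row in rows:
        if len(row) not in (2, 3):
            raise ValueError("each row needs 2 or 3 entries")
        onset, description = int(row[0]), int(row[1])
        # the third column is optional and defaults to 1 sample
        duration = int(row[2]) if len(row) == 3 else 1
        events.append({
            "onset": onset,
            "duration": duration,
            "description": description,
            "type": "Stimulus",
            "channels": "all",
        })
    return events


events = rows_to_event_dicts([(10, 1), (250, 2, 5)])

# Only the list-of-dict form can express a "Comment" with a str description;
# "onset" and "description" are the two required keys.
events.append({"onset": 300, "description": "blink", "type": "Comment"})
```

The resulting list could then be passed as `events`, assuming the validation in `_chk_events` accepts it as described in the docstring.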
resolution : float | np.ndarray, shape (n_channels,)
The resolution in `unit` in which you'd like the data to be stored. If
float, the same resolution is applied to all channels. If ndarray with
float, the same resolution is applied to all channels. If array with
n_channels elements, each channel is scaled with its own corresponding
resolution from the ndarray. Note that `resolution` is applied on top
resolution from the array. Note that `resolution` is applied on top
of the default resolution that a data format (see `fmt`) has. For
example, the binary_int16 format by design has no floating point
example, the 'binary_int16' format by design has no floating point
support, but when scaling the data in µV for 0.1 resolution (default),
accurate writing for all values >= 0.1 µV is guaranteed. In contrast,
the binary_float32 format by design already supports floating points up
to 1e-6 resolution, and writing data in µV with 0.1 resolution will
thus guarantee accurate writing for all values >= 1e-7 µV
the 'binary_float32' format by design already supports floating points
up to 1e-6 resolution, and writing data in 'µV' with 0.1 resolution
will thus guarantee accurate writing for all values >= 1e-7 'µV'
(``1e-6 * 0.1``).
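The arithmetic behind the `resolution` description can be sketched in a few lines. This is only an illustration under the assumption that integer storage amounts to dividing by the resolution and rounding to the nearest integer; the helper `quantize` is hypothetical and is not how ``pybv`` is necessarily implemented.

```python
def quantize(value, resolution):
    """Return (stored integer, value read back) for a given resolution.

    Hypothetical helper: assumes storage is round(value / resolution).
    """
    stored = round(value / resolution)
    return stored, stored * resolution


# A sample of 12.34 µV written with 0.1 µV resolution keeps one decimal place.
stored, recovered = quantize(12.34, 0.1)

# For 'binary_float32', the docstring multiplies the format's own 1e-6
# resolution by the requested 0.1 to arrive at 1e-7.
float32_base_resolution = 1e-6
effective_resolution = float32_base_resolution * 0.1
```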
unit : str | list of str
The unit of the exported data. This can be one of 'V', 'mV', 'µV' (or
equivalently 'uV'), or 'nV', which will scale the data accordingly.
Defaults to 'µV'. Can also be a list of units with one unit per
channel. Non-voltage channels are stored as is, for example temperature
might be available in ``°C``, which ``pybv`` will not scale.
Defaults to ``'µV'``. Can also be a list of units with one unit per
channel. Non-voltage channels are stored as is, for example
temperature might be available in '°C', which ``pybv`` will not scale.
fmt : str
Binary format the data should be written as. Valid choices are
'binary_float32' (default) and 'binary_int16'.
meas_date : datetime.datetime | str | None
The measurement date specified as a datetime.datetime object.
Alternatively, can be a string in the format 'YYYYMMDDhhmmssuuuuuu'
The measurement date specified as a :class:`datetime.datetime` object.
Alternatively, can be a str in the format 'YYYYMMDDhhmmssuuuuuu'
('u' stands for microseconds). Note that setting a measurement date
implies that one additional event is created in the .vmrk file. To
prevent this, set this parameter to None (default).
implies that one additional event is created in the *.vmrk* file. To
prevent this, set this parameter to ``None`` (default).
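For reference, the 'YYYYMMDDhhmmssuuuuuu' string form of `meas_date` corresponds to the standard `strftime` pattern shown below. This sketch uses only the Python standard library; the variable names are illustrative.

```python
from datetime import datetime

# %f is the zero-padded, six-digit microsecond field ('uuuuuu' above)
BV_DATE_FORMAT = "%Y%m%d%H%M%S%f"

meas_date = datetime(2022, 5, 27, 13, 5, 9, 123)
meas_date_str = meas_date.strftime(BV_DATE_FORMAT)

# Parsing the string with the same pattern recovers the original datetime.
round_trip = datetime.strptime(meas_date_str, BV_DATE_FORMAT)
```

Either `meas_date` or `meas_date_str` would match the parameter types documented above.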

Notes
-----
iEEG/EEG/MEG data is assumed to be in V, and we will scale these data to µV
by default. Any unit besides µV is officially unsupported in the
iEEG/EEG/MEG data is assumed to be in 'V', and we will scale these data to
'µV' by default. Any unit besides 'µV' is officially unsupported in the
BrainVision specification. However, if one specifies other voltage units
such as 'mV' or 'nV', we will still scale the signals accordingly in the
exported file. We will also write channels with non-voltage units such as
``°C`` as is (without scaling). For maximum compatibility, all signals
should be written as µV.
'°C' as is (without scaling). For maximum compatibility, all signals
should be written as 'µV'.

References
----------
@@ -130,25 +168,11 @@ def write_brainvision(*, data, sfreq, ch_names,
>>> # remove the files
>>> for ext in ['.vhdr', '.vmrk', '.eeg']:
... os.remove('pybv_test_file' + ext)
"""
""" # noqa: E501
# Input checks
if not isinstance(overwrite, bool):
raise ValueError("overwrite must be a boolean (True or False).")

ev_err = ("events must be an ndarray of shape (n_events, 2) or "
"(n_events, 3) containing numeric values, or None")
if not isinstance(events, (np.ndarray, type(None))):
raise ValueError(ev_err)
if isinstance(events, np.ndarray):
if events.ndim != 2:
raise ValueError(ev_err)
if events.shape[1] not in (2, 3):
raise ValueError(ev_err)
try:
events.astype(float)
except ValueError:
raise ValueError(ev_err)

nchan = len(ch_names)
for ch in ch_names:
if not isinstance(ch, (str, int)):
@@ -162,6 +186,8 @@ def write_brainvision(*, data, sfreq, ch_names,
if len(set(ch_names)) != nchan:
raise ValueError("Channel names must be unique, found duplicate name.")

events = _chk_events(events, ch_names)

# Ensure we have a list of strings as reference channel names
if ref_ch_names is None:
ref_ch_names = [''] * nchan # common but unspecified reference
@@ -281,6 +307,102 @@ def write_brainvision(*, data, sfreq, ch_names,
raise


def _chk_events(events, ch_names):
    """Check that the events parameter is as expected.

    This function may change events in-place. It will add missing keys with
    default values, and it will turn events[i]["channels"] into a list of
    1-based channel name indices, where 0 = "all".
    """
    if not isinstance(events, (type(None), np.ndarray, list)):
        raise ValueError("events must be an array, a list of dict, or None")

    # validate input: None
    if isinstance(events, type(None)):
        return events

    # validate input: ndarray
    if isinstance(events, np.ndarray):
        if events.ndim != 2:
            raise ValueError(f"When array, events must be 2D, but got {events.ndim}")
        if events.shape[1] not in (2, 3):
            raise ValueError(f"When array, events must have 2 or 3 columns, but got: {events.shape[1]}")
        try:
            events.astype(float)
        except ValueError:
            raise ValueError("When array, events must be numeric, but found non-numeric types")
        # a valid array needs no further checks: return here so that it does
        # not fall through to the list-of-dict validation below
        return events

    # validate input: list of dict
    assert isinstance(events, list)  # must be true
    for event in events:

        # each item must be dict
        if not isinstance(event, dict):
            raise ValueError("When list, events must be a list of dict, but found non-dict element in list")

        # required keys
        for required_key in ["onset", "description"]:
            if required_key not in event:
                raise ValueError("When list of dict, each dict in events must have the keys 'onset' and 'description'")

        # populate keys with default if missing (in-place)
        # NOTE: using "ch_names" as default for channels translates directly
        #       into "all" but is robust with respect to channels named
        #       "all"
        event_defaults = dict(duration=1, type="Stimulus", channels=ch_names)
        for optional_key, default in event_defaults.items():
            event[optional_key] = event.get(optional_key, default)

        # validate key types
        # `onset`, `duration`
        for key in ["onset", "duration"]:
            if not isinstance(event[key], int):
                raise ValueError(f"events: `{key}` must be int")

        # `type`
        event_types = ["Stimulus", "Response", "Comment"]
        if event["type"] not in event_types:
            raise ValueError(f"events: `type` must be one of {event_types}")

        # `description`
        if event["type"] in ["Stimulus", "Response"]:
            if not isinstance(event["description"], int):
                raise ValueError(f"events: when `type` is {event['type']}, `description` must be int")
        else:
            assert event["type"] == "Comment"
            if not isinstance(event["description"], (int, str)):
                raise ValueError(f"events: when `type` is {event['type']}, `description` must be str or int")

        # `channels`
        # "all" becomes ch_names (list of all channel names)
        # single str 'ch_name' becomes [ch_name]
        if not isinstance(event["channels"], (list, str)):
            raise ValueError("events: `channels` must be str or list of str")

        if isinstance(event["channels"], str):
            if event["channels"] == "all":
                if "all" in ch_names:
                    raise ValueError("Found channel named 'all'. Your `channels` specification in events is also 'all': This is ambiguous, because 'all' is a reserved keyword. Either rename the channel called 'all', or explicitly list all ch_names in `channels` in each event instead of using 'all'")
                event["channels"] = ch_names
            else:
                event["channels"] = [event["channels"]]

        # now channels is a list
        for ch in event["channels"]:
            if not isinstance(ch, (str, int)):
                raise ValueError("events: `channels` must be list of str or int corresponding to ch_names")

            if str(ch) not in ch_names:
                raise ValueError(f"events: found channel name that is not present in the data: {ch}")

        # check for duplicates
        event["channels"] = [str(ch) for ch in event["channels"]]
        if len(set(event["channels"])) != len(event["channels"]):
            raise ValueError("events: found duplicate channel names")

    return events

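The `channels` handling in `_chk_events` can be read in isolation as a small normalization step. The standalone sketch below restates that logic for a single value; the function name `normalize_channels` is hypothetical and the error messages are shortened, so treat it as an approximation of the checks above rather than the code in this PR.

```python
def normalize_channels(channels, ch_names):
    """Return `channels` as a validated list of str channel names.

    Hypothetical restatement of the `channels` checks in `_chk_events`.
    """
    if not isinstance(channels, (list, str)):
        raise ValueError("`channels` must be str or list of str")

    if isinstance(channels, str):
        if channels == "all":
            # "all" is a reserved keyword, so a channel with that name is ambiguous
            if "all" in ch_names:
                raise ValueError("found channel named 'all'")
            return list(ch_names)
        # a single channel name becomes a one-element list
        channels = [channels]

    for ch in channels:
        if not isinstance(ch, (str, int)):
            raise ValueError("`channels` must be list of str or int")
        if str(ch) not in ch_names:
            raise ValueError(f"channel not present in the data: {ch}")

    channels = [str(ch) for ch in channels]
    if len(set(channels)) != len(channels):
        raise ValueError("found duplicate channel names")
    return channels


ch_names = ["Fz", "Cz", "Pz"]
all_channels = normalize_channels("all", ch_names)
single_channel = normalize_channels("Cz", ch_names)
two_channels = normalize_channels(["Fz", "Pz"], ch_names)
```

Unlike the code in the diff, this sketch returns a copy of `ch_names` for "all" instead of the same list object, which avoids several events sharing one mutable list.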

def _chk_fmt(fmt):
"""Check that the format string is valid, return (BV, numpy) datatypes."""
if fmt not in SUPPORTED_FORMATS: