Skip to content

Commit

Permalink
formatted with black package
Browse files Browse the repository at this point in the history
  • Loading branch information
briangow committed Oct 29, 2024
1 parent c9808ad commit 6cba5a7
Showing 1 changed file with 81 additions and 39 deletions.
120 changes: 81 additions & 39 deletions wfdb/io/record.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import datetime
import multiprocessing.dummy
import posixpath
import os
import re
Expand Down Expand Up @@ -2822,6 +2821,9 @@ def wrsamp(
sig_name,
p_signal=None,
d_signal=None,
e_p_signal=None,
e_d_signal=None,
samps_per_frame=None,
fmt=None,
adc_gain=None,
baseline=None,
Expand Down Expand Up @@ -2860,6 +2862,14 @@ def wrsamp(
file(s). The dtype must be an integer type. Either p_signal or d_signal
must be set, but not both. In addition, if d_signal is set, fmt, gain
and baseline must also all be set.
e_p_signal : ndarray, optional
The expanded physical conversion of the signal. Either a 2d numpy
array or a list of 1d numpy arrays.
e_d_signal : ndarray, optional
The expanded digital conversion of the signal. Either a 2d numpy
array or a list of 1d numpy arrays.
samps_per_frame : int or list of ints, optional
The total number of samples per frame.
fmt : list, optional
A list of strings giving the WFDB format of each file used to store each
channel. Accepted formats are: '80','212','16','24', and '32'. There are
Expand Down Expand Up @@ -2911,59 +2921,91 @@ def wrsamp(
if "." in record_name:
raise Exception("Record name must not contain '.'")
# Check input field combinations
if p_signal is not None and d_signal is not None:
signal_list = [p_signal, d_signal, e_p_signal, e_d_signal]
signals_set = sum(1 for var in signal_list if var is not None)
if signals_set != 1:
raise Exception(
"Must only give one of the inputs: p_signal or d_signal"
"Must provide one and only one input signal: p_signal, d_signal, e_p_signal, or e_d_signal"
)
if d_signal is not None:
if d_signal is not None or e_d_signal is not None:
if fmt is None or adc_gain is None or baseline is None:
raise Exception(
"When using d_signal, must also specify 'fmt', 'gain', and 'baseline' fields."
"When using d_signal or e_d_signal, must also specify 'fmt', 'gain', and 'baseline' fields"
)
# Depending on whether d_signal or p_signal was used, set other
# required features.

# If samps_per_frame is a list, check that it aligns as expected with the channels in the signal
if len(samps_per_frame) > 1:
# Get properties of the signal being passed
non_none_signal = next(signal for signal in signal_list if signal is not None)
if isinstance(non_none_signal, np.ndarray):
num_sig_channels = non_none_signal.shape[1]
channel_samples = [non_none_signal.shape[0]] * non_none_signal.shape[1]
elif isinstance(non_none_signal, list):
num_sig_channels = len(non_none_signal)
channel_samples = [len(channel) for channel in non_none_signal]
else:
raise TypeError("Unsupported signal format. Must be ndarray or list of lists.")

# Check that the number of channels matches the number of samps_per_frame entries
if num_sig_channels != len(samps_per_frame):
raise Exception(
"When passing samps_per_frame as a list, it must have the same number of entries as the signal has channels"
)

# Check that the number of frames is the same across all channels
frames = [a / b for a, b in zip(channel_samples, samps_per_frame)]
if len(set(frames)) > 1:
raise Exception(
"The number of samples in a channel divided by the corresponding samples_per_frame entry must be uniform"
)

# Create the Record object
record = Record(
record_name=record_name,
p_signal=p_signal,
d_signal=d_signal,
e_p_signal=e_p_signal,
e_d_signal=e_d_signal,
samps_per_frame=samps_per_frame,
fs=fs,
fmt=fmt,
units=units,
sig_name=sig_name,
adc_gain=adc_gain,
baseline=baseline,
comments=comments,
base_time=base_time,
base_date=base_date,
base_datetime=base_datetime,
)

# Depending on which signal was used, set other required fields.
if p_signal is not None:
# Create the Record object
record = Record(
record_name=record_name,
p_signal=p_signal,
fs=fs,
fmt=fmt,
units=units,
sig_name=sig_name,
adc_gain=adc_gain,
baseline=baseline,
comments=comments,
base_time=base_time,
base_date=base_date,
base_datetime=base_datetime,
)
# Compute optimal fields to store the digital signal, carry out adc,
# and set the fields.
record.set_d_features(do_adc=1)
else:
# Create the Record object
record = Record(
record_name=record_name,
d_signal=d_signal,
fs=fs,
fmt=fmt,
units=units,
sig_name=sig_name,
adc_gain=adc_gain,
baseline=baseline,
comments=comments,
base_time=base_time,
base_date=base_date,
base_datetime=base_datetime,
)
elif d_signal is not None:
# Use d_signal to set the fields directly
record.set_d_features()
elif e_p_signal is not None:
# Compute optimal fields to store the digital signal, carry out adc,
# and set the fields.
record.set_d_features(do_adc=1, expanded=True)
elif e_d_signal is not None:
# Use e_d_signal to set the fields directly
record.set_d_features(expanded=True)

# Set default values of any missing field dependencies
record.set_defaults()

# Determine whether the signal is expanded
if (e_d_signal or e_p_signal) is not None:
expanded = True
else:
expanded = False

# Write the record files - header and associated dat
record.wrsamp(write_dir=write_dir)
record.wrsamp(write_dir=write_dir, expanded=expanded)


def dl_database(
Expand Down

0 comments on commit 6cba5a7

Please sign in to comment.