Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MultiAnnotator class #515

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions mirdata/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,26 @@ def __repr__(self):
return repr_str


class MultiAnnotator(Annotation):
magdalenafuentes marked this conversation as resolved.
Show resolved Hide resolved
"""Multiple annotator class.
This class should be used for datasets with multiple annotators (e.g. multiple annotators per track).

Attributes:
magdalenafuentes marked this conversation as resolved.
Show resolved Hide resolved
annotators (list): list with annotator ids
annotations (list): list of annotations (e.g. [annotations.BeatData, annotations.ChordData]
magdalenafuentes marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(self, annotators, annotations) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these two inputs would be better as a single dictionary:
{ 'annotator-id': Annotation }

same with how it's stored, self.annotations = {'annotator1': ... , 'annotator2', ...}
because it will make it easier to look up a specific dictionary instead of having to look for matching indexes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried this out and noticed that some tracks have the same annotator id for different annotations in the same track (which might be a bit weird actually... but it's there :/). Also the sintax gets a bit messy, since I have to iterate in the keys each time I want to access an annotation because the annotator id changes from track to track. So I changed back to what was before for now. Let me know what you think

validate_array_like(annotators, list, str, none_allowed=True)
validate_array_like(
annotations, list, Annotation, check_child=True, none_allowed=True
magdalenafuentes marked this conversation as resolved.
Show resolved Hide resolved
)
validate_lengths_equal([annotators, annotations])

self.annotators = annotators
self.annotations = annotations
magdalenafuentes marked this conversation as resolved.
Show resolved Hide resolved


class BeatData(Annotation):
"""BeatData class

Expand Down Expand Up @@ -1383,7 +1403,9 @@ def closest_index(input_array, target_array):
return indexes


def validate_array_like(array_like, expected_type, expected_dtype, none_allowed=False):
def validate_array_like(
array_like, expected_type, expected_dtype, check_child=False, none_allowed=False
):
"""Validate that array-like object is well formed

If array_like is None, validation passes automatically.
Expand All @@ -1392,11 +1414,12 @@ def validate_array_like(array_like, expected_type, expected_dtype, none_allowed=
array_like (array-like): object to validate
expected_type (type): expected type, either list or np.ndarray
expected_dtype (type): expected dtype
check_child (bool): if True, checks if all elements of array are children of expected_dtype
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens if there are multiple children with different expected dtypes (e.g a beat annotation with ints for the beat positions and floats for the time stamps?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here we're only checking if the list is composed of BeatData, ChordData, etc. The checks within the annotation type are done for each annotation type independently

none_allowed (bool): if True, allows array to be None

Raises:
TypeError: if type/dtype does not match expected_type/expected_dtype
ValueError: if array
ValueError: if array is empty but it shouldn't be

"""
if array_like is None:
Expand All @@ -1416,8 +1439,20 @@ def validate_array_like(array_like, expected_type, expected_dtype, none_allowed=
)

if expected_type == list and not all(
isinstance(n, expected_dtype) for n in array_like
isinstance(n, expected_dtype)
for n in array_like
if not ((n is None) and none_allowed)
):
if check_child:
if not all(
issubclass(type(n), expected_dtype)
for n in array_like
if not ((n is None) and none_allowed)
):
raise TypeError(
f"List elements should all be instances of {expected_dtype} class"
)

raise TypeError(f"List elements should all have type {expected_dtype}")

if (
Expand Down
46 changes: 28 additions & 18 deletions mirdata/datasets/salami.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,20 +158,30 @@ def genre(self):
return self._track_metadata.get("genre")

@core.cached_property
def sections_annotator_1_uppercase(self) -> Optional[annotations.SectionData]:
return load_sections(self.sections_annotator1_uppercase_path)

@core.cached_property
def sections_annotator_1_lowercase(self) -> Optional[annotations.SectionData]:
return load_sections(self.sections_annotator1_lowercase_path)

@core.cached_property
def sections_annotator_2_uppercase(self) -> Optional[annotations.SectionData]:
return load_sections(self.sections_annotator2_uppercase_path)
def sections_uppercase(self) -> Optional[annotations.MultiAnnotator]:
magdalenafuentes marked this conversation as resolved.
Show resolved Hide resolved
return annotations.MultiAnnotator(
[
self._track_metadata.get("annotator_1_id"),
self._track_metadata.get("annotator_2_id"),
],
[
load_sections(self.sections_annotator1_uppercase_path),
load_sections(self.sections_annotator2_uppercase_path),
],
)

@core.cached_property
def sections_annotator_2_lowercase(self) -> Optional[annotations.SectionData]:
return load_sections(self.sections_annotator2_lowercase_path)
def sections_lowercase(self) -> Optional[annotations.MultiAnnotator]:
return annotations.MultiAnnotator(
[
self._track_metadata.get("annotator_1_id"),
self._track_metadata.get("annotator_2_id"),
],
[
load_sections(self.sections_annotator1_lowercase_path),
load_sections(self.sections_annotator2_lowercase_path),
],
)

@property
def audio(self) -> Tuple[np.ndarray, float]:
Expand All @@ -196,17 +206,17 @@ def to_jams(self):
multi_section_data=[
(
[
(self.sections_annotator_1_uppercase, 0),
(self.sections_annotator_1_lowercase, 1),
(self.sections_uppercase.annotations[0], 0),
(self.sections_lowercase.annotations[0], 1),
],
"annotator_1",
self.sections_lowercase.annotators[0],
),
(
[
(self.sections_annotator_2_uppercase, 0),
(self.sections_annotator_2_lowercase, 1),
(self.sections_uppercase.annotations[1], 0),
(self.sections_lowercase.annotations[1], 1),
],
"annotator_2",
self.sections_lowercase.annotators[1],
),
],
metadata=self._track_metadata,
Expand Down
22 changes: 10 additions & 12 deletions tests/datasets/test_salami.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,8 @@ def test_track():
}

expected_property_types = {
"sections_annotator_1_uppercase": annotations.SectionData,
"sections_annotator_1_lowercase": annotations.SectionData,
"sections_annotator_2_uppercase": annotations.SectionData,
"sections_annotator_2_lowercase": annotations.SectionData,
"sections_uppercase": annotations.MultiAnnotator,
"sections_lowercase": annotations.MultiAnnotator,
"audio": tuple,
}

Expand Down Expand Up @@ -81,10 +79,10 @@ def test_track():
}

# test that cached properties don't fail and have the expected type
assert type(track.sections_annotator_1_uppercase) is annotations.SectionData
assert type(track.sections_annotator_1_lowercase) is annotations.SectionData
assert track.sections_annotator_2_uppercase is None
assert track.sections_annotator_2_lowercase is None
assert type(track.sections_uppercase) is annotations.MultiAnnotator
assert type(track.sections_lowercase) is annotations.MultiAnnotator
assert track.sections_uppercase.annotations[1] is None
assert track.sections_lowercase.annotations[1] is None

# Test file with missing annotations
track = dataset.track("1015")
Expand All @@ -104,10 +102,10 @@ def test_track():
}

# test that cached properties don't fail and have the expected type
assert track.sections_annotator_1_uppercase is None
assert track.sections_annotator_1_lowercase is None
assert type(track.sections_annotator_2_uppercase) is annotations.SectionData
assert type(track.sections_annotator_2_lowercase) is annotations.SectionData
assert track.sections_uppercase.annotations[0] is None
assert track.sections_lowercase.annotations[0] is None
assert type(track.sections_uppercase) is annotations.MultiAnnotator
assert type(track.sections_lowercase) is annotations.MultiAnnotator


def test_to_jams():
Expand Down
29 changes: 29 additions & 0 deletions tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,35 @@ def __init__(self):
)


def test_multiannotator():
# test good data
annotators = ["annotator_1", "annotator_2"]
labels_1 = ["Vocals", "Guitar"]
labels_2 = ["Vocals", "Drums"]
intervals_1 = np.array([[0.0, 0.1], [0.5, 1.5]])
intervals_2 = np.array([[0.0, 1.0], [0.5, 1.0]])
multi_annot = [
annotations.EventData(intervals_1, "s", labels_1, "open"),
annotations.EventData(intervals_2, "s", labels_2, "open"),
]
events = annotations.MultiAnnotator(annotators, multi_annot)

assert events.annotations[0].events == labels_1
assert events.annotators[1] == "annotator_2"
assert np.allclose(events.annotations[1].intervals, intervals_2)

# test bad data
bad_labels = ["Is a", "Number", 5]
pytest.raises(TypeError, annotations.MultiAnnotator, annotators, bad_labels)
pytest.raises(TypeError, annotations.MultiAnnotator, [0, 1], multi_annot)
pytest.raises(
TypeError,
annotations.MultiAnnotator,
annotators,
[["bad", "format"], ["indeed"]],
)


def test_beat_data():
times = np.array([1.0, 2.0])
positions = np.array([3, 4])
Expand Down