Skip to content

Commit

Permalink
Merge pull request #119 from JamesOwers/amt
Browse files Browse the repository at this point in the history
Amt
  • Loading branch information
JamesOwers authored Jul 22, 2020
2 parents 993e82a + 2054062 commit 17393c4
Show file tree
Hide file tree
Showing 10 changed files with 1,146 additions and 20 deletions.
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Tools to generate datasets of Altered and Corrupted MIDI Excerpts -`ACME`
datasets.

The accompanying paper (submitted to ICASSP, available upon request)
The accompanying paper (submitted to ISMIR)
"Symbolic Music Correction using The MIDI Degradation Toolkit" describes the
toolkit and its motivation in detail. For instructions to reproduce the results
from the paper, see [`./baselines/README.md`](./baselines/README.md).
Expand Down Expand Up @@ -38,6 +38,8 @@ Some highlights include:
data for use
* [`mdtk.degradations`](./mdtk/degradations.py) - functions to alter midi data
e.g. `pitch_shift` or `time_shift`
* [`mdtk.degrader`](./mdtk/degrader.py) - Degrader class that can be used to
degrade data points randomly on the fly
* [`mdtk.eval`](./mdtk/eval.py) - functions for evaluating model performance
on each task, given a list of outputs and targets
* [`mdtk.formatters`](./mdtk/formatters.py) - functions converting between
Expand Down Expand Up @@ -71,3 +73,10 @@ pip install . # use pip install -e . for dev mode if you want to edit files

To generate an `ACME` dataset simply install the package with instructions
above and run `./make_dataset.py`.

For usage instructions for the `measure_errors.py` script, run
`python measure_errors.py -h`. You should create a directory of transcriptions
and a directory of ground truth files (in mid or csv format). Each ground
truth file and its corresponding transcription must have the same filename.

See `measure_errors_example.ipynb` for an example of the script's usage.
17 changes: 17 additions & 0 deletions make_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ def parse_args(args_input=None):
parser.add_argument('-i', '--input-dir', type=str, default=default_indir,
help='the directory to store the preprocessed '
'downloaded data to.')
parser.add_argument('--config', default=None, help='Load a json config '
'file, in the format created by measure_errors.py. '
'This will override --degradations, --degradation-'
'dist, and --clean-prop.')
parser.add_argument('--formats', metavar='format', help='Create '
'custom versions of the acme data for easier loading '
'with our provided pytorch Dataset classes. Choices are'
Expand Down Expand Up @@ -175,6 +179,7 @@ def parse_args(args_input=None):
seed = ARGS.seed
print(f'Setting random seed to {seed}.')
np.random.seed(seed)

# Check given degradation_kwargs
assert (ARGS.degradation_kwargs is None or
ARGS.degradation_kwarg_json is None), ("Don't specify both "
Expand All @@ -185,6 +190,18 @@ def parse_args(args_input=None):
degradation_kwargs = parse_degradation_kwargs(
ARGS.degradation_kwarg_json
)

# Load config
if ARGS.config is not None:
with open(ARGS.config, 'r') as file:
config = json.load(file)
if ARGS.verbose:
print(f'Loading from config file {ARGS.config}.')
if 'degradation_dist' in config:
ARGS.degradation_dist = np.array(config['degradation_dist'])
ARGS.degradations = list(degradations.DEGRADATIONS.keys())
if 'clean_prop' in config:
ARGS.clean_prop = config['clean_prop']
# Warn user they specified kwargs for degradation not being used
for deg, args in degradation_kwargs.items():
if deg not in ARGS.degradations and len(args) > 0:
Expand Down
2 changes: 1 addition & 1 deletion mdtk/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Description

The accompanying paper (submitted to ICASSP, available upon request) referenced
The accompanying paper (submitted to ISMIR) referenced
below is "Symbolic Music Correction using The MIDI Degradation Toolkit" and
describes the toolkit and its motivation in detail.

Expand Down
3 changes: 3 additions & 0 deletions mdtk/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,9 @@ def fix_overlaps(df):
df.loc[prev_idx, 'offset'] = note.onset
current_offset = max(current_offset, note.offset)
df.loc[idx, 'offset'] = current_offset
else:
# No overlap. Update latest offset.
current_offset = note.offset
# Always iterate, but no need to update current_offset here,
# because it will definitely be < next_note.onset (because sorted).
prev_idx = idx
Expand Down
20 changes: 11 additions & 9 deletions mdtk/degradations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
MIN_DURATION_DEFAULT = 50
MAX_DURATION_DEFAULT = np.inf

MAX_GAP_DEFAULT = 50

TRIES_DEFAULT = 10

TRIES_WARN_MSG = ("WARNING: Generated invalid (overlapping) degraded excerpt "
Expand Down Expand Up @@ -259,7 +261,7 @@ def pitch_shift(excerpt, min_pitch=MIN_PITCH_DEFAULT,
'distribution[zero_idx] to 0). Returning None.')
return None

degraded = excerpt
degraded = excerpt.copy()

# Sample a random note
note_index = valid_notes[randint(len(valid_notes))]
Expand Down Expand Up @@ -395,7 +397,7 @@ def time_shift(excerpt, min_shift=MIN_SHIFT_DEFAULT,
else:
onset = split_range_sample([(eeo, leo), (elo, llo)])

degraded = excerpt
degraded = excerpt.copy()

degraded.loc[index, 'onset'] = onset

Expand Down Expand Up @@ -575,7 +577,7 @@ def onset_shift(excerpt, min_shift=MIN_SHIFT_DEFAULT,
# No alignment
onset = split_range_sample([(elo, llo), (eso, lso)])

degraded = excerpt
degraded = excerpt.copy()

degraded.loc[index, 'onset'] = onset
degraded.loc[index, 'dur'] = offset[index] - onset
Expand Down Expand Up @@ -706,7 +708,7 @@ def offset_shift(excerpt, min_shift=MIN_SHIFT_DEFAULT,
else:
duration = split_range_sample([(ssd, lsd), (sld, lld)])

degraded = excerpt
degraded = excerpt.copy()

degraded.loc[index, 'dur'] = duration

Expand Down Expand Up @@ -889,7 +891,7 @@ def add_note(excerpt, min_pitch=MIN_PITCH_DEFAULT, max_pitch=MAX_PITCH_DEFAULT,
'dur': duration,
'track': track}

degraded = excerpt
degraded = excerpt.copy()
degraded = degraded.append(note, ignore_index=True)

# Check if overlaps
Expand Down Expand Up @@ -977,7 +979,7 @@ def split_note(excerpt, min_duration=MIN_DURATION_DEFAULT, num_splits=1,
onsets[i] = int(round(this_onset))
durs[i] = int(round(next_onset)) - int(round(this_onset))

degraded = excerpt
degraded = excerpt.copy()
degraded.loc[note_index]['dur'] = int(round(short_duration_float))
new_df = pd.DataFrame({'onset': onsets,
'track': tracks,
Expand All @@ -991,8 +993,8 @@ def split_note(excerpt, min_duration=MIN_DURATION_DEFAULT, num_splits=1,


@set_random_seed
def join_notes(excerpt, max_gap=50, max_notes=20, only_first=False,
tries=TRIES_DEFAULT):
def join_notes(excerpt, max_gap=MAX_GAP_DEFAULT, max_notes=20,
only_first=False, tries=TRIES_DEFAULT):
"""
Combine two notes of the same pitch and track into one.
Expand Down Expand Up @@ -1083,7 +1085,7 @@ def join_notes(excerpt, max_gap=50, max_notes=20, only_first=False,
start = valid_starts[index]
nexts = valid_nexts[index]

degraded = excerpt
degraded = excerpt.copy()

# Extend first note
degraded.loc[start]['dur'] = (degraded.loc[nexts[-1]]['onset'] +
Expand Down
149 changes: 149 additions & 0 deletions mdtk/degrader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""A degrader object can be used to easily degrade data points on the fly
according to some given parameters."""
import json
import numpy as np
import warnings

import mdtk.degradations as degs

class Degrader():
    """A Degrader object can be used to easily degrade musical excerpts
    on the fly, sampling a random degradation per call from a
    configurable distribution."""

    def __init__(self, seed=None, degradations=None, degradation_dist=None,
                 clean_prop=None, config=None):
        """
        Create a new degrader with the given parameters.

        Parameters
        ----------
        seed : int
            A random seed for numpy.
        degradations : list(string)
            A list of the names of the degradations to use (and in what order
            to label them). Defaults to all degradations in
            mdtk.degradations.DEGRADATIONS.
        degradation_dist : list(float)
            A list of the probability of each degradation given in
            degradations. This list will be normalized to sum to 1.
            Defaults to a uniform distribution.
        clean_prop : float
            The proportion of degrade calls that should return clean excerpts.
            Defaults to 1 / (len(degradations) + 1).
        config : string
            The path of a json config file (created by measure_errors.py).
            If given, degradations, degradation_dist, and clean_prop will
            all be overwritten by the values in the json file (when present).
        """
        if seed is not None:
            np.random.seed(seed)

        # Resolve mutable/derived defaults here rather than in the signature,
        # so each instance gets a fresh copy and a custom degradations list
        # gets consistent defaults.
        if degradations is None:
            degradations = list(degs.DEGRADATIONS.keys())
        if degradation_dist is None:
            degradation_dist = np.ones(len(degradations))
        if clean_prop is None:
            clean_prop = 1 / (len(degradations) + 1)

        # Load config (overrides the arguments above)
        if config is not None:
            with open(config, 'r') as file:
                config = json.load(file)

            if 'degradation_dist' in config:
                degradation_dist = np.array(config['degradation_dist'])
                degradations = list(degs.DEGRADATIONS.keys())
            if 'clean_prop' in config:
                clean_prop = config['clean_prop']

        # Coerce to an ndarray so the sampling arithmetic in degrade()
        # works even when a plain list was given.
        degradation_dist = np.asarray(degradation_dist, dtype=float)

        # Check arg validity
        assert len(degradation_dist) == len(degradations), (
            "Given degradation_dist is not the same length as degradations:"
            f"\nlen({degradation_dist}) != len({degradations})"
        )
        assert min(degradation_dist) >= 0, ("degradation_dist values must "
                                            "not be negative.")
        assert sum(degradation_dist) > 0, ("Some degradation_dist value "
                                           "must be positive.")
        assert 0 <= clean_prop <= 1, ("clean_prop must be between 0 and 1 "
                                      "(inclusive).")

        self.degradations = degradations
        self.degradation_dist = degradation_dist
        self.clean_prop = clean_prop
        # Per-degradation count of outstanding failures; used to bias future
        # sampling toward degradations that previously failed.
        self.failed = np.zeros(len(degradations))

    def degrade(self, note_df):
        """
        Degrade the given note_df.

        Parameters
        ----------
        note_df : pd.DataFrame
            A note_df to degrade.

        Returns
        -------
        degraded_df : pd.DataFrame
            A degraded version of the given note_df. If self.clean_prop > 0,
            this can be a copy of the given note_df.
        deg_label : int
            The label of the degradation that was performed. 0 means none,
            and larger numbers mean the degradation
            "self.degradations[deg_label - 1]" was performed.
        """
        if self.clean_prop > 0 and np.random.rand() <= self.clean_prop:
            return note_df.copy(), 0

        this_deg_dist = self.degradation_dist.copy()
        this_failed = self.failed.copy()

        # First, sample from degradations with outstanding failures,
        # proportional to how often each has failed.
        while np.any(this_failed > 0):
            deg_index = np.random.choice(
                len(self.degradations),
                p=this_failed / np.sum(this_failed)
            )
            deg_fun = degs.DEGRADATIONS[self.degradations[deg_index]]

            # Try to degrade (degradation functions warn on failure and
            # return None; suppress the warning here)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                degraded_df = deg_fun(note_df)

            # Check for success!
            if degraded_df is not None:
                # Success pays off one outstanding failure for this deg.
                self.failed[deg_index] -= 1
                return degraded_df, deg_index + 1

            # Degradation failed -- don't retry it during this call
            this_failed[deg_index] = 0

        # No degradations have remaining failures. Draw from standard dist
        while np.any(this_deg_dist > 0):
            # Select a degradation proportional to the distribution
            deg_index = np.random.choice(
                len(self.degradations),
                p=this_deg_dist / np.sum(this_deg_dist)
            )
            # This deg already failed in the loop above. Count the draw as
            # another failure, but zero its probability here so this loop
            # is guaranteed to terminate.
            if self.failed[deg_index] > 0:
                self.failed[deg_index] += 1
                this_deg_dist[deg_index] = 0
                continue
            deg_fun = degs.DEGRADATIONS[self.degradations[deg_index]]

            # Try to degrade
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                degraded_df = deg_fun(note_df)

            # Check for success!
            if degraded_df is not None:
                return degraded_df, deg_index + 1

            # Degradation failed -- record the failure and don't retry it
            # during this call.
            self.failed[deg_index] += 1
            this_deg_dist[deg_index] = 0

        # Here, all degradations (with dist > 0) failed
        return note_df.copy(), 0

3 changes: 2 additions & 1 deletion mdtk/filesystem_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import os
import sys
import shutil
import urllib
import urllib.request
import urllib.error
import warnings
import zipfile

Expand Down
16 changes: 8 additions & 8 deletions mdtk/tests/test_data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,16 @@
'ch' :[0, 1, 0, 0, 0, 0]
})
note_df_complex_overlap = pd.DataFrame({
'onset': [50, 75, 150, 200, 200, 300, 300, 300],
'track': [0, 0, 0, 0, 0, 0, 0, 1],
'pitch': [10, 10, 20, 10, 20, 30, 30, 10],
'dur': [300, 25, 100, 125, 50, 50, 100, 100]
'onset': [0, 50, 75, 150, 200, 200, 300, 300, 300],
'track': [0, 0, 0, 0, 0, 0, 0, 0, 1],
'pitch': [10, 10, 10, 20, 10, 20, 30, 30, 10],
'dur': [50, 300, 25, 100, 125, 50, 50, 100, 100]
})
note_df_complex_overlap_fixed = pd.DataFrame({
'onset': [50, 75, 150, 200, 200, 300, 300],
'track': [0, 0, 0, 0, 0, 0, 1],
'pitch': [10, 10, 20, 10, 20, 30, 10],
'dur': [25, 125, 50, 150, 50, 100, 100]
'onset': [0, 50, 75, 150, 200, 200, 300, 300],
'track': [0, 0, 0, 0, 0, 0, 0, 1],
'pitch': [10, 10, 10, 20, 10, 20, 30, 10],
'dur': [50, 25, 125, 50, 150, 50, 100, 100]
})
# midinote keyboard range from 0 to 127 inclusive
all_midinotes = list(range(0, 128))
Expand Down
Loading

0 comments on commit 17393c4

Please sign in to comment.