
Commit

Merge pull request #109 from JamesOwers/rule
Rule
JamesOwers authored Jan 9, 2020
2 parents 1d9255c + e057274 commit 993e82a
Showing 1 changed file with 118 additions and 0 deletions.
118 changes: 118 additions & 0 deletions baselines/rule_based.py
@@ -0,0 +1,118 @@
#!/usr/bin/env python
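"""Rule-based baseline for the ACME tasks.

Statistics are gathered from the training split of an ACME dataset (the
distribution of degradation labels, the proportion of changed frames, and
per-cell pianoroll frequencies) and used as constant predictions on the
test split; a loss and a metric are printed for each of Tasks 1-4.

Example usage (assuming the dataset sits in ./acme, the default):
    python baselines/rule_based.py --input acme --seq_len 250
"""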

import numpy as np
import torch.nn
import mdtk.pytorch_datasets
from mdtk.formatters import FORMATTERS, create_corpus_csvs
import argparse
import os

def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("-i", "--input", default='acme', help='The '
                        'base directory of the ACME dataset to use as input.')
    parser.add_argument("-s", "--seq_len", type=int, default=250,
                        help="maximum sequence length for a pianoroll.")
    parser.add_argument("--reformat", action="store_true", help="Force the "
                        "creation of the pianoroll csvs, even if they "
                        "already exist.")

    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = parse_args()

    # Create the pianoroll-format corpus csvs if either split is missing
    # (or if --reformat was given), then point at the train and test csvs.
    prefix = FORMATTERS['pianoroll']["prefix"]
    if (not all([os.path.exists(
            os.path.join(args.input, f'{split}_{prefix}_corpus.csv')
            ) for split in ['train', 'test']])) or args.reformat:
        create_corpus_csvs(args.input, FORMATTERS['pianoroll'])
    train_dataset = os.path.join(args.input, f'train_{prefix}_corpus.csv')
    test_dataset = os.path.join(args.input, f'test_{prefix}_corpus.csv')

    # Gather statistics from the training set: the distribution over the 9
    # degradation labels, the proportion of changed vs unchanged frames, and
    # the joint counts of (clean, degraded) pianoroll cell values.
    seq_len = args.seq_len
    dataset = mdtk.pytorch_datasets.PianorollDataset(train_dataset, seq_len)
    deg_counts = np.zeros(9)
    frame_counts = np.zeros(2)
    pr_outputs = np.zeros((2, 2))
    for data in dataset:
        deg_counts[data['deg_label']] += 1
        sum_frames = np.sum(data['changed_frames'])
        frame_counts[0] += seq_len - sum_frames
        frame_counts[1] += sum_frames
        deg_pr = data['deg_pr']
        clean_pr = data['clean_pr']
        num_pitches = np.shape(clean_pr)[1]
        for deg in [0, 1]:
            for clean in [0, 1]:
                pr_outputs[clean, deg] += (
                    np.sum(np.where(np.logical_and(deg_pr == deg, clean_pr == clean), 1, 0))
                )
    # Normalise counts to proportions; pr_outputs becomes the conditional
    # frequency of each clean cell value given the degraded cell value.
    deg_counts /= np.sum(deg_counts)
    frame_counts /= np.sum(frame_counts)
    pr_outputs /= np.sum(pr_outputs, axis=0)

    # Calculate losses and metrics on the test set. The rule-based baseline
    # outputs the same constant distributions for every test excerpt.
    dataset = mdtk.pytorch_datasets.PianorollDataset(test_dataset, seq_len)

    labels = []
    binary_labels = []
    frame_labels = []
    clean_prs = np.zeros((seq_len, num_pitches * len(dataset)))
    deg_prs = np.zeros((seq_len, num_pitches * len(dataset)))
    for i, data in enumerate(dataset):
        labels.append(data['deg_label'])
        binary_labels.append(0 if data['deg_label'] == 0 else 1)
        frame_labels.extend(data['changed_frames'])
        deg_prs[:, i * num_pitches : (i + 1) * num_pitches] = data['deg_pr']
        clean_prs[:, i * num_pitches : (i + 1) * num_pitches] = data['clean_pr']
    labels = np.array(labels)
    binary_labels = np.array(binary_labels)
    frame_labels = np.array(frame_labels)

    # Task 1: binary degraded-or-not classification. The constant output is
    # the training-set prior for [not degraded, degraded].
    outputs = np.zeros((len(binary_labels), 2))
    outputs[:] = [deg_counts[0], np.sum(deg_counts[1:])]

    loss = torch.nn.CrossEntropyLoss()
    task1 = loss(torch.from_numpy(outputs).float(),
                 torch.from_numpy(binary_labels).long())
    print(f'Task 1 loss = {task1}')
    if deg_counts[0] < 0.5:
        # The rule labels every excerpt degraded, so it never predicts the
        # clean class and the reverse F-measure is 0.
        print('Task 1 rev F-measure = 0.0')
    else:
        # The rule labels every excerpt clean: recall on the clean class is 1
        # and precision is the proportion of clean excerpts, giving
        # F = 2 * precision / (1 + precision).
        precision = np.sum(1 - binary_labels) / len(binary_labels)
        print(f'Task 1 rev F-measure = {2 * precision / (1 + precision)}')

    # Task 2: classification of the degradation type. The constant output is
    # the training-set distribution over the 9 degradation labels.
    outputs = np.zeros((len(labels), 9))
    outputs[:] = deg_counts

    task2 = loss(torch.from_numpy(outputs).float(),
                 torch.from_numpy(labels).long())
    print(f'Task 2 loss = {task2}')
    # Accuracy when the rule always predicts label 0 (no degradation): the
    # proportion of clean excerpts.
    print(f'Task 2 acc = {np.mean(1 - binary_labels)}')

    # Task 3: frame-level error location. The constant output is the
    # training-set proportion of [unchanged, changed] frames.
    outputs = np.zeros((len(frame_labels), 2))
    outputs[:] = frame_counts

    task3 = loss(torch.from_numpy(outputs).float(),
                 torch.from_numpy(frame_labels).long())
    print(f'Task 3 loss = {task3}')
    if frame_counts[0] > 0.5:
        # The rule labels every frame unchanged, so the F-measure is 0.
        print('Task 3 F-measure = 0.0')
    else:
        # The rule labels every frame changed: recall is 1 and precision is
        # the proportion of changed frames, giving
        # F = 2 * precision / (1 + precision).
        precision = np.sum(frame_labels) / len(frame_labels)
        print(f'Task 3 F-measure = {2 * precision / (1 + precision)}')

    # Task 4: pianoroll correction. Each degraded cell is mapped to the
    # training-set frequency with which the corresponding clean cell is on.
    outputs = np.zeros(np.shape(clean_prs))
    outputs[deg_prs == 0] = pr_outputs[1, 0]
    outputs[deg_prs == 1] = pr_outputs[1, 1]

    loss = torch.nn.BCEWithLogitsLoss()
    task4 = loss(torch.from_numpy(outputs).float(),
                 torch.from_numpy(clean_prs).float())
    # Helpfulness is scored 0.5 for each degraded excerpt and 1.0 for each
    # clean excerpt.
    helpfulness = (0.5 * np.sum(binary_labels)
                   + np.sum(1 - binary_labels)) / len(binary_labels)
    print(f'Task 4 loss = {task4}')
    print(f'Task 4 Helpfulness = {helpfulness}')
