-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbeat_tracking.cpp
230 lines (204 loc) · 8.51 KB
/
beat_tracking.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
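/*
 * Agent-based beat tracking. The functions in this file seed tracking agents
 * from tempo clusters (each cluster's average_IOI becomes an agent's beat
 * interval), update those agents as new onsets arrive, and report the
 * highest-scoring agent as the current prediction of the song's beat.
 */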
#include "beat_tracking.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
// Global pool of beat-tracking agents shared by every function in this file.
// Each agent carries a beat interval, its next predicted beat time, the last
// onset it accepted (history) and an accumulated score. Entries are raw
// pointers to agents created with `new` in this file.
// NOTE(review): nothing in this file frees the agents still in the pool at shutdown.
std::vector<Agent *> agents;
// NOTE: THIS FUNCTION SHOULD ONLY BE CALLED ONCE FROM SOMEWHERE ELSE SINCE REAL TIME
void beat_tracking_initialisation(std::vector<Cluster > &clusters, std::deque<Onset> &onsets, int numClusters)
{
// Create the initial agents based on the first set of onsets
for (int i = clusters.size() - numClusters; i < clusters.size(); ++i) {
for (int j = 0; j < 20; ++j) {
Agent *new_agent = new Agent();
new_agent->interval = clusters[i].average_IOI;
new_agent->prediction = onsets[j].time_stamp + new_agent->interval;
new_agent->history = onsets[j];
new_agent->score = onsets[j].magnitude;
agents.push_back(new_agent);
}
}
}
Agent* beat_tracking(std::deque<Onset> &new_onsets)
{
std::vector<Agent *> new_agents;
static Agent *highest_score_agent = agents[0];
for (int i = 0; i < new_onsets.size(); ++i) {
for (int j = 0; j < agents.size(); ++j) {
// If this new onset is way past agent's last onset, this agent is invalid
if (new_onsets[i].time_stamp - agents[j]->history.time_stamp > TIME_OUT) {
agents.erase(agents.begin() + j);
}
// Otherwise create tolerance windows and check whether this onset can be added to this agent
else {
double tolerance_pre = 0.2 * agents[j]->interval; // 20% of current interval
double tolerance_post = 0.4 * agents[j]->interval; // 40% of current interval
while (agents[j]->prediction + tolerance_post < new_onsets[i].time_stamp) {
agents[j]->prediction += agents[j]->interval;
}
// This onset falls within an existing agent's prediction interval
if ((new_onsets[i].time_stamp >= agents[j]->prediction - tolerance_pre) &&
(new_onsets[i].time_stamp < agents[j]->prediction + tolerance_post)) {
// Now see if onset falls between inner and outer interval
if (std::abs(agents[j]->prediction - new_onsets[i].time_stamp) > tolerance_inner) {
Agent *new_agent = new Agent();
new_agent->interval = agents[j]->interval;
new_agent->prediction = agents[j]->prediction;
new_agent->history = agents[j]->history;
new_agent->score = agents[j]->score;
new_agents.push_back(new_agent);
}
double error = new_onsets[i].time_stamp - agents[j]->prediction;
double relative_error;
if (error <= 0) relative_error = -error / tolerance_pre;
else relative_error = error / tolerance_post;
agents[j]->interval += error / CORRECTION_FACTOR;
agents[j]->prediction = new_onsets[i].time_stamp + agents[j]->interval;
agents[j]->history = new_onsets[i];
agents[j]->score += (1 - relative_error/2) * new_onsets[i].magnitude;
if (agents[j]->score > highest_score_agent->score) {
highest_score_agent = agents[j];
}
}
}
}
}
// Add new agents
for (int i = 0; i < new_agents.size(); ++i) {
agents.push_back(new_agents[i]);
}
// Remove duplicate agents
for (int i = 0; i < agents.size(); ++i) {
for (int j = i+1; j < agents.size(); ++j) {
if ((std::abs(agents[i]->interval - agents[j]->interval) < fs * 20/1000) &&
(std::abs(agents[i]->prediction - agents[j]->prediction) < fs * 40/1000)) {
if (agents[i]->score >= agents[j]->score) {
delete agents[j];
agents.erase(agents.begin() + j);
}
else {
delete agents[i];
agents.erase(agents.begin() + i);
}
}
}
}
return highest_score_agent;
}
Agent* beat_tracking_update(std::deque<Onset> &new_onsets)
{
std::vector<Agent *> new_agents;
static Agent *highest_score_agent = agents[0];
for (int i = 0; i < new_onsets.size(); ++i) {
for (int j = agents.size() - 20; j < agents.size(); ++j) {
// If this new onset is way past agent's last onset, this agent is invalid
if (new_onsets[i].time_stamp - agents[j]->history.time_stamp > TIME_OUT) {
agents.erase(agents.begin() + j);
}
// Otherwise create tolerance windows and check whether this onset can be added to this agent
else {
double tolerance_pre = 0.2 * agents[j]->interval; // 20% of current interval
double tolerance_post = 0.4 * agents[j]->interval; // 40% of current interval
while (agents[j]->prediction + tolerance_post < new_onsets[i].time_stamp) {
agents[j]->prediction += agents[j]->interval;
}
// This onset falls within an existing agent's prediction interval
if ((new_onsets[i].time_stamp >= agents[j]->prediction - tolerance_pre) &&
(new_onsets[i].time_stamp < agents[j]->prediction + tolerance_post)) {
// Now see if onset falls between inner and outer interval
if (std::abs(agents[j]->prediction - new_onsets[i].time_stamp) > tolerance_inner) {
Agent *new_agent = new Agent();
new_agent->interval = agents[j]->interval;
new_agent->prediction = agents[j]->prediction;
new_agent->history = agents[j]->history;
new_agent->score = agents[j]->score;
new_agents.push_back(new_agent);
}
double error = new_onsets[i].time_stamp - agents[j]->prediction;
double relative_error;
if (error <= 0) relative_error = -error / tolerance_pre;
else relative_error = error / tolerance_post;
agents[j]->interval += error / CORRECTION_FACTOR;
agents[j]->prediction = new_onsets[i].time_stamp + agents[j]->interval;
agents[j]->history = new_onsets[i];
agents[j]->score += (1 - relative_error / 2) * new_onsets[i].magnitude;
if (agents[j]->score > highest_score_agent->score) {
highest_score_agent = agents[j];
}
}
}
}
}
// Add new agents
for (int i = 0; i < new_agents.size(); ++i) {
agents.push_back(new_agents[i]);
}
// Remove duplicate agents
for (int i = 0; i < agents.size(); ++i) {
for (int j = i + 1; j < agents.size(); ++j) {
if ((std::abs(agents[i]->interval - agents[j]->interval) < fs * 20 / 1000) &&
(std::abs(agents[i]->prediction - agents[j]->prediction) < fs * 40 / 1000)) {
if (agents[i]->score >= agents[j]->score) {
delete agents[j];
agents.erase(agents.begin() + j);
}
else {
delete agents[i];
agents.erase(agents.begin() + i);
}
}
}
}
return highest_score_agent;
}
// Picks the highest- and second-highest-scoring clusters. When either one
// represents a tempo not yet recorded in clusters_max (no stored cluster's
// average_IOI lies within cluster_width of it), the cluster is appended to
// clusters_max. A fresh set of agents is then seeded for it and advanced over
// the global onset buffer.
//
// On the first call (clusters_max empty) the two clusters are only recorded;
// no agents are seeded. A cluster must out-score a value-initialised Cluster
// to be selected. If none does, nothing is recorded, so a placeholder cluster
// can never seed agents.
//
// NOTE(review): highest_score_agent is taken by value and the function returns
// void, so the best agent computed here is not visible to the caller. Consider
// returning it or taking it by reference.
void new_clusters_tracking(std::vector<Cluster *> &clusters, std::vector<Cluster > &clusters_max, Agent* highest_score_agent) {
    // Stack baseline equivalent to the original `new Cluster()` sentinel, without the leak.
    const Cluster baseline = Cluster();

    const Cluster* max_score_cluster = &baseline;
    for (const Cluster* cluster : clusters) {
        if (cluster->score > max_score_cluster->score) max_score_cluster = cluster;
    }
    const Cluster* secmax_score_cluster = &baseline;
    for (const Cluster* cluster : clusters) {
        if (cluster == max_score_cluster) continue;
        if (cluster->score > secmax_score_cluster->score) secmax_score_cluster = cluster;
    }

    // No cluster beat the baseline: there is no tempo candidate to record or track.
    if (max_score_cluster == &baseline) {
        return;
    }
    const bool have_secmax = secmax_score_cluster != &baseline;

    if (clusters_max.empty()) {
        clusters_max.push_back(*max_score_cluster);
        if (have_secmax) clusters_max.push_back(*secmax_score_cluster);
        return;
    }

    // True when no recorded cluster has an average_IOI within cluster_width of `candidate`.
    const auto is_new_tempo = [&clusters_max](const Cluster& candidate) {
        return std::none_of(clusters_max.begin(), clusters_max.end(),
                            [&candidate](const Cluster& known) {
                                return std::abs(known.average_IOI - candidate.average_IOI) < cluster_width;
                            });
    };
    // Both decisions are made against clusters_max as it was on entry, before either push.
    const bool max_cluster_change = is_new_tempo(*max_score_cluster);
    const bool secmax_cluster_change = have_secmax && is_new_tempo(*secmax_score_cluster);

    // Records `cluster`, seeds agents for it, advances them, and keeps the better agent.
    const auto track_new_cluster = [&](const Cluster& cluster) {
        clusters_max.push_back(cluster);
        beat_tracking_initialisation(clusters_max, onsets, 1);
        Agent* candidate = beat_tracking_update(onsets);
        if (candidate != nullptr &&
            (highest_score_agent == nullptr || highest_score_agent->score < candidate->score)) {
            highest_score_agent = candidate;
        }
    };
    if (max_cluster_change) track_new_cluster(*max_score_cluster);
    if (secmax_cluster_change) track_new_cluster(*secmax_score_cluster);
}