forked from turcato-niccolo/neurorobotics_prj
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfisher_score.m
43 lines (36 loc) · 1.88 KB
/
fisher_score.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
function [fisher_scores] = fisher_score(PSD, run_labels, type_labels, run_codes, class_codes)
% [fisher_scores] = fisher_score(PSD, run_labels, type_labels, run_codes, class_codes)
%
% Computes, separately for each run, the Fisher score of every
% (frequency, channel) feature of the PSD given in input:
%
%   FS = |mu_2 - mu_1| / sqrt(sigma_1^2 + sigma_2^2)
%
% where mu_k and sigma_k are the mean and the sample standard deviation of
% the feature over the windows of that run labelled with class_codes(k).
% Only the first two entries of class_codes are used.
% (Remember to normalize the distribution of the features first, e.g. log(PSD).)
%
% Input arguments:
%   - PSD          PSD matrix [windows x frequencies x channels]
%   - run_labels   labelling vector (one entry per window) containing the run
%                  each window is assigned to (row or column vector)
%   - type_labels  labelling vector (one entry per window) containing the task
%                  type each window is assigned to (row or column vector)
%   - run_codes    codes of the runs used in run_labels
%   - class_codes  codes of the task classes used in type_labels (at least
%                  two; only the first two are used)
%
% Output arguments:
%   - fisher_scores  Fisher score matrix for each run
%                    [frequencies x channels x runs]. Entries are NaN for a
%                    run that contains no window of one of the two classes.
%
% Errors:
%   - fisher_score:notEnoughClasses   fewer than two class codes given
%   - fisher_score:labelSizeMismatch  label vectors do not have one entry
%                                     per window of PSD
num_windows  = size(PSD, 1);
num_freqs    = size(PSD, 2);
num_channels = size(PSD, 3);
num_runs     = numel(run_codes);

if numel(class_codes) < 2
    error('fisher_score:notEnoughClasses', ...
        'At least two class codes are required to compute the Fisher score.');
end

% Force column vectors: mixing a row and a column label vector would make
% '&' broadcast the two masks into a windows x windows matrix.
run_labels  = run_labels(:);
type_labels = type_labels(:);
% A logical mask shorter than size(PSD, 1) would silently select a subset
% of the windows, so insist on one label per window.
if numel(run_labels) ~= num_windows || numel(type_labels) ~= num_windows
    error('fisher_score:labelSizeMismatch', ...
        'run_labels and type_labels must have one entry per window (size(PSD, 1)).');
end

% One [frequencies x channels] Fisher score matrix per run.
fisher_scores = nan(num_freqs, num_channels, num_runs);
for run_i = 1 : num_runs
    mask_run = (run_labels == run_codes(run_i));
    [mu_1, sigma_1] = class_stats(PSD, mask_run & (type_labels == class_codes(1)), num_freqs, num_channels);
    [mu_2, sigma_2] = class_stats(PSD, mask_run & (type_labels == class_codes(2)), num_freqs, num_channels);
    fisher_scores(:, :, run_i) = abs(mu_2 - mu_1) ./ sqrt(sigma_1.^2 + sigma_2.^2);
end
end

function [mu, sigma] = class_stats(PSD, mask, num_freqs, num_channels)
% Mean and sample standard deviation over the windows selected by mask,
% each returned as a [frequencies x channels] matrix.
% The reduction is done explicitly along dimension 1 (windows): without the
% dim argument mean/std act on the first non-singleton dimension, so a class
% with a single window would be averaged across frequencies instead.
% reshape (rather than squeeze) keeps the orientation fixed even when there
% is a single frequency or a single channel.
selected = PSD(mask, :, :);
mu    = reshape(mean(selected, 1), num_freqs, num_channels);
sigma = reshape(std(selected, 0, 1), num_freqs, num_channels);
end