-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathSampleNRandomTrials.m
43 lines (39 loc) · 1.32 KB
/
SampleNRandomTrials.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
function [train, test] = SampleNRandomTrials(Labels, n, param)
% SampleNRandomTrials  Draw n random train/test splits of source and target labels.
%
%   [train, test] = SampleNRandomTrials(Labels, n, param)
%
%   Inputs:
%     Labels - cell whose first element holds the source labels and whose
%              second element holds the target labels; both are forwarded
%              unchanged to the split helper.
%     n      - number of random trials to draw (iterations of 1:n).
%     param  - struct. Required fields: num_train_source, num_train_target,
%              held_out_categories, categories. Optional field:
%                result_filename - file used as a cache. If it names an
%                existing file, train/test are loaded from it and no new
%                splits are drawn; otherwise the freshly drawn splits are
%                saved to it together with param. If the field is absent
%                or empty, nothing is loaded or saved.
%              The whole struct is also passed to the split helper, which
%              may read further fields (e.g. num_categories) -- TODO confirm.
%
%   Outputs:
%     train, test - structs with fields 'source' and 'target'. Each field is
%                   a 1-by-numel(1:n) cell array whose k-th entry is the
%                   split returned by the k-th trial. Both outputs are []
%                   when a required param field is missing.
%
%   When param.held_out_categories is true (nonzero), each trial calls
%   SplitDiffCategories; otherwise it calls SplitAllCategories.

required = {'num_train_source', 'num_train_target', ...
    'held_out_categories', 'categories'};
if ~all(isfield(param, required))
    disp(['Parameters should include -num_train_source, ' ...
        '-num_train_target, -categories and -held_out_categories. ' ...
        'Missing some value, cannot return samples']);
    train = [];
    test = [];
    return;
end

% Caching is optional: without a result_filename the splits are simply
% recomputed and returned.
result_filename = '';
if isfield(param, 'result_filename')
    result_filename = param.result_filename;
end
use_cache = ~isempty(result_filename);

% exist(...,'file') returns 2 for files (7 would be a folder, which load
% cannot read).
if use_cache && exist(result_filename, 'file') == 2
    % Load into a struct instead of the function workspace, so the cached
    % file cannot silently overwrite local variables such as param.
    cached = load(result_filename, 'train', 'test');
    train = cached.train;
    test = cached.test;
    return;
end

% Preallocate so both outputs are defined (as 1x0 cells) even when n == 0.
trials = 1:n;
train = struct('source', {cell(size(trials))}, 'target', {cell(size(trials))});
test = struct('source', {cell(size(trials))}, 'target', {cell(size(trials))});

for iter = trials
    % Both helpers return [train_source, test_source, train_target, test_target].
    if param.held_out_categories
        [tr_s, te_s, tr_t, te_t] = SplitDiffCategories(Labels{1}, ...
            Labels{2}, param);
    else
        [tr_s, te_s, tr_t, te_t] = SplitAllCategories(Labels{1}, ...
            Labels{2}, param);
    end
    train.source{iter} = tr_s;
    train.target{iter} = tr_t;
    test.source{iter} = te_s;
    test.target{iter} = te_t;
end

if use_cache
    save(result_filename, 'param', 'train', 'test');
end
end