-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTNS_CP.m
132 lines (110 loc) · 3.94 KB
/
TNS_CP.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
function A = TNS_CP(X, R, J, varargin)
%TNS_CP Computes CP decomposition via TNS sampling
%
%A = TNS_CP(X, R, J) returns the factor matrices in a cell A for a
%rank R CP decomposition of X. J is the sampling rate used in least squares
%sampling. See our paper for details.
%
%A = TNS_CP(___, 'maxiters', maxiters) can be used to control the
%maximum number of iterations. maxiters is 50 by default.
%
%A = TNS_CP(___, 'permute_for_speed', permute_for_speed) can be used to
%permute the modes of the input tensor so that the largest mode comes
%first. This can speed up the sampling, since all 1st indices can be drawn
%together rather than one at a time. Set permute_for_speed to true to
%enable this. It is false by default.
%
%A = TNS_CP(___, 'A_init', A_init) can be used to set how the factor
%matrices are initialized. If A_init is "rand", then all the factor
%matrices are initialized to have entries drawn uniformly at random from
%[0,1]. If A_init is "RRF", then the factor matrices are initalized via a
%randomized range finder applied to the unfoldings of X. A_init can also be
%a cell array containing initializations for the factor matrices.
%
%This function currently does not support any checking of convergence
%criteria.
%
%This is an adaption of the function cp_als_es.m in the repo at
%https://github.com/OsmanMalik/TD-ALS-ES, which is the repo associated with
%the paper [Mal22].
% Handle optional inputs
params = inputParser;
addParameter(params, 'maxiters', 50, @(x) isscalar(x) & x > 0);
addParameter(params, 'permute_for_speed', false);
addParameter(params, 'A_init', "rand")
parse(params, varargin{:});
maxiters = params.Results.maxiters;
permute_for_speed = params.Results.permute_for_speed;
A_init = params.Results.A_init;
N = ndims(X);
if isscalar(J)
J = repmat(J, N, 1);
end
% Permute modes of X for increased speed
if permute_for_speed
sz = size(X);
[~, max_idx] = max(sz);
perm_vec = mod((max_idx : max_idx+N-1) - 1, N) + 1;
inv_perm_vec(perm_vec) = 1:N;
X = permute(X, perm_vec);
J = J(perm_vec);
end
sz = size(X);
% Initialize factor matrices
if iscell(A_init)
A = A_init;
else
A = cell(1,N);
for j = 2:N
if strcmp(A_init, "rand")
A{j} = rand(sz(j), R);
elseif strcmp(A_init, "RRF")
Xn = classical_mode_unfolding(X,j);
A{j} = Xn * randn(size(Xn,2), R);
end
end
end
% Compute factor Gram matrices
AtA = cell(1,N);
for j = 2:N
AtA{j} = A{j}.' * A{j};
end
% Main loop
for it = 1:maxiters
% Iterate through all factor matrices
for n = 1:N
% Draw samples
[samples, sqrt_p] = draw_samples_TNS_CP(A, AtA, n, J(n));
% Merge identical samples and count occurences
[occurs, ~] = groupcounts(samples);
[unq_samples, unq_idx] = unique(samples(:,[1:n-1 n+1:N]), 'rows'); % Since groupcounts can't output this...
J2_unq = size(unq_samples,1);
unq_samples = [unq_samples(:, 1:n-1) nan(J2_unq,1) unq_samples(:, n:N-1)];
% Compute rescaling factors
rescale = sqrt(occurs./J(n)) ./ sqrt_p(unq_idx);
% Construct sketched design matrix
SA = repmat(rescale, 1, R);
for j = N:-1:1
if j ~= n
SA = SA .* A{j}(unq_samples(:, j), :);
end
end
% Construct sketched right hand side
szp = cumprod([1 sz(1:end-1)]);
samples_temp = unq_samples - 1; samples_temp(:,n) = 0;
llin = 1+samples_temp*szp';
llin = repelem(llin, sz(n), 1) + repmat((0:sz(n)-1)'*szp(n), size(unq_samples,1), 1);
SXnT = X(llin);
SXnT = reshape(SXnT, sz(n), size(unq_samples,1))';
SXnT = SXnT .* rescale;
% Solve sketched LS problem and update nth factor matrix
A{n} = (SA \ SXnT).';
% Update AtA
AtA{n} = A{n}.' * A{n};
end
fprintf('\tIteration %d complete\n', it);
end
if permute_for_speed
A = A(inv_perm_vec);
end
end