Skip to content

Commit

Permalink
Merge pull request #60 from kc-ml2/DEV/main
Browse files Browse the repository at this point in the history
Dev/main
  • Loading branch information
yonghakim authored Nov 22, 2023
2 parents 98fd479 + c2f9ebb commit 7a97137
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 14 deletions.
12 changes: 6 additions & 6 deletions meent/on_numpy/emsolver/convolution_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,15 @@ def to_conv_mat_discrete(ucell, fourier_order_x, fourier_order_y, device=None, t
e_conv_all = np.zeros((ucell_pmt.shape[0], ff, ff)).astype(type_complex)
o_e_conv_all = np.zeros((ucell_pmt.shape[0], ff, ff)).astype(type_complex)
if improve_dft:
minimum_pattern_size = 2 * ff * ucell_pmt.shape[2]
minimum_pattern_size = 2 * ff * ucell_pmt.shape[2] # TODO: scale factor is 2? to avoid alias?
else:
minimum_pattern_size = 2 * ff
minimum_pattern_size = 4 * fourier_order_x + 1 # TODO: other bds

for i, layer in enumerate(ucell_pmt):
n = minimum_pattern_size // layer.shape[1]
layer = np.repeat(layer, n + 1, axis=1)
f_coeffs = np.fft.fftshift(np.fft.fft(layer / layer.size))
o_f_coeffs = np.fft.fftshift(np.fft.fft(1/layer / layer.size))
f_coeffs = np.fft.fftshift(np.fft.fft(layer / layer.size).astype(type_complex))
o_f_coeffs = np.fft.fftshift(np.fft.fft(1/layer / layer.size).astype(type_complex))
# FFT scaling:
# https://kr.mathworks.com/matlabcentral/answers/15770-scaling-the-fft-and-the-ifft?s_tid=srchtitle

Expand Down Expand Up @@ -257,8 +257,8 @@ def to_conv_mat_discrete(ucell, fourier_order_x, fourier_order_y, device=None, t
n = minimum_pattern_size_x // layer.shape[1]
layer = np.repeat(layer, n + 1, axis=1)

f_coeffs = np.fft.fftshift(np.fft.fft2(layer / layer.size))
o_f_coeffs = np.fft.fftshift(np.fft.fft2(1/layer / layer.size))
f_coeffs = np.fft.fftshift(np.fft.fft2(layer / layer.size).astype(type_complex))
o_f_coeffs = np.fft.fftshift(np.fft.fft2(1/layer / layer.size).astype(type_complex))
center = np.array(f_coeffs.shape) // 2

conv_idx_y = np.arange(-ff_y + 1, ff_y, 1)
Expand Down
4 changes: 2 additions & 2 deletions meent/on_torch/emsolver/convolution_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,8 @@ def to_conv_mat_raster_discrete(ucell, fourier_order_x, fourier_order_y, device=
for i, layer in enumerate(ucell_pmt):
n = minimum_pattern_size // layer.shape[1]
layer = layer.repeat_interleave(n + 1, axis=1)
f_coeffs = torch.fft.fftshift(torch.fft.fft(layer / layer.numel()))
o_f_coeffs = torch.fft.fftshift(torch.fft.fft(1/layer / layer.numel()))
f_coeffs = torch.fft.fftshift(torch.fft.fft(layer / layer.numel()).type(type_complex))
o_f_coeffs = torch.fft.fftshift(torch.fft.fft(1/layer / layer.numel()).type(type_complex))
center = torch.tensor(f_coeffs.shape, device=device) // 2
# center = torch.div(torch.tensor(f_coeffs.shape, device=device), 2, rounding_mode='trunc')
# center = torch.tensor(center, device=device)
Expand Down
2 changes: 1 addition & 1 deletion meent/on_torch/emsolver/transfer_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff_xy, delt

final_RT = torch.linalg.inv(final_A) @ final_B

R_s = final_RT[:ff_xy, :].flatten()
R_s = final_RT[:ff_xy, :].flatten() # TODO: why flatten?
R_p = final_RT[ff_xy:2 * ff_xy, :].flatten()

big_T1 = final_RT[2 * ff_xy:, :]
Expand Down
8 changes: 4 additions & 4 deletions meent/on_torch/mee.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class MeeTorch(ModelingTorch, RCWATorch, OptimizerTorch):

def __init__(self, device=None, type_complex=None, *args, **kwargs):
def __init__(self, device=0, type_complex=0, *args, **kwargs):

# device
if device in (0, 'cpu'):
Expand All @@ -35,6 +35,6 @@ def __init__(self, device=None, type_complex=None, *args, **kwargs):
self.device = device
self.type_complex = type_complex

ModelingTorch.__init__(self, *args, **kwargs)
RCWATorch.__init__(self, *args, **kwargs)
OptimizerTorch.__init__(self, *args, **kwargs)
ModelingTorch.__init__(self, device=device, type_complex=type_complex, *args, **kwargs)
RCWATorch.__init__(self, device=device, type_complex=type_complex, *args, **kwargs)
OptimizerTorch.__init__(self, device=device, type_complex=type_complex, *args, **kwargs)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
}
setup(
name='meent',
version='0.9.6',
version='0.9.7',
url='https://github.com/kc-ml2/meent',
author='KC ML2',
author_email='[email protected]',
Expand Down

0 comments on commit 7a97137

Please sign in to comment.