Skip to content

Commit

Permalink
bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
yonghakim committed Nov 22, 2023
1 parent 530c64a commit c2f9ebb
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
8 changes: 4 additions & 4 deletions meent/on_numpy/emsolver/convolution_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ def to_conv_mat_discrete(ucell, fourier_order_x, fourier_order_y, device=None, t
for i, layer in enumerate(ucell_pmt):
n = minimum_pattern_size // layer.shape[1]
layer = np.repeat(layer, n + 1, axis=1)
f_coeffs = np.fft.fftshift(np.fft.fft(layer / layer.size))
o_f_coeffs = np.fft.fftshift(np.fft.fft(1/layer / layer.size))
f_coeffs = np.fft.fftshift(np.fft.fft(layer / layer.size).astype(type_complex))
o_f_coeffs = np.fft.fftshift(np.fft.fft(1/layer / layer.size).astype(type_complex))
# FFT scaling:
# https://kr.mathworks.com/matlabcentral/answers/15770-scaling-the-fft-and-the-ifft?s_tid=srchtitle

Expand Down Expand Up @@ -257,8 +257,8 @@ def to_conv_mat_discrete(ucell, fourier_order_x, fourier_order_y, device=None, t
n = minimum_pattern_size_x // layer.shape[1]
layer = np.repeat(layer, n + 1, axis=1)

f_coeffs = np.fft.fftshift(np.fft.fft2(layer / layer.size))
o_f_coeffs = np.fft.fftshift(np.fft.fft2(1/layer / layer.size))
f_coeffs = np.fft.fftshift(np.fft.fft2(layer / layer.size).astype(type_complex))
o_f_coeffs = np.fft.fftshift(np.fft.fft2(1/layer / layer.size).astype(type_complex))
center = np.array(f_coeffs.shape) // 2

conv_idx_y = np.arange(-ff_y + 1, ff_y, 1)
Expand Down
4 changes: 2 additions & 2 deletions meent/on_torch/emsolver/convolution_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,8 @@ def to_conv_mat_raster_discrete(ucell, fourier_order_x, fourier_order_y, device=
for i, layer in enumerate(ucell_pmt):
n = minimum_pattern_size // layer.shape[1]
layer = layer.repeat_interleave(n + 1, axis=1)
f_coeffs = torch.fft.fftshift(torch.fft.fft(layer / layer.numel()))
o_f_coeffs = torch.fft.fftshift(torch.fft.fft(1/layer / layer.numel()))
f_coeffs = torch.fft.fftshift(torch.fft.fft(layer / layer.numel()).type(type_complex))
o_f_coeffs = torch.fft.fftshift(torch.fft.fft(1/layer / layer.numel()).type(type_complex))
center = torch.tensor(f_coeffs.shape, device=device) // 2
# center = torch.div(torch.tensor(f_coeffs.shape, device=device), 2, rounding_mode='trunc')
# center = torch.tensor(center, device=device)
Expand Down
6 changes: 3 additions & 3 deletions meent/on_torch/mee.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@ def __init__(self, device=0, type_complex=0, *args, **kwargs):
self.device = device
self.type_complex = type_complex

ModelingTorch.__init__(self, *args, **kwargs)
RCWATorch.__init__(self, *args, **kwargs)
OptimizerTorch.__init__(self, *args, **kwargs)
ModelingTorch.__init__(self, device=device, type_complex=type_complex, *args, **kwargs)
RCWATorch.__init__(self, device=device, type_complex=type_complex, *args, **kwargs)
OptimizerTorch.__init__(self, device=device, type_complex=type_complex, *args, **kwargs)

0 comments on commit c2f9ebb

Please sign in to comment.