Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 3a2d04f

Browse files
authored
Merge pull request #87 from tknopp/rfft
Change FFT to Real to Complex FFT
2 parents 170ae0e + fed68b8 commit 3a2d04f

File tree

6 files changed

+25
-14
lines changed

6 files changed

+25
-14
lines changed

src/Transform/chebyshev_transform.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ function truncate_modes(t::ChebyshevTransform, 𝐱̂::AbstractArray)
1515
return view(𝐱̂, map(d -> 1:d, t.modes)..., :, :) # [t.modes..., in_chs, batch]
1616
end
1717

18-
function inverse(t::ChebyshevTransform{N}, 𝐱̂::AbstractArray) where {N}
18+
function inverse(t::ChebyshevTransform, 𝐱̂::AbstractArray{T, N},
19+
M::NTuple{N, Int64}) where {T, N}
1920
normalized_𝐱̂ = 𝐱̂ ./ (prod(2 .* (size(𝐱̂)[1:N] .- 1)))
2021
return FFTW.r2r(normalized_𝐱̂, FFTW.REDFT01, 1:N) # [size(x)..., in_chs, batch]
2122
end

src/Transform/fourier_transform.jl

+4-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Base.ndims(::FourierTransform{N}) where {N} = N
88
Base.eltype(::Type{FourierTransform}) = ComplexF32
99

1010
function transform(ft::FourierTransform, 𝐱::AbstractArray)
11-
return fft(Zygote.hook(real, 𝐱), 1:ndims(ft)) # [size(x)..., in_chs, batch]
11+
return rfft(Zygote.hook(real, 𝐱), 1:ndims(ft)) # [size(x)..., in_chs, batch]
1212
end
1313

1414
function low_pass(ft::FourierTransform, 𝐱_fft::AbstractArray)
@@ -17,6 +17,7 @@ end
1717

1818
truncate_modes(ft::FourierTransform, 𝐱_fft::AbstractArray) = low_pass(ft, 𝐱_fft)
1919

20-
function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray)
21-
return real(ifft(𝐱_fft, 1:ndims(ft))) # [size(x_fft)..., out_chs, batch]
20+
function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray{T, N},
21+
M::NTuple{N, Int64}) where {T, N}
22+
return real(irfft(𝐱_fft, M[1], 1:ndims(ft))) # [size(x_fft)..., out_chs, batch]
2223
end

src/operator_kernel.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ function operator_conv(m::OperatorConv, 𝐱::AbstractArray)
9292
𝐱_padded = pad_modes(𝐱_applied_pattern,
9393
(size(𝐱_transformed)[1:(end - 2)]...,
9494
size(𝐱_applied_pattern)[(end - 1):end]...)) # [size(x)..., out_chs, batch] <- [modes..., out_chs, batch]
95-
𝐱_inversed = inverse(m.transform, 𝐱_padded)
95+
𝐱_inversed = inverse(m.transform, 𝐱_padded, size(𝐱))
9696

9797
return 𝐱_inversed
9898
end

test/Transform/chebyshev_transform.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
@test ndims(t) == 3
99
@test size(transform(t, 𝐱)) == (30, 40, 50, ch, batch)
1010
@test size(truncate_modes(t, transform(t, 𝐱))) == (3, 4, 5, ch, batch)
11-
@test size(inverse(t, truncate_modes(t, transform(t, 𝐱)))) == (3, 4, 5, ch, batch)
11+
@test size(inverse(t, truncate_modes(t, transform(t, 𝐱)), size(𝐱))) ==
12+
(3, 4, 5, ch, batch)
1213

13-
g = gradient(x -> sum(inverse(t, truncate_modes(t, transform(t, x)))), 𝐱)
14+
g = gradient(x -> sum(inverse(t, truncate_modes(t, transform(t, x)), size(𝐱))), 𝐱)
1415
@test size(g[1]) == (30, 40, 50, ch, batch)
1516
end

test/Transform/fourier_transform.jl

+11-3
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,18 @@
55

66
ft = FourierTransform((3, 4, 5))
77

8-
@test size(transform(ft, 𝐱)) == (30, 40, 50, ch, batch)
8+
@test size(transform(ft, 𝐱)) == (16, 40, 50, ch, batch)
99
@test size(truncate_modes(ft, transform(ft, 𝐱))) == (3, 4, 5, ch, batch)
10-
@test size(inverse(ft, truncate_modes(ft, transform(ft, 𝐱)))) == (3, 4, 5, ch, batch)
10+
@test size(inverse(ft,
11+
NeuralOperators.pad_modes(truncate_modes(ft, transform(ft, 𝐱)),
12+
size(transform(ft, 𝐱))),
13+
size(𝐱))) == (30, 40, 50, ch, batch)
1114

12-
g = Zygote.gradient(x -> sum(inverse(ft, truncate_modes(ft, transform(ft, x)))), 𝐱)
15+
g = Zygote.gradient(x -> sum(inverse(ft,
16+
NeuralOperators.pad_modes(truncate_modes(ft,
17+
transform(ft,
18+
x)),
19+
(16, 40, 50, ch, batch)),
20+
(30, 40, 50, ch, batch))), 𝐱)
1321
@test size(g[1]) == (30, 40, 50, ch, batch)
1422
end

test/operator_kernel.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ end
7171
end
7272

7373
@testset "2D OperatorConv" begin
74-
modes = (16, 16)
74+
modes = (10, 10)
7575
ch = 64 => 64
7676

7777
m = Chain(Dense(1, 64),
@@ -87,7 +87,7 @@ end
8787
end
8888

8989
@testset "permuted 2D OperatorConv" begin
90-
modes = (16, 16)
90+
modes = (10, 10)
9191
ch = 64 => 64
9292

9393
m = Chain(Conv((1, 1), 1 => 64),
@@ -104,7 +104,7 @@ end
104104
end
105105

106106
@testset "2D OperatorKernel" begin
107-
modes = (16, 16)
107+
modes = (10, 10)
108108
ch = 64 => 64
109109

110110
m = Chain(Dense(1, 64),
@@ -119,7 +119,7 @@ end
119119
end
120120

121121
@testset "permuted 2D OperatorKernel" begin
122-
modes = (16, 16)
122+
modes = (10, 10)
123123
ch = 64 => 64
124124

125125
m = Chain(Conv((1, 1), 1 => 64),

0 commit comments

Comments
 (0)