Skip to content

Commit

Permalink
Implement WLS fit for kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
KeitaNakamura committed Jul 30, 2024
1 parent 3235a6d commit 6e958cb
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 35 deletions.
35 changes: 2 additions & 33 deletions src/Interpolations/kernelcorrection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ gridspan(kc::KernelCorrection) = gridspan(get_kernel(kc))
indices = neighboringnodes(mp)
isnearbounds = size(mp.w) != size(indices) || !alltrue(filter, indices)
if isnearbounds
update_property_nearbounds!(mp, it, pt, mesh, filter)
update_property!(mp, WLS(get_kernel(it), get_polynomial(it)), pt, mesh, filter)
else
@inbounds @simd for ip in eachindex(indices)
i = indices[ip]
Expand All @@ -37,41 +37,10 @@ end
indices = neighboringnodes(mp)
isnearbounds = size(mp.w) != size(indices) || !alltrue(filter, indices)
if isnearbounds
update_property_nearbounds!(mp, it, pt, mesh, filter)
update_property!(mp, WLS(get_kernel(it), get_polynomial(it)), pt, mesh, filter)
else
set_kernel_values!(mp, values(difftype(mp), get_kernel(it), getx(pt), mesh))
end
end

@inline function update_property_nearbounds!(mp::MPValue, it::KernelCorrection, pt, mesh::CartesianMesh{dim}, filter::AbstractArray{Bool}) where {dim}
indices = neighboringnodes(mp)
kernel = get_kernel(it)
poly = get_polynomial(it)
xₚ = getx(pt)

M = fastsum(eachindex(indices)) do ip
@inbounds begin
i = indices[ip]
xᵢ = mesh[i]
w = mp.w[ip] = value(kernel, pt, mesh, i) * filter[i]
P = value(poly, xᵢ - xₚ)
w * P P
end
end
M⁻¹ = inv(M)

P₀, ∇P₀, ∇∇P₀, ∇∇∇P₀ = value(all, poly, zero(xₚ))
@inbounds for ip in eachindex(indices)
i = indices[ip]
xᵢ = mesh[i]
w = mp.w[ip]
P = value(poly, xᵢ - xₚ)
wq = w * (M⁻¹ P)
hasproperty(mp, :w) && set_kernel_values!(mp, ip, (wqP₀,))
hasproperty(mp, :∇w) && set_kernel_values!(mp, ip, (wqP₀, wq∇P₀))
hasproperty(mp, :∇∇w) && set_kernel_values!(mp, ip, (wqP₀, wq∇P₀, wq∇∇P₀))
hasproperty(mp, :∇∇∇w) && set_kernel_values!(mp, ip, (wqP₀, wq∇P₀, wq∇∇P₀, wq∇∇∇P₀))
end
end

Base.show(io::IO, kc::KernelCorrection) = print(io, KernelCorrection, "(", get_kernel(kc), ", ", get_polynomial(kc), ")")
45 changes: 45 additions & 0 deletions src/Interpolations/wls.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
struct WLS{K <: Kernel, P <: AbstractPolynomial} <: Interpolation
kernel::K
poly::P
end

WLS(k::Kernel) = WLS(k, LinearPolynomial())

get_kernel(wls::WLS) = wls.kernel
get_polynomial(wls::WLS) = wls.poly
gridspan(wls::WLS) = gridspan(get_kernel(wls))
@inline neighboringnodes(wls::WLS, pt, mesh::CartesianMesh) = neighboringnodes(get_kernel(wls), pt, mesh)

# implementation is not fast
@inline function update_property!(mp::MPValue, it::WLS, pt, mesh::CartesianMesh, filter::AbstractArray{Bool} = Trues(size(mesh)))
indices = neighboringnodes(mp)
kernel = get_kernel(it)
poly = get_polynomial(it)
xₚ = getx(pt)

M = fastsum(eachindex(indices)) do ip
@inbounds begin
i = indices[ip]
xᵢ = mesh[i]
w = mp.w[ip] = value(kernel, pt, mesh, i) * filter[i]
P = value(poly, xᵢ - xₚ)
w * P P
end
end
M⁻¹ = inv(M)

P₀, ∇P₀, ∇∇P₀, ∇∇∇P₀ = value(all, poly, zero(xₚ))
@inbounds for ip in eachindex(indices)
i = indices[ip]
xᵢ = mesh[i]
w = mp.w[ip]
P = value(poly, xᵢ - xₚ)
wq = w * (M⁻¹ P)
hasproperty(mp, :w) && set_kernel_values!(mp, ip, (wqP₀,))
hasproperty(mp, :∇w) && set_kernel_values!(mp, ip, (wqP₀, wq∇P₀))
hasproperty(mp, :∇∇w) && set_kernel_values!(mp, ip, (wqP₀, wq∇P₀, wq∇∇P₀))
hasproperty(mp, :∇∇∇w) && set_kernel_values!(mp, ip, (wqP₀, wq∇P₀, wq∇∇P₀, wq∇∇∇P₀))
end
end

Base.show(io::IO, wls::WLS) = print(io, WLS, "(", get_kernel(wls), ", ", get_polynomial(wls), ")")
2 changes: 2 additions & 0 deletions src/Tesserae.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ export
SteffenCubicBSpline,
uGIMP,
Interpolation,
WLS,
KernelCorrection,
# MPValue
generate_mpvalues,
Expand Down Expand Up @@ -93,6 +94,7 @@ include("Interpolations/mpvalue.jl")
include("Interpolations/bspline.jl")
include("Interpolations/gimp.jl")
include("Interpolations/polynomials.jl")
include("Interpolations/wls.jl")
include("Interpolations/kernelcorrection.jl")

include("transfer.jl")
Expand Down
4 changes: 2 additions & 2 deletions test/interpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ end
end
end

@testset "KernelCorrection($kernel)" for kernel in (LinearBSpline(), QuadraticBSpline(), CubicBSpline(), SteffenLinearBSpline(), SteffenQuadraticBSpline(), SteffenCubicBSpline(), uGIMP())
it = KernelCorrection(kernel)
@testset "$(Wrapper(kernel))" for Wrapper in (WLS, KernelCorrection), kernel in (LinearBSpline(), QuadraticBSpline(), CubicBSpline(), SteffenLinearBSpline(), SteffenQuadraticBSpline(), SteffenCubicBSpline(), uGIMP())
it = Wrapper(kernel)
for dim in (1,2,3)
Random.seed!(1234)
mp = MPValue(Vec{dim}, it)
Expand Down

0 comments on commit 6e958cb

Please sign in to comment.