Skip to content

Commit

Permalink
Make getting the knots MUCH faster, and add the ability to use averag…
Browse files Browse the repository at this point in the history
…ed leading edge for knots instead of a maximum
  • Loading branch information
DanielVandH committed Jul 8, 2023
1 parent d5f3ccd commit 3913861
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EpithelialDynamics1D"
uuid = "ace8a2d7-7779-48a6-a8a4-cf6831a7e55b"
authors = ["Daniel VandenHeuvel <danj.vandenheuvel@gmail.com>"]
version = "1.3.3"
version = "1.4.0"

[deps]
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
Expand Down
46 changes: 33 additions & 13 deletions src/statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ function node_densities(cell_positions::AbstractVector{T}) where {T<:Number}
end

"""
get_knots(sol, num_knots = 500; indices = eachindex(sol))
get_knots(sol, num_knots = 500; indices = eachindex(sol), use_max=true)
Computes knots for each time, covering the extremum of the cell positions across all
cell simulations. You can restrict the simultaions to consider using the `indices`.
If `use_max` is true, then the knots will be obtained by taking the extreme node positions
for each `time`, otherwise the average is used.
"""
function get_knots(sol::EnsembleSolution, num_knots=500; indices=eachindex(sol))
function get_knots(sol::EnsembleSolution, num_knots=500; indices=eachindex(sol), use_extrema=true)
@static if VERSION < v"1.7"
knots = Vector{LinRange{Float64}}(undef, length(first(sol)))
else
Expand All @@ -74,14 +76,30 @@ function get_knots(sol::EnsembleSolution, num_knots=500; indices=eachindex(sol))
times = first(sol).t
Base.Threads.@threads for i in eachindex(times)
local a, b
a = Inf
b = -Inf
if use_extrema
a = Inf
b = -Inf
else
a = 0.0
b = 0.0
ctr = 0
end
for j in indices
for r in sol[j][i]
a = min(a, r[begin])
b = max(b, r[end])
_a = sol[j][i][begin]
_b = sol[j][i][end]
if use_extrema
a = min(a, _a)
b = max(b, _b)
else
a += _a
b += _b
ctr += 1
end
end
if !use_extrema
a /= ctr
b /= ctr
end
knots[i] = LinRange(a, b, num_knots)
end
return knots
Expand All @@ -104,7 +122,8 @@ Computes summary statistics for the node densities from an `EnsembleSolution` to
# Keyword Arguments
- `indices = eachindex(sol)`: The indices of the cell simulations to consider.
- `num_knots::Int = 500`: The number of knots to use for the spline interpolation.
- `knots::Vector{Vector{Float64}} = get_knots(sol, num_knots; indices)`: The knots to use for the spline interpolation.
- `use_extrema::Bool = true`: Whether to use the extrema of the cell positions for the knots, or the average.
- `knots::Vector{Vector{Float64}} = get_knots(sol, num_knots; indices, use_extrema)`: The knots to use for the spline interpolation.
- `alpha::Float64 = 0.05`: The significance level for the confidence intervals.
- `interp_fnc = (u, t) -> LinearInterpolation{true}(u, t)`: The function to use for constructing the interpolant.
Expand All @@ -116,12 +135,13 @@ Computes summary statistics for the node densities from an `EnsembleSolution` to
- `uppers::Vector{Vector{Float64}}`: The upper bounds of the confidence intervals for the node densities for each cell simulation.
- `knots::Vector{Vector{Float64}}`: The knots used for the spline interpolation.
"""
function node_densities(sol::EnsembleSolution;
indices=eachindex(sol),
num_knots=500,
knots=get_knots(sol, num_knots; indices),
function node_densities(sol::EnsembleSolution;
indices=eachindex(sol),
num_knots=500,
use_extrema=true,
knots=get_knots(sol, num_knots; indices, use_extrema),
alpha=0.05,
interp_fnc = (u, t) -> LinearInterpolation{true}(u, t))
interp_fnc=(u, t) -> LinearInterpolation{true}(u, t))
q = Vector{Vector{Vector{Float64}}}(undef, length(indices))
r = Vector{Vector{Vector{Float64}}}(undef, length(indices))
Base.Threads.@threads for i in eachindex(indices)
Expand Down
67 changes: 67 additions & 0 deletions test/step_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,35 @@ end
@test quantile(all_q, 0.975) uppers[j][i]
end
end

# Using average leading edge
_indices = rand(eachindex(sol), 40)
q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, use_extrema=false)
@inferred node_densities(sol; indices=_indices, use_extrema=false)
@test all((LinRange(0, 30, 500)), knots)
for (enum_k, k) in enumerate(_indices)
for j in rand(1:length(sol[k]), 40)
for i in rand(1:length(sol[k][j]), 60)
if i == 1
@test q[enum_k][j][1] 1 / (r[enum_k][j][2] - r[enum_k][j][1])
elseif i == length(sol[k][j])
n = length(sol[k][j])
@test q[enum_k][j][n] 1 / (r[enum_k][j][n] - r[enum_k][j][n-1])
else
@test q[enum_k][j][i] 2 / (r[enum_k][j][i+1] - r[enum_k][j][i-1])
end
@test r[enum_k][j][i] == sol[k][j][i]
end
end
end
for j in rand(1:length(fvm_sol), 50)
for i in rand(1:length(knots[j]), 50)
all_q = [LinearInterpolation(q[k][j], r[k][j])(knots[j][i]) for k in eachindex(_indices)]
@test mean(all_q) means[j][i]
@test quantile(all_q, 0.025) lowers[j][i]
@test quantile(all_q, 0.975) uppers[j][i]
end
end
end

@testset "Proliferation with a Moving Boundary" begin
Expand Down Expand Up @@ -519,4 +548,42 @@ end
@test quantile(all_q, 0.975) uppers[j][i] rtol = 1e-3
end
end

# Using the average leading edge
(; L) = leading_edges(sol)
_L = stack(L)
_indices = rand(eachindex(sol), 20)
_L = _L[:, _indices]
_mL = mean.(eachrow(_L))
q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, use_extrema=false)
@inferred node_densities(sol; indices=_indices, use_extrema=false)
for j in eachindex(knots)
a = mean(sol[k][j][begin] for k in _indices)
b = mean(sol[k][j][end] for k in _indices)
@test knots[j] LinRange(a, b, 500)
@test knots[j][end] _mL[j]
end
for (enum_k, k) in enumerate(_indices)
for j in rand(1:length(sol[k]), 40)
for i in 1:length(sol[k][j])
if i == 1
@test q[enum_k][j][1] 1 / (r[enum_k][j][2] - r[enum_k][j][1])
elseif i == length(sol[k][j])
n = length(sol[k][j])
@test q[enum_k][j][n] 1 / (r[enum_k][j][n] - r[enum_k][j][n-1])
else
@test q[enum_k][j][i] 2 / (r[enum_k][j][i+1] - r[enum_k][j][i-1])
end
@test r[enum_k][j][i] == sol[k][j][i]
end
end
end
for j in rand(eachindex(mb_sol), 40)
for i in eachindex(knots[j])
all_q = max.(0.0, [LinearInterpolation(q[k][j], r[k][j])(knots[j][i]) * (knots[j][i] r[k][j][end]) for k in eachindex(_indices)])
@test mean(all_q) means[j][i] rtol = 1e-3
@test quantile(all_q, 0.025) lowers[j][i] rtol = 1e-3
@test quantile(all_q, 0.975) uppers[j][i] rtol = 1e-3
end
end
end

0 comments on commit 3913861

Please sign in to comment.