Skip to content

Commit

Permalink
Allow for passing a subset of simulation indices
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielVandH committed Jun 29, 2023
1 parent 1d145b7 commit c0f5447
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 26 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EpithelialDynamics1D"
uuid = "ace8a2d7-7779-48a6-a8a4-cf6831a7e55b"
authors = ["Daniel VandenHeuvel <danj.vandenheuvel@gmail.com>"]
version = "1.1.0"
version = "1.2.0"

[deps]
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
Expand Down
45 changes: 20 additions & 25 deletions src/statistics.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
# densities
# cell numbers
# right endpoints

"""
cell_densities(cell_positions::AbstractVector{T}) where {T<:Number}
Expand Down Expand Up @@ -66,12 +62,12 @@ function node_densities(cell_positions::AbstractVector{T}) where {T<:Number}
end

"""
get_knots(sol, num_knots = 500)
get_knots(sol, num_knots = 500; indices = eachindex(sol))
Computes knots for each time, covering the extremum of the cell positions across all
cell simulations.
cell simulations. You can restrict the simultaions to consider using the `indices`.
"""
function get_knots(sol::EnsembleSolution, num_knots=500)
function get_knots(sol::EnsembleSolution, num_knots=500; indices=eachindex(sol))
@static if VERSION < v"1.7"
knots = Vector{LinRange{Float64}}(undef, length(first(sol)))
else
Expand All @@ -81,7 +77,7 @@ function get_knots(sol::EnsembleSolution, num_knots=500)
for i in eachindex(times)
a = Inf
b = -Inf
for j in eachindex(sol)
for j in indices
for r in sol[j][i]
a = min(a, r[begin])
b = max(b, r[end])
Expand All @@ -107,8 +103,9 @@ Computes summary statistics for the node densities from an `EnsembleSolution` to
- `sol::EnsembleSolution`: The ensemble solution to a `CellProblem`.
# Keyword Arguments
- `indices = eachindex(sol)`: The indices of the cell simulations to consider.
- `num_knots::Int = 500`: The number of knots to use for the spline interpolation.
- `knots::Vector{Vector{Float64}} = get_knots(sol, num_knots)`: The knots to use for the spline interpolation.
- `knots::Vector{Vector{Float64}} = get_knots(sol, num_knots; indices)`: The knots to use for the spline interpolation.
- `alpha::Float64 = 0.05`: The significance level for the confidence intervals.
# Outputs
Expand All @@ -119,15 +116,11 @@ Computes summary statistics for the node densities from an `EnsembleSolution` to
- `uppers::Vector{Vector{Float64}}`: The upper bounds of the confidence intervals for the node densities for each cell simulation.
- `knots::Vector{Vector{Float64}}`: The knots used for the spline interpolation.
"""
function node_densities(sol::EnsembleSolution; num_knots=500, knots=get_knots(sol, num_knots), alpha=0.05)
q = map(sol) do sol
node_densities.(sol.u)
end
r = map(sol) do sol
sol.u
end
function node_densities(sol::EnsembleSolution; indices=eachindex(sol), num_knots=500, knots=get_knots(sol, num_knots; indices), alpha=0.05)
q = [node_densities.(sol[i].u) for i in indices]
r = [sol[i].u for i in indices]
nt = length(first(sol))
nsims = length(sol)
nsims = length(indices)
q_splines = zeros(num_knots, nt, nsims)
q_means = [zeros(num_knots) for _ in 1:nt]
q_lowers = [zeros(num_knots) for _ in 1:nt]
Expand Down Expand Up @@ -159,14 +152,15 @@ function node_densities(sol::EnsembleSolution; num_knots=500, knots=get_knots(so
end

"""
cell_numbers(sol::EnsembleSolution; alpha=0.05)
cell_numbers(sol::EnsembleSolution; indices = eachindex(sol), alpha=0.05)
Computes summary statistics for the cell numbers from an `EnsembleSolution` to a [`CellProblem`](@ref).
# Arguments
- `sol::EnsembleSolution`: The ensemble solution to a `CellProblem`.
# Keyword Arguments
- `indices = eachindex(sol)`: The indices of the cell simulations to consider.
- `alpha::Float64 = 0.05`: The significance level for the confidence intervals.
# Outputs
Expand All @@ -175,9 +169,9 @@ Computes summary statistics for the cell numbers from an `EnsembleSolution` to a
- `lowers::Vector{Float64}`: The lower bounds of the confidence intervals for the cell numbers for each cell simulation.
- `uppers::Vector{Float64}`: The upper bounds of the confidence intervals for the cell numbers for each cell simulation.
"""
function cell_numbers(sol::EnsembleSolution; alpha=0.05)
N = map(sol) do sol
length.(sol.u) .- 1
function cell_numbers(sol::EnsembleSolution; indices=eachindex(sol), alpha=0.05)
N = map(indices) do i
length.(sol[i].u) .- 1
end |> x -> reduce(hcat, x)
N_means = zeros(size(N, 1))
N_lowers = zeros(size(N, 1))
Expand All @@ -195,14 +189,15 @@ function cell_numbers(sol::EnsembleSolution; alpha=0.05)
end

"""
leading_edges(sol::EnsembleSolution; alpha=0.05)
leading_edges(sol::EnsembleSolution; indices = eachindex(sol), alpha=0.05)
Computes summary statistics for the leading edges from an `EnsembleSolution` to a [`CellProblem`](@ref).
# Arguments
- `sol::EnsembleSolution`: The ensemble solution to a `CellProblem`.
# Keyword Arguments
- `indices = eachindex(sol)`: The indices of the cell simulations to consider.
- `alpha::Float64 = 0.05`: The significance level for the confidence intervals.
# Outputs
Expand All @@ -211,9 +206,9 @@ Computes summary statistics for the leading edges from an `EnsembleSolution` to
- `lowers::Vector{Float64}`: The lower bounds of the confidence intervals for the leading edges for each cell simulation.
- `uppers::Vector{Float64}`: The upper bounds of the confidence intervals for the leading edges for each cell simulation.
"""
function leading_edges(sol::EnsembleSolution; alpha=0.05)
L = map(sol) do sol
map(sol) do sol
function leading_edges(sol::EnsembleSolution; indices=eachindex(sol), alpha=0.05)
L = map(indices) do i
map(sol[i]) do sol
sol[end]
end
end |> x -> reduce(hcat, x)
Expand Down
92 changes: 92 additions & 0 deletions test/step_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,46 @@ end
resize_to_layout!(fig)
fig_path = normpath(@__DIR__, "..", "docs", "src", "figures")
@test_reference joinpath(fig_path, "step_function_proliferation.png") fig by = psnr_equality(16.5)

# Test the statistics when restricting to a specific set of simulation indices
_indices = rand(eachindex(sol), 20)
q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices)
@inferred node_densities(sol; indices=_indices)
N, N_means, N_lowers, N_uppers = cell_numbers(sol; indices=_indices)
@inferred cell_numbers(sol; indices=_indices)
@test all((LinRange(0, 30, 500)), knots)
for (enum_k, k) in enumerate(_indices)
for j in rand(1:length(sol[k]), 40)
for i in rand(1:length(sol[k][j]), 60)
if i == 1
@test q[enum_k][j][1] 1 / (r[enum_k][j][2] - r[enum_k][j][1])
elseif i == length(sol[k][j])
n = length(sol[k][j])
@test q[enum_k][j][n] 1 / (r[enum_k][j][n] - r[enum_k][j][n-1])
else
@test q[enum_k][j][i] 2 / (r[enum_k][j][i+1] - r[enum_k][j][i-1])
end
@test r[enum_k][j][i] == sol[k][j][i]
end
@test N[enum_k][j] length(sol[k][j]) - 1
end
end
for j in rand(1:length(fvm_sol), 50)
Nj = [N[k][j] for k in eachindex(_indices)]
@test mean(Nj) N_means[j]
@test quantile(Nj, 0.025) N_lowers[j]
@test quantile(Nj, 0.975) N_uppers[j]
@test pde_N[j] DataInterpolations.integral(
LinearInterpolation(fvm_sol.u[j], fvm_sol.prob.p.geometry.mesh_points),
0.0, 30.0
)
for i in rand(1:length(knots[j]), 50)
all_q = [LinearInterpolation(q[k][j], r[k][j])(knots[j][i]) for k in eachindex(_indices)]
@test mean(all_q) means[j][i]
@test quantile(all_q, 0.025) lowers[j][i]
@test quantile(all_q, 0.975) uppers[j][i]
end
end
end

@testset "Proliferation with a Moving Boundary" begin
Expand Down Expand Up @@ -363,4 +403,56 @@ end
resize_to_layout!(fig)
fig_path = normpath(@__DIR__, "..", "docs", "src", "figures")
@test_reference joinpath(fig_path, "step_function_proliferation_moving_boundary.png") fig by = psnr_equality(15)

# Test the statistics when restricting to a specific set of simulation indices
_indices = rand(eachindex(sol), 20)
q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices)
@inferred node_densities(sol; indices=_indices)
N, N_means, N_lowers, N_uppers = cell_numbers(sol; indices=_indices)
@inferred cell_numbers(sol; indices=_indices)
L, L_means, L_lowers, L_uppers = leading_edges(sol; indices=_indices)
for j in eachindex(knots)
a = Inf
b = -Inf
m = minimum(sol[k][j][begin] for k in _indices)
M = maximum(sol[k][j][end] for k in _indices)
@test knots[j] == LinRange(m, M, 500)
end
for (enum_k, k) in enumerate(_indices)
for j in rand(1:length(sol[k]), 40)
for i in rand(1:length(sol[k][j]), 60)
if i == 1
@test q[enum_k][j][1] 1 / (r[enum_k][j][2] - r[enum_k][j][1])
elseif i == length(sol[k][j])
n = length(sol[k][j])
@test q[enum_k][j][n] 1 / (r[enum_k][j][n] - r[enum_k][j][n-1])
else
@test q[enum_k][j][i] 2 / (r[enum_k][j][i+1] - r[enum_k][j][i-1])
end
@test r[enum_k][j][i] == sol[k][j][i]
end
@test N[enum_k][j] length(sol[k][j]) - 1
end
end
for j in rand(eachindex(mb_sol), 40)
Nj = [N[k][j] for k in eachindex(_indices)]
@test @views mean(Nj) N_means[j]
@test @views quantile(Nj, 0.025) N_lowers[j]
@test @views quantile(Nj, 0.975) N_uppers[j]
Lj = [L[k][j] for k in eachindex(_indices)]
@test @views mean(Lj) L_means[j]
@test @views quantile(Lj, 0.025) L_lowers[j]
@test @views quantile(Lj, 0.975) L_uppers[j]
@test pde_N[j] DataInterpolations.integral(
LinearInterpolation(mb_sol.u[j][begin:(end-1)], mb_sol.u[j][end] * mb_sol.prob.p.geometry.mesh_points),
0.0, mb_sol.u[j][end]
)
@test pde_L[j] mb_sol.u[j][end]
for i in rand(eachindex(knots[j]), 60)
all_q = max.(0.0, [LinearInterpolation(q[k][j], r[k][j])(knots[j][i]) * (knots[j][i] r[k][j][end]) for k in eachindex(_indices)])
@test mean(all_q) means[j][i] rtol = 1e-3
@test quantile(all_q, 0.025) lowers[j][i] rtol = 1e-3
@test quantile(all_q, 0.975) uppers[j][i] rtol = 1e-3
end
end
end

0 comments on commit c0f5447

Please sign in to comment.