Skip to content

Commit

Permalink
DA on prior x likelihood factorization
Browse files Browse the repository at this point in the history
  • Loading branch information
arzwa committed Jan 15, 2020
1 parent 1342414 commit 77ecf8f
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/Whale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module Whale
using Optim
using ForwardDiff
using Random
using Parameters
# using MCMCChains
using DataFrames
using CSV
Expand Down
21 changes: 10 additions & 11 deletions src/mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,13 @@ end

function Base.rand(m::IRModel, st)
x = State()
for f in fieldnames(typeof(m))
d = getfield(m, f)
x[f] = typeof(d) <: Real ? d : (
typeof(d) <: AbstractArray ? rand.(d) : rand(d))
end
x[] = rand(MvLogNormal(repeat([log(x[])], nrates(st)), x[]))
x[] = rand(MvLogNormal(repeat([log(x[])], nrates(st)), x[]))
x[] = rand(m.η)
x[] = rand(m.ν)
x[:q] = rand(m.q, nwgd(st))
l = rand(m.λ)
m = rand(m.μ)
x[] = rand(MvLogNormal(repeat([log(l)], nrates(st)), x[]))
x[] = rand(MvLogNormal(repeat([log(m)], nrates(st)), x[]))
return x
end

Expand Down Expand Up @@ -305,11 +305,10 @@ function mcmc!(w::WhaleChain, D::CCDArray, n::Int64, args...;
log_mcmc(w, stdout, show_trace, show_every)
backtrack ? backtrack!(D, WhaleModel(w)) : nothing
end
# Chains(w)
w
end

function mcmc!(w::WhaleChain, n::Int64, args...; show_trace=true, show_every=10)
function mcmc!(w::WhaleChain, n::Int64, args...;
show_trace=true, show_every=10, kwargs...)
@warn "No data provided, sampling from the prior"
mcmc!(w, distribute(CCD[get_dummy_ccd()]), n, args...,
show_trace=show_trace, show_every=show_every, backtrack=false)
Expand Down Expand Up @@ -356,7 +355,7 @@ function acceptreject_da(chain, f, g, q)
α1 = min- chain[] + q, 0.)
accept = log(rand()) < α1
= accept ? f() : -Inf
α2 = min(ℓ - state[:l], 0.)
α2 = min(ℓ - chain[:l], 0.)
accept = log(rand()) < α2
return accept, ℓ, π
end
Expand Down
7 changes: 7 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[deps]
ConsensusTrees = "d3d4590a-60df-11e9-122f-f56e3d820c59"
DistributedArrays = "aaf54ef3-cdf8-58ed-94cc-d582ad619b94"
PhyloTrees = "0d4d4e69-8856-4a2e-902c-8a6c3add14ba"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Whale = "eab5b8f2-ac71-4eb1-ac2b-64791b8dae63"
54 changes: 54 additions & 0 deletions test/mcmc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using Whale, Distributions
using Plots, StatsPlots

# prior check
begin
tree = Whale.example_tree()
w = WhaleChain(tree,
IRModel=LogNormal(-1, 0.5),
μ=LogNormal(-1, 0.5),
ν=Exponential(0.1),
η=Beta(16,4)))
w.da = true
for i=1:100
@time mcmc!(w, 100, show_every=100, backtrack=false)
p = plot(stephist(w.df[!,:λ1],alpha=0.2,fill=true,color=:black,normalize=true),
stephist(w.df[!,:λ7],alpha=0.2,fill=true,color=:black,normalize=true),
stephist(w.df[!,:q1],alpha=0.2,fill=true,color=:black,normalize=true),
stephist(w.df[!,], alpha=0.2,fill=true,color=:black,normalize=true),
plot(log.(w.df[!,:μ4]), color=:black, linewidth=1.5),
plot(log.(w.df[!,:μ9]), color=:black, linewidth=1.5),
legend=false, grid=false)
plot!(p[1], color=:black, w.prior.λ, linewidth=2)
plot!(p[2], color=:black, w.prior.λ, linewidth=2)
plot!(p[3], color=:black, w.prior.q, linewidth=2)
plot!(p[4], color=:black, w.prior.η, linewidth=2)
display(p)
end
end

begin
tree = Whale.example_tree()
ccd = read_ale("example/example-ale/", tree)
w = WhaleChain(tree,
IRModel=LogNormal(-1, 0.5),
μ=LogNormal(-1, 0.5),
ν=Exponential(0.1),
η=Beta(16,4)))
w.da = true
for i=1:100
@time mcmc!(w, ccd, 100, show_every=100, backtrack=false)
p = plot(stephist(w.df[!,:λ1],alpha=0.2,fill=true,color=:black,normalize=true),
stephist(w.df[!,:λ7],alpha=0.2,fill=true,color=:black,normalize=true),
stephist(w.df[!,:q1],alpha=0.2,fill=true,color=:black,normalize=true),
stephist(w.df[!,], alpha=0.2,fill=true,color=:black,normalize=true),
plot(log.(w.df[!,:μ4]), color=:black, linewidth=1.5),
plot(log.(w.df[!,:μ9]), color=:black, linewidth=1.5),
legend=false, grid=false)
plot!(p[1], color=:black, w.prior.λ, linewidth=2)
plot!(p[2], color=:black, w.prior.λ, linewidth=2)
plot!(p[3], color=:black, w.prior.q, linewidth=2)
plot!(p[4], color=:black, w.prior.η, linewidth=2)
display(p)
end
end

0 comments on commit 77ecf8f

Please sign in to comment.