Skip to content

Commit 45f32cd

Browse files
authored
Merge pull request #53 from nossleinad/Hastings-root-pos
Metropolis-Hastings root position sampler
2 parents b860f44 + 69ade96 commit 45f32cd

File tree

6 files changed

+97
-15
lines changed

6 files changed

+97
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MolecularEvolution"
22
uuid = "9f975960-e239-4209-8aa0-3d3ad5a82892"
33
authors = ["Ben Murrell <[email protected]> and contributors"]
4-
version = "0.2.5"
4+
version = "0.2.6"
55

66
[deps]
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

src/core/algorithms/AbstractUpdate.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ end
7070
Normal(0, 2),
7171
Normal(-1, 1),
7272
),
73-
root_sampler::RootSample = StandardRootSample(1),
73+
root_sampler::RootSample = StandardRootSample(1.0, 1),
7474
models_sampler::ModelsUpdate = StandardModelsUpdate()
7575
)
7676
@@ -86,7 +86,7 @@ BayesUpdate(;
8686
Normal(0, 2),
8787
Normal(-1, 1),
8888
),
89-
root_sampler = StandardRootSample(1),
89+
root_sampler = StandardRootSample(1e-2, 1),
9090
models_sampler::ModelsUpdate = StandardModelsUpdate(),
9191
) = StandardUpdate(
9292
nni,

src/core/algorithms/new_update_template.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,10 @@ either implement the metropolis_step interface separately for current values of
6161
and...
6262
=#
6363
Base.length(root_sample::MyRootSample) = error("length() not yet implemented for $(typeof(root_sample)). Required for root_update!.") # the number of consecutive samples of root state and position for a single update call
64-
64+
#=
65+
If you want to be a subtype of UniformRootPositionSample, also implement:
66+
radius(root_sample::MyRootSample, total_bl::Real) = error("radius() not yet implemented for $(typeof(root_sample)). Required for root_update!.") # the local radius of the uniform proposal. Can be absolute or relative to the total branchlength.
67+
=#
6568
#or
6669
function (root_sample::MyRootSample)(tree::FelNode, models, partition_list, node_message::Vector{<:Partition}, temp_message::Vector{<:Partition})
6770
error("")

src/core/algorithms/root_optim.jl

Lines changed: 80 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@ mutable struct StandardRootOpt <: RootOpt
99
K::Int
1010
end
1111

12-
struct StandardRootSample <: UniformRootPositionSample
12+
mutable struct StandardRootSample <: UniformRootPositionSample
1313
acc_ratio::Tuple{Float64, Int64, Int64}
14+
radius::Float64 #[0, 1] multiplier of the total branchlength
1415
consecutive::Int64
1516

16-
function StandardRootSample(consecutive::Int64)
17-
new((0.0, 0, 0), consecutive)
17+
function StandardRootSample(radius::Float64,consecutive::Int64)
18+
new((0.0, 0, 0), radius, consecutive)
1819
end
1920
end
2021

2122
Base.length(root_sample::StandardRootSample) = root_sample.consecutive
23+
radius(root_sample::StandardRootSample, total_bl::Real) = root_sample.radius * total_bl
2224

2325
#Assume that felsenstein_roundtrip! has been called
2426
#Compute the log likelihood of observations below this root-candidate
@@ -211,13 +213,83 @@ function (root_sample::RootSample)(tree::FelNode, models, partition_list, node_m
211213
return merge(sampled_position, (state=sampled_state,))
212214
end
213215

214-
# Propose a new root position with a global uniform distribution
215-
function proposal(::UniformRootPositionSample, curr_value::@NamedTuple{root::FelNode, dist_above_node::Float64})
216-
nodelist = getnodelist(curr_value.root)
217-
cum = cumsum(n.branchlength for n in nodelist)
216+
217+
function traverse(node0::FelNode, dist_above_node0::Float64, radius::Float64)
218+
stack = [(node0, dist_above_node0, radius, length(node0.children)+1)]
219+
points = Vector{Tuple{FelNode, Float64, Int64}}()
220+
weights = Vector{Float64}()
221+
222+
while !isempty(stack)
223+
node, dist_above_node, radius_left, prev_ind = pop!(stack)
224+
#interpret prev_ind = 1,2,... as the index of the child we arrived from (an out-of-range index means we're at the start node, node0),
225+
#and prev_ind = 0 as having arrived from the parent (i.e. from the root side), so only downward traversal remains.
226+
if prev_ind > 0 && !isroot(node) #Upward traversal
227+
child_ind = findfirst(x -> x == node, node.parent.children)
228+
radius_that_would_be_left = radius_left - (node.branchlength - dist_above_node)
229+
if radius_that_would_be_left < 0.0
230+
push!(points, (node, dist_above_node + radius_left, child_ind))
231+
push!(weights, radius_left)
232+
else
233+
radius_that_would_be_left != 0.0 && push!(stack, (node.parent, 0.0, radius_that_would_be_left, child_ind))
234+
push!(points, (node, node.branchlength, child_ind)) #(node, node.branchlength) <=> (node.parent, 0.0)
235+
push!(weights, node.branchlength - dist_above_node)
236+
end
237+
end
238+
#Downward traversal
239+
radius_that_would_be_left = radius_left - dist_above_node
240+
if radius_that_would_be_left < 0.0
241+
push!(points, (node, -radius_that_would_be_left, 0))
242+
push!(weights, radius_left)
243+
continue
244+
end
245+
if dist_above_node != 0.0 #node0 may add itself, otherwise dist_above_node is 0.0
246+
push!(points, (node, 0.0, 0)) #most cases down will be true, but when it is false, we want to remember that
247+
push!(weights, dist_above_node)
248+
end
249+
for (child_ind, child) in enumerate(node.children)
250+
if child_ind == prev_ind
251+
continue
252+
end
253+
radius_that_would_be_left = radius_left - dist_above_node - child.branchlength
254+
if radius_that_would_be_left < 0.0
255+
push!(points, (child, -radius_that_would_be_left, 0))
256+
push!(weights, radius_left - dist_above_node)
257+
continue
258+
end
259+
radius_that_would_be_left != 0.0 && push!(stack, (child, 0.0, radius_that_would_be_left, 0))
260+
push!(points, (child, 0.0, 0))
261+
push!(weights, child.branchlength)
262+
end
263+
end
264+
return points, weights
265+
end
266+
267+
function total_bl(node::FelNode)
268+
while !isroot(node)
269+
node = node.parent
270+
end
271+
return sum(n.branchlength for n in nodes(node))
272+
end
273+
274+
function log_proposal(modifier::UniformRootPositionSample,
275+
x::@NamedTuple{root::FelNode, dist_above_node::Float64},
276+
conditioned_on::@NamedTuple{root::FelNode, dist_above_node::Float64})
277+
points, weights = traverse(conditioned_on..., radius(modifier, total_bl(x.root)))
278+
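#NOTE: if x is not within radius(modifier, total_bl) of conditioned_on, we should return log(0.0) (i.e. -Inf); that check is not performed here, so x is assumed to be reachable from conditioned_on.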
279+
return -log(sum(weights))
280+
end
281+
282+
# Propose a new root position with a local uniform distribution
283+
function proposal(modifier::UniformRootPositionSample, curr_value::@NamedTuple{root::FelNode, dist_above_node::Float64})
284+
points, weights = traverse(curr_value..., radius(modifier, total_bl(curr_value.root)))
285+
#Sample a new root position within radius radius(modifier, <total branchlength>)
286+
cum = cumsum(weights)
218287
sample = rand() * cum[end]
219288
idx = searchsortedfirst(cum, sample)
220-
return (root=nodelist[idx], dist_above_node=cum[idx] - sample)
289+
diff = cum[idx] - sample
290+
node, dist_above_node, prev_ind = points[idx]
291+
dist_above_node += prev_ind > 0 ? -diff : diff
292+
return (root=node, dist_above_node=dist_above_node)
221293
end
222294

223295
log_prior(::UniformRootPositionSample, curr_value::@NamedTuple{root::FelNode, dist_above_node::Float64}) = 0.0 #Uninformative/improper prior

src/utils/simple_sample.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ You need a `MySampler <: Any` to implement
3939
Although, it is possible to transform the current value before proposing a new value, and
4040
then take the inverse transform to match the argument `LL` expects.
4141
# Extended interface
42+
## Hastings
43+
To allow for asymmetric proposals, you must overload
44+
- `log_proposal(modifier::MySampler, x, conditioned_on)`
45+
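which should return the log density of proposing `x` when the current value is `conditioned_on`. By default it returns the constant `0.0`, which is only valid for symmetric proposals (where the two terms cancel in the acceptance ratio).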
46+
## Transformations
4247
To make proposals in a transformed space, you overload
4348
- `tr(modifier::MySampler, x)`
4449
- `invtr(modifier::MySampler, x)`
@@ -49,13 +54,15 @@ function metropolis_step(LL::Function, modifier, curr_value)
4954
tr_prop = proposal(modifier, tr_curr_value)
5055
accept_proposal =
5156
rand() <= exp(
52-
LL(invtr(modifier, tr_prop)) + log_prior(modifier, tr_prop) -
53-
LL(invtr(modifier, tr_curr_value)) - log_prior(modifier, tr_curr_value),
57+
LL(invtr(modifier, tr_prop)) + log_prior(modifier, tr_prop) + log_proposal(modifier, tr_curr_value, tr_prop) -
58+
LL(invtr(modifier, tr_curr_value)) - log_prior(modifier, tr_curr_value) - log_proposal(modifier, tr_prop, tr_curr_value),
5459
)
5560
apply_decision(modifier, accept_proposal)
5661
return invtr(modifier, ifelse(accept_proposal, tr_prop, tr_curr_value))
5762
end
5863

64+
log_proposal(modifier, x, y) = 0.0
65+
5966
# The prior distribution for the variable log(branchlength). A small perturbation of +1e-12 is added to enhance numerical stability near zero.
6067
proposal(modifier::BranchlengthSampler, curr_value) = exp(log(curr_value) + rand(modifier.log_bl_proposal))
6168
log_prior(modifier::BranchlengthSampler, x) = logpdf(modifier.log_bl_prior, log(x + 1e-12))

test/partition_selection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,6 @@ begin
8787

8888
updater = MaxLikUpdate(root=1)
8989
tree, bm_models = updater(tree, bm_models, partition_list = [1])
90-
updater = BayesUpdate(root=0, models=1)
90+
updater = BayesUpdate(root=1, models=1)
9191
tree, bm_models = updater(tree, bm_models, partition_list = [1])
9292
end

0 commit comments

Comments
 (0)