Commit 040d2a6

Merge pull request #42 from nossleinad/metropolis-ext
Metropolis extension
2 parents c49f93f + a6bcc3d commit 040d2a6

File tree

7 files changed: +210 −118 lines


src/MolecularEvolution.jl

Lines changed: 4 additions & 0 deletions
@@ -119,9 +119,12 @@ export
     isinternalnode,
     isbranchnode,
     reroot!,
+    nni_update!,
     nni_optim!,
+    branchlength_update!,
     branchlength_optim!,
     metropolis_sample,
+    metropolis_step,
     copy_tree,
 
     #util functions
@@ -137,6 +140,7 @@ export
     P_from_diagonalized_Q,
     scale_cols_by_vec!,
     BranchlengthSampler,
+    softmax_sampler,
 
     #things the user might overload
     copy_partition_to!,
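
Note: `softmax_sampler` is newly exported here; in the old code it was an inline closure inside `metropolis_sample` (see the removed line in the src/bayes/sampling.jl diff below). A minimal sketch of what the helper plausibly does, assuming it keeps that removed definition; the local `softmax` is included only to make the snippet self-contained:

```julia
using Distributions  # for Categorical

# Numerically stable softmax over a vector of log-weights (assumed helper,
# included only so this sketch runs on its own).
softmax(x) = (e = exp.(x .- maximum(x)); e ./ sum(e))

# Mirrors the inline definition removed from metropolis_sample below:
#   softmax_sampler = x -> rand(Categorical(softmax(x)))
# Given log-likelihoods of candidate configurations, draw an index with
# probability proportional to the exponential of each entry.
softmax_sampler(x) = rand(Categorical(softmax(x)))
```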

src/bayes/sampling.jl

Lines changed: 64 additions & 38 deletions
@@ -1,22 +1,23 @@
 """
-    function metropolis_sample(
+    metropolis_sample(
+        update!::Function,
         initial_tree::FelNode,
         models::Vector{<:BranchModel},
         num_of_samples;
-        bl_modifier::UnivariateSampler = BranchlengthSampler(Normal(0,2), Normal(-1,1))
-        burn_in=1000,
-        sample_interval=10,
+        burn_in = 1000,
+        sample_interval = 10,
         collect_LLs = false,
-        midpoint_rooting=false,
+        midpoint_rooting = false,
+        ladderize = false,
     )
 
-Samples tree topologies from a posterior distribution.
+Samples tree topologies from a posterior distribution using a custom `update!` function.
 
 # Arguments
+- `update!`: A function that takes `(tree::FelNode, models::Vector{<:BranchModel})` and updates `tree`. One call to `update!` corresponds to one iteration of the Metropolis algorithm.
 - `initial_tree`: An initial tree topology with the leaves populated with data, for the likelihood calculation.
 - `models`: A list of branch models.
 - `num_of_samples`: The number of tree samples drawn from the posterior.
-- `bl_sampler`: Sampler used to drawn branchlengths from the posterior.
 - `burn_in`: The number of samples discarded at the start of the Markov Chain.
 - `sample_interval`: The distance between samples in the underlying Markov Chain (to reduce sample correlation).
 - `collect_LLs`: Specifies if the function should return the log-likelihoods of the trees.
@@ -30,16 +31,16 @@ Samples tree topologies from a posterior distribution.
 - `sample_LLs`: The associated log-likelihoods of the tree (optional).
 """
 function metropolis_sample(
+    update!::Function,
     initial_tree::FelNode,
     models::Vector{<:BranchModel},
     num_of_samples;
-    bl_sampler::UnivariateSampler = BranchlengthSampler(Normal(0,2), Normal(-1,1)),
-    burn_in=1000,
-    sample_interval=10,
+    burn_in = 1000,
+    sample_interval = 10,
     collect_LLs = false,
-    midpoint_rooting=false,
+    midpoint_rooting = false,
     ladderize = false,
-    )
+)
 
     # The prior over the (log) of the branchlengths should be specified in bl_sampler.
     # Furthermore, a non-informative/uniform prior is assumed over the tree topolgies (excluding the branchlengths).
@@ -48,30 +49,27 @@ function metropolis_sample(
     samples = FelNode[]
     tree = deepcopy(initial_tree)
     iterations = burn_in + num_of_samples * sample_interval
-
-    softmax_sampler = x -> rand(Categorical(softmax(x)))
-    for i=1:iterations
-
-        # Updates the tree topolgy and branchlengths.
-        nni_optim!(tree, x -> models, selection_rule = softmax_sampler)
-        branchlength_optim!(tree, x -> models, bl_modifier = bl_sampler)
-
-        if (i-burn_in) % sample_interval == 0 && i > burn_in
-
-            push!(samples, copy_tree(tree, true))
-
-            if collect_LLs
-                push!(sample_LLs, log_likelihood!(tree, models))
-            end
-
+
+    for i = 1:iterations
+        # Updates the tree topology and branchlengths.
+        update!(tree, models)
+
+        if (i - burn_in) % sample_interval == 0 && i > burn_in
+
+            push!(samples, copy_tree(tree, true))
+
+            if collect_LLs
+                push!(sample_LLs, log_likelihood!(tree, models))
+            end
 
+        end
+
     end
 
     if midpoint_rooting
-        for (i,sample) in enumerate(samples)
+        for (i, sample) in enumerate(samples)
             node, len = midpoint(sample)
-            samples[i] = reroot!(node, dist_above_child=len)
+            samples[i] = reroot!(node, dist_above_child = len)
         end
     end

@@ -88,6 +86,36 @@ function metropolis_sample(
     return samples
 end
 
+"""
+    metropolis_sample(
+        initial_tree::FelNode,
+        models::Vector{<:BranchModel},
+        num_of_samples;
+        bl_sampler::UnivariateSampler = BranchlengthSampler(Normal(0,2), Normal(-1,1))
+        burn_in=1000,
+        sample_interval=10,
+        collect_LLs = false,
+        midpoint_rooting=false,
+    )
+
+A convenience method. One step of the Metropolis algorithm is performed by calling [`nni_update!`](@ref) with `softmax_sampler` and [`branchlength_update!`](@ref) with `bl_sampler`.
+
+# Additional Arguments
+- `bl_sampler`: Sampler used to draw branchlengths from the posterior.
+"""
+function metropolis_sample(
+    initial_tree::FelNode,
+    models::Vector{<:BranchModel},
+    num_of_samples;
+    bl_sampler::UnivariateSampler = BranchlengthSampler(Normal(0, 2), Normal(-1, 1)),
+    kwargs...,
+)
+    metropolis_sample(initial_tree, models, num_of_samples; kwargs...) do tree, models
+        nni_update!(softmax_sampler, tree, x -> models)
+        branchlength_update!(bl_sampler, tree, x -> models)
+    end
+end
+
 # Below are some functions that help to assess the mixing by looking at the distance between leaf nodes.
 
 """
@@ -109,16 +137,14 @@ end
 Returns a matrix of the distances between the leaf nodes where the index on the columns and rows are sorted by the leaf names.
 """
 function leaf_distmat(tree)
-
+
     distmat, node_dic = MolecularEvolution.tree2distances(tree)
-
+
     leaflist = getleaflist(tree)
-
-    sort!(leaflist, by = x-> x.name)
-
+
+    sort!(leaflist, by = x -> x.name)
+
     order = [node_dic[leaf] for leaf in leaflist]
-
+
     return distmat[order, order]
 end
-
-

src/core/algorithms/branchlength_optim.jl

Lines changed: 62 additions & 34 deletions
@@ -19,15 +19,31 @@ function branch_LL_up(
     return tot_LL
 end
 
+get_next_branchlength(
+    bl_sampler::UnivariateModifier,
+    ll_curr::Real,
+    ll_prop::Real,
+    bl_curr::Real,
+    bl_prop::Real
+) = bl_prop
+
+get_next_branchlength(
+    bl_optimizer::UnivariateOpt,
+    ll_curr::Real,
+    ll_prop::Real,
+    bl_curr::Real,
+    bl_prop::Real
+) = ifelse(ll_prop > ll_curr, bl_prop, bl_curr)
+
 #I need to add a version of this that takes a generic optimizer function and uses that instead of golden_section_maximize on just the branchlength.
 #This is for cases where the user stores node-level parameters and wants to optimize them.
-function branchlength_optim!(
+function branchlength_update!(
+    bl_modifier::UnivariateModifier,
     temp_messages::Vector{Vector{T}},
     tree::FelNode,
     models,
     partition_list,
     tol;
-    bl_modifier::UnivariateModifier = GoldenSectionOpt(),
     traversal = Iterators.reverse
 ) where {T <: Partition}
 
@@ -111,9 +127,12 @@ function branchlength_optim!(
             model_list = models(node)
             fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list)
             bl = univariate_modifier(fun, bl_modifier; a=0+tol, b=1-tol, tol=tol, transform=unit_transform, curr_value=node.branchlength)
-            if fun(bl) > fun(node.branchlength) || !(bl_modifier isa UnivariateOpt)
-                node.branchlength = bl
-            end
+            #Next, we dispatch on the bl_modifier type to get the next branchlength
+            #=
+                Note: for a user-defined bl_modifier, this can be overloaded,
+                the default behaviour is just to return bl
+            =#
+            node.branchlength = get_next_branchlength(bl_modifier, fun(node.branchlength), fun(bl), node.branchlength, bl)
             #Consider checking for improvement, and bailing if none.
             #Then we need to set the "message_to_set", which is node.parent.child_messages[but_the_right_one]
             for part in partition_list
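
The added comment notes that `get_next_branchlength` can be overloaded for a user-defined `bl_modifier`. A hedged sketch of such an overload, assuming a symmetric branch-length proposal (so no Hastings correction) and a hypothetical `MyMHSampler` type that is not part of this commit:

```julia
using MolecularEvolution

# Hypothetical user-defined modifier (not from this commit). UnivariateModifier and
# get_next_branchlength are assumed reachable via the package namespace.
struct MyMHSampler <: MolecularEvolution.UnivariateModifier end

# Metropolis accept/reject on the proposed branch length: accept with
# probability min(1, exp(ll_prop - ll_curr)), otherwise keep the current value.
function MolecularEvolution.get_next_branchlength(
    bl_modifier::MyMHSampler,
    ll_curr::Real,
    ll_prop::Real,
    bl_curr::Real,
    bl_prop::Real,
)
    return log(rand()) < ll_prop - ll_curr ? bl_prop : bl_curr
end
```

Dispatching on the modifier type keeps the inner loop free of `isa` checks, which is what the replaced `if ... isa UnivariateOpt` branch above was doing.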
@@ -131,50 +150,59 @@ end
 
 #BM: Check if running felsenstein_down! makes a difference.
 """
-    branchlength_optim!(tree::FelNode, models; <keyword arguments>)
+    branchlength_update!(bl_modifier::UnivariateModifier, tree::FelNode, models; <keyword arguments>)
 
-Uses golden section search, or optionally Brent's method, to optimize all branches recursively, maintaining the integrity of the messages.
-Requires felsenstein!() to have been run first.
-models can either be a single model (if the messages on the tree contain just one Partition) or an array of models, if the messages have >1 Partition, or
-a function that takes a node, and returns a Vector{<:BranchModel} if you need the models to vary from one branch to another.
+A more general version of [`branchlength_optim!`](@ref). Here `bl_modifier` can be either an optimizer or a sampler (or, more generally, a `UnivariateModifier`).
 
 # Keyword Arguments
-- `partition_list=nothing`: (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize branch lengths with all models, the default option).
-- `tol=1e-5`: absolute tolerance for the `bl_modifier`.
-- `bl_modifier=GoldenSectionOpt()`: can either be a optimizer or a sampler (subtype of UnivariateModifier). For optimization, in addition to golden section search, Brent's method can be used by setting bl_modifier=BrentsMethodOpt().
-- `sort_tree=false`: determines if a [`lazysort!`](@ref) will be performed, which can reduce the amount of temporary messages that has to be initialized.
-- `traversal=Iterators.reverse`: a function that determines the traversal, permutes an iterable.
-- `shuffle=false`: do a randomly shuffled traversal, overrides `traversal`.
+See [`branchlength_optim!`](@ref).
+!!! note
+    `bl_modifier` is a positional argument here, not a keyword argument.
 """
-function branchlength_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, bl_modifier::UnivariateModifier = GoldenSectionOpt(), sort_tree = false, traversal = Iterators.reverse, shuffle = false)
+function branchlength_update!(bl_modifier::UnivariateModifier, tree::FelNode, models; partition_list = nothing, tol = 1e-5, sort_tree = false, traversal = Iterators.reverse, shuffle = false)
     sort_tree && lazysort!(tree) #A lazysorted tree minimizes the amount of temp_messages needed
     temp_messages = [copy_message(tree.message)]
 
     if partition_list === nothing
         partition_list = 1:length(tree.message)
     end
 
-    branchlength_optim!(temp_messages, tree, models, partition_list, tol, bl_modifier=bl_modifier, traversal = shuffle ? x -> sample(x, length(x), replace=false) : traversal)
+    branchlength_update!(bl_modifier, temp_messages, tree, models, partition_list, tol, traversal = shuffle ? x -> sample(x, length(x), replace=false) : traversal)
 end
 
 #Overloading to allow for direct model and model vec inputs
-branchlength_optim!(
+branchlength_update!(
+    bl_modifier::UnivariateModifier,
     tree::FelNode,
     models::Vector{<:BranchModel};
-    partition_list = nothing,
-    tol = 1e-5,
-    bl_modifier::UnivariateModifier = GoldenSectionOpt(),
-    sort_tree = false,
-    traversal = Iterators.reverse,
-    shuffle = false
-) = branchlength_optim!(tree, x -> models, partition_list = partition_list, tol = tol, bl_modifier = bl_modifier, sort_tree = sort_tree, traversal = traversal, shuffle = shuffle)
-branchlength_optim!(
+    kwargs...
+) = branchlength_update!(bl_modifier, tree, x -> models; kwargs...)
+branchlength_update!(
+    bl_modifier::UnivariateModifier,
     tree::FelNode,
     model::BranchModel;
-    partition_list = nothing,
-    tol = 1e-5,
-    bl_modifier::UnivariateModifier = GoldenSectionOpt(),
-    sort_tree = false,
-    traversal = Iterators.reverse,
-    shuffle = false
-) = branchlength_optim!(tree, x -> [model], partition_list = partition_list, tol = tol, bl_modifier = bl_modifier, sort_tree = sort_tree, traversal = traversal, shuffle = shuffle)
+    kwargs...
+) = branchlength_update!(bl_modifier, tree, x -> [model]; kwargs...)
+
+"""
+    branchlength_optim!(tree::FelNode, models; <keyword arguments>)
+
+Uses golden section search, or optionally Brent's method, to optimize all branches recursively, maintaining the integrity of the messages.
+Requires felsenstein!() to have been run first.
+models can either be a single model (if the messages on the tree contain just one Partition) or an array of models, if the messages have >1 Partition, or
+a function that takes a node, and returns a Vector{<:BranchModel} if you need the models to vary from one branch to another.
+
+# Keyword Arguments
+- `partition_list=nothing`: (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize branch lengths with all models, the default option).
+- `tol=1e-5`: absolute tolerance for the `bl_optimizer`.
+- `bl_optimizer::UnivariateModifier=GoldenSectionOpt()`: the algorithm used to optimize the log likelihood of a branch length. In addition to golden section search, Brent's method can be used by setting `bl_optimizer=BrentsMethodOpt()`.
+- `sort_tree=false`: determines if a [`lazysort!`](@ref) will be performed, which can reduce the number of temporary messages that have to be initialized.
+- `traversal=Iterators.reverse`: a function that determines the traversal, permutes an iterable.
+- `shuffle=false`: do a randomly shuffled traversal, overrides `traversal`.
+"""
+branchlength_optim!(
+    args...;
+    bl_optimizer::UnivariateModifier = GoldenSectionOpt(),
+    kwargs...
+) = branchlength_update!(bl_optimizer, args...; kwargs...)
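
In short, `branchlength_optim!` is now a thin wrapper that forwards to `branchlength_update!` with the modifier as the first positional argument. A hedged sketch of the resulting call shapes (the tree and model are assumed placeholders, and the sampler parameters are arbitrary):

```julia
using MolecularEvolution, Distributions

# `tree` is assumed to be a FelNode on which felsenstein!() has been run, and
# `model` a BranchModel; both are placeholders for whatever your setup provides.
function tune_branchlengths!(tree::FelNode, model::BranchModel)
    # Unchanged user-facing call: optimize branch lengths (golden section search
    # by default, or Brent's method via the keyword).
    branchlength_optim!(tree, model)
    branchlength_optim!(tree, model, bl_optimizer = BrentsMethodOpt())

    # Equivalent lower-level form with the modifier as a positional argument;
    # pass a sampler instead of an optimizer to draw branch lengths from the posterior.
    branchlength_update!(GoldenSectionOpt(), tree, model)
    branchlength_update!(BranchlengthSampler(Normal(0, 2), Normal(-1, 1)), tree, model)
end
```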
