Skip to content

Metropolis extension #42

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/MolecularEvolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,12 @@ export
isinternalnode,
isbranchnode,
reroot!,
nni_update!,
nni_optim!,
branchlength_update!,
branchlength_optim!,
metropolis_sample,
metropolis_step,
copy_tree,

#util functions
Expand All @@ -137,6 +140,7 @@ export
P_from_diagonalized_Q,
scale_cols_by_vec!,
BranchlengthSampler,
softmax_sampler,

#things the user might overload
copy_partition_to!,
Expand Down
102 changes: 64 additions & 38 deletions src/bayes/sampling.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
"""
function metropolis_sample(
metropolis_sample(
update!::Function,
initial_tree::FelNode,
models::Vector{<:BranchModel},
num_of_samples;
bl_modifier::UnivariateSampler = BranchlengthSampler(Normal(0,2), Normal(-1,1))
burn_in=1000,
sample_interval=10,
burn_in = 1000,
sample_interval = 10,
collect_LLs = false,
midpoint_rooting=false,
midpoint_rooting = false,
ladderize = false,
)

Samples tree topologies from a posterior distribution.
Samples tree topologies from a posterior distribution using a custom `update!` function.

# Arguments
- `update!`: A function that takes (tree::FelNode, models::Vector{<:BranchModel}) and updates `tree`. One call to `update!` corresponds to one iteration of the Metropolis algorithm.
- `initial_tree`: An initial tree topology with the leaves populated with data, for the likelihood calculation.
- `models`: A list of branch models.
- `num_of_samples`: The number of tree samples drawn from the posterior.
- `bl_sampler`: Sampler used to draw branchlengths from the posterior.
- `burn_in`: The number of samples discarded at the start of the Markov Chain.
- `sample_interval`: The distance between samples in the underlying Markov Chain (to reduce sample correlation).
- `collect_LLs`: Specifies if the function should return the log-likelihoods of the trees.
Expand All @@ -30,16 +31,16 @@ Samples tree topologies from a posterior distribution.
- `sample_LLs`: The associated log-likelihoods of the tree (optional).
"""
function metropolis_sample(
update!::Function,
initial_tree::FelNode,
models::Vector{<:BranchModel},
num_of_samples;
bl_sampler::UnivariateSampler = BranchlengthSampler(Normal(0,2), Normal(-1,1)),
burn_in=1000,
sample_interval=10,
burn_in = 1000,
sample_interval = 10,
collect_LLs = false,
midpoint_rooting=false,
midpoint_rooting = false,
ladderize = false,
)
)

# The prior over the (log) of the branchlengths should be specified in bl_sampler.
# Furthermore, a non-informative/uniform prior is assumed over the tree topologies (excluding the branchlengths).
Expand All @@ -48,30 +49,27 @@ function metropolis_sample(
samples = FelNode[]
tree = deepcopy(initial_tree)
iterations = burn_in + num_of_samples * sample_interval

softmax_sampler = x -> rand(Categorical(softmax(x)))
for i=1:iterations

# Updates the tree topology and branchlengths.
nni_optim!(tree, x -> models, selection_rule = softmax_sampler)
branchlength_optim!(tree, x -> models, bl_modifier = bl_sampler)

if (i-burn_in) % sample_interval == 0 && i > burn_in

push!(samples, copy_tree(tree, true))

if collect_LLs
push!(sample_LLs, log_likelihood!(tree, models))
end


for i = 1:iterations
# Updates the tree topology and branchlengths.
update!(tree, models)

if (i - burn_in) % sample_interval == 0 && i > burn_in

push!(samples, copy_tree(tree, true))

if collect_LLs
push!(sample_LLs, log_likelihood!(tree, models))
end

end

end

if midpoint_rooting
for (i,sample) in enumerate(samples)
for (i, sample) in enumerate(samples)
node, len = midpoint(sample)
samples[i] = reroot!(node, dist_above_child=len)
samples[i] = reroot!(node, dist_above_child = len)
end
end

Expand All @@ -88,6 +86,36 @@ function metropolis_sample(
return samples
end

"""
metropolis_sample(
initial_tree::FelNode,
models::Vector{<:BranchModel},
num_of_samples;
bl_sampler::UnivariateSampler = BranchlengthSampler(Normal(0,2), Normal(-1,1)),
burn_in=1000,
sample_interval=10,
collect_LLs = false,
midpoint_rooting=false,
)

A convenience method. One step of the Metropolis algorithm is performed by calling [`nni_update!`](@ref) with `softmax_sampler` and [`branchlength_update!`](@ref) with `bl_sampler`.

# Additional Arguments
- `bl_sampler`: Sampler used to draw branchlengths from the posterior.
"""
function metropolis_sample(
    initial_tree::FelNode,
    models::Vector{<:BranchModel},
    num_of_samples;
    bl_sampler::UnivariateSampler = BranchlengthSampler(Normal(0, 2), Normal(-1, 1)),
    kwargs...,
)
    # Default update step for one Metropolis iteration: first an NNI update driven by
    # `softmax_sampler`, then a branch-length update driven by `bl_sampler`.
    # The model vector is wrapped in a closure because both update functions are
    # called here with a node -> models function.
    function default_update!(tree, branch_models)
        nni_update!(softmax_sampler, tree, _node -> branch_models)
        return branchlength_update!(bl_sampler, tree, _node -> branch_models)
    end

    # Delegate to the method that takes the update function as its first argument;
    # every remaining keyword argument is forwarded untouched.
    return metropolis_sample(
        default_update!,
        initial_tree,
        models,
        num_of_samples;
        kwargs...,
    )
end

# Below are some functions that help to assess the mixing by looking at the distance between leaf nodes.

"""
Expand All @@ -109,16 +137,14 @@ end
Returns a matrix of the distances between the leaf nodes where the index on the columns and rows are sorted by the leaf names.
"""
function leaf_distmat(tree)
    # `tree2distances` yields the pairwise distance matrix together with a lookup
    # (`node_dic`) that is indexed by node below to find its position in `distmat`.
    distmat, node_dic = MolecularEvolution.tree2distances(tree)

    # Sort the leaves by name exactly once; a second identical sort is redundant
    # because the list is already in name order after the first.
    leaflist = getleaflist(tree)
    sort!(leaflist, by = x -> x.name)

    # Positions of the name-ordered leaves inside `distmat`.
    order = [node_dic[leaf] for leaf in leaflist]

    # Keep only leaf rows/columns, arranged in leaf-name order.
    return distmat[order, order]
end


96 changes: 62 additions & 34 deletions src/core/algorithms/branchlength_optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,31 @@ function branch_LL_up(
return tot_LL
end

# Fallback rule for choosing the next branch length: for a generic
# `UnivariateModifier` the proposed value is always taken, whatever the
# log-likelihoods are. Methods for more specific modifier types can override this.
function get_next_branchlength(
    bl_sampler::UnivariateModifier,
    ll_curr::Real,
    ll_prop::Real,
    bl_curr::Real,
    bl_prop::Real,
)
    return bl_prop
end

# Rule for optimizers: move to the proposed branch length only when its
# log-likelihood is strictly greater than the current one; otherwise keep the
# current branch length.
function get_next_branchlength(
    bl_optimizer::UnivariateOpt,
    ll_curr::Real,
    ll_prop::Real,
    bl_curr::Real,
    bl_prop::Real,
)
    return ll_prop > ll_curr ? bl_prop : bl_curr
end

#I need to add a version of this that takes a generic optimizer function and uses that instead of golden_section_maximize on just the branchlength.
#This is for cases where the user stores node-level parameters and wants to optimize them.
function branchlength_optim!(
function branchlength_update!(
bl_modifier::UnivariateModifier,
temp_messages::Vector{Vector{T}},
tree::FelNode,
models,
partition_list,
tol;
bl_modifier::UnivariateModifier = GoldenSectionOpt(),
traversal = Iterators.reverse
) where {T <: Partition}

Expand Down Expand Up @@ -111,9 +127,12 @@ function branchlength_optim!(
model_list = models(node)
fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list)
bl = univariate_modifier(fun, bl_modifier; a=0+tol, b=1-tol, tol=tol, transform=unit_transform, curr_value=node.branchlength)
if fun(bl) > fun(node.branchlength) || !(bl_modifier isa UnivariateOpt)
node.branchlength = bl
end
#Next, we dispatch on the bl_modifier type to get the next branchlength
#=
Note: for a user-defined bl_modifier, this can be overloaded,
the default behaviour is just to return bl
=#
node.branchlength = get_next_branchlength(bl_modifier, fun(node.branchlength), fun(bl), node.branchlength, bl)
#Consider checking for improvement, and bailing if none.
#Then we need to set the "message_to_set", which is node.parent.child_messages[but_the_right_one]
for part in partition_list
Expand All @@ -131,50 +150,59 @@ end

#BM: Check if running felsenstein_down! makes a difference.
"""
branchlength_optim!(tree::FelNode, models; <keyword arguments>)
branchlength_update!(bl_modifier::UnivariateModifier, tree::FelNode, models; <keyword arguments>)

Uses golden section search, or optionally Brent's method, to optimize all branches recursively, maintaining the integrity of the messages.
Requires felsenstein!() to have been run first.
models can either be a single model (if the messages on the tree contain just one Partition) or an array of models, if the messages have >1 Partition, or
a function that takes a node, and returns a Vector{<:BranchModel} if you need the models to vary from one branch to another.
A more general version of [`branchlength_optim!`](@ref). Here `bl_modifier` can be either an optimizer or a sampler (or more generally, a UnivariateModifier).

# Keyword Arguments
- `partition_list=nothing`: (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize branch lengths with all models, the default option).
- `tol=1e-5`: absolute tolerance for the `bl_modifier`.
- `bl_modifier=GoldenSectionOpt()`: can either be an optimizer or a sampler (subtype of UnivariateModifier). For optimization, in addition to golden section search, Brent's method can be used by setting bl_modifier=BrentsMethodOpt().
- `sort_tree=false`: determines if a [`lazysort!`](@ref) will be performed, which can reduce the amount of temporary messages that has to be initialized.
- `traversal=Iterators.reverse`: a function that determines the traversal, permutes an iterable.
- `shuffle=false`: do a randomly shuffled traversal, overrides `traversal`.
See [`branchlength_optim!`](@ref).
!!! note
`bl_modifier` is a positional argument here, and not a keyword argument.
"""
function branchlength_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, bl_modifier::UnivariateModifier = GoldenSectionOpt(), sort_tree = false, traversal = Iterators.reverse, shuffle = false)
function branchlength_update!(bl_modifier::UnivariateModifier, tree::FelNode, models; partition_list = nothing, tol = 1e-5, sort_tree = false, traversal = Iterators.reverse, shuffle = false)
sort_tree && lazysort!(tree) #A lazysorted tree minimizes the amount of temp_messages needed
temp_messages = [copy_message(tree.message)]

if partition_list === nothing
partition_list = 1:length(tree.message)
end

branchlength_optim!(temp_messages, tree, models, partition_list, tol, bl_modifier=bl_modifier, traversal = shuffle ? x -> sample(x, length(x), replace=false) : traversal)
branchlength_update!(bl_modifier, temp_messages, tree, models, partition_list, tol, traversal = shuffle ? x -> sample(x, length(x), replace=false) : traversal)
end

#Overloading to allow for direct model and model vec inputs
branchlength_optim!(
branchlength_update!(
bl_modifier::UnivariateModifier,
tree::FelNode,
models::Vector{<:BranchModel};
partition_list = nothing,
tol = 1e-5,
bl_modifier::UnivariateModifier = GoldenSectionOpt(),
sort_tree = false,
traversal = Iterators.reverse,
shuffle = false
) = branchlength_optim!(tree, x -> models, partition_list = partition_list, tol = tol, bl_modifier = bl_modifier, sort_tree = sort_tree, traversal = traversal, shuffle = shuffle)
branchlength_optim!(
kwargs...
) = branchlength_update!(bl_modifier, tree, x -> models; kwargs...)
branchlength_update!(
bl_modifier::UnivariateModifier,
tree::FelNode,
model::BranchModel;
partition_list = nothing,
tol = 1e-5,
bl_modifier::UnivariateModifier = GoldenSectionOpt(),
sort_tree = false,
traversal = Iterators.reverse,
shuffle = false
) = branchlength_optim!(tree, x -> [model], partition_list = partition_list, tol = tol, bl_modifier = bl_modifier, sort_tree = sort_tree, traversal = traversal, shuffle = shuffle)
kwargs...
) = branchlength_update!(bl_modifier, tree, x -> [model]; kwargs...)

"""
branchlength_optim!(tree::FelNode, models; <keyword arguments>)

Uses golden section search, or optionally Brent's method, to optimize all branches recursively, maintaining the integrity of the messages.
Requires felsenstein!() to have been run first.
models can either be a single model (if the messages on the tree contain just one Partition) or an array of models, if the messages have >1 Partition, or
a function that takes a node, and returns a Vector{<:BranchModel} if you need the models to vary from one branch to another.

# Keyword Arguments
- `partition_list=nothing`: (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize branch lengths with all models, the default option).
- `tol=1e-5`: absolute tolerance for the `bl_optimizer`.
- `bl_optimizer::UnivariateModifier=GoldenSectionOpt()`: the algorithm used to optimize the log likelihood of a branch length. In addition to golden section search, Brent's method can be used by setting `bl_optimizer=BrentsMethodOpt()`.
- `sort_tree=false`: determines if a [`lazysort!`](@ref) will be performed, which can reduce the amount of temporary messages that has to be initialized.
- `traversal=Iterators.reverse`: a function that determines the traversal, permutes an iterable.
- `shuffle=false`: do a randomly shuffled traversal, overrides `traversal`.
"""
function branchlength_optim!(
    args...;
    bl_optimizer::UnivariateModifier = GoldenSectionOpt(),
    kwargs...,
)
    # Thin wrapper: turn the `bl_optimizer` keyword into the leading positional
    # argument of `branchlength_update!`, passing all other positional and keyword
    # arguments through unchanged.
    return branchlength_update!(bl_optimizer, args...; kwargs...)
end

Loading
Loading