diff --git a/src/MolecularEvolution.jl b/src/MolecularEvolution.jl index 5afd96d..b433ce0 100644 --- a/src/MolecularEvolution.jl +++ b/src/MolecularEvolution.jl @@ -29,7 +29,9 @@ abstract type SimulationModel <: BranchModel end #Simulation models typically ca abstract type StatePath end -abstract type UnivariateOpt end +abstract type UnivariateModifier end +abstract type UnivariateOpt <: UnivariateModifier end +abstract type UnivariateSampler <: UnivariateModifier end abstract type LazyDirection end @@ -39,6 +41,7 @@ include("core/algorithms/algorithms.jl") include("core/sim_tree.jl") include("models/models.jl") include("utils/utils.jl") +include("bayes/bayes.jl") #Optional dependencies function __init__() @@ -115,6 +118,8 @@ export reroot!, nni_optim!, branchlength_optim!, + metropolis_sample, + copy_tree, #util functions one_hot_sample, @@ -128,6 +133,7 @@ export HKY85, P_from_diagonalized_Q, scale_cols_by_vec!, + BranchlengthSampler, #things the user might overload copy_partition_to!, diff --git a/src/bayes/bayes.jl b/src/bayes/bayes.jl new file mode 100644 index 0000000..f9201c5 --- /dev/null +++ b/src/bayes/bayes.jl @@ -0,0 +1 @@ +include("sampling.jl") \ No newline at end of file diff --git a/src/bayes/sampling.jl b/src/bayes/sampling.jl new file mode 100644 index 0000000..2525fbb --- /dev/null +++ b/src/bayes/sampling.jl @@ -0,0 +1,124 @@ +""" + function metropolis_sample( + initial_tree::FelNode, + models::Vector{<:BranchModel}, + num_of_samples; + bl_modifier::UnivariateSampler = BranchlengthSampler(Normal(0,2), Normal(-1,1)) + burn_in=1000, + sample_interval=10, + collect_LLs = false, + midpoint_rooting=false, + ) + +Samples tree topologies from a posterior distribution. + +# Arguments +- `initial_tree`: An initial tree topology with the leaves populated with data, for the likelihood calculation. +- `models`: A list of branch models. +- `num_of_samples`: The number of tree samples drawn from the posterior. +- `bl_sampler`: Sampler used to drawn branchlengths from the posterior. +- `burn_in`: The number of samples discarded at the start of the Markov Chain. +- `sample_interval`: The distance between samples in the underlying Markov Chain (to reduce sample correlation). +- `collect_LLs`: Specifies if the function should return the log-likelihoods of the trees. +- `midpoint_rooting`: Specifies whether the drawn samples should be midpoint rerooted (Important! Should only be used for time-reversible branch models starting in equilibrium). + +!!! note + The leaves of the initial tree should be populated with data and felsenstein! should be called on the initial tree before calling this function. + +# Returns +- `samples`: The trees drawn from the posterior. Returns shallow tree copies, which needs to be repopulated before running felsenstein! etc. +- `sample_LLs`: The associated log-likelihoods of the tree (optional). +""" +function metropolis_sample( + initial_tree::FelNode, + models::Vector{<:BranchModel}, + num_of_samples; + bl_sampler::UnivariateSampler = BranchlengthSampler(Normal(0,2), Normal(-1,1)), + burn_in=1000, + sample_interval=10, + collect_LLs = false, + midpoint_rooting=false, + ladderize = false, + ) + + # The prior over the (log) of the branchlengths should be specified in bl_sampler. + # Furthermore, a non-informative/uniform prior is assumed over the tree topolgies (excluding the branchlengths). + + sample_LLs = [] + samples = FelNode[] + tree = deepcopy(initial_tree) + iterations = burn_in + num_of_samples * sample_interval + + softmax_sampler = x -> rand(Categorical(softmax(x))) + for i=1:iterations + + # Updates the tree topolgy and branchlengths. + nni_optim!(tree, x -> models, selection_rule = softmax_sampler) + branchlength_optim!(tree, x -> models, bl_modifier = bl_sampler) + + if (i-burn_in) % sample_interval == 0 && i > burn_in + + push!(samples, copy_tree(tree, true)) + + if collect_LLs + push!(sample_LLs, log_likelihood!(tree, models)) + end + + end + + end + + if midpoint_rooting + for (i,sample) in enumerate(samples) + node, len = midpoint(sample) + samples[i] = reroot!(node, dist_above_child=len) + end + end + + if ladderize + for sample in samples + ladderize!(sample) + end + end + + if collect_LLs + return samples, sample_LLs + end + + return samples +end + +# Below are some functions that help to assess the mixing by looking at the distance between leaf nodes. + +""" + collect_leaf_dists(trees::Vector{<:AbstractTreeNode}) + + Returns a list of distance matrices containing the distance between the leaf nodes, which can be used to assess mixing. +""" +function collect_leaf_dists(trees::Vector{<:AbstractTreeNode}) + distmats = [] + for tree in trees + push!(distmats, leaf_distmat(tree)) + end + return distmats +end + +""" + leaf_distmat(tree) + + Returns a matrix of the distances between the leaf nodes where the index on the columns and rows are sorted by the leaf names. +""" +function leaf_distmat(tree) + + distmat, node_dic = MolecularEvolution.tree2distances(tree) + + leaflist = getleaflist(tree) + + sort!(leaflist, by = x-> x.name) + + order = [node_dic[leaf] for leaf in leaflist] + + return distmat[order, order] +end + + diff --git a/src/core/algorithms/branchlength_optim.jl b/src/core/algorithms/branchlength_optim.jl index 5ef432d..fb20b21 100644 --- a/src/core/algorithms/branchlength_optim.jl +++ b/src/core/algorithms/branchlength_optim.jl @@ -27,7 +27,7 @@ function branchlength_optim!( models, partition_list, tol; - bl_optimizer::UnivariateOpt = GoldenSectionOpt(), + bl_modifier::UnivariateModifier = GoldenSectionOpt(), traversal = Iterators.reverse ) where {T <: Partition} @@ -87,10 +87,11 @@ function branchlength_optim!( temp_message = pop!(temp_messages) model_list = models(node) fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list) - opt = univariate_maximize(fun, 0 + tol, 1 - tol, unit_transform, bl_optimizer, tol) - if fun(opt) > fun(node.branchlength) - node.branchlength = opt + bl = univariate_modifier(fun, bl_modifier; a=0+tol, b=1-tol, tol=tol, transform=unit_transform, curr_value=node.branchlength) + if fun(bl) > fun(node.branchlength) || !(bl_modifier isa UnivariateOpt) + node.branchlength = bl end + #Consider checking for improvement, and bailing if none. #Then we need to set the "message_to_set", which is node.parent.child_messages[but_the_right_one] for part in partition_list @@ -109,9 +110,9 @@ function branchlength_optim!( #------------------- model_list = models(node) fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list) - opt = univariate_maximize(fun, 0 + tol, 1 - tol, unit_transform, bl_optimizer, tol) - if fun(opt) > fun(node.branchlength) - node.branchlength = opt + bl = univariate_modifier(fun, bl_modifier; a=0+tol, b=1-tol, tol=tol, transform=unit_transform, curr_value=node.branchlength) + if fun(bl) > fun(node.branchlength) || !(bl_modifier isa UnivariateOpt) + node.branchlength = bl end #Consider checking for improvement, and bailing if none. #Then we need to set the "message_to_set", which is node.parent.child_messages[but_the_right_one] @@ -139,13 +140,13 @@ a function that takes a node, and returns a Vector{<:BranchModel} if you need th # Keyword Arguments - `partition_list=nothing`: (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize branch lengths with all models, the default option). -- `tol=1e-5`: absolute tolerance for the `bl_optimizer`. -- `bl_optimizer=GoldenSectionOpt()`: univariate branchlength optimizer, has Brent's method as an option by setting bl_optimizer=BrentsMethodOpt(). +- `tol=1e-5`: absolute tolerance for the `bl_modifier`. +- `bl_modifier=GoldenSectionOpt()`: can either be a optimizer or a sampler (subtype of UnivariateModifier). For optimization, in addition to golden section search, Brent's method can be used by setting bl_modifier=BrentsMethodOpt(). - `sort_tree=false`: determines if a [`lazysort!`](@ref) will be performed, which can reduce the amount of temporary messages that has to be initialized. - `traversal=Iterators.reverse`: a function that determines the traversal, permutes an iterable. - `shuffle=false`: do a randomly shuffled traversal, overrides `traversal`. """ -function branchlength_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, bl_optimizer::UnivariateOpt = GoldenSectionOpt(), sort_tree = false, traversal = Iterators.reverse, shuffle = false) +function branchlength_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, bl_modifier::UnivariateModifier = GoldenSectionOpt(), sort_tree = false, traversal = Iterators.reverse, shuffle = false) sort_tree && lazysort!(tree) #A lazysorted tree minimizes the amount of temp_messages needed temp_messages = [copy_message(tree.message)] @@ -153,7 +154,7 @@ function branchlength_optim!(tree::FelNode, models; partition_list = nothing, to partition_list = 1:length(tree.message) end - branchlength_optim!(temp_messages, tree, models, partition_list, tol, bl_optimizer=bl_optimizer, traversal = shuffle ? x -> sample(x, length(x), replace=false) : traversal) + branchlength_optim!(temp_messages, tree, models, partition_list, tol, bl_modifier=bl_modifier, traversal = shuffle ? x -> sample(x, length(x), replace=false) : traversal) end #Overloading to allow for direct model and model vec inputs @@ -162,18 +163,18 @@ branchlength_optim!( models::Vector{<:BranchModel}; partition_list = nothing, tol = 1e-5, - bl_optimizer::UnivariateOpt = GoldenSectionOpt(), + bl_modifier::UnivariateModifier = GoldenSectionOpt(), sort_tree = false, traversal = Iterators.reverse, shuffle = false -) = branchlength_optim!(tree, x -> models, partition_list = partition_list, tol = tol, bl_optimizer = bl_optimizer, sort_tree = sort_tree, traversal = traversal, shuffle = shuffle) +) = branchlength_optim!(tree, x -> models, partition_list = partition_list, tol = tol, bl_modifier = bl_modifier, sort_tree = sort_tree, traversal = traversal, shuffle = shuffle) branchlength_optim!( tree::FelNode, model::BranchModel; partition_list = nothing, tol = 1e-5, - bl_optimizer::UnivariateOpt = GoldenSectionOpt(), + bl_modifier::UnivariateModifier = GoldenSectionOpt(), sort_tree = false, traversal = Iterators.reverse, shuffle = false -) = branchlength_optim!(tree, x -> [model], partition_list = partition_list, tol = tol, bl_optimizer = bl_optimizer, sort_tree = sort_tree, traversal = traversal, shuffle = shuffle) +) = branchlength_optim!(tree, x -> [model], partition_list = partition_list, tol = tol, bl_modifier = bl_modifier, sort_tree = sort_tree, traversal = traversal, shuffle = shuffle) diff --git a/src/core/algorithms/nni_optim.jl b/src/core/algorithms/nni_optim.jl index 6e8cc84..ec7a39c 100644 --- a/src/core/algorithms/nni_optim.jl +++ b/src/core/algorithms/nni_optim.jl @@ -1,13 +1,3 @@ -#= -About clades getting skipped: -- the iterative implementation perfectly mimics the recursive one (they can both skip clades) -- some nnis can lead to some clades not getting optimized and some getting optimized multiple times -- I could push "every other" during first down and use lastind to know if a clade's been visisted, if a sibling clade's not been visited, I'll simply not fel-up yet but continue down -- -- Sanity checks: compare switch_LL with log_likelihood! of deepcopied tree with said switch -full_traversal passed the sanity check -=# - #After a do_nni, we have to update parent_message if we want to continue down (assume that temp_message is the forwarded parent.parent_message) function update_parent_message!( node::FelNode, @@ -29,12 +19,12 @@ function update_parent_message!( end end -function nni_optim_full_traversal!( +function nni_optim!( temp_messages::Vector{Vector{T}}, tree::FelNode, models, partition_list; - acc_rule = (x, y) -> x > y, + selection_rule = x -> argmax(x), traversal = Iterators.reverse ) where {T <: Partition} @@ -99,12 +89,12 @@ function nni_optim_full_traversal!( temp_message = pop!(temp_messages) model_list = models(node) if first #We only do_nni first up - nnid, exceed_sib, exceed_child = do_nni( + nnid, sampled_sib_ind, sampled_child_ind = do_nni( node, temp_message, models; partition_list = partition_list, - acc_rule = acc_rule, + selection_rule = selection_rule, ) if nnid && last(last(stack)) #We nnid a sibling that hasn't been visited (then, down would be true in the next iter)... #... and now we want to continue down the nnid sibling (now a child to node) @@ -128,7 +118,7 @@ function nni_optim_full_traversal!( end pop!(stack) push!(stack, (Vector{T}(), node, ind, lastind, false, false)) #When we're going up a second time, we no longer need a temp - push!(stack, (temp_message, node, exceed_child, exceed_child, false, true)) #Go to the "new" child - the "new" lastind + push!(stack, (temp_message, node, sampled_child_ind, sampled_child_ind, false, true)) #Go to the "new" child - the "new" lastind continue #Don't fel-up yet end end @@ -147,104 +137,13 @@ function nni_optim_full_traversal!( end end -function nni_optim!( - temp_messages::Vector{Vector{T}}, - tree::FelNode, - models, - partition_list; - acc_rule = (x, y) -> x > y, - traversal = Iterators.reverse -) where {T <: Partition} - - #Consider a NamedTuple/struct - stack = [(pop!(temp_messages), tree, 1, 1, true, true)] - while !isempty(stack) - temp_message, node, ind, lastind, first, down = pop!(stack) - #We start out with a regular downward pass... - #(except for some extra bookkeeping to track if node is visited for the first time) - #------------------- - if isleafnode(node) - push!(temp_messages, temp_message) - continue - end - if down - if first - model_list = models(node) - for part in partition_list - forward!( - temp_message[part], - node.parent_message[part], - model_list[part], - node, - ) - end - @assert length(node.children) <= 2 - #Temp must be constant between iterations for a node during down... - child_iter = traversal(1:length(node.children)) - lastind = Base.first(child_iter) #(which is why we track the last child to be visited during down) - push!(stack, (Vector{T}(), node, ind, lastind, false, false)) #... but not up - for i = child_iter #Iterative reverse <=> Recursive non-reverse, also optimal for lazysort!?? - push!(stack, (temp_message, node, i, lastind, false, true)) - end - end - if !first - sib_inds = sibling_inds(node.children[ind]) - for part in partition_list - combine!( - (node.children[ind]).parent_message[part], - [mess[part] for mess in node.child_messages[sib_inds]], - true, - ) - combine!( - (node.children[ind]).parent_message[part], - [temp_message[part]], - false, - ) - end - #But calling nni_optim! recursively... (the iterative equivalent) - push!(stack, (safepop!(temp_messages, temp_message), node.children[ind], ind, lastind, true, true)) #first + down combination => safepop! - ind == lastind && push!(temp_messages, temp_message) #We no longer need constant temp - end - end - if !down - #Then combine node.child_messages into node.message... - for part in partition_list - combine!(node.message[part], [mess[part] for mess in node.child_messages], true) - end - #But now we need to optimize the current node, and then prop back up to set your parents children message correctly. - #------------------- - if !isroot(node) - temp_message = pop!(temp_messages) - model_list = models(node) - nnid, exceed_sib, exceed_child = do_nni( - node, - temp_message, - models; - partition_list = partition_list, - acc_rule = acc_rule, - ) - for part in partition_list - combine!(node.message[part], [mess[part] for mess in node.child_messages], true) - backward!(node.parent.child_messages[ind][part], node.message[part], model_list[part], node) - combine!( - node.parent.message[part], - [mess[part] for mess in node.parent.child_messages], - true, - ) - end - push!(temp_messages, temp_message) - end - end - end -end - #Unsure if this is the best choice to handle the model,models, and model_func stuff. function nni_optim!( temp_messages::Vector{Vector{T}}, tree::FelNode, models::Vector{<:BranchModel}, partition_list; - acc_rule = (x, y) -> x > y, + selection_rule = x -> argmax(x), traversal = Iterators.reverse, ) where {T <: Partition} nni_optim!( @@ -252,7 +151,7 @@ function nni_optim!( tree, x -> models, partition_list, - acc_rule = acc_rule, + selection_rule = selection_rule, traversal = traversal, ) end @@ -261,7 +160,7 @@ function nni_optim!( tree::FelNode, model::BranchModel, partition_list; - acc_rule = (x, y) -> x > y, + selection_rule = x -> argmax(x), traversal = Iterators.reverse, ) where {T <: Partition} @@ -270,7 +169,7 @@ function nni_optim!( tree, x -> [model], partition_list, - acc_rule = acc_rule, + selection_rule = selection_rule, traversal = traversal, ) end @@ -280,7 +179,7 @@ function do_nni( temp_message, models::F; partition_list = 1:length(node.message), - acc_rule = (x, y) -> x > y, + selection_rule = x -> argmax(x), ) where {F<:Function} if length(node.children) == 0 || node.parent === nothing return false @@ -299,12 +198,16 @@ function do_nni( #total_LL(node.parent_message[part]) for part in partition_list]) - max_LL = -Inf - exceeded, exceed_sib, exceed_child = (false, 0, 0) + change = false + nni_LLs = [curr_LL] + nni_configs = [(0,0)] + + + for sib_ind in [x for x in 1:length(node.parent.children) if node.parent.children[x] != node] - switch_LL = 0.0 + for child_ind = 1:length(node.children) for part in partition_list #move the sibling message, after upward propogation, to temp_message to work with it @@ -350,36 +253,34 @@ function do_nni( combine!(temp_message[part], [node.parent.parent_message[part]], false) end - switch_LL = sum([total_LL(temp_message[part]) for part in partition_list]) - + LL = sum([total_LL(temp_message[part]) for part in partition_list]) - if switch_LL > max_LL - exceed_sib = sib_ind - exceed_child = child_ind - max_LL = switch_LL - end + push!(nni_LLs, LL) + push!(nni_configs, (sib_ind, child_ind)) end end - exceeded = acc_rule(max_LL, curr_LL) + sampled_config_ind = selection_rule(nni_LLs) + change = sampled_config_ind != 1 + (sampled_sib_ind, sampled_child_ind) = nni_configs[sampled_config_ind] - #do the actual move here, switching exceed child and exceed sib - if !(exceeded) - return false, exceed_sib, exceed_child + #do the actual move here, switching sampled_child_in and sampled_sib_ind + if !(change) + return false, sampled_sib_ind, sampled_child_ind else - sib = node.parent.children[exceed_sib] - child = node.children[exceed_child] + sib = node.parent.children[sampled_sib_ind] + child = node.children[sampled_child_ind] child.parent = node.parent sib.parent = node - node.children[exceed_child] = sib - node.parent.children[exceed_sib] = child + node.children[sampled_child_ind] = sib + node.parent.children[sampled_sib_ind] = child - node.parent.child_messages[exceed_sib], node.child_messages[exceed_child] = - node.child_messages[exceed_child], node.parent.child_messages[exceed_sib] + node.parent.child_messages[sampled_sib_ind], node.child_messages[sampled_child_ind] = + node.child_messages[sampled_child_ind], node.parent.child_messages[sampled_sib_ind] - return true, exceed_sib, exceed_child + return true, sampled_sib_ind, sampled_child_ind end end end @@ -394,7 +295,7 @@ a function that takes a node, and returns a Vector{<:BranchModel} if you need th # Keyword Arguments - `partition_list=nothing`: (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize tree topology with all models, the default option). -- `acc_rule=(x, y) -> x > y`: a function that takes the current and proposed log likelihoods, and if true is returned the move is accepted. +- `selection_rule = x -> argmax(x)`: a function that takes the current and proposed log likelihoods and selects a nni configuration. Note that the current log likelihood is stored at x[1]. - `sort_tree=false`: determines if a [`lazysort!`](@ref) will be performed, which can reduce the amount of temporary messages that has to be initialized. - `traversal=Iterators.reverse`: a function that determines the traversal, permutes an iterable. - `shuffle=false`: do a randomly shuffled traversal, overrides `traversal`. @@ -403,7 +304,7 @@ function nni_optim!( tree::FelNode, models; partition_list = nothing, - acc_rule = (x, y) -> x > y, + selection_rule = x -> argmax(x), sort_tree = false, traversal = Iterators.reverse, shuffle = false @@ -415,13 +316,12 @@ function nni_optim!( partition_list = 1:length(tree.message) end - #Need to decide here between nni_optim and nni_optim_full_traversal nni_optim!( temp_messages, tree, models, partition_list, - acc_rule = acc_rule, + selection_rule = selection_rule, traversal = shuffle ? x -> sample(x, length(x), replace=false) : traversal ) -end +end \ No newline at end of file diff --git a/src/core/nodes/FelNode.jl b/src/core/nodes/FelNode.jl index 89ce09a..97119d7 100644 --- a/src/core/nodes/FelNode.jl +++ b/src/core/nodes/FelNode.jl @@ -98,3 +98,55 @@ function mixed_type_equilibrium_message( end return out_mess end + +""" + function copy_tree(root::FelNode, shallow_copy=false) + + Returns an untangled copy of the tree. Optionally, the flag `shallow_copy` can be used to obtain a copy of the tree with only the names and branchlengths. +""" +function copy_tree(root::FelNode, shallow_copy=false) + + root_copy = FelNode(root.branchlength, root.name) + stack = [(root, root_copy)] + + while !isempty(stack) + node, node_copy = pop!(stack) + + if !shallow_copy + + if isdefined(node, :nodeindex) + node_copy.nodeindex = node.nodeindex + end + if isdefined(node, :seqindex) + node_copy.seqindex = node.seqindex + end + if isdefined(node, :state_path) + node_copy.state_path = deepcopy(node.state_path) + end + if isdefined(node, :branch_params) + node_copy.branch_params = copy(node.branch_params) + end + if isdefined(node, :node_data) + node_copy.node_data = deepcopy(node.node_data) + end + if isdefined(node, :message) + node_copy.message = copy_message(node.message) + end + if isdefined(node, :parent_message) + node_copy.parent_message = copy_message(node.parent_message) + end + if isdefined(node, :child_messages) + node_copy.child_messages = [copy_message(msg) for msg in node.child_messages] + end + end + + for child in node.children + child_copy = FelNode(child.branchlength, child.name) + push!(stack, (child, child_copy)) + child_copy.parent = node_copy + push!(node_copy.children, child_copy) + end + end + + return root_copy +end diff --git a/src/utils/misc.jl b/src/utils/misc.jl index 114f249..ebb8590 100644 --- a/src/utils/misc.jl +++ b/src/utils/misc.jl @@ -318,3 +318,8 @@ function write_nexus(fname::String,tree::FelNode) n.name = old_names[i] end end + +function softmax(x) + exp_x = exp.(x .- maximum(x)) # For numerical stability + return exp_x ./ sum(exp_x) +end \ No newline at end of file diff --git a/src/utils/simple_optim.jl b/src/utils/simple_optim.jl index c2dd5d4..f559f1d 100644 --- a/src/utils/simple_optim.jl +++ b/src/utils/simple_optim.jl @@ -15,6 +15,10 @@ end struct GoldenSectionOpt <: UnivariateOpt end struct BrentsMethodOpt <: UnivariateOpt end +function univariate_modifier(fun, modifier::UnivariateOpt; a=0, b=1, transform=unit_transform, tol=10e-5, kwargs...) + return univariate_maximize(fun, a, b, unit_transform, modifier, tol) +end + """ Golden section search. diff --git a/src/utils/simple_sample.jl b/src/utils/simple_sample.jl new file mode 100644 index 0000000..56103c7 --- /dev/null +++ b/src/utils/simple_sample.jl @@ -0,0 +1,45 @@ + +function univariate_modifier(f, modifier::UnivariateSampler; curr_value=nothing, kwargs...) + return univariate_sampler(f, modifier, curr_value) +end + +""" + BranchlengthSampler + + A type that allows you to specify a additive proposal function in the log domain and a prior distrubution over the log of the branchlengths. It also holds the acceptance ratio acc_ratio (acc_ratio[1] stores the number of accepts, and acc_ratio[2] stores the number of rejects). +""" +struct BranchlengthSampler <: UnivariateSampler + acc_ratio::Vector{Int} + log_bl_proposal::Distribution + log_bl_prior::Distribution + BranchlengthSampler(log_bl_proposal,log_bl_prior) = new([0,0],log_bl_proposal,log_bl_prior) +end + +""" + univariate_sampler(LL, modifier::BranchlengthPeturbation, curr_branchlength) + + A MCMC algorithm that draws the next sample of a Markov Chain that approximates the Posterior distrubution over the branchlengths. +""" +function univariate_sampler(LL, modifier::BranchlengthSampler, curr_branchlength) + return branchlength_metropolis(LL, modifier, curr_branchlength) +end + +function branchlength_metropolis(LL, modifier, curr_value) + # The prior distribution for the variable log(branchlength). A small perturbation of +1e-12 is added to enhance numerical stability near zero. + log_prior(x) = logpdf(modifier.log_bl_prior,log(x + 1e-12)) + # Adding additive normal symmetrical noise in the log(branchlength) domain to ensure the proposal function is symmetric. + noise = rand(modifier.log_bl_proposal) + proposal = exp(log(curr_value)+noise) + # The standard Metropolis acceptance criterion. + if rand() <= exp(LL(proposal)+log_prior(proposal)-LL(curr_value)-log_prior(curr_value)) + modifier.acc_ratio[1] = modifier.acc_ratio[1] + 1 + return proposal + else + modifier.acc_ratio[2] = modifier.acc_ratio[2] + 1 + return curr_value + end +end + + + + diff --git a/src/utils/utils.jl b/src/utils/utils.jl index cd69ae9..253395b 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -1,4 +1,5 @@ include("misc.jl") include("simple_optim.jl") +include("simple_sample.jl") include("tree_hash.jl") #fasta_io.jl is optionally included with Requires.jl in MolecularEvolution.jl diff --git a/test/partition_selection.jl b/test/partition_selection.jl index 16bf0de..fb7fbb9 100644 --- a/test/partition_selection.jl +++ b/test/partition_selection.jl @@ -57,7 +57,7 @@ begin branchlength_optim!(tree, bm_models, partition_list = [1]) branchlength_optim!(tree, bm_models, partition_list = [2]) branchlength_optim!(tree, bm_models) - branchlength_optim!(tree, bm_models, bl_optimizer=BrentsMethodOpt()) + branchlength_optim!(tree, bm_models, bl_modifier=BrentsMethodOpt()) branchlength_optim!(tree, x -> bm_models, partition_list = [2]) branchlength_optim!(tree, x -> bm_models)