Skip to content

Commit e14b08f

Browse files
authored
Merge pull request #49 from nossleinad/update-pattern
Update pattern
2 parents fe44c7b + 5e5526b commit e14b08f

16 files changed

+796
-27
lines changed

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
22
Cairo = "159f3aea-2a34-519c-b102-8c37f9878175"
33
Compose = "a81c6b42-2e10-5240-aca2-a61377ecd94b"
4+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
45
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
56
FASTX = "c2308a5c-f048-11e8-3e8a-31650f418d12"
67
Fontconfig = "186bb1d3-e1f7-5a2c-a377-96d770f13627"

docs/make.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using MolecularEvolution
22
using Documenter, Literate
3-
using Phylo
3+
using Phylo, Distributions
44
using Plots
55
using Compose, Cairo, Fontconfig
66
using FASTX
@@ -48,6 +48,7 @@ makedocs(;
4848
"optimization.md",
4949
"ancestors.md",
5050
"generated/viz.md",
51+
"generated/update.md",
5152
],
5253
)
5354

examples/update.jl

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# # Updating a phylogenetic tree
2+
#=
3+
## Interface
4+
5+
```@docs; canonical=false
6+
AbstractUpdate
7+
StandardUpdate
8+
```
9+
=#
10+
11+
# ## Example
12+
13+
using MolecularEvolution, Plots, Distributions
14+
# Simulate a tree
15+
tree = sim_tree(n = 50)
16+
initial_message = GaussianPartition()
17+
models = BrownianMotion(0.0, 1.0)
18+
internal_message_init!(tree, initial_message)
19+
sample_down!(tree, models)
20+
log_likelihood!(tree, models)
21+
# Add some noise to the branch lengths
22+
for n in getnodelist(tree)
23+
n.branchlength += 100 * rand()
24+
end
25+
log_likelihood!(tree, models)
26+
# Optimize under the brownian motion model
27+
update = MaxLikUpdate(branchlength = 1, nni = 0, root = 1)
28+
tree, models = update(tree, models)
29+
@show log_likelihood!(tree, models)
30+
# ### Set up a Bayesian model sampler
31+
#=
32+
Let's assume the target of inference is not the tree itself, but rather the models.
33+
Assume further that you want to, for a fixed mean drift, sample the variance of the brownian motion model,
34+
with the metropolis algorithm.
35+
=#
36+
# We begin with a struct that defines the model and how it's updated
37+
tree = sim_tree(n = 200)
38+
internal_message_init!(tree, GaussianPartition())
39+
#Simulate brownian motion over the tree
40+
models = BrownianMotion(0.0, 2.0)
41+
sample_down!(tree, models)
42+
mutable struct MyModelSampler{
43+
T1<:ContinuousUnivariateDistribution,
44+
T2<:ContinuousUnivariateDistribution,
45+
} <: ModelsUpdate
46+
acc_ratio::Vector{Int}
47+
log_var_drift_proposal::T1
48+
log_var_drift_prior::T2
49+
mean_drift::Float64
50+
function MyModelSampler(
51+
log_var_drift_proposal::T1,
52+
log_var_drift_prior::T2,
53+
mean_drift::Float64,
54+
) where {T1<:ContinuousUnivariateDistribution, T2<:ContinuousUnivariateDistribution}
55+
new{T1, T2}([0, 0], log_var_drift_proposal, log_var_drift_prior, mean_drift)
56+
end
57+
end
58+
# Then we let this struct implement our [`metropolis_step`](@ref) interface
59+
MolecularEvolution.tr(::MyModelSampler, x::BrownianMotion) = log(x.var_drift)
60+
MolecularEvolution.invtr(modifier::MyModelSampler, x::Float64) =
61+
BrownianMotion(modifier.mean_drift, exp(x))
62+
63+
MolecularEvolution.proposal(modifier::MyModelSampler, curr_value::Float64) =
64+
curr_value + rand(modifier.log_var_drift_proposal)
65+
MolecularEvolution.log_prior(modifier::MyModelSampler, x::Float64) =
66+
logpdf(modifier.log_var_drift_prior, x)
67+
# Now we define what a model update is
68+
function (update::MyModelSampler)(
69+
tree::FelNode,
70+
models::BranchModel;
71+
partition_list = 1:length(tree.message),
72+
)
73+
metropolis_step(update, models) do x::BrownianMotion
74+
log_likelihood!(tree, x)
75+
end
76+
end
77+
# Now we define how the model is collapsed to its parameter
78+
function MolecularEvolution.collapse_models(::MyModelSampler, models::BranchModel)
79+
return models.var_drift
80+
end
81+
# Now we define a Bayesian sampler
82+
update = BayesUpdate(
83+
nni = 0,
84+
branchlength = 0,
85+
models = 1,
86+
models_sampler = MyModelSampler(Normal(0.0, 1.0), Normal(-10.0, 1.0), 0.0),
87+
)
88+
trees, models_samples = metropolis_sample(
89+
update,
90+
tree,
91+
BrownianMotion(0.0, 7.67),
92+
1000,
93+
burn_in = 1000,
94+
collect_models = true,
95+
)
96+
97+
ll(x) = log_likelihood!(tree, BrownianMotion(0.0, x))
98+
prior(x) = logpdf(update.models_update.log_var_drift_prior, log(x)) - log(x)
99+
x_range = 0.1:0.1:5
100+
101+
p1 = histogram(
102+
models_samples,
103+
normalize = :pdf,
104+
alpha = 0.5,
105+
label = "Posterior samples",
106+
xlims = (minimum(x_range), maximum(x_range)),
107+
xlabel = "variance per unit time",
108+
ylabel = "probability density",
109+
)
110+
p2 = plot(x_range, ll, label = "Tree likelihood")
111+
112+
p3 = plot(x_range, prior, label = "Prior")
113+
plot(p1, p2, p3, layout = (1, 3), size = (1100, 400))
114+
#-
115+
plot(models_samples)

src/MolecularEvolution.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ abstract type UnivariateModifier end
3434
abstract type UnivariateOpt <: UnivariateModifier end
3535
abstract type UnivariateSampler <: UnivariateModifier end
3636

37+
abstract type RootUpdate <: Function end
38+
abstract type RootOpt <: RootUpdate end
39+
abstract type RootSample <: RootUpdate end
40+
abstract type UniformRootPositionSample <: RootSample end
41+
abstract type ModelsUpdate <: Function end
42+
43+
3744
abstract type LazyDirection end
3845

3946
#include("core/core.jl")
@@ -103,6 +110,7 @@ export
103110
combine!,
104111
felsenstein!,
105112
felsenstein_down!,
113+
felsenstein_roundtrip!,
106114
sample_down!,
107115
#endpoint_conditioned_sample_down!,
108116
log_likelihood!,
@@ -123,10 +131,28 @@ export
123131
nni_optim!,
124132
branchlength_update!,
125133
branchlength_optim!,
134+
root_optim!,
135+
root_position_sample!,
126136
metropolis_sample,
127137
metropolis_step,
128138
copy_tree,
129139

140+
#update
141+
AbstractUpdate,
142+
StandardUpdate,
143+
BayesUpdate,
144+
MaxLikUpdate,
145+
RootUpdate,
146+
RootOpt,
147+
RootSample,
148+
UniformRootPositionSample,
149+
StandardRootOpt,
150+
StandardRootSample,
151+
ModelsUpdate,
152+
StandardModelsUpdate,
153+
collapse_models,
154+
155+
130156
#Tree simulation functions
131157
sim_tree,
132158
standard_tree_sim,

src/bayes/sampling.jl

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,30 @@
11
"""
22
metropolis_sample(
3-
update!::Function,
3+
update!::AbstractUpdate,
44
initial_tree::FelNode,
5-
models::Vector{<:BranchModel},
5+
models,
66
num_of_samples;
7+
partition_list = 1:length(initial_tree.message),
78
burn_in = 1000,
89
sample_interval = 10,
910
collect_LLs = false,
11+
collect_models = false,
1012
midpoint_rooting = false,
1113
ladderize = false,
1214
)
1315
1416
Samples tree topologies from a posterior distribution using a custom `update!` function.
1517
1618
# Arguments
17-
- `update!`: A function that takes (tree::FelNode, models::Vector{<:BranchModel}) and updates `tree`. `update!` takes (tree::FelNode, models::Vector{<:BranchModel}) and updates `tree`. One call to `update!` corresponds to one iteration of the Metropolis algorithm.
19+
- `update!`: A callable that takes (tree::FelNode, models) and updates `tree` and `models`. One call to `update!` corresponds to one iteration of the Metropolis algorithm.
1820
- `initial_tree`: An initial tree topology with the leaves populated with data, for the likelihood calculation.
1921
- `models`: A list of branch models.
2022
- `num_of_samples`: The number of tree samples drawn from the posterior.
23+
- `partition_list`: (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to sample with all partitions, the default option).
2124
- `burn_in`: The number of samples discarded at the start of the Markov Chain.
2225
- `sample_interval`: The distance between samples in the underlying Markov Chain (to reduce sample correlation).
2326
- `collect_LLs`: Specifies if the function should return the log-likelihoods of the trees.
27+
- `collect_models`: Specifies if the function should return the models.
2428
- `midpoint_rooting`: Specifies whether the drawn samples should be midpoint rerooted (Important! Should only be used for time-reversible branch models starting in equilibrium).
2529
2630
!!! note
@@ -29,39 +33,49 @@ Samples tree topologies from a posterior distribution using a custom `update!` f
2933
# Returns
3034
- `samples`: The trees drawn from the posterior. Returns shallow tree copies, which needs to be repopulated before running felsenstein! etc.
3135
- `sample_LLs`: The associated log-likelihoods of the tree (optional).
36+
- `sample_models`: The models drawn from the posterior (optional). The models can be collapsed into it's parameters with `collapse_models`.
3237
"""
3338
function metropolis_sample(
34-
update!::Function,
39+
update!::AbstractUpdate,
3540
initial_tree::FelNode,
36-
models::Vector{<:BranchModel},
41+
models,#::Vector{<:BranchModel},
3742
num_of_samples;
43+
partition_list = 1:length(initial_tree.message),
3844
burn_in = 1000,
3945
sample_interval = 10,
4046
collect_LLs = false,
4147
midpoint_rooting = false,
4248
ladderize = false,
49+
collect_models = false,
4350
)
4451

4552
# The prior over the (log) of the branchlengths should be specified in bl_sampler.
4653
# Furthermore, a non-informative/uniform prior is assumed over the tree topolgies (excluding the branchlengths).
4754

48-
sample_LLs = []
55+
sample_LLs = Float64[]
4956
samples = FelNode[]
50-
tree = deepcopy(initial_tree)
57+
sample_models = []
58+
tree = initial_tree#deepcopy(initial_tree)
5159
iterations = burn_in + num_of_samples * sample_interval
5260

5361
for i = 1:iterations
5462
# Updates the tree topolgy and branchlengths.
55-
update!(tree, models)
63+
tree, models = update!(tree, models, partition_list = partition_list)
64+
if isnothing(tree)
65+
break
66+
end
5667

5768
if (i - burn_in) % sample_interval == 0 && i > burn_in
5869

5970
push!(samples, copy_tree(tree, true))
6071

6172
if collect_LLs
62-
push!(sample_LLs, log_likelihood!(tree, models))
73+
push!(sample_LLs, log_likelihood!(tree, models, partition_list = partition_list))
6374
end
6475

76+
if collect_models
77+
push!(sample_models, collapse_models(update!, models))
78+
end
6579
end
6680

6781
end
@@ -79,10 +93,15 @@ function metropolis_sample(
7993
end
8094
end
8195

82-
if collect_LLs
96+
if collect_LLs && collect_models
97+
return samples, sample_LLs, sample_models
98+
elseif collect_LLs && !collect_models
8399
return samples, sample_LLs
100+
elseif !collect_LLs && collect_models
101+
return samples, sample_models
84102
end
85103

104+
86105
return samples
87106
end
88107

@@ -110,10 +129,7 @@ function metropolis_sample(
110129
bl_sampler::UnivariateSampler = BranchlengthSampler(Normal(0, 2), Normal(-1, 1)),
111130
kwargs...,
112131
)
113-
metropolis_sample(initial_tree, models, num_of_samples; kwargs...) do tree, models
114-
nni_update!(softmax_sampler, tree, x -> models)
115-
branchlength_update!(bl_sampler, tree, x -> models)
116-
end
132+
metropolis_sample(BayesUpdate(; branchlength_sampler = bl_sampler), initial_tree, models, num_of_samples; kwargs...)
117133
end
118134

119135
# Below are some functions that help to assess the mixing by looking at the distance between leaf nodes.

0 commit comments

Comments
 (0)