Skip to content

Commit 91ed19b

Browse files
authored
Merge pull request #50 from nossleinad/Hastings-root-pos
Acc ratio structure
2 parents c66c2bb + d4ccef6 commit 91ed19b

File tree

4 files changed

+14
-12
lines changed

4 files changed

+14
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MolecularEvolution"
22
uuid = "9f975960-e239-4209-8aa0-3d3ad5a82892"
33
authors = ["Ben Murrell <[email protected]> and contributors"]
4-
version = "0.2.4"
4+
version = "0.2.5"
55

66
[deps]
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

examples/update.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ mutable struct MyModelSampler{
4343
T1<:ContinuousUnivariateDistribution,
4444
T2<:ContinuousUnivariateDistribution,
4545
} <: ModelsUpdate
46-
acc_ratio::Vector{Int}
46+
acc_ratio::Tuple{Float64, Int64, Int64}
4747
log_var_drift_proposal::T1
4848
log_var_drift_prior::T2
4949
mean_drift::Float64
@@ -52,7 +52,7 @@ mutable struct MyModelSampler{
5252
log_var_drift_prior::T2,
5353
mean_drift::Float64,
5454
) where {T1<:ContinuousUnivariateDistribution, T2<:ContinuousUnivariateDistribution}
55-
new{T1, T2}([0, 0], log_var_drift_proposal, log_var_drift_prior, mean_drift)
55+
new{T1, T2}((0.0, 0, 0), log_var_drift_proposal, log_var_drift_prior, mean_drift)
5656
end
5757
end
5858
# Then we let this struct implement our [`metropolis_step`](@ref) interface

src/core/algorithms/root_optim.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ mutable struct StandardRootOpt <: RootOpt
1010
end
1111

1212
struct StandardRootSample <: UniformRootPositionSample
13-
acc_ratio::Array{Int64,1}
13+
acc_ratio::Tuple{Float64, Int64, Int64}
1414
consecutive::Int64
1515

1616
function StandardRootSample(consecutive::Int64)
17-
new(zeros(Int64, 2), consecutive)
17+
new((0.0, 0, 0), consecutive)
1818
end
1919
end
2020

src/utils/simple_sample.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ end
66
"""
77
BranchlengthSampler
88
9-
A type that allows you to specify a additive proposal function in the log domain and a prior distrubution over the log of the branchlengths. It also holds the acceptance ratio acc_ratio (acc_ratio[1] stores the number of accepts, and acc_ratio[2] stores the number of rejects).
9+
A type that allows you to specify a additive proposal function in the log domain and a prior distrubution over the log of the branchlengths. It also stores `acc_ratio` which is a tuple of `(ratio, total, #acceptances)`, where `ratio::Float64` is the acceptance ratio, `total::Int64` is the total number of proposals, and `#acceptances::Int64` is the number of acceptances.
1010
"""
11-
struct BranchlengthSampler <: UnivariateSampler
12-
acc_ratio::Vector{Int}
11+
mutable struct BranchlengthSampler <: UnivariateSampler
12+
acc_ratio::Tuple{Float64, Int64, Int64} #(ratio, total, #acceptances)
1313
log_bl_proposal::Distribution
1414
log_bl_prior::Distribution
15-
BranchlengthSampler(log_bl_proposal,log_bl_prior) = new([0,0],log_bl_proposal,log_bl_prior)
15+
BranchlengthSampler(log_bl_proposal,log_bl_prior) = new((0.0, 0, 0),log_bl_proposal,log_bl_prior)
1616
end
1717

1818
"""
@@ -61,11 +61,13 @@ proposal(modifier::BranchlengthSampler, curr_value) = exp(log(curr_value) + rand
6161
log_prior(modifier::BranchlengthSampler, x) = logpdf(modifier.log_bl_prior, log(x + 1e-12))
6262
#Default definition. Overload it for your own modifier type
6363
function apply_decision(modifier, accept::Bool)
64+
ratio, total, acc = modifier.acc_ratio
65+
total += 1
6466
if accept
65-
modifier.acc_ratio[1] += 1
66-
else
67-
modifier.acc_ratio[2] += 1
67+
acc += 1
6868
end
69+
ratio = acc / total
70+
modifier.acc_ratio = (ratio, total, acc)
6971
end
7072

7173
tr(modifier, x) = x

0 commit comments

Comments
 (0)