Skip to content

Be consistent with message type declaration #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions src/core/algorithms/ancestors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function depth_first_reconstruction(
run_fel_up = true,
run_fel_down = true,
partition_list = 1:length(tree.message),
node_message_dict = Dict{FelNode,Vector{Partition}}(),
node_message_dict = Dict{FelNode,Vector{<:Partition}}(),
)
if run_fel_up
felsenstein!(tree, model_func, partition_list = partition_list)
Expand All @@ -56,7 +56,7 @@ function depth_first_reconstruction(
run_fel_up = true,
run_fel_down = true,
partition_list = 1:length(tree.message),
node_message_dict = Dict{FelNode,Vector{Partition}}(),
node_message_dict = Dict{FelNode,Vector{<:Partition}}(),
)
depth_first_reconstruction(
tree,
Expand All @@ -76,7 +76,7 @@ function depth_first_reconstruction(
run_fel_up = true,
run_fel_down = true,
partition_list = 1:length(tree.message),
node_message_dict = Dict{FelNode,Vector{Partition}}(),
node_message_dict = Dict{FelNode,Vector{<:Partition}}(),
)
depth_first_reconstruction(
tree,
Expand All @@ -91,7 +91,7 @@ end

#For marginal reconstructions
function reconstruct_marginal_node!(
node_message_dict::Dict{FelNode,Vector{Partition}},
node_message_dict::Dict{FelNode,Vector{<:Partition}},
node::FelNode,
model_array::Vector{<:BranchModel},
partition_list,
Expand All @@ -109,7 +109,7 @@ end

export marginal_state_dict
"""
marginal_state_dict(tree::FelNode, model; partition_list = 1:length(tree.message), node_message_dict = Dict{FelNode,Vector{Partition}}())
marginal_state_dict(tree::FelNode, model; partition_list = 1:length(tree.message), node_message_dict = Dict{FelNode,Vector{<:Partition}}())

Takes in a tree and a model (which can be a single model, an array of models, or a function that maps FelNode->Array{<:BranchModel}), and
returns a dictionary mapping nodes to their marginal reconstructions (ie. P(state|all observations,model)). A subset of partitions can be specified by partition_list,
Expand All @@ -119,7 +119,7 @@ function marginal_state_dict(
tree::FelNode,
model;
partition_list = 1:length(tree.message),
node_message_dict = Dict{FelNode,Vector{Partition}}(),
node_message_dict = Dict{FelNode,Vector{<:Partition}}(),
)
return depth_first_reconstruction(
tree,
Expand All @@ -133,7 +133,7 @@ end
#For joint max reconstructions
export dependent_reconstruction!
function dependent_reconstruction!(
node_message_dict::Dict{FelNode,Vector{Partition}},
node_message_dict::Dict{FelNode,Vector{<:Partition}},
node::FelNode,
model_array::Vector{<:BranchModel},
partition_list;
Expand Down Expand Up @@ -173,7 +173,7 @@ reconstruct_cascading_max_node!(node_message_dict, node, model_array, partition_
)
export cascading_max_state_dict
"""
cascading_max_state_dict(tree::FelNode, model; partition_list = 1:length(tree.message), node_message_dict = Dict{FelNode,Vector{Partition}}())
cascading_max_state_dict(tree::FelNode, model; partition_list = 1:length(tree.message), node_message_dict = Dict{FelNode,Vector{<:Partition}}())

Takes in a tree and a model (which can be a single model, an array of models, or a function that maps FelNode->Array{<:BranchModel}), and
returns a dictionary mapping nodes to their inferred ancestors under the following scheme: the state that maximizes the marginal likelihood is selected at the root,
Expand All @@ -184,7 +184,7 @@ function cascading_max_state_dict(
tree::FelNode,
model;
partition_list = 1:length(tree.message),
node_message_dict = Dict{FelNode,Vector{Partition}}(),
node_message_dict = Dict{FelNode,Vector{<:Partition}}(),
)
return depth_first_reconstruction(
tree,
Expand All @@ -206,7 +206,7 @@ conditioned_sample_node!(node_message_dict, node, model_array, partition_list) =
)
export endpoint_conditioned_sample_state_dict
"""
endpoint_conditioned_sample_state_dict(tree::FelNode, model; partition_list = 1:length(tree.message), node_message_dict = Dict{FelNode,Vector{Partition}}())
endpoint_conditioned_sample_state_dict(tree::FelNode, model; partition_list = 1:length(tree.message), node_message_dict = Dict{FelNode,Vector{<:Partition}}())

Takes in a tree and a model (which can be a single model, an array of models, or a function that maps FelNode->Array{<:BranchModel}), and draws samples under the model
conditions on the leaf observations. These samples are stored in the node_message_dict, which is returned. A subset of partitions can be specified by partition_list, and a
Expand All @@ -216,7 +216,7 @@ function endpoint_conditioned_sample_state_dict(
tree::FelNode,
model;
partition_list = 1:length(tree.message),
node_message_dict = Dict{FelNode,Vector{Partition}}(),
node_message_dict = Dict{FelNode,Vector{<:Partition}}(),
)
return depth_first_reconstruction(
tree,
Expand Down
6 changes: 3 additions & 3 deletions src/core/algorithms/branchlength_optim.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#Model list should be a list of P matrices.
function branch_LL_up(
bl::Real,
temp_message::Vector{Partition},
temp_message::Vector{<:Partition},
node::FelNode,
model_list::Vector{<:BranchModel},
partition_list,
Expand All @@ -22,8 +22,8 @@ end
#I need to add a version of this that takes a generic optimizer function and uses that instead of golden_section_maximize on just the branchlength.
#This is for cases where the user stores node-level parameters and wants to optimize them.
function branchlength_optim!(
temp_message::Vector{Partition},
message_to_set::Vector{Partition},
temp_message::Vector{<:Partition},
message_to_set::Vector{<:Partition},
node::FelNode,
models,
partition_list,
Expand Down
12 changes: 6 additions & 6 deletions src/core/algorithms/nni_optim.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@


function nni_optim!(
temp_message::Vector{Partition},
message_to_set::Vector{Partition},
temp_message::Vector{<:Partition},
message_to_set::Vector{<:Partition},
node::FelNode,
models,
partition_list;
Expand Down Expand Up @@ -72,8 +72,8 @@ end

#Unsure if this is the best choice to handle the model,models, and model_func stuff.
function nni_optim!(
temp_message::Vector{Partition},
message_to_set::Vector{Partition},
temp_message::Vector{<:Partition},
message_to_set::Vector{<:Partition},
node::FelNode,
models::Vector{<:BranchModel},
partition_list;
Expand All @@ -89,8 +89,8 @@ function nni_optim!(
)
end
function nni_optim!(
temp_message::Vector{Partition},
message_to_set::Vector{Partition},
temp_message::Vector{<:Partition},
message_to_set::Vector{<:Partition},
node::FelNode,
model::BranchModel,
partition_list;
Expand Down
1 change: 1 addition & 0 deletions src/models/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ function partition_from_template(partition_template::T) where {T <: DiscretePart
end
=#

#Note: not enforcing a return type causes some unnecesarry conversions
copy_message(msg::Vector{<:Partition}) = [copy_partition(x) for x in msg]

#This is a function shared for all models - perhaps move this elsewhere
Expand Down
17 changes: 17 additions & 0 deletions test/message_type_stability.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
begin
# Single partition example
tree = sim_tree(n = 10)
GAA_partition = GappyAminoAcidPartition(5)
AA_freqs = [1 / GAA_partition.states for _ = 1:GAA_partition.states]
GAA_partition.state .= AA_freqs
internal_message_init!(tree, GAA_partition)
Q = gappy_Q_from_symmetric_rate_matrix(WAGmatrix, 1.0, AA_freqs)
model = DiagonalizedCTMC(Q)
sample_down!(tree, model)
felsenstein!(tree, model)

# These would previously break since Vector{GappyAminoAcidPartition} is not <: Vector{Partition}, for example.
branchlength_optim!(tree, model)
marginal_state_dict(tree, model)
nni_optim!(tree, model)
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,8 @@ using Test
include("partition_selection.jl")
end

@testset "message_type_stability" begin
include("message_type_stability.jl")
end

end