Skip to content

Commit 8edb5e0

Browse files
authored
Merge pull request #55 from MurrellGroup/hipstr
Hipstr
2 parents 4809231 + d4f0243 commit 8edb5e0

File tree

4 files changed

+255
-8
lines changed

4 files changed

+255
-8
lines changed

examples/viz.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ model = DiagonalizedCTMC(reversibleQ(ones(6) ./ (6 * mean), rparams(4)))
7474
internal_message_init!(tree, NucleotidePartition(ones(4) ./ 4, 100))
7575
sample_down!(tree, model)
7676
@time trees, LLs = metropolis_sample(tree, [model], 300, collect_LLs=true);
77-
reference = trees[argmax(LLs)];
78-
# We'll use the maximum a posteriori tree as reference
77+
# We'll use the [`HIPSTR`](@ref) tree as reference
78+
reference = HIPSTR(trees);
7979
plot_multiple_trees(trees, reference)
8080
# We can pass in a weight function to fit query trees against `reference` in a weighted least squares fashion with a location and scale parameter.
8181
#=
@@ -99,5 +99,6 @@ values_from_phylo_tree
9999
savefig_tweakSVG
100100
tree_draw
101101
plot_multiple_trees
102+
HIPSTR
102103
```
103104
=#

src/utils/HIPSTR.jl

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
export HIPSTR
# Sentinel child hash meaning "this clade has no recorded children" (treated as a tip).
const _HIPSTR_NO_CLADE = (zero(UInt64), zero(UInt64))

# Memoised post-order computation of the best credibility of the subtree rooted at
# `clade_hash`. For each visited clade, `cred_cache` stores
# `(best credibility, left child hash, right child hash)`.
function _hipstr_credibility!(cred_cache, clades_stats, clade_hash)
    cached = get(cred_cache, clade_hash, nothing)
    cached === nothing || return cached[1]

    stats = get(clades_stats, clade_hash, nothing)
    if stats === nothing
        # A clade that was never observed has zero credibility. (The original code
        # indexed the missing key here and raised a KeyError.)
        cred_cache[clade_hash] = (0.0, _HIPSTR_NO_CLADE, _HIPSTR_NO_CLADE)
        return 0.0
    end
    if isempty(stats.child_pairs)
        # No recorded child pair: the clade's own frequency is its credibility.
        cred_cache[clade_hash] = (stats.frequency, _HIPSTR_NO_CLADE, _HIPSTR_NO_CLADE)
        return stats.frequency
    end

    # Pick the child pair maximising the product of the two subtree credibilities
    # and this clade's frequency.
    best_cred = 0.0
    best_left = _HIPSTR_NO_CLADE
    best_right = _HIPSTR_NO_CLADE
    for (left_hash, right_hash) in stats.child_pairs
        left_cred = _hipstr_credibility!(cred_cache, clades_stats, left_hash)
        right_cred = _hipstr_credibility!(cred_cache, clades_stats, right_hash)
        pair_cred = left_cred * right_cred * stats.frequency
        if pair_cred > best_cred
            best_cred = pair_cred
            best_left = left_hash
            best_right = right_hash
        end
    end

    cred_cache[clade_hash] = (best_cred, best_left, best_right)
    return best_cred
end

# Recursively materialise the tree selected in `cred_cache`, starting at `clade_hash`.
# `tip_lookup` maps the hash of a single-tip clade to `(leaf index, leaf name)`.
function _hipstr_build_tree(cred_cache, tip_lookup, clade_hash)
    _, left_hash, right_hash = cred_cache[clade_hash]

    # A clade without a selected child pair must be a single tip.
    if left_hash == _HIPSTR_NO_CLADE && right_hash == _HIPSTR_NO_CLADE
        tip = get(tip_lookup, clade_hash, nothing)
        tip === nothing && error("Failed to find leaf corresponding to hash $clade_hash")
        index, name = tip
        leaf = FelNode(1.0, name)
        leaf.seqindex = index
        return leaf
    end

    node = FelNode(1.0, "")
    left_child = _hipstr_build_tree(cred_cache, tip_lookup, left_hash)
    right_child = _hipstr_build_tree(cred_cache, tip_lookup, right_hash)

    # Placeholder branch lengths; `HIPSTR` may overwrite them with clade-wise means.
    left_child.branchlength = 1.0
    right_child.branchlength = 1.0

    left_child.parent = node
    right_child.parent = node
    push!(node.children, left_child)
    push!(node.children, right_child)

    return node
end

"""
    HIPSTR(trees::Vector{FelNode}; set_branchlengths = true)

Construct a Highest Independent Posterior Subtree Reconstruction (HIPSTR) tree
from a collection of trees.

Returns a single `FelNode`: the root of the HIPSTR consensus tree.

Every clade (set of leaf names below a node) seen in `trees` gets a frequency equal to
its number of occurrences divided by `length(trees)`. Starting from the clade that
contains all leaves, the child pair maximising the product of the children's
credibilities and the clade's own frequency is chosen recursively.

Child pairs are only recorded for nodes with exactly two children (see
`collect_clades!`). A non-tip clade with no recorded child pair cannot be resolved and
results in an error.

# Keywords
- `set_branchlengths = true`: if `true`, the branch length of a node in the HIPSTR tree
  is set to the mean branch length of all nodes from the input trees that have the same
  clade (i.e. the same set of leaves below the node). Otherwise, the root branch length
  is `0.0` and all others are `1.0`.

# Throws
- `ArgumentError` if `trees` is empty, or if the clade containing every leaf name found
  in `trees` was not observed in any input tree (for example when the trees do not share
  the same leaf set).

Source: https://www.biorxiv.org/content/10.1101/2024.12.08.627395v1.full.pdf
"""
function HIPSTR(trees::Vector{FelNode}; set_branchlengths = true)
    isempty(trees) && throw(ArgumentError("HIPSTR requires at least one input tree"))

    # Step 1: Collect all clades, their frequencies, and child pairs
    leaf_names = Set{String}()
    for tree in trees, leaf in getleaflist(tree)
        push!(leaf_names, leaf.name)
    end

    # Sorted leaf names give every leaf a stable index, so clade hashes are
    # comparable across trees.
    leaf_dict = Dict(name => i for (i, name) in enumerate(sort!(collect(leaf_names))))

    clades_stats = Dict{Tuple{UInt64, UInt64}, CladeStats}()
    for tree in trees
        collect_clades!(tree, clades_stats, leaf_dict)
    end

    # Scale clade counts to get credibilities
    n_trees = length(trees)
    for stats in values(clades_stats)
        stats.frequency /= n_trees
    end

    # Step 2: Compute the root clade hash (all tips)
    root_hash = hash_clade(BitSet(1:length(leaf_dict)))
    haskey(clades_stats, root_hash) || throw(
        ArgumentError(
            "the clade containing all leaves was not observed in any input tree; " *
            "all trees must share the same leaf set",
        ),
    )

    # Step 3: Build the credibility cache through a memoised post-order traversal
    cred_cache = Dict{
        Tuple{UInt64, UInt64},
        Tuple{Float64, Tuple{UInt64, UInt64}, Tuple{UInt64, UInt64}},
    }()
    _hipstr_credibility!(cred_cache, clades_stats, root_hash)

    # Step 4: Construct the HIPSTR tree. Single-tip hashes are computed once up front
    # so that resolving a leaf is a dictionary lookup instead of a scan over all tips.
    tip_lookup = Dict(
        hash_clade(BitSet((index,))) => (index, name) for (name, index) in leaf_dict
    )
    hipstr_tree = _hipstr_build_tree(cred_cache, tip_lookup, root_hash)

    # The root has no branch above it
    hipstr_tree.branchlength = 0.0

    set_node_indices!(hipstr_tree)

    # Optionally replace the placeholder branch lengths by clade-wise means
    set_branchlengths && set_mean_branchlengths!(hipstr_tree, trees)

    return hipstr_tree
end
137+
"""
    set_mean_branchlengths!(tree::FelNode, trees::Vector{FelNode})

Overwrite the branch length of every node in `tree` with the mean branch length of the
nodes in `trees` that `tree_match_pairs(tree, t, push_leaves = true)` pairs it with.

Nodes that are never paired keep a branch length of `0.0` and trigger a warning.
Returns `nothing`.

This deviates somewhat from the HIPSTR paper, which finds distributions over node ages
(with respect to a molecular clock); that differs a bit from the notion of branch length.
"""
function set_mean_branchlengths!(tree::FelNode, trees::Vector{FelNode})
    consensus_nodes = nodes(tree)

    # Branch lengths double as running sums; a parallel dictionary holds the
    # number of terms added to each sum.
    branch_length_counts = Dict(zip(consensus_nodes, zeros(Int64, length(consensus_nodes))))
    for consensus_node in consensus_nodes
        consensus_node.branchlength = 0.0
    end

    # Accumulate, over every input tree, the branch length of each matched node.
    for sampled_tree in trees
        for (consensus_node, sampled_node) in
            tree_match_pairs(tree, sampled_tree, push_leaves = true)
            consensus_node.branchlength += sampled_node.branchlength
            branch_length_counts[consensus_node] += 1
        end
    end

    # Turn the sums into means; unmatched nodes are divided by 1 and stay at 0.0.
    for node in consensus_nodes
        if branch_length_counts[node] == 0
            @warn "Branch length counts for node $(node.name) with nodeindex $(node.nodeindex) is 0. Coming from a `HIPSTR` call, this number should be strictly positive."
        end
        node.branchlength /= max(1, branch_length_counts[node])
    end
    return nothing
end
"""
    CladeStats()

Mutable record of what has been observed for one clade.

# Fields
- `frequency::Float64`: how often the clade was seen; starts at `0.0`.
- `child_pairs`: the distinct `(left, right)` child clade hashes observed directly below
  this clade, each hash being a `Tuple{UInt64, UInt64}`; starts empty.
"""
mutable struct CladeStats
    frequency::Float64
    child_pairs::Set{NTuple{2, NTuple{2, UInt64}}}

    function CladeStats()
        return new(0.0, Set{NTuple{2, NTuple{2, UInt64}}}())
    end
end
173+
"""
    hash_clade(tips::BitSet)

Return a pair of hashes identifying the clade whose tip indices are `tips`.

The first component hashes the `BitSet` itself; the second hashes the same indices
collected into a vector in descending order.
"""
function hash_clade(tips::BitSet)
    descending_tips = reverse!(collect(tips))
    return (hash(tips), hash(descending_tips))
end
182+
# Increment the observation count of the clade made of `tips`, creating its entry on
# first sight, and return that entry.
function _record_clade!(clades_stats, tips::BitSet)
    stats = get!(clades_stats, hash_clade(tips)) do
        CladeStats()
    end
    stats.frequency += 1
    return stats
end

"""
    collect_clades!(node, clades_stats, leaf_dict)

Recursively collect clades from the tree below `node`.

For `node` and every node below it, the clade (set of leaf indices, looked up by leaf
name in `leaf_dict`) is hashed with `hash_clade` and its `frequency` in `clades_stats`
is incremented by one. For nodes with exactly two children, the pair of child clade
hashes is also added to the clade's `child_pairs`.

A leaf whose name is not in `leaf_dict` is skipped with a warning and contributes no tip.

Returns the `BitSet` of leaf indices below `node`.
"""
function collect_clades!(node::FelNode, clades_stats::Dict{Tuple{UInt64, UInt64}, CladeStats}, leaf_dict::Dict{String, Int})
    if isleafnode(node)
        index = get(leaf_dict, node.name, nothing)
        if index === nothing
            # Skip if the leaf name is not recognized
            @warn "Skipping unrecognized leaf name: $(node.name)"
            return BitSet()
        end
        # For a leaf, the tips are just this node
        tips = BitSet((index,))
        _record_clade!(clades_stats, tips)
        return tips
    end

    # The recursive calls already return the tip set of each child, so keep them
    # instead of walking each child subtree again.
    child_tips = BitSet[
        collect_clades!(child, clades_stats, leaf_dict) for child in node.children
    ]
    tips = BitSet()
    for subset in child_tips
        union!(tips, subset)
    end

    stats = _record_clade!(clades_stats, tips)

    # Child pairs are only recorded for nodes with exactly two children
    if length(child_tips) == 2
        push!(stats.child_pairs, (hash_clade(child_tips[1]), hash_clade(child_tips[2])))
    end

    return tips
end

src/utils/tree_hash.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,25 +42,31 @@ function get_node_hashes(newt; push_leaves = false)
4242
#This makes a hash that matches everything except the given node
4343
other_hash = xor.(hash_container, all_names_hash)
4444
#Sort these, to make comparisons order invariant, which makes the comparison rooting invariant
45+
#Only sort internal node hashes, and don't use other_hash for leaves (to avoid collisions)
46+
isleafposition = isleafnode.(node_container)
47+
sensitive_tuple_sort(t::Tuple{Bool, Tuple{UInt64, UInt64}}) = ifelse(t[1], (t[2][1], t[2][1]), tuple_sort(t[2]))
4548
#Consider making this sort an option, and then we can have a rooted comparison and an unrooted one
46-
sorted_hash_pairs = tuple_sort.(collect(zip(hash_container, other_hash)))
47-
return sorted_hash_pairs, node_container
49+
sorted_hash_pairs = sensitive_tuple_sort.(collect(zip(isleafposition, zip(hash_container, other_hash))))
50+
return sorted_hash_pairs, node_container, hash_container
4851
end
4952

5053
export tree_diff
5154
#returns nodes in the query that don't have matching splits in the reference
"""
    tree_diff(query, reference)

Return the nodes of `query` whose hash, as computed by `get_node_hashes`, does not occur
among the hashes of `reference`.
"""
function tree_diff(query, reference)
    query_hashes, query_nodes, _ = get_node_hashes(query)
    reference_hashes = Set(first(get_node_hashes(reference)))
    is_unmatched = [!(h in reference_hashes) for h in query_hashes]
    return query_nodes[is_unmatched]
end
5962

6063
export tree_match_pairs
64+
#returns pairs of nodes in the query and reference trees that have the same clade (i.e. rooting dependent)
6165
function tree_match_pairs(query, reference; push_leaves = false)
62-
newt_hc, newt_nc = get_node_hashes(query, push_leaves = push_leaves)
63-
n_hc, n_nc = get_node_hashes(reference, push_leaves = push_leaves)
66+
newt_hash_pairs, newt_nc, newt_hash_container = get_node_hashes(query, push_leaves = push_leaves)
67+
n_hash_pairs, n_nc, n_hash_container = get_node_hashes(reference, push_leaves = push_leaves)
68+
newt_hc = collect(zip(newt_hash_pairs, newt_hash_container))
69+
n_hc = collect(zip(n_hash_pairs, n_hash_container))
6470
newt_hash2node = Dict(zip(newt_hc, newt_nc))
6571
n_hash2node = Dict(zip(n_hc, n_nc))
6672
return map(h -> (newt_hash2node[h], n_hash2node[h]), filter(x -> haskey(n_hash2node, x), newt_hc))

src/utils/utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ include("misc.jl")
22
include("simple_optim.jl")
33
include("simple_sample.jl")
44
include("tree_hash.jl")
5+
include("HIPSTR.jl")
56
#fasta_io.jl is optionally included with Requires.jl in MolecularEvolution.jl

0 commit comments

Comments
 (0)