|
| 1 | +export HIPSTR |
"""
    HIPSTR(trees::Vector{FelNode}; set_branchlengths = true)

Construct a Highest Independent Posterior Subtree Reconstruction (HIPSTR) tree
from a collection of trees.

Returns a single FelNode representing the HIPSTR consensus tree.

If `set_branchlengths = true`, the branch length of a node in the HIPSTR tree will be
set to the mean branch length of all nodes from the input trees that have the same
clade. (By the same clade, we mean that the set of leaves below the node is the same.)
Otherwise, the root branch length is 0.0 and the rest 1.0.

Throws an `ArgumentError` if `trees` is empty, or if no input tree contains the clade
made of every leaf name seen across `trees` (i.e. the trees do not share a leaf set).

Source: https://www.biorxiv.org/content/10.1101/2024.12.08.627395v1.full.pdf
"""
function HIPSTR(trees::Vector{FelNode}; set_branchlengths = true)
    isempty(trees) &&
        throw(ArgumentError("HIPSTR requires at least one input tree"))

    # Step 1: Collect all clades, their frequencies, and child pairs.
    # First gather every leaf name across all trees.
    leaf_names = Set{String}()
    for tree in trees, leaf in getleaflist(tree)
        push!(leaf_names, leaf.name)
    end

    # Sorted names give every leaf a stable integer index, so clade hashes are
    # consistent between trees.
    leaf_dict = Dict{String, Int}(
        name => i for (i, name) in enumerate(sort(collect(leaf_names)))
    )

    clades_stats = Dict{Tuple{UInt64, UInt64}, CladeStats}()
    for tree in trees
        collect_clades!(tree, clades_stats, leaf_dict)
    end

    # Scale clade counts to get credibilities.
    n_trees = length(trees)
    for stats in values(clades_stats)
        stats.frequency /= n_trees
    end

    # Step 2: Compute the root clade hash (all tips).
    root_hash = hash_clade(BitSet(1:length(leaf_dict)))
    haskey(clades_stats, root_hash) || throw(
        ArgumentError(
            "the clade containing all leaves was not observed in any input tree; " *
            "all input trees must share the same leaf names",
        ),
    )

    # Step 3: Build the credibility cache through a memoized post-order traversal.
    # Each entry is (log credibility, best left child hash, best right child hash).
    # Scores are kept in log space: a raw product of many probabilities underflows
    # to 0.0 on large trees, after which no child pair could be selected.
    no_clade = (UInt64(0), UInt64(0))  # sentinel: "this entry has no children"
    cred_cache = Dict{
        Tuple{UInt64, UInt64},
        Tuple{Float64, Tuple{UInt64, UInt64}, Tuple{UInt64, UInt64}},
    }()

    function compute_credibility(clade_hash)
        cached = get(cred_cache, clade_hash, nothing)
        cached === nothing || return cached[1]

        entry = get(clades_stats, clade_hash, nothing)
        if entry === nothing
            # Clade never observed: zero credibility (log(0) == -Inf).
            cred_cache[clade_hash] = (-Inf, no_clade, no_clade)
            return -Inf
        end

        log_freq = log(entry.frequency)
        if isempty(entry.child_pairs)
            # Base case: no recorded bifurcation below this clade (e.g. a tip).
            cred_cache[clade_hash] = (log_freq, no_clade, no_clade)
            return log_freq
        end

        # Find the child pair with the highest combined credibility.
        best_cred = -Inf
        best_left = no_clade
        best_right = no_clade
        for (left_hash, right_hash) in entry.child_pairs
            pair_cred =
                compute_credibility(left_hash) + compute_credibility(right_hash) + log_freq
            if pair_cred > best_cred
                best_cred = pair_cred
                best_left = left_hash
                best_right = right_hash
            end
        end

        cred_cache[clade_hash] = (best_cred, best_left, best_right)
        return best_cred
    end

    compute_credibility(root_hash)

    # Step 4: Construct the HIPSTR tree through another traversal.
    # Precompute single-tip hashes so that resolving a leaf is one lookup instead
    # of rehashing every leaf for every tip.
    tip_lookup = Dict(
        hash_clade(BitSet([index])) => (index, name) for (name, index) in leaf_dict
    )

    function build_tree(clade_hash)
        _, left_hash, right_hash = cred_cache[clade_hash]

        # Leaf case: no child pair was selected for this clade.
        if left_hash == no_clade && right_hash == no_clade
            tip = get(tip_lookup, clade_hash, nothing)
            tip === nothing &&
                error("Failed to find leaf corresponding to hash $clade_hash")
            index, name = tip
            leaf = FelNode(1.0, name)
            leaf.seqindex = index
            return leaf
        end

        # Internal node: attach the best left and right subtrees, in that order.
        node = FelNode(1.0, "")
        for child_hash in (left_hash, right_hash)
            child = build_tree(child_hash)
            # Default branch lengths to 1.0 if we don't have better information.
            child.branchlength = 1.0
            child.parent = node
            push!(node.children, child)
        end
        return node
    end

    hipstr_tree = build_tree(root_hash)
    # The root branch length is 0.0 by convention.
    hipstr_tree.branchlength = 0.0

    set_node_indices!(hipstr_tree)

    # Optionally replace the default branch lengths by per-clade means.
    set_branchlengths && set_mean_branchlengths!(hipstr_tree, trees)

    return hipstr_tree
end
| 137 | + |
"""
    set_mean_branchlengths!(tree::FelNode, trees::Vector{FelNode})

Overwrite the branch length of every node in `tree` with the mean branch length of
the nodes matched to it (via `tree_match_pairs` with `push_leaves = true`) across
`trees`. A node without any match keeps a branch length of 0.0 and triggers a warning.

This deviates somewhat from the HIPSTR paper. They find distributions over node ages
(with respect to a molecular clock), which differs a bit from the notion of branch
length.
"""
function set_mean_branchlengths!(tree::FelNode, trees::Vector{FelNode})
    hipstr_nodes = nodes(tree)

    # Branch lengths double as running sums; match_counts holds the number of terms.
    match_counts = Dict{eltype(hipstr_nodes), Int64}()
    for node in hipstr_nodes
        node.branchlength = 0.0
        match_counts[node] = 0
    end

    # Accumulate the branch length of every matching clade from every input tree.
    for sample in trees
        for (target, source) in tree_match_pairs(tree, sample, push_leaves = true)
            target.branchlength += source.branchlength
            match_counts[target] += 1
        end
    end

    # Turn the sums into means; max(1, ...) guards the division for unmatched nodes.
    for node in hipstr_nodes
        n_matches = match_counts[node]
        if n_matches == 0
            @warn "Branch length counts for node $(node.name) with nodeindex $(node.nodeindex) is 0. Coming from a `HIPSTR` call, this number should be strictly positive."
        end
        node.branchlength /= max(1, n_matches)
    end
    return nothing
end
"""
    CladeStats()

Mutable record of what has been observed for one clade: `frequency` (starts at 0.0)
and `child_pairs`, the set of `(left_hash, right_hash)` pairs seen for the clade,
where each hash is a `Tuple{UInt64, UInt64}`.
"""
mutable struct CladeStats
    frequency::Float64
    child_pairs::Set{NTuple{2, NTuple{2, UInt64}}}

    function CladeStats()
        return new(0.0, Set{NTuple{2, NTuple{2, UInt64}}}())
    end
end
| 173 | + |
"""
    hash_clade(tips::BitSet)

Return a two-component hash identifying the clade whose tip indices are `tips`.
The first component hashes the set itself; the second hashes the vector of its
members in reversed iteration order. `tips` is not modified.
"""
function hash_clade(tips::BitSet)
    members = collect(tips)
    reverse!(members)
    return (hash(tips), hash(members))
end
| 182 | + |
"""
    collect_clades!(node, clades_stats, leaf_dict)

Recursively record every clade in the subtree rooted at `node` into `clades_stats`
and return the `BitSet` of tip indices found below `node`.

For each visited node, the `frequency` of its clade (keyed by `hash_clade` of its tip
set) is incremented by one. For nodes with exactly two children, the pair of child
clade hashes is also added to that clade's `child_pairs`.

A leaf whose name is missing from `leaf_dict` is skipped with a warning: it adds no
tip and no clade entry.
"""
function collect_clades!(
    node::FelNode,
    clades_stats::Dict{Tuple{UInt64, UInt64}, CladeStats},
    leaf_dict::Dict{String, Int},
)
    tips = BitSet()
    # Tip sets returned by the recursive calls, one per child. Keeping them avoids
    # walking each child's subtree a second time to rebuild the same sets.
    child_tips = BitSet[]

    if isleafnode(node)
        if !haskey(leaf_dict, node.name)
            # Skip if the leaf name is not recognized.
            @warn "Skipping unrecognized leaf name: $(node.name)"
            return tips
        end
        push!(tips, leaf_dict[node.name])
    else
        for child in node.children
            below = collect_clades!(child, clades_stats, leaf_dict)
            push!(child_tips, below)
            union!(tips, below)
        end
    end

    # Fetch the stats entry for this clade, creating it on first sight.
    stats = get!(CladeStats, clades_stats, hash_clade(tips))
    stats.frequency += 1

    # Only bifurcating internal nodes contribute a (left, right) child pair.
    if length(child_tips) == 2
        push!(stats.child_pairs, (hash_clade(child_tips[1]), hash_clade(child_tips[2])))
    end

    return tips
end
0 commit comments