From cb1c4d154ec0cf0042d319a760b36dc881957987 Mon Sep 17 00:00:00 2001 From: Glenn Hickey Date: Thu, 23 May 2024 13:58:42 -0400 Subject: [PATCH] start traversal-clustering logic --- src/deconstructor.cpp | 20 +++++--- src/deconstructor.hpp | 8 +++- src/subcommand/deconstruct_main.cpp | 13 ++++-- src/traversal_clusters.cpp | 72 +++++++++++++++++++++++------ src/traversal_clusters.hpp | 26 ++++++++++- src/traversal_finder.cpp | 2 +- src/traversal_finder.hpp | 2 +- 7 files changed, 116 insertions(+), 27 deletions(-) diff --git a/src/deconstructor.cpp b/src/deconstructor.cpp index 5dea7b7be50..7228669f51e 100644 --- a/src/deconstructor.cpp +++ b/src/deconstructor.cpp @@ -26,7 +26,7 @@ vector Deconstructor::get_alleles(vcflib::Variant& v, const vector& travs, const vector>& trav_steps, int ref_path_idx, - const vector& use_trav, + const vector>& trav_clusters, char prev_char, bool use_start) const { assert(ref_path_idx >=0 && ref_path_idx < travs.size()); @@ -57,10 +57,11 @@ vector Deconstructor::get_alleles(vcflib::Variant& v, bool substitution = true; // set the other alleles (they can end up as 0 alleles too if their strings match the reference) - for (int i = 0; i < travs.size(); ++i) { - if (i != ref_path_idx) { - if (use_trav[i]) { - string allele = trav_to_string(travs[i]); + // note that we have one (unique) allele per cluster, so we take advantage of that here + for (const vector& cluster : trav_clusters) { + string allele = trav_to_string(travs[cluster.front()]); + for (const int& i : cluster) { + if (i != ref_path_idx) { auto ai_it = allele_idx.find(allele); if (ai_it == allele_idx.end()) { // make a new allele for this string @@ -797,9 +798,15 @@ bool Deconstructor::deconstruct_site(const handle_t& snarl_start, const handle_t } } + // Sort the traversals for clustering + vector sorted_travs = get_traversal_order(graph, travs, trav_path_names, ref_travs, use_trav); + + // jaccard clustering (using handles for now) on traversals + vector> trav_clusters = 
cluster_traversals(graph, travs, sorted_travs, cluster_threshold); + vector trav_to_allele = get_alleles(v, travs, trav_steps, ref_trav_idx, - use_trav, + trav_clusters, prev_char, use_start); // Fill in the genotypes @@ -1074,6 +1081,7 @@ void Deconstructor::deconstruct(vector ref_paths, const PathPositionHand bool keep_conflicted, bool strict_conflicts, bool long_ref_contig, + double cluster_threshold, gbwt::GBWT* gbwt) { this->graph = graph; diff --git a/src/deconstructor.hpp b/src/deconstructor.hpp index 19f8e0300ff..af9668a2c41 100644 --- a/src/deconstructor.hpp +++ b/src/deconstructor.hpp @@ -45,6 +45,7 @@ class Deconstructor : public VCFOutputCaller { bool keep_conflicted, bool strict_conflicts, bool long_ref_contig, + double cluster_threshold = 1.0, gbwt::GBWT* gbwt = nullptr); private: @@ -78,7 +79,7 @@ class Deconstructor : public VCFOutputCaller { const vector& travs, const vector>& trav_steps, int ref_path_idx, - const vector& use_trav, + const vector>& trav_clusters, char prev_char, bool use_start) const; // write traversal path names as genotypes @@ -145,6 +146,11 @@ class Deconstructor : public VCFOutputCaller { // should we keep conflicted genotypes or not bool keep_conflicted_genotypes = false; + + // used to merge together similar traversals (to keep allele counts down) + // currently implemented as handle jaccard coefficient. So 1 means only + // merge if identical (which is what deconstruct has always done) + double cluster_threshold = 1.0; }; diff --git a/src/subcommand/deconstruct_main.cpp b/src/subcommand/deconstruct_main.cpp index 497922e5b70..2ccc8192786 100644 --- a/src/subcommand/deconstruct_main.cpp +++ b/src/subcommand/deconstruct_main.cpp @@ -51,6 +51,7 @@ void help_deconstruct(char** argv){ << " -K, --keep-conflicted Retain conflicted genotypes in output." << endl << " -S, --strict-conflicts Drop genotypes when we have more than one haplotype for any given phase (set by default when using GBWT input)." 
<< endl << " -C, --contig-only-ref Only use the CONTIG name (and not SAMPLE#CONTIG#HAPLOTYPE etc) for the reference if possible (ie there is only one reference sample)." << endl + << " -L, --cluster F Cluster traversals whose (handle) Jaccard coefficient is >= F together (default: 1.0)" << endl << " -t, --threads N Use N threads" << endl << " -v, --verbose Print some status messages" << endl << endl; @@ -77,6 +78,7 @@ int main_deconstruct(int argc, char** argv){ int context_jaccard_window = 10000; bool untangle_traversals = false; bool contig_only_ref = false; + double cluster_threshold = 1.0; int c; optind = 2; // force optind past command positional argument @@ -97,14 +99,15 @@ int main_deconstruct(int argc, char** argv){ {"all-snarls", no_argument, 0, 'a'}, {"keep-conflicted", no_argument, 0, 'K'}, {"strict-conflicts", no_argument, 0, 'S'}, - {"contig-only-ref", no_argument, 0, 'C'}, + {"contig-only-ref", no_argument, 0, 'C'}, + {"cluster", required_argument, 0, 'L'}, {"threads", required_argument, 0, 't'}, {"verbose", no_argument, 0, 'v'}, {0, 0, 0, 0} }; int option_index = 0; - c = getopt_long (argc, argv, "hp:P:H:r:g:T:OeKSCd:c:uat:v", + c = getopt_long (argc, argv, "hp:P:H:r:g:T:OeKSCd:c:uaL:t:v", long_options, &option_index); // Detect the end of the options. 
@@ -158,7 +161,10 @@ int main_deconstruct(int argc, char** argv){ break; case 'C': contig_only_ref = true; - break; + break; + case 'L': + cluster_threshold = max(0.0, min(1.0, parse(optarg))); // clamp the user value into [0, 1] + break; case 't': omp_set_num_threads(parse(optarg)); break; @@ -344,6 +350,7 @@ int main_deconstruct(int argc, char** argv){ keep_conflicted, strict_conflicts, !contig_only_ref, + cluster_threshold, gbwt_index); return 0; } diff --git a/src/traversal_clusters.cpp b/src/traversal_clusters.cpp index c5a01e3b00e..0b7b26aed09 100644 --- a/src/traversal_clusters.cpp +++ b/src/traversal_clusters.cpp @@ -2,36 +2,82 @@ namespace vg { + +vector get_traversal_order(const PathHandleGraph* graph, + const vector& traversals, + const vector& trav_path_names, + const vector& ref_travs, + const vector& use_traversal) { + set ref_set(ref_travs.begin(), ref_travs.end()); + + function trav_less = [&](int i, int j) { + if (ref_set.count(i) != ref_set.count(j)) { // reference traversals always sort first + return ref_set.count(i) > 0; + } + if (!trav_path_names[i].empty() && (trav_path_names[j].empty() || trav_path_names[i] < trav_path_names[j])) { // named before unnamed, then by name (strict weak ordering, as std::sort requires) + return true; + } + return false; + }; + + vector sorted_travs; + sorted_travs.reserve(traversals.size()); + for (int64_t i = 0; i < traversals.size(); ++i) { + if (use_traversal[i]) { + sorted_travs.push_back(i); + } + } + std::sort(sorted_travs.begin(), sorted_travs.end(), trav_less); + return sorted_travs; +} + vector> cluster_traversals(const PathHandleGraph* graph, const vector& traversals, - const vector& traversal_order, + const vector& traversal_order, double min_jaccard) { - assert(traversal_order.size() == traversals.size()); + assert(traversal_order.size() <= traversals.size()); // the values are indexes in the input traversals vector // the "reference" traversal of each cluster (to which distance is computed) // is always its first element vector> clusters; - for (const int64_t& i : traversal_order) { - const Traversal& trav = traversals[i]; - double min_jaccard = numeric_limits::max(); - 
int64_t min_cluster_idx = -1; + // need the clusters as sorted lists. we'll forget the endpoints while we're at + // it since they're always shared. note we work with multisets since we want to + // count differences between, say, cycle copy numbers. + vector> sorted_traversals; + for (const Traversal& trav : traversals) { + assert(trav.size() >=2 ); + // prune snarl nodes as they're always the same + // todo: does jaccard properly handle empty sets? + int64_t first = trav.size() == 2 ? 0 : 1; + int64_t last = trav.size() == 2 ? trav.size() : trav.size() - 1; + multiset sorted_trav; + for (int64_t i = first; i < last; ++i) { + sorted_trav.insert(trav[i]); + } + sorted_traversals.push_back(sorted_trav); + } + + for (const int& i : traversal_order) { + const auto& trav = sorted_traversals[i]; + double max_jaccard = 0; + int64_t max_cluster_idx = -1; for (int64_t j = 0; j < clusters.size(); ++j) { - const Traversal& cluster_trav = traversals[clusters[j][0]]; + const auto& cluster_trav = sorted_traversals[clusters[j][0]]; double jac = jaccard_coefficient(trav, cluster_trav); - if (jac < min_jaccard) { - min_jaccard = jac; - min_cluster_idx = j; - if (jac == 0) { + if (jac > max_jaccard) { + max_jaccard = jac; + max_cluster_idx = j; + if (jac == 1) { break; } } } - if (min_cluster_idx >= 0) { + if (max_cluster_idx >= 0 && max_jaccard >= min_jaccard) { // we've found a suitably similar cluster, add it do that - clusters[min_cluster_idx].push_back(i); + clusters[max_cluster_idx].push_back(i); } else { // there's no cluster close enough, need to start a new one clusters.push_back({i}); diff --git a/src/traversal_clusters.hpp b/src/traversal_clusters.hpp index 13b869a41f2..b03bd779267 100644 --- a/src/traversal_clusters.hpp +++ b/src/traversal_clusters.hpp @@ -42,14 +42,36 @@ inline double jaccard_coefficient(const T& target, const U& query) { } +// the information needed from the parent traversal in order to +// genotype a child traversal +struct ParentGenotypeInfo { + 
Traversal ref_traversal; + pair ref_steps; + unordered_map sample_to_ploidy; // to check/enforce consistency +}; + + +/// sort the traversals, putting the reference first then using names +/// traversals masked out by use_traversal will be filtered out entirely +/// (so the output vector may be smaller than the input...) +vector get_traversal_order(const PathHandleGraph* graph, + const vector& traversals, + const vector& trav_path_names, + const vector& ref_travs, + const vector& use_traversal); + + /// cluster the traversals. The algorithm is: /// - visit traversals in provided order /// - if the traversal is <= min_jaccard away from the reference traversal of cluster, add to cluster /// - else start a new cluster, with the given traversal as a reference +/// note that traversal_order can specify a subset of traversals vector> cluster_traversals(const PathHandleGraph* graph, const vector& traversals, - function traversal_order, + const vector& traversal_order, double min_jaccard); - +//int64_t find_parent_traversal(const PathHandleGraph* graph, +// const vector& traversals, + } diff --git a/src/traversal_finder.cpp b/src/traversal_finder.cpp index 4a2b7498f9b..2c1a1590c95 100644 --- a/src/traversal_finder.cpp +++ b/src/traversal_finder.cpp @@ -10,7 +10,7 @@ namespace vg { using namespace std; -string traversal_to_string(const PathHandleGraph* graph, const Traversal& traversal, bool max_steps) { +string traversal_to_string(const PathHandleGraph* graph, const Traversal& traversal, int64_t max_steps) { string s; function handle_to_string = [&](handle_t handle) { string ss = graph->get_is_reverse(handle) ? 
"<" : ">"; diff --git a/src/traversal_finder.hpp b/src/traversal_finder.hpp index 349d6c27f00..e39499cba03 100644 --- a/src/traversal_finder.hpp +++ b/src/traversal_finder.hpp @@ -39,7 +39,7 @@ class AugmentedGraph; // some Protobuf replacements using Traversal = vector; using PathInterval = pair; -string traversal_to_string(const PathHandleGraph* graph, const Traversal& traversal, bool max_steps = 10); +string traversal_to_string(const PathHandleGraph* graph, const Traversal& traversal, int64_t max_steps = 10); // replaces pb2json(snarl) string graph_interval_to_string(const HandleGraph* graph, const handle_t& start_handle, const handle_t& end_handle);