Skip to content

Commit

Permalink
start traversal-clustering logic
Browse files Browse the repository at this point in the history
  • Loading branch information
glennhickey committed May 23, 2024
1 parent 3e766a7 commit cb1c4d1
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 27 deletions.
20 changes: 14 additions & 6 deletions src/deconstructor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ vector<int> Deconstructor::get_alleles(vcflib::Variant& v,
const vector<Traversal>& travs,
const vector<pair<step_handle_t, step_handle_t>>& trav_steps,
int ref_path_idx,
const vector<bool>& use_trav,
const vector<vector<int>>& trav_clusters,
char prev_char, bool use_start) const {

assert(ref_path_idx >=0 && ref_path_idx < travs.size());
Expand Down Expand Up @@ -57,10 +57,11 @@ vector<int> Deconstructor::get_alleles(vcflib::Variant& v,
bool substitution = true;

// set the other alleles (they can end up as 0 alleles too if their strings match the reference)
for (int i = 0; i < travs.size(); ++i) {
if (i != ref_path_idx) {
if (use_trav[i]) {
string allele = trav_to_string(travs[i]);
// note that we have one (unique) allele per cluster, so we take advantage of that here
for (const vector<int>& cluster : trav_clusters) {
string allele = trav_to_string(travs[cluster.front()]);
for (const int& i : cluster) {
if (i != ref_path_idx) {
auto ai_it = allele_idx.find(allele);
if (ai_it == allele_idx.end()) {
// make a new allele for this string
Expand Down Expand Up @@ -797,9 +798,15 @@ bool Deconstructor::deconstruct_site(const handle_t& snarl_start, const handle_t
}
}

// Sort the traversals for clustering
vector<int> sorted_travs = get_traversal_order(graph, travs, trav_path_names, ref_travs, use_trav);

// jaccard clustering (using handles for now) on traversals
vector<vector<int>> trav_clusters = cluster_traversals(graph, travs, sorted_travs, cluster_threshold);

vector<int> trav_to_allele = get_alleles(v, travs, trav_steps,
ref_trav_idx,
use_trav,
trav_clusters,
prev_char, use_start);

// Fill in the genotypes
Expand Down Expand Up @@ -1074,6 +1081,7 @@ void Deconstructor::deconstruct(vector<string> ref_paths, const PathPositionHand
bool keep_conflicted,
bool strict_conflicts,
bool long_ref_contig,
double cluster_threshold,
gbwt::GBWT* gbwt) {

this->graph = graph;
Expand Down
8 changes: 7 additions & 1 deletion src/deconstructor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class Deconstructor : public VCFOutputCaller {
bool keep_conflicted,
bool strict_conflicts,
bool long_ref_contig,
double cluster_threshold = 1.0,
gbwt::GBWT* gbwt = nullptr);

private:
Expand Down Expand Up @@ -78,7 +79,7 @@ class Deconstructor : public VCFOutputCaller {
const vector<Traversal>& travs,
const vector<pair<step_handle_t, step_handle_t>>& trav_steps,
int ref_path_idx,
const vector<bool>& use_trav,
const vector<vector<int>>& trav_clusters,
char prev_char, bool use_start) const;

// write traversal path names as genotypes
Expand Down Expand Up @@ -145,6 +146,11 @@ class Deconstructor : public VCFOutputCaller {

// should we keep conflicted genotypes or not
bool keep_conflicted_genotypes = false;

// used to merge together similar traversals (to keep allele counts down)
// currently implemented as handle jaccard coefficient. So 1 means only
// merge if identical (which is what deconstruct has always done)
double cluster_threshold = 1.0;
};


Expand Down
13 changes: 10 additions & 3 deletions src/subcommand/deconstruct_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ void help_deconstruct(char** argv){
<< " -K, --keep-conflicted Retain conflicted genotypes in output." << endl
<< " -S, --strict-conflicts Drop genotypes when we have more than one haplotype for any given phase (set by default when using GBWT input)." << endl
<< " -C, --contig-only-ref Only use the CONTIG name (and not SAMPLE#CONTIG#HAPLOTYPE etc) for the reference if possible (ie there is only one reference sample)." << endl
<< " -L, --cluster F Cluster traversals whose (handle) Jaccard coefficient is >= F together (default: 1.0)" << endl
<< " -t, --threads N Use N threads" << endl
<< " -v, --verbose Print some status messages" << endl
<< endl;
Expand All @@ -77,6 +78,7 @@ int main_deconstruct(int argc, char** argv){
int context_jaccard_window = 10000;
bool untangle_traversals = false;
bool contig_only_ref = false;
double cluster_threshold = 1.0;

int c;
optind = 2; // force optind past command positional argument
Expand All @@ -97,14 +99,15 @@ int main_deconstruct(int argc, char** argv){
{"all-snarls", no_argument, 0, 'a'},
{"keep-conflicted", no_argument, 0, 'K'},
{"strict-conflicts", no_argument, 0, 'S'},
{"contig-only-ref", no_argument, 0, 'C'},
{"contig-only-ref", no_argument, 0, 'C'},
{"cluster", required_argument, 0, 'L'},
{"threads", required_argument, 0, 't'},
{"verbose", no_argument, 0, 'v'},
{0, 0, 0, 0}
};

int option_index = 0;
c = getopt_long (argc, argv, "hp:P:H:r:g:T:OeKSCd:c:uat:v",
c = getopt_long (argc, argv, "hp:P:H:r:g:T:OeKSCd:c:uaL:t:v",
long_options, &option_index);

// Detect the end of the options.
Expand Down Expand Up @@ -158,7 +161,10 @@ int main_deconstruct(int argc, char** argv){
break;
case 'C':
contig_only_ref = true;
break;
break;
case 'L':
    // Clamp the user-supplied threshold into the valid Jaccard range [0, 1].
    // (The previous min(0.0, max(1.0, x)) form always evaluated to 0.0,
    // silently discarding whatever value was passed on the command line.)
    cluster_threshold = max(0.0, min(1.0, parse<double>(optarg)));
    break;
case 't':
omp_set_num_threads(parse<int>(optarg));
break;
Expand Down Expand Up @@ -344,6 +350,7 @@ int main_deconstruct(int argc, char** argv){
keep_conflicted,
strict_conflicts,
!contig_only_ref,
cluster_threshold,
gbwt_index);
return 0;
}
Expand Down
72 changes: 59 additions & 13 deletions src/traversal_clusters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,82 @@

namespace vg {


// Build the order in which traversals are visited (e.g. for clustering).
//
// Only indexes i for which use_traversal[i] is set are returned, so the
// result may be shorter than `traversals`. Sort keys, in priority order:
//   1. traversals whose index is listed in ref_travs come first
//   2. traversals with a non-empty path name come before unnamed ones
//   3. path names are compared lexicographically
//   4. remaining ties fall back to the traversal index, which keeps the
//      output deterministic (std::sort is not stable)
//
// `graph` is not used by this function.
vector<int> get_traversal_order(const PathHandleGraph* graph,
                                const vector<Traversal>& traversals,
                                const vector<string>& trav_path_names,
                                const vector<int>& ref_travs,
                                const vector<bool>& use_traversal) {
    const set<int> ref_set(ref_travs.begin(), ref_travs.end());

    // std::sort requires a strict weak ordering: less(i, i) must be false and
    // less(i, j) and less(j, i) must never both be true. Each key below is
    // therefore only decisive when the two sides actually differ on it.
    auto trav_less = [&](int i, int j) {
        const bool i_is_ref = ref_set.count(i) > 0;
        const bool j_is_ref = ref_set.count(j) > 0;
        if (i_is_ref != j_is_ref) {
            return i_is_ref;
        }
        const string& i_name = trav_path_names[i];
        const string& j_name = trav_path_names[j];
        if (i_name.empty() != j_name.empty()) {
            // named traversals take priority over unnamed ones
            return !i_name.empty();
        }
        if (i_name != j_name) {
            return i_name < j_name;
        }
        return i < j;
    };

    // collect the indexes of the traversals that were not masked out
    vector<int> sorted_travs;
    sorted_travs.reserve(traversals.size());
    for (size_t i = 0; i < traversals.size(); ++i) {
        if (use_traversal[i]) {
            sorted_travs.push_back(static_cast<int>(i));
        }
    }
    std::sort(sorted_travs.begin(), sorted_travs.end(), trav_less);
    return sorted_travs;
}

vector<vector<int>> cluster_traversals(const PathHandleGraph* graph,
const vector<Traversal>& traversals,
const vector<int64_t>& traversal_order,
const vector<int>& traversal_order,
double min_jaccard) {

assert(traversal_order.size() == traversals.size());
assert(traversal_order.size() <= traversals.size());

// the values are indexes in the input traversals vector
// the "reference" traversal of each cluster (to which distance is computed)
// is always its first element
vector<vector<int>> clusters;

for (const int64_t& i : traversal_order) {
const Traversal& trav = traversals[i];
double min_jaccard = numeric_limits<double>::max();
int64_t min_cluster_idx = -1;
// need the clusters as sorted lists. we'll forget the endpoints while we're at
// it since they're always shared. note we work with multisets since we want to
// count differences between, say, cycle copy numbers.
vector<multiset<handle_t>> sorted_traversals;
for (const Traversal& trav : traversals) {
assert(trav.size() >=2 );
// prune snarl nodes as they're always the same
// todo: does jaccard properly handle empty sets?
int64_t first = trav.size() == 2 ? 0 : 1;
int64_t last = trav.size() == 2 ? trav.size() : trav.size() - 1;
multiset<handle_t> sorted_trav;
for (int64_t i = first; i < last; ++i) {
sorted_trav.insert(trav[i]);
}
sorted_traversals.push_back(sorted_trav);
}

for (const int& i : traversal_order) {
const auto& trav = sorted_traversals[i];
double max_jaccard = 0;
int64_t max_cluster_idx = -1;
for (int64_t j = 0; j < clusters.size(); ++j) {
const Traversal& cluster_trav = traversals[clusters[j][0]];
const auto& cluster_trav = sorted_traversals[clusters[j][0]];
double jac = jaccard_coefficient(trav, cluster_trav);
if (jac < min_jaccard) {
min_jaccard = jac;
min_cluster_idx = j;
if (jac == 0) {
if (jac > max_jaccard) {
max_jaccard = jac;
max_cluster_idx = j;
if (jac == 1) {
break;
}
}
}
if (min_cluster_idx >= 0) {
if (max_cluster_idx >= 0 && max_jaccard >= min_jaccard) {
// we've found a suitably similar cluster, add it to that
clusters[min_cluster_idx].push_back(i);
clusters[max_cluster_idx].push_back(i);
} else {
// there's no cluster close enough, need to start a new one
clusters.push_back({i});
Expand Down
26 changes: 24 additions & 2 deletions src/traversal_clusters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,36 @@ inline double jaccard_coefficient(const T& target, const U& query) {

}

// the information needed from the parent traversal in order to
// genotype a child traversal
struct ParentGenotypeInfo {
// NOTE(review): presumably the reference traversal through the parent site -- confirm against callers
Traversal ref_traversal;
// NOTE(review): presumably the pair of path steps delimiting ref_traversal
// (same pair type as the trav_steps entries passed to get_alleles) -- confirm
pair<step_handle_t, step_handle_t> ref_steps;
// ploidy recorded per key; keys are presumably sample names (going by the field name) -- confirm
unordered_map<string, int64_t> sample_to_ploidy; // to check/enforce consistency
};


/// sort the traversals, putting the reference first then using names
/// traversals masked out by use_traversal will be filtered out entirely
/// (so the output vector may be smaller than the input...)
vector<int> get_traversal_order(const PathHandleGraph* graph,
const vector<Traversal>& traversals,
const vector<string>& trav_path_names,
const vector<int>& ref_travs,
const vector<bool>& use_traversal);


/// cluster the traversals. The algorithm is:
/// - visit traversals in provided order
/// - if the traversal's jaccard coefficient with the reference traversal of some cluster is >= min_jaccard,
///   add it to the cluster whose reference traversal it is most similar to
/// - else start a new cluster, with the given traversal as a reference
/// note that traversal_order can specify a subset of traversals
vector<vector<int>> cluster_traversals(const PathHandleGraph* graph,
const vector<Traversal>& traversals,
function<int64_t(int64_t)> traversal_order,
const vector<int>& traversal_order,
double min_jaccard);


//int64_t find_parent_traversal(const PathHandleGraph* graph,
// const vector<Traversal>& traversals,

}
2 changes: 1 addition & 1 deletion src/traversal_finder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace vg {

using namespace std;

string traversal_to_string(const PathHandleGraph* graph, const Traversal& traversal, bool max_steps) {
string traversal_to_string(const PathHandleGraph* graph, const Traversal& traversal, int64_t max_steps) {
string s;
function<string(handle_t)> handle_to_string = [&](handle_t handle) {
string ss = graph->get_is_reverse(handle) ? "<" : ">";
Expand Down
2 changes: 1 addition & 1 deletion src/traversal_finder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class AugmentedGraph;
// some Protobuf replacements
using Traversal = vector<handle_t>;
using PathInterval = pair<step_handle_t, step_handle_t>;
string traversal_to_string(const PathHandleGraph* graph, const Traversal& traversal, bool max_steps = 10);
string traversal_to_string(const PathHandleGraph* graph, const Traversal& traversal, int64_t max_steps = 10);
// replaces pb2json(snarl)
string graph_interval_to_string(const HandleGraph* graph, const handle_t& start_handle, const handle_t& end_handle);

Expand Down

1 comment on commit cb1c4d1

@adamnovak
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vg CI tests complete for branch deconstruct. View the full report here.

16 tests passed, 0 tests failed and 0 tests skipped in 17635 seconds

Please sign in to comment.