Skip to content

Commit 680caec

Browse files
committed
Let --aux-len work correctly
This includes some refactoring. In particular, instead of aux_len, we now use a main_hash_mask and instead of right-shifting by aux_len, we use the mask to get only the relevant part of the hash.
1 parent dd95fb4 commit 680caec

13 files changed

+142
-55
lines changed

CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ add_executable(test-strobealign
132132
tests/test_aligner.cpp
133133
tests/test_cigar.cpp
134134
tests/test_randstrobes.cpp
135+
tests/test_indexparameters.cpp
135136
)
136137
target_link_libraries(test-strobealign salib)
137138
target_include_directories(test-strobealign PUBLIC src/ ext/ ${PROJECT_BINARY_DIR})

src/arguments.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ struct SeedingArguments {
2626
"results with non default values.", {'s'}}
2727
, bits{parser, "INT", "No. of top bits of hash to use as bucket indices (8-31)"
2828
"[determined from reference size]", {'b'}}
29-
, aux_len{parser, "INT", "No. of bits to use from secondary strobe hash [24]", {"aux-len"}}
29+
, aux_len{parser, "INT", "No. of bits to use from secondary strobe hash [17]", {"aux-len"}}
3030
{
3131
}
3232
args::ArgumentParser& parser;

src/cmdline.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ struct CommandLineOptions {
5252
int u { 7 };
5353
int s { 16 };
5454
int c { 8 };
55-
int aux_len{26};
55+
int aux_len{17};
5656

5757
// Alignment
5858
int A { 2 };

src/dumpstrobes.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ int run_dumpstrobes(int argc, char **argv) {
101101
}
102102

103103
// Seeding
104-
int r{150}, k{20}, s{16}, c{8}, l{1}, u{7}, aux_len{26};
104+
int r{150}, k{20}, s{16}, c{8}, l{1}, u{7}, aux_len{17};
105105
int max_seed_len{};
106106

107107
bool k_set{false}, s_set{false}, c_set{false}, max_seed_len_set{false}, l_set{false}, u_set{false};

src/index.hpp

+23-27
Original file line numberDiff line numberDiff line change
@@ -59,25 +59,23 @@ struct StrobemerIndex {
5959

6060
// Find first entry that matches the given key
6161
size_t find_full(randstrobe_hash_t key) const {
62-
return find(key, 0);
62+
return find(key, RANDSTROBE_HASH_MASK);
6363
}
6464

6565
/*
6666
* Find the first entry that matches the main hash (ignoring the aux_len
6767
* least significant bits)
6868
*/
6969
size_t find_partial(randstrobe_hash_t key) const {
70-
return find(key, parameters.randstrobe.aux_len);
70+
return find(key, parameters.randstrobe.main_hash_mask);
7171
}
7272

7373
/*
74-
* Find first entry whose hash matches the given key, but ignore the
75-
* b least significant bits
74+
* Find first entry whose hash matches the given key. Mask both key and
75+
* entry by hash_mask.
7676
*/
77-
size_t find(randstrobe_hash_t key, uint8_t b) const {
78-
const unsigned int aux_len = b;
79-
randstrobe_hash_t key_prefix = key >> aux_len;
80-
77+
size_t find(randstrobe_hash_t key, uint64_t hash_mask) const {
78+
randstrobe_hash_t masked_key = key & hash_mask;
8179
constexpr int MAX_LINEAR_SEARCH = 4;
8280
const unsigned int top_N = key >> (64 - bits);
8381
bucket_index_t position_start = randstrobe_start_indices[top_N];
@@ -88,19 +86,20 @@ struct StrobemerIndex {
8886

8987
if (position_end - position_start < MAX_LINEAR_SEARCH) {
9088
for ( ; position_start < position_end; ++position_start) {
91-
if (randstrobes[position_start].hash() >> aux_len == key_prefix) return position_start;
92-
if (randstrobes[position_start].hash() >> aux_len > key_prefix) return end();
89+
if ((randstrobes[position_start].hash() & hash_mask) == masked_key) return position_start;
90+
if ((randstrobes[position_start].hash() & hash_mask) > masked_key) return end();
9391
}
9492
return end();
9593
}
96-
auto cmp = [&aux_len](const RefRandstrobe lhs, const RefRandstrobe rhs) {
97-
return (lhs.hash() >> aux_len) < (rhs.hash() >> aux_len); };
94+
auto cmp = [&hash_mask](const RefRandstrobe lhs, const RefRandstrobe rhs) {
95+
return (lhs.hash() & hash_mask) < (rhs.hash() & hash_mask);
96+
};
9897

9998
auto pos = std::lower_bound(randstrobes.begin() + position_start,
10099
randstrobes.begin() + position_end,
101100
RefRandstrobe{key, 0, 0, 0, 0},
102101
cmp);
103-
if (pos->hash() >> aux_len == key_prefix) return pos - randstrobes.begin();
102+
if ((pos->hash() & hash_mask) == masked_key) return pos - randstrobes.begin();
104103
return end();
105104
}
106105

@@ -114,7 +113,7 @@ struct StrobemerIndex {
114113

115114
randstrobe_hash_t get_main_hash(bucket_index_t position) const {
116115
if (position < randstrobes.size()) {
117-
return randstrobes[position].hash() >> parameters.randstrobe.aux_len;
116+
return randstrobes[position].hash() & parameters.randstrobe.main_hash_mask;
118117
} else {
119118
return end();
120119
}
@@ -129,8 +128,7 @@ struct StrobemerIndex {
129128
}
130129

131130
bool is_partial_filtered(bucket_index_t position) const {
132-
const unsigned int shift = parameters.randstrobe.aux_len;
133-
return (get_hash(position) >> shift) == (get_hash(position + partial_filter_cutoff) >> shift);
131+
return get_main_hash(position) == get_main_hash(position + partial_filter_cutoff);
134132
}
135133

136134
unsigned int get_strobe1_position(bucket_index_t position) const {
@@ -163,14 +161,14 @@ struct StrobemerIndex {
163161
}
164162

165163
unsigned int get_count_full(bucket_index_t position) const {
166-
return get_count(position, 0);
164+
return get_count(position, RANDSTROBE_HASH_MASK);
167165
}
168166

169167
unsigned int get_count_partial(bucket_index_t position) const {
170-
return get_count(position, parameters.randstrobe.aux_len);
168+
return get_count(position, parameters.randstrobe.main_hash_mask);
171169
}
172170

173-
unsigned int get_count(bucket_index_t position, uint8_t b) const {
171+
unsigned int get_count(bucket_index_t position, uint64_t hash_mask) const {
174172
// For 95% of cases, the result will be small and a brute force search
175173
// is the best option. Once, we go over MAX_LINEAR_SEARCH, though, we
176174
// use a binary search to get the next position
@@ -182,27 +180,25 @@ struct StrobemerIndex {
182180
// seed with the given hash to yield the number of seeds with this hash.
183181

184182
constexpr unsigned int MAX_LINEAR_SEARCH = 8;
185-
const unsigned int aux_len = b;
186183

187184
const auto key = randstrobes[position].hash();
188-
randstrobe_hash_t key_prefix = key >> aux_len;
185+
randstrobe_hash_t masked_key = key & hash_mask;
189186

190187
const unsigned int top_N = key >> (64 - bits);
191188
bucket_index_t position_end = randstrobe_start_indices[top_N + 1];
192189
uint64_t count = 1;
193190

194191
if (position_end - position < MAX_LINEAR_SEARCH) {
195192
for (bucket_index_t position_start = position + 1; position_start < position_end; ++position_start) {
196-
if (randstrobes[position_start].hash() >> aux_len == key_prefix){
193+
if ((randstrobes[position_start].hash() & hash_mask) == masked_key) {
197194
count += 1;
198-
}
199-
else{
195+
} else {
200196
break;
201197
}
202198
}
203199
return count;
204200
}
205-
auto cmp = [&aux_len](const RefRandstrobe lhs, const RefRandstrobe rhs) {return (lhs.hash() >> aux_len) < (rhs.hash() >> aux_len); };
201+
auto cmp = [&hash_mask](const RefRandstrobe lhs, const RefRandstrobe rhs) {return (lhs.hash() & hash_mask) < (rhs.hash() & hash_mask); };
206202

207203
auto pos = std::upper_bound(randstrobes.begin() + position,
208204
randstrobes.begin() + position_end,
@@ -223,8 +219,8 @@ struct StrobemerIndex {
223219
return bits;
224220
}
225221

226-
int get_aux_len() const {
227-
return parameters.randstrobe.aux_len;
222+
uint64_t get_main_hash_mask() const {
223+
return parameters.randstrobe.main_hash_mask;
228224
}
229225

230226
private:

src/indexparameters.cpp

+26-5
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ bool RandstrobeParameters::operator==(const RandstrobeParameters& other) const {
1717
&& this->max_dist == other.max_dist
1818
&& this->w_min == other.w_min
1919
&& this->w_max == other.w_max
20-
&& this->aux_len == other.aux_len;
20+
&& this->main_hash_mask == other.main_hash_mask;
2121
}
2222

2323
/* Pre-defined index parameters that work well for a certain
@@ -78,7 +78,7 @@ IndexParameters IndexParameters::from_read_length(int read_length, int k, int s,
7878
}
7979
int q = std::pow(2, c == DEFAULT ? default_c : c) - 1;
8080
if (aux_len == DEFAULT) {
81-
aux_len = 26;
81+
aux_len = 17;
8282
}
8383

8484
return IndexParameters(canonical_read_length, k, s, l, u, q, max_dist, aux_len);
@@ -92,7 +92,7 @@ void IndexParameters::write(std::ostream& os) const {
9292
write_int_to_ostream(os, randstrobe.w_max);
9393
write_int_to_ostream(os, randstrobe.q);
9494
write_int_to_ostream(os, randstrobe.max_dist);
95-
write_int_to_ostream(os, randstrobe.aux_len);
95+
write_uint64_to_ostream(os, randstrobe.main_hash_mask);
9696
}
9797

9898
IndexParameters IndexParameters::read(std::istream& is) {
@@ -105,8 +105,8 @@ IndexParameters IndexParameters::read(std::istream& is) {
105105
uint32_t w_max = read_int_from_istream(is);
106106
uint64_t q = read_int_from_istream(is);
107107
int max_dist = read_int_from_istream(is);
108-
uint32_t aux_len = read_int_from_istream(is);
109-
const RandstrobeParameters randstrobe_parameters{q, max_dist, w_min, w_max, aux_len};
108+
uint64_t main_hash_mask = read_uint64_from_istream(is);
109+
const RandstrobeParameters randstrobe_parameters{q, max_dist, w_min, w_max, main_hash_mask};
110110

111111
return IndexParameters(canonical_read_length, syncmer_parameters, randstrobe_parameters);
112112
}
@@ -132,6 +132,26 @@ std::string IndexParameters::filename_extension() const {
132132
return sstream.str();
133133
}
134134

135+
std::ostream& operator<<(std::ostream& os, const SyncmerParameters& parameters) {
136+
os << "SyncmerParameters("
137+
<< "k=" << parameters.k
138+
<< ", s=" << parameters.s
139+
<< ", t_syncmer=" << parameters.t_syncmer
140+
<< ")";
141+
return os;
142+
}
143+
144+
std::ostream& operator<<(std::ostream& os, const RandstrobeParameters& parameters) {
145+
os << "RandstrobeParameters("
146+
<< "q=" << parameters.q
147+
<< ", max_dist=" << parameters.max_dist
148+
<< ", w_min=" << parameters.w_min
149+
<< ", w_max=" << parameters.w_max
150+
<< ", main_hash_mask=0x" << std::hex << parameters.main_hash_mask << std::dec
151+
<< ")";
152+
return os;
153+
}
154+
135155
std::ostream& operator<<(std::ostream& os, const IndexParameters& parameters) {
136156
os << "IndexParameters("
137157
<< "r=" << parameters.canonical_read_length
@@ -142,6 +162,7 @@ std::ostream& operator<<(std::ostream& os, const IndexParameters& parameters) {
142162
<< ", max_dist=" << parameters.randstrobe.max_dist
143163
<< ", w_min=" << parameters.randstrobe.w_min
144164
<< ", w_max=" << parameters.randstrobe.w_max
165+
<< ", main_hash_mask=0x" << std::hex << parameters.randstrobe.main_hash_mask << std::dec
145166
<< ")";
146167
return os;
147168
}

src/indexparameters.hpp

+15-7
Original file line numberDiff line numberDiff line change
@@ -41,22 +41,21 @@ struct RandstrobeParameters {
4141
const int max_dist;
4242
const unsigned w_min;
4343
const unsigned w_max;
44-
const unsigned aux_len;
44+
const uint64_t main_hash_mask;
4545

46-
RandstrobeParameters(uint64_t q, int max_dist, unsigned w_min, unsigned w_max, unsigned aux_len)
46+
RandstrobeParameters(uint64_t q, int max_dist, unsigned w_min, unsigned w_max, uint64_t main_hash_mask)
4747
: q(q)
4848
, max_dist(max_dist)
4949
, w_min(w_min)
5050
, w_max(w_max)
51-
, aux_len(aux_len)
51+
, main_hash_mask(main_hash_mask)
5252
{
53-
verify();
5453
}
5554

5655
bool operator==(const RandstrobeParameters& other) const;
5756

5857
private:
59-
void verify() const {
58+
void verify(unsigned aux_len) const {
6059
if (max_dist > 255) {
6160
throw BadParameter("maximum seed length (-m <max_dist>) is larger than 255");
6261
}
@@ -78,11 +77,12 @@ class IndexParameters {
7877

7978
static const int DEFAULT = std::numeric_limits<int>::min();
8079

81-
IndexParameters(size_t canonical_read_length, int k, int s, int l, int u, int q, int max_dist, int aux_len)
80+
IndexParameters(size_t canonical_read_length, int k, int s, int l, int u, uint64_t q, int max_dist, int aux_len)
8281
: canonical_read_length(canonical_read_length)
8382
, syncmer(k, s)
84-
, randstrobe(q, max_dist, std::max(0, k / (k - s + 1) + l), k / (k - s + 1) + u, aux_len)
83+
, randstrobe(q, max_dist, std::max(0, k / (k - s + 1) + l), k / (k - s + 1) + u, ~0ul << (9 + aux_len))
8584
{
85+
verify(aux_len);
8686
}
8787

8888
IndexParameters(size_t canonical_read_length, SyncmerParameters syncmer, RandstrobeParameters randstrobe)
@@ -92,6 +92,12 @@ class IndexParameters {
9292
{
9393
}
9494

95+
void verify(unsigned aux_len) const {
96+
if (aux_len > 27) {
97+
throw BadParameter("aux_len must be less than 28");
98+
}
99+
}
100+
95101
static IndexParameters from_read_length(
96102
int read_length, int k = DEFAULT, int s = DEFAULT, int l = DEFAULT, int u = DEFAULT, int c = DEFAULT, int max_seed_len = DEFAULT, int aux_len = DEFAULT);
97103
static IndexParameters read(std::istream& os);
@@ -101,6 +107,8 @@ class IndexParameters {
101107
bool operator!=(const IndexParameters& other) const { return !(*this == other); }
102108
};
103109

110+
std::ostream& operator<<(std::ostream& os, const SyncmerParameters& parameters);
111+
std::ostream& operator<<(std::ostream& os, const RandstrobeParameters& parameters);
104112
std::ostream& operator<<(std::ostream& os, const IndexParameters& parameters);
105113

106114
#endif

src/io.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,15 @@ int32_t read_int_from_istream(std::istream& is) {
1111
is.read(reinterpret_cast<char*>(&val), sizeof(val));
1212
return val;
1313
}
14+
15+
void write_uint64_to_ostream(std::ostream& os, uint64_t value) {
16+
uint64_t val;
17+
val = value;
18+
os.write(reinterpret_cast<const char*>(&val), sizeof(val));
19+
}
20+
21+
uint64_t read_uint64_from_istream(std::istream& is) {
22+
uint64_t val;
23+
is.read(reinterpret_cast<char*>(&val), sizeof(val));
24+
return val;
25+
}

src/io.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
void write_int_to_ostream(std::ostream& os, int32_t value);
99
int32_t read_int_from_istream(std::istream& is);
1010

11+
void write_uint64_to_ostream(std::ostream& os, uint64_t value);
12+
uint64_t read_uint64_from_istream(std::istream& is);
13+
1114
/* Write a vector to an output stream, preceded by its length */
1215
template <typename T>
1316
void write_vector(std::ostream& os, const std::vector<T>& v) {

src/main.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,8 @@ int run_strobealign(int argc, char **argv) {
230230
throw InvalidFasta("Too many reference sequences. Current maximum is " + std::to_string(RefRandstrobe::max_number_of_references));
231231
}
232232

233-
logger.debug() << "Auxiliary hash length: " << index_parameters.randstrobe.aux_len << "\n";
233+
logger.debug() << "Auxiliary hash length: " << opt.aux_len << "\n";
234+
logger.debug() << "Base hash mask: " << std::hex << index_parameters.randstrobe.main_hash_mask << std::dec << '\n';
234235
logger.info() << "Using multi-context seeds: " << (map_param.use_mcs ? "yes" : "no") << '\n';
235236
StrobemerIndex index(references, index_parameters, opt.bits);
236237
if (opt.use_index) {

src/nam.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ std::tuple<float, int, std::vector<Nam>> find_nams(
234234
int total_hits = 0;
235235
for (const auto &q : query_randstrobes) {
236236
size_t position = index.find_full(q.hash);
237-
if (position != index.end()){
237+
if (position != index.end()) {
238238
total_hits++;
239239
if (index.is_filtered(position)) {
240240
continue;
@@ -243,7 +243,7 @@ std::tuple<float, int, std::vector<Nam>> find_nams(
243243
add_to_matches_map_full(matches_map[q.is_reverse], q.start, q.end, index, position);
244244
}
245245
else if (use_mcs) {
246-
PartialHit ph{q.hash >> index.get_aux_len(), q.partial_start, q.is_reverse};
246+
PartialHit ph{q.hash & index.get_main_hash_mask(), q.partial_start, q.is_reverse};
247247
if (std::find(partial_queried.begin(), partial_queried.end(), ph) != partial_queried.end()) {
248248
// already queried
249249
continue;
@@ -312,7 +312,7 @@ std::pair<int, std::vector<Nam>> find_nams_rescue(
312312
}
313313
}
314314
else if (use_mcs) {
315-
PartialHit ph = {qr.hash >> index.get_aux_len(), qr.partial_start, qr.is_reverse};
315+
PartialHit ph = {qr.hash & index.get_main_hash_mask(), qr.partial_start, qr.is_reverse};
316316
if (std::find(partial_queried.begin(), partial_queried.end(), ph) != partial_queried.end()) {
317317
// already queried
318318
continue;

0 commit comments

Comments
 (0)