Skip to content

Commit 97d34d7

Browse files
committed
store original query information
store relative data indices for pairs store only original instead of paired labels
1 parent c40965a commit 97d34d7

File tree

2 files changed

+26
-11
lines changed

2 files changed

+26
-11
lines changed

include/LightGBM/dataset.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -387,10 +387,8 @@ class Metadata {
387387
data_size_t num_weights_;
388388
/*! \brief Number of positions, used to check correct position file */
389389
data_size_t num_positions_;
390-
/*! \brief Label data */
390+
/*! \brief Label data. In pairwise ranking, the label_ refer to the labels of the original unpaired dataset. */
391391
std::vector<label_t> label_;
392-
/*! \brief Paired label data for pairwise lambdarank */
393-
std::vector<label_t> paired_label_;
394392
/*! \brief Weights data */
395393
std::vector<label_t> weights_;
396394
/*! \brief Positions data */
@@ -399,6 +397,8 @@ class Metadata {
399397
std::vector<std::string> position_ids_;
400398
/*! \brief Query boundaries */
401399
std::vector<data_size_t> query_boundaries_;
400+
/*! \brief Original query boundaries, used in pairwise ranking */
401+
std::vector<data_size_t> original_query_boundaries_;
402402
/*! \brief Query weights */
403403
std::vector<label_t> query_weights_;
404404
/*! \brief Number of querys */

src/io/metadata.cpp

+23-8
Original file line numberDiff line numberDiff line change
@@ -857,11 +857,28 @@ data_size_t Metadata::BuildPairwiseFeatureRanking(const Metadata& metadata) {
857857
num_data_ = 0;
858858
num_queries_ = metadata.num_queries();
859859
label_.clear();
860-
paired_label_.clear();
861860
if (pairwise_ranking_mode_ == PairwiseRankingMode::kRelevance) {
862-
const label_t* labels = metadata.label();
861+
const label_t* original_label = metadata.label();
863862
paired_ranking_item_index_map_.clear();
864863
const data_size_t* query_boundaries = metadata.query_boundaries();
864+
865+
// backup original query boundaries
866+
original_query_boundaries_.clear();
867+
original_query_boundaries_.resize(num_queries_);
868+
const int num_threads = OMP_NUM_THREADS();
869+
#pragma omp parallel for schedule(static) num_threads(num_threads) if (num_queries_ >= 1024)
870+
for (data_size_t i = 0; i < num_queries_; ++i) {
871+
original_query_boundaries_[i] = query_boundaries[i];
872+
}
873+
874+
// copy labels
875+
const data_size_t original_num_data = query_boundaries[num_queries_];
876+
label_.resize(original_num_data);
877+
#pragma omp parallel for schedule(static) num_threads(num_threads) if (original_num_data >= 1024)
878+
for (data_size_t i = 0; i < original_num_data; ++i) {
879+
label_[i] = original_label[i];
880+
}
881+
865882
data_size_t num_pairs_in_query = 0;
866883
query_boundaries_.clear();
867884
query_boundaries_.push_back(0);
@@ -870,16 +887,14 @@ data_size_t Metadata::BuildPairwiseFeatureRanking(const Metadata& metadata) {
870887
const data_size_t query_start = query_boundaries[query_index];
871888
const data_size_t query_end = query_boundaries[query_index + 1];
872889
for (data_size_t item_index_i = query_start; item_index_i < query_end; ++item_index_i) {
873-
const label_t label_i = labels[item_index_i];
890+
const label_t label_i = label_[item_index_i];
874891
for (data_size_t item_index_j = query_start; item_index_j < query_end; ++item_index_j) {
875892
if (item_index_i == item_index_j) {
876893
continue;
877894
}
878-
const label_t label_j = labels[item_index_j];
879-
label_.push_back(label_i);
880-
paired_label_.push_back(label_j);
895+
const label_t label_j = label_[item_index_j];
881896
if (label_i != label_j) {
882-
paired_ranking_item_index_map_.push_back(std::pair<data_size_t, data_size_t>{item_index_i, item_index_j});
897+
paired_ranking_item_index_map_.push_back(std::pair<data_size_t, data_size_t>{item_index_i - query_start, item_index_j - query_start});
883898
++num_pairs_in_query;
884899
++num_data_;
885900
}
@@ -894,7 +909,7 @@ data_size_t Metadata::BuildPairwiseFeatureRanking(const Metadata& metadata) {
894909
// TODO(shiyu1994)
895910
Log::Fatal("Not implemented.");
896911
}
897-
912+
898913
return num_data_;
899914
}
900915

0 commit comments

Comments
 (0)