Skip to content

Commit 05430e5

Browse files
committed
[gbdt] enhance error handling for forced splits file loading
1 parent d02a01a commit 05430e5

File tree

1 file changed

+30
-11
lines changed

1 file changed

+30
-11
lines changed

src/boosting/gbdt.cpp

+30-11
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,19 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
8383
// load forced_splits file
8484
if (!config->forcedsplits_filename.empty()) {
8585
std::ifstream forced_splits_file(config->forcedsplits_filename.c_str());
86-
std::stringstream buffer;
87-
buffer << forced_splits_file.rdbuf();
88-
std::string err;
89-
forced_splits_json_ = Json::parse(buffer.str(), &err);
86+
if (!forced_splits_file.good()) {
87+
Log::Warning("Forced splits file '%s' does not exist. Forced splits will be ignored.",
88+
config->forcedsplits_filename.c_str());
89+
} else {
90+
std::stringstream buffer;
91+
buffer << forced_splits_file.rdbuf();
92+
std::string err;
93+
forced_splits_json_ = Json::parse(buffer.str(), &err);
94+
if (!err.empty()) {
95+
Log::Fatal("Failed to parse forced splits file '%s': %s",
96+
config->forcedsplits_filename.c_str(), err.c_str());
97+
}
98+
}
9099
}
91100

92101
objective_function_ = objective_function;
@@ -823,13 +832,23 @@ void GBDT::ResetConfig(const Config* config) {
823832
if (config_.get() != nullptr && config_->forcedsplits_filename != new_config->forcedsplits_filename) {
824833
// load forced_splits file
825834
if (!new_config->forcedsplits_filename.empty()) {
826-
std::ifstream forced_splits_file(
827-
new_config->forcedsplits_filename.c_str());
828-
std::stringstream buffer;
829-
buffer << forced_splits_file.rdbuf();
830-
std::string err;
831-
forced_splits_json_ = Json::parse(buffer.str(), &err);
832-
tree_learner_->SetForcedSplit(&forced_splits_json_);
835+
std::ifstream forced_splits_file(new_config->forcedsplits_filename.c_str());
836+
if (!forced_splits_file.good()) {
837+
Log::Warning("Forced splits file '%s' does not exist. Forced splits will be ignored.",
838+
new_config->forcedsplits_filename.c_str());
839+
forced_splits_json_ = Json();
840+
tree_learner_->SetForcedSplit(nullptr);
841+
} else {
842+
std::stringstream buffer;
843+
buffer << forced_splits_file.rdbuf();
844+
std::string err;
845+
forced_splits_json_ = Json::parse(buffer.str(), &err);
846+
if (!err.empty()) {
847+
Log::Fatal("Failed to parse forced splits file '%s': %s",
848+
new_config->forcedsplits_filename.c_str(), err.c_str());
849+
}
850+
tree_learner_->SetForcedSplit(&forced_splits_json_);
851+
}
833852
} else {
834853
forced_splits_json_ = Json();
835854
tree_learner_->SetForcedSplit(nullptr);

0 commit comments

Comments
 (0)