Skip to content

Commit

Permalink
add sample option
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmad88me committed Feb 20, 2023
1 parent 5f00529 commit 9deda18
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 27 deletions.
8 changes: 7 additions & 1 deletion include/entity.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,12 @@ class EntityAnn {
string get_quoted(string);
string get_taged(string);
std::list<string> *recompute_f(double);
void set_language_tag(string);
string get_title_case(string);
void set_title_case(bool);
bool get_title_case();
void set_language_tag(string tag);
unsigned long get_m();

std::list<string> *annotate_entity_property_column(std::list<std::list<string>*> *, long, long);
void annotate_entity_property_pair(string, string);
std::list<string> *annotate_entity_property_heuristic(std::list<std::list<string>*> *, string, long);
Expand All @@ -96,6 +98,9 @@ class EntityAnn {
//double m_ambiguitity_penalty=2;
double m_ambiguitity_penalty = 1; // no penalty

void set_sample_size(long sample_size);
long get_sample_size();

std::list<string> get_labels_uris();
bool append_label_uri(string);
bool clear_label_uri();
Expand Down Expand Up @@ -123,6 +128,7 @@ class EntityAnn {
unsigned long m_m;
bool m_retry_with_title_case = false;
string m_lang_tag;
long m_sample_size = 0;
void init(string hdt_file_dir, string log_file_dir, double alpha);
void init(hdt::HDT *hdt_ptr, string log_file_dir, double alpha);
};
Expand Down
38 changes: 31 additions & 7 deletions src/entity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,18 @@ std::list<string> *EntityAnn::annotate_column(std::list<std::list<string>*> *dat

if (this->compute_intermediate_coverage(entity, prop, double_levels)) {
m++;

if (m_sample_size > 0 && m >= m_sample_size) {
m_logger->log("annotate_column> sample size is reached");
break;
}

}

delete prop;
}

m_m = m;
// m_m = m;
m_logger->log("annotate_column> m: " + to_string(m));
return this->annotate_semi_scored_column(m);
} else {
Expand All @@ -192,6 +198,11 @@ std::list<string> *EntityAnn::annotate_column(std::list<std::list<string>*> *dat

if (this->compute_intermediate_coverage(l)) {
m++;

if (m_sample_size > 0 && m >= m_sample_size) {
m_logger->log("annotate_column> sample size is reached");
break;
}
}
}

Expand All @@ -200,6 +211,7 @@ std::list<string> *EntityAnn::annotate_column(std::list<std::list<string>*> *dat


std::list<string> *EntityAnn::annotate_semi_scored_column(unsigned long m, double alpha) {
m_m = m;
this->compute_Ic_for_all();
this->compute_Lc_for_all();
this->pick_root();
Expand All @@ -226,8 +238,6 @@ std::list<string> *EntityAnn::annotate_semi_scored_column(unsigned long m) {
}




// Get entities of a given cell value or name using the rdfs:label property
std::list<string> *EntityAnn::get_entities_of_value(string value) {
string qvalue;
Expand Down Expand Up @@ -697,8 +707,6 @@ void EntityAnn::compute_Ic_for_node(TNode *tnode) {
}




void EntityAnn::compute_classes_entities_counts() {
hdt::IteratorTripleString *itt;
unsigned long num_of_entities;
Expand All @@ -721,6 +729,7 @@ void EntityAnn::compute_classes_entities_counts() {
}
}


// include the counts of the children because HDT does not perform reasoning
void EntityAnn::propagate_counts(TNode *tnode) {
unsigned long count;
Expand Down Expand Up @@ -770,6 +779,7 @@ void EntityAnn::compute_Is_for_all() {
}
}


void EntityAnn::compute_Is_for_node(TNode *tnode) {
TNode *ch = nullptr;
double ch_count;
Expand All @@ -794,6 +804,7 @@ void EntityAnn::compute_Is_for_node(TNode *tnode) {
}
}


void EntityAnn::compute_Ls_for_all() {
std::list<TNode *> *leaves = m_graph->get_leaves();

Expand All @@ -802,6 +813,7 @@ void EntityAnn::compute_Ls_for_all() {
}
}


double EntityAnn::compute_Ls_for_node(TNode *tnode) {
double ls = tnode->ls;
double ls_2, ls_max;
Expand Down Expand Up @@ -834,6 +846,7 @@ double EntityAnn::compute_Ls_for_node(TNode *tnode) {
return ls;
}


void EntityAnn::compute_fs() {
TNode *tnode;

Expand All @@ -844,6 +857,7 @@ void EntityAnn::compute_fs() {
}
}


void EntityAnn::compute_fc(unsigned long m) {
TNode *tnode;
double m_d = static_cast<unsigned long>(m);
Expand All @@ -856,6 +870,7 @@ void EntityAnn::compute_fc(unsigned long m) {
}
}


void EntityAnn::compute_f() {
TNode *tnode;

Expand Down Expand Up @@ -1206,8 +1221,6 @@ unsigned long EntityAnn::get_counts_of_class(string uri) {
return class_num + propagated_num;
}



std::list<string> EntityAnn::get_labels_uris() {
return m_labels_uris;
}
Expand All @@ -1222,4 +1235,15 @@ bool EntityAnn::clear_label_uri() {
return true;
}

void EntityAnn::set_sample_size(long sample_size) {
m_sample_size = sample_size;
}

long EntityAnn::get_sample_size() {
return m_sample_size;
}

unsigned long EntityAnn::get_m() {
return m_m;
}

73 changes: 55 additions & 18 deletions src/tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ TEST(EntityTest, IntermediateScoresMultiClass) {
string label = "golferboxer1";
// entity with two classes
ea->compute_intermediate_coverage(label);
ea->get_graph()->print_nodes();
//ea->get_graph()->print_nodes();
tnode = ea->get_tnode(class_uri);
ASSERT_NE(tnode, nullptr);
ASSERT_DOUBLE_EQ(0.5, tnode->tc);
Expand All @@ -149,14 +149,14 @@ TEST(EntityTest, GraphLeaves) {
ea->compute_intermediate_coverage(label);
Graph *graph = ea->get_graph();
leaves = graph->get_leaves();
graph->print_nodes();
//graph->print_nodes();
ASSERT_EQ(leaves->size(), 1);
ASSERT_STREQ(leaves->front()->uri.c_str(), class_uri.c_str());
label = "golferboxer1";
ea->compute_intermediate_coverage(label);
graph = ea->get_graph();
leaves = graph->get_leaves();
graph->print_nodes();
//graph->print_nodes();
ASSERT_EQ(leaves->size(), 2);
delete ea;
}
Expand All @@ -173,7 +173,7 @@ TEST(EntityTest, GraphRoots) {
ea->compute_intermediate_coverage(label);
Graph *graph = ea->get_graph();
roots = graph->get_candidate_roots();
graph->print_nodes();
//graph->print_nodes();
ASSERT_EQ(roots->size(), 1);
ASSERT_STREQ(roots->front()->uri.c_str(), root.c_str());
delete ea;
Expand All @@ -197,7 +197,7 @@ TEST(EntityTest, GraphContruction) {
cout << "leaves: " << (*it)->uri << endl;
}

graph->print_nodes();
//graph->print_nodes();
ASSERT_EQ(leaves->size(), 2);
tnode = graph->get_node(class_uri);
ASSERT_EQ(tnode->parents->size(), 1);
Expand All @@ -220,7 +220,7 @@ TEST(EntityTest, IcLcSingle) {
ea->compute_intermediate_coverage(label);
ea->compute_Ic_for_all();
ea->compute_Lc_for_all();
ea->get_graph()->print_nodes();
//ea->get_graph()->print_nodes();
tnode = ea->get_tnode(class_uri);
ASSERT_NE(tnode, nullptr);
ASSERT_EQ(1.0, tnode->lc);
Expand All @@ -240,7 +240,7 @@ TEST(EntityTest, IcLcMulti) {
ea->compute_intermediate_coverage(label);
ea->compute_Ic_for_all();
ea->compute_Lc_for_all();
ea->get_graph()->print_nodes();
//ea->get_graph()->print_nodes();
tnode = ea->get_tnode(class_uri);
cout << "class ic: " << tnode->ic << "Lc: " << tnode->lc << endl;
ASSERT_NE(tnode, nullptr);
Expand All @@ -266,13 +266,13 @@ TEST(EntityTest, ClassEntityCounts) {
ea->compute_intermediate_coverage(label);
ea->compute_Ic_for_all();
ea->compute_Lc_for_all();
ea->get_graph()->print_nodes();
//ea->get_graph()->print_nodes();
ea->get_graph()->pick_root();
ASSERT_STREQ((dbo_prefix + "Agent").c_str(), ea->get_graph()->get_root()->uri.c_str());
ea->compute_classes_entities_counts();
ea->compute_Is_for_all();
ea->compute_Ls_for_all();
ea->get_graph()->print_nodes();
//ea->get_graph()->print_nodes();
boxer_tnode = ea->get_graph()->get_node(dbo_prefix + "Boxer");
amature_tnode = ea->get_graph()->get_node(dbo_prefix + "AmateurBoxer");
agent_tnode = ea->get_graph()->get_node(dbo_prefix + "Agent");
Expand All @@ -294,13 +294,13 @@ TEST(EntityTest, NONExistantLabel) {
ea->compute_intermediate_coverage(label);
ea->compute_Ic_for_all();
ea->compute_Lc_for_all();
ea->get_graph()->print_nodes();
//ea->get_graph()->print_nodes();
ea->get_graph()->pick_root();
ASSERT_EQ(ea->get_graph()->get_root(), nullptr);
ea->compute_classes_entities_counts();
ea->compute_Is_for_all();
ea->compute_Ls_for_all();
ea->get_graph()->print_nodes();
//ea->get_graph()->print_nodes();
delete ea;
}

Expand All @@ -312,7 +312,22 @@ TEST(EntityTest, Scores) {
Parser p(base_dir + "test_files/test1.csv");
data = p.parse();
candidates = ea->annotate_column(data, 2, 0.1);
ea->get_graph()->print_nodes();
//ea->get_graph()->print_nodes();
ASSERT_STREQ(class_uri.c_str(), candidates->front().c_str());
delete ea;
}

// With a sample size of 1, annotate_column must stop after the first counted
// cell (get_m() == 1) and still rank dbo:Boxer as the top candidate.
TEST(EntityTest, ScoresSample) {
    EntityAnn *ea = new EntityAnn(hdt_file, log_file);
    std::list<string> *candidates;
    string class_uri = dbo_prefix + "Boxer";
    std::list<std::list<string>*> *data;
    Parser p(base_dir + "test_files/test1.csv");
    data = p.parse();
    ea->set_sample_size(1);
    candidates = ea->annotate_column(data, 2, 0.1);
    //ea->get_graph()->print_nodes();
    // Fail cleanly instead of dereferencing a null or empty result below.
    ASSERT_NE(candidates, nullptr);
    ASSERT_FALSE(candidates->empty());
    // get_m() returns unsigned long; use an unsigned literal to avoid a
    // signed/unsigned comparison inside ASSERT_EQ.
    ASSERT_EQ(ea->get_m(), 1UL);
    ASSERT_STREQ(class_uri.c_str(), candidates->front().c_str());
    // The returned list is released by the caller, as the other tests in this
    // file do after annotate_column.
    delete candidates;
    delete ea;
    // NOTE(review): `data` is not released here; confirm whether Parser owns it.
}
Expand All @@ -325,7 +340,7 @@ TEST(EntityTest, ScoresExtraRoot) {
Parser p(base_dir + "test_files/test2.csv");
data = p.parse();
candidates = ea->annotate_column(data, 2, 0.1);
ea->get_graph()->print_nodes();
//ea->get_graph()->print_nodes();
ASSERT_STREQ(class_uri.c_str(), candidates->front().c_str());
delete ea;
}
Expand All @@ -346,11 +361,33 @@ TEST(EntityTest, Context) {
ea = new EntityAnn(hdt_file, log_file, 0.1);
data = p.parse_vertical();
candidates = ea->annotate_column(data, 1, true, false);
ea->get_graph()->print_nodes();
// ea->get_graph()->print_nodes();
ASSERT_STREQ(football_class_uri.c_str(), candidates->front().c_str());
delete ea;
}

// Annotates the same column twice: first without a sample limit (expects
// m == 3), then with a sample size of 2 (expects m == 2). Both runs must
// rank dbo:VolleyballPlayer as the top candidate.
TEST(EntityTest, ContextSample) {
    EntityAnn *ea = new EntityAnn(hdt_file, log_file, 0.1);
    std::list<string> *candidates;
    string volley_class_uri = dbo_prefix + "VolleyballPlayer";
    std::list<std::list<string>*> *data;
    Parser p(base_dir + "test_files/test3.csv");
    data = p.parse();

    // Run 1: no sampling, every matching cell is counted.
    candidates = ea->annotate_column(data, 1, 0.1);
    // Fail cleanly instead of dereferencing a null or empty result below.
    ASSERT_NE(candidates, nullptr);
    ASSERT_FALSE(candidates->empty());
    ASSERT_STREQ(volley_class_uri.c_str(), candidates->front().c_str());
    // get_m() returns unsigned long; use unsigned literals to avoid a
    // signed/unsigned comparison inside ASSERT_EQ.
    ASSERT_EQ(ea->get_m(), 3UL);
    // Release the first result before the pointer is reused for run 2.
    delete candidates;
    delete ea;

    // Run 2: stop once two cells have been counted.
    ea = new EntityAnn(hdt_file, log_file, 0.1);
    ea->set_sample_size(2);
    candidates = ea->annotate_column(data, 1, 0.1);
    ASSERT_NE(candidates, nullptr);
    ASSERT_FALSE(candidates->empty());
    ASSERT_STREQ(volley_class_uri.c_str(), candidates->front().c_str());
    ASSERT_EQ(ea->get_m(), 2UL);
    delete candidates;
    delete ea;
    // NOTE(review): `data` is not released here; confirm whether Parser owns it.
}

TEST(EntityTest, DoubleLevel) {
EntityAnn *ea = new EntityAnn(hdt_file, log_file, 0.1);
//EntityAnn* ea = new EntityAnn(hdt_file, log_file);
Expand All @@ -360,7 +397,7 @@ TEST(EntityTest, DoubleLevel) {
Parser p(base_dir + "test_files/test4.csv");
data = p.parse_vertical();
candidates = ea->annotate_column(data, 1, true, true);
ea->get_graph()->print_nodes();
//ea->get_graph()->print_nodes();
ASSERT_STREQ(wrestler_class_uri.c_str(), candidates->front().c_str());
delete ea;
}
Expand All @@ -374,7 +411,7 @@ TEST(EntityTest, recomputef) {
Parser p(base_dir + "test_files/test4.csv");
data = p.parse_vertical();
candidates = ea->annotate_column(data, 1, true, true);
ea->get_graph()->print_nodes();
//ea->get_graph()->print_nodes();
delete candidates;
candidates = ea->recompute_f(0.1);
ASSERT_STREQ(wrestler_class_uri.c_str(), candidates->front().c_str());
Expand All @@ -390,7 +427,7 @@ TEST(EntityTest, LangTag) {
Parser p(base_dir + "test_files/test5.csv");
data = p.parse_vertical();
candidates = ea->annotate_column(data, 0, true, true);
ea->get_graph()->print_nodes();
//ea->get_graph()->print_nodes();
delete candidates;
candidates = ea->recompute_f(0.1);
ASSERT_STREQ(class_uri.c_str(), candidates->front().c_str());
Expand Down Expand Up @@ -418,7 +455,7 @@ TEST(EntityTest, TitleCase) {
Parser p(base_dir + "test_files/test6.csv");
data = p.parse_vertical();
candidates = ea->annotate_column(data, 0, true, true);
ea->get_graph()->print_nodes();
//ea->get_graph()->print_nodes();
delete candidates;
candidates = ea->recompute_f(0.1);
ASSERT_STREQ(class_uri.c_str(), candidates->front().c_str());
Expand Down
2 changes: 1 addition & 1 deletion test_files/test3.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
id,name,city
1,common_name_1,city1
2,common_name_2,city2
3,volleyballp4,city4
2,common_name_2,city2

0 comments on commit 9deda18

Please sign in to comment.