From 9deda1878d6657f2412ff1f60e0c6f3d7ec789c1 Mon Sep 17 00:00:00 2001 From: Ahmad Alobaid Date: Tue, 21 Feb 2023 02:56:13 +0300 Subject: [PATCH] add sample option --- include/entity.h | 8 ++++- src/entity.cpp | 38 ++++++++++++++++++----- src/tests.cpp | 73 +++++++++++++++++++++++++++++++++----------- test_files/test3.csv | 2 +- 4 files changed, 94 insertions(+), 27 deletions(-) diff --git a/include/entity.h b/include/entity.h index c31c672..7580c9e 100644 --- a/include/entity.h +++ b/include/entity.h @@ -76,10 +76,12 @@ class EntityAnn { string get_quoted(string); string get_taged(string); std::list *recompute_f(double); - void set_language_tag(string); string get_title_case(string); void set_title_case(bool); bool get_title_case(); + void set_language_tag(string tag); + unsigned long get_m(); + std::list *annotate_entity_property_column(std::list*> *, long, long); void annotate_entity_property_pair(string, string); std::list *annotate_entity_property_heuristic(std::list*> *, string, long); @@ -96,6 +98,9 @@ class EntityAnn { //double m_ambiguitity_penalty=2; double m_ambiguitity_penalty = 1; // no penalty + void set_sample_size(long sample_size); + long get_sample_size(); + std::list get_labels_uris(); bool append_label_uri(string); bool clear_label_uri(); @@ -123,6 +128,7 @@ class EntityAnn { unsigned long m_m; bool m_retry_with_title_case = false; string m_lang_tag; + long m_sample_size = 0; void init(string hdt_file_dir, string log_file_dir, double alpha); void init(hdt::HDT *hdt_ptr, string log_file_dir, double alpha); }; diff --git a/src/entity.cpp b/src/entity.cpp index 27fe588..1b882eb 100644 --- a/src/entity.cpp +++ b/src/entity.cpp @@ -160,12 +160,18 @@ std::list *EntityAnn::annotate_column(std::list*> *dat if (this->compute_intermediate_coverage(entity, prop, double_levels)) { m++; + + if (m_sample_size > 0 && m >= m_sample_size) { + m_logger->log("annotate_column> sample size is reached"); + break; + } + } delete prop; } - m_m = m; +// m_m = m; m_logger->log("annotate_column> m: " + to_string(m)); return this->annotate_semi_scored_column(m); } else { @@ -192,6 +198,11 @@ std::list *EntityAnn::annotate_column(std::list*> *dat if (this->compute_intermediate_coverage(l)) { m++; + + if (m_sample_size > 0 && m >= m_sample_size) { + m_logger->log("annotate_column> sample size is reached"); + break; + } } } @@ -200,6 +211,7 @@ std::list *EntityAnn::annotate_column(std::list*> *dat std::list *EntityAnn::annotate_semi_scored_column(unsigned long m, double alpha) { + m_m = m; this->compute_Ic_for_all(); this->compute_Lc_for_all(); this->pick_root(); @@ -226,8 +238,6 @@ std::list *EntityAnn::annotate_semi_scored_column(unsigned long m) { } - - // Get entities of a given cell value or name using the rdfs:label property std::list *EntityAnn::get_entities_of_value(string value) { string qvalue; @@ -697,8 +707,6 @@ void EntityAnn::compute_Ic_for_node(TNode *tnode) { } - - void EntityAnn::compute_classes_entities_counts() { hdt::IteratorTripleString *itt; unsigned long num_of_entities; @@ -721,6 +729,7 @@ void EntityAnn::compute_classes_entities_counts() { } } + // include the counts of the childs because HDT does not perform reasoning void EntityAnn::propagate_counts(TNode *tnode) { unsigned long count; @@ -770,6 +779,7 @@ void EntityAnn::compute_Is_for_all() { } } + void EntityAnn::compute_Is_for_node(TNode *tnode) { TNode *ch = nullptr; double ch_count; @@ -794,6 +804,7 @@ void EntityAnn::compute_Is_for_node(TNode *tnode) { } } + void EntityAnn::compute_Ls_for_all() { std::list *leaves = m_graph->get_leaves(); @@ -802,6 +813,7 @@ void EntityAnn::compute_Ls_for_all() { } } + double EntityAnn::compute_Ls_for_node(TNode *tnode) { double ls = tnode->ls; double ls_2, ls_max; @@ -834,6 +846,7 @@ double EntityAnn::compute_Ls_for_node(TNode *tnode) { return ls; } + void EntityAnn::compute_fs() { TNode *tnode; @@ -844,6 +857,7 @@ void EntityAnn::compute_fs() { } } + void EntityAnn::compute_fc(unsigned long m) { TNode *tnode; double m_d = static_cast(m); @@ -856,6 +870,7 @@ void EntityAnn::compute_fc(unsigned long m) { } } + void EntityAnn::compute_f() { TNode *tnode; @@ -1206,8 +1221,6 @@ unsigned long EntityAnn::get_counts_of_class(string uri) { return class_num + propagated_num; } - - std::list EntityAnn::get_labels_uris() { return m_labels_uris; } @@ -1222,4 +1235,15 @@ bool EntityAnn::clear_label_uri() { return true; } +void EntityAnn::set_sample_size(long sample_size) { + m_sample_size = sample_size; +} + +long EntityAnn::get_sample_size() { + return m_sample_size; +} + +unsigned long EntityAnn::get_m() { + return m_m; +} diff --git a/src/tests.cpp b/src/tests.cpp index 9040bd2..24f51af 100644 --- a/src/tests.cpp +++ b/src/tests.cpp @@ -131,7 +131,7 @@ TEST(EntityTest, IntermediateScoresMultiClass) { string label = "golferboxer1"; // entity with two classes ea->compute_intermediate_coverage(label); - ea->get_graph()->print_nodes(); + //ea->get_graph()->print_nodes(); tnode = ea->get_tnode(class_uri); ASSERT_NE(tnode, nullptr); ASSERT_DOUBLE_EQ(0.5, tnode->tc); @@ -149,14 +149,14 @@ TEST(EntityTest, GraphLeaves) { ea->compute_intermediate_coverage(label); Graph *graph = ea->get_graph(); leaves = graph->get_leaves(); - graph->print_nodes(); + //graph->print_nodes(); ASSERT_EQ(leaves->size(), 1); ASSERT_STREQ(leaves->front()->uri.c_str(), class_uri.c_str()); label = "golferboxer1"; ea->compute_intermediate_coverage(label); graph = ea->get_graph(); leaves = graph->get_leaves(); - graph->print_nodes(); + //graph->print_nodes(); ASSERT_EQ(leaves->size(), 2); delete ea; } @@ -173,7 +173,7 @@ TEST(EntityTest, GraphRoots) { ea->compute_intermediate_coverage(label); Graph *graph = ea->get_graph(); roots = graph->get_candidate_roots(); - graph->print_nodes(); + //graph->print_nodes(); ASSERT_EQ(roots->size(), 1); ASSERT_STREQ(roots->front()->uri.c_str(), root.c_str()); delete ea; @@ -197,7 +197,7 @@ TEST(EntityTest, GraphContruction) { cout << "leaves: " << (*it)->uri << endl; } - graph->print_nodes(); + //graph->print_nodes(); ASSERT_EQ(leaves->size(), 2); tnode = graph->get_node(class_uri); ASSERT_EQ(tnode->parents->size(), 1); @@ -220,7 +220,7 @@ TEST(EntityTest, IcLcSingle) { ea->compute_intermediate_coverage(label); ea->compute_Ic_for_all(); ea->compute_Lc_for_all(); - ea->get_graph()->print_nodes(); + //ea->get_graph()->print_nodes(); tnode = ea->get_tnode(class_uri); ASSERT_NE(tnode, nullptr); ASSERT_EQ(1.0, tnode->lc); @@ -240,7 +240,7 @@ TEST(EntityTest, IcLcMulti) { ea->compute_intermediate_coverage(label); ea->compute_Ic_for_all(); ea->compute_Lc_for_all(); - ea->get_graph()->print_nodes(); + //ea->get_graph()->print_nodes(); tnode = ea->get_tnode(class_uri); cout << "class ic: " << tnode->ic << "Lc: " << tnode->lc << endl; ASSERT_NE(tnode, nullptr); @@ -266,13 +266,13 @@ TEST(EntityTest, ClassEntityCounts) { ea->compute_intermediate_coverage(label); ea->compute_Ic_for_all(); ea->compute_Lc_for_all(); - ea->get_graph()->print_nodes(); + //ea->get_graph()->print_nodes(); ea->get_graph()->pick_root(); ASSERT_STREQ((dbo_prefix + "Agent").c_str(), ea->get_graph()->get_root()->uri.c_str()); ea->compute_classes_entities_counts(); ea->compute_Is_for_all(); ea->compute_Ls_for_all(); - ea->get_graph()->print_nodes(); + //ea->get_graph()->print_nodes(); boxer_tnode = ea->get_graph()->get_node(dbo_prefix + "Boxer"); amature_tnode = ea->get_graph()->get_node(dbo_prefix + "AmateurBoxer"); agent_tnode = ea->get_graph()->get_node(dbo_prefix + "Agent"); @@ -294,13 +294,13 @@ TEST(EntityTest, NONExistantLabel) { ea->compute_intermediate_coverage(label); ea->compute_Ic_for_all(); ea->compute_Lc_for_all(); - ea->get_graph()->print_nodes(); + //ea->get_graph()->print_nodes(); ea->get_graph()->pick_root(); ASSERT_EQ(ea->get_graph()->get_root(), nullptr); ea->compute_classes_entities_counts(); ea->compute_Is_for_all(); ea->compute_Ls_for_all(); - ea->get_graph()->print_nodes(); + //ea->get_graph()->print_nodes(); delete ea; } @@ -312,7 +312,22 @@ TEST(EntityTest, Scores) { Parser p(base_dir + "test_files/test1.csv"); data = p.parse(); candidates = ea->annotate_column(data, 2, 0.1); - ea->get_graph()->print_nodes(); + //ea->get_graph()->print_nodes(); + ASSERT_STREQ(class_uri.c_str(), candidates->front().c_str()); + delete ea; +} + +TEST(EntityTest, ScoresSample) { + EntityAnn *ea = new EntityAnn(hdt_file, log_file); + std::list *candidates; + string class_uri = dbo_prefix + "Boxer"; + std::list*> *data; + Parser p(base_dir + "test_files/test1.csv"); + data = p.parse(); + ea->set_sample_size(1); + candidates = ea->annotate_column(data, 2, 0.1); + //ea->get_graph()->print_nodes(); + ASSERT_EQ(ea->get_m(), 1); ASSERT_STREQ(class_uri.c_str(), candidates->front().c_str()); delete ea; } @@ -325,7 +340,7 @@ TEST(EntityTest, ScoresExtraRoot) { Parser p(base_dir + "test_files/test2.csv"); data = p.parse(); candidates = ea->annotate_column(data, 2, 0.1); - ea->get_graph()->print_nodes(); + //ea->get_graph()->print_nodes(); ASSERT_STREQ(class_uri.c_str(), candidates->front().c_str()); delete ea; } @@ -346,11 +361,33 @@ TEST(EntityTest, Context) { ea = new EntityAnn(hdt_file, log_file, 0.1); data = p.parse_vertical(); candidates = ea->annotate_column(data, 1, true, false); - ea->get_graph()->print_nodes(); +// ea->get_graph()->print_nodes(); ASSERT_STREQ(football_class_uri.c_str(), candidates->front().c_str()); delete ea; } +TEST(EntityTest, ContextSample) { + EntityAnn *ea = new EntityAnn(hdt_file, log_file, 0.1); + std::list *candidates; + string volley_class_uri = dbo_prefix + "VolleyballPlayer"; + string football_class_uri = dbo_prefix + "FootballPlayer"; + std::list*> *data; + Parser p(base_dir + "test_files/test3.csv"); + data = p.parse(); + candidates = ea->annotate_column(data, 1, 0.1); + ASSERT_STREQ(volley_class_uri.c_str(), candidates->front().c_str()); + ASSERT_EQ(ea->get_m(), 3); + delete ea; + + ea = new EntityAnn(hdt_file, log_file, 0.1); + ea->set_sample_size(2); + candidates = ea->annotate_column(data, 1, 0.1); + ASSERT_STREQ(volley_class_uri.c_str(), candidates->front().c_str()); + ASSERT_EQ(ea->get_m(), 2); + delete ea; + +} + TEST(EntityTest, DoubleLevel) { EntityAnn *ea = new EntityAnn(hdt_file, log_file, 0.1); //EntityAnn* ea = new EntityAnn(hdt_file, log_file); @@ -360,7 +397,7 @@ TEST(EntityTest, DoubleLevel) { Parser p(base_dir + "test_files/test4.csv"); data = p.parse_vertical(); candidates = ea->annotate_column(data, 1, true, true); - ea->get_graph()->print_nodes(); + //ea->get_graph()->print_nodes(); ASSERT_STREQ(wrestler_class_uri.c_str(), candidates->front().c_str()); delete ea; } @@ -374,7 +411,7 @@ TEST(EntityTest, recomputef) { Parser p(base_dir + "test_files/test4.csv"); data = p.parse_vertical(); candidates = ea->annotate_column(data, 1, true, true); - ea->get_graph()->print_nodes(); + //ea->get_graph()->print_nodes(); delete candidates; candidates = ea->recompute_f(0.1); ASSERT_STREQ(wrestler_class_uri.c_str(), candidates->front().c_str()); @@ -390,7 +427,7 @@ TEST(EntityTest, LangTag) { Parser p(base_dir + "test_files/test5.csv"); data = p.parse_vertical(); candidates = ea->annotate_column(data, 0, true, true); - ea->get_graph()->print_nodes(); + //ea->get_graph()->print_nodes(); delete candidates; candidates = ea->recompute_f(0.1); ASSERT_STREQ(class_uri.c_str(), candidates->front().c_str()); @@ -418,7 +455,7 @@ TEST(EntityTest, TitleCase) { Parser p(base_dir + "test_files/test6.csv"); data = p.parse_vertical(); candidates = ea->annotate_column(data, 0, true, true); - ea->get_graph()->print_nodes(); + //ea->get_graph()->print_nodes(); delete candidates; candidates = ea->recompute_f(0.1); ASSERT_STREQ(class_uri.c_str(), candidates->front().c_str()); diff --git a/test_files/test3.csv b/test_files/test3.csv index 2bdb013..b4b14b2 100644 --- a/test_files/test3.csv +++ b/test_files/test3.csv @@ -1,4 +1,4 @@ id,name,city 1,common_name_1,city1 -2,common_name_2,city2 3,volleyballp4,city4 +2,common_name_2,city2