diff --git a/.env b/.env
index 0bf600b..7f8f2a8 100644
--- a/.env
+++ b/.env
@@ -1,6 +1,7 @@
 # change this to a directory on your local machine to store pubmed articles
 PUBMED_DIR=/path/to/pubmed/folder
-NEO4J_DIR=/path/to/neo4j/folder
+NEO4J_KG_DIR=/path/to/neo4j/folder
+NEO4J_SEMMEDDB_DIR=/path/to/neo4j/folder
 
 # password hash (password is 'password' by default; to change it, you need
 # to generate a hash yourself using bcrypt and put it here)
diff --git a/docker-compose.yml b/docker-compose.yml
index 273c99e..e6bc1fc 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -18,7 +18,7 @@ services:
       context: .
       dockerfile: ./src/workers/Dockerfile
     image: fast_km-worker:build
-    command: --workers 2 --high_priority 1 --medium_priority 0 --low_priority 0
+    command: --workers 2 --high_priority 1 --medium_priority 0 --low_priority 0 --neo4j_address neo4j_kg:7687,neo4j_semmeddb:7687
     volumes:
       - ${PUBMED_DIR}:/mnt/pubmed # edit .env file to change pubmed dir
     depends_on:
@@ -37,16 +37,29 @@ services:
     networks:
       - fast_km-network
 
-  neo4j:
+  neo4j_kg:
     image: neo4j:4.4.16
     environment:
       NEO4J_AUTH: neo4j/mypassword
     volumes:
-      - ${NEO4J_DIR}/data:/data
-      - ${NEO4J_DIR}/logs:/logs
-    ports:
-      - "7474:7474"
-      - "7687:7687"
+      - ${NEO4J_KG_DIR}/data:/data
+      - ${NEO4J_KG_DIR}/logs:/logs
+    expose:
+      - 7474
+      - 7687
+    networks:
+      - fast_km-network
+
+  neo4j_semmeddb:
+    image: neo4j:4.4.16
+    environment:
+      NEO4J_AUTH: neo4j/mypassword
+    volumes:
+      - ${NEO4J_SEMMEDDB_DIR}/data:/data
+      - ${NEO4J_SEMMEDDB_DIR}/logs:/logs
+    expose:
+      - 7474
+      - 7687
     networks:
       - fast_km-network
 
diff --git a/src/indexing/km_util.py b/src/indexing/km_util.py
index e6643d9..24271bc 100644
--- a/src/indexing/km_util.py
+++ b/src/indexing/km_util.py
@@ -4,7 +4,7 @@
 redis_host = 'redis'
 mongo_host = 'mongo'
-neo4j_host = 'neo4j'
+neo4j_host = ['neo4j:7687'] # overridden in run_worker.py
 
 tokenizer = nltk.RegexpTokenizer(r"\w+")
 encoding = 'utf-8'
 
@@ -74,8 +74,8 @@ def get_index_file(abstracts_dir: str) -> str:
 def get_cataloged_files(abstracts_dir: str) -> str:
     return os.path.join(get_index_dir(abstracts_dir), 'cataloged.txt')
 
-def get_knowledge_graph_node_id_index(abstracts_dir: str) -> str:
-    return os.path.join(get_index_dir(abstracts_dir), 'kg_node_ids.txt')
+def get_knowledge_graph_node_id_index(abstracts_dir: str, graph_name: str) -> str:
+    return os.path.join(get_index_dir(abstracts_dir), graph_name + '_node_ids.txt')
 
 def get_icite_file(abstracts_dir: str) -> str:
     return os.path.join(get_index_dir(abstracts_dir), 'icite.json')
\ No newline at end of file
diff --git a/src/knowledge_graph/knowledge_graph.py b/src/knowledge_graph/knowledge_graph.py
index ec334a2..e722f4a 100644
--- a/src/knowledge_graph/knowledge_graph.py
+++ b/src/knowledge_graph/knowledge_graph.py
@@ -6,7 +6,6 @@
 import indexing.index as index
 import workers.loaded_index as li
 
-neo4j_port = "7687"
 user = "neo4j"
 password = "mypassword"
 
@@ -15,29 +14,34 @@
 max_synonyms = 9999
 
 class KnowledgeGraph:
-    def __init__(self):
+    def __init__(self, url: str):
         self.query_cache = dict()
         self.node_ids = dict()
+        self.graph_name = 'neo4j'
 
         try:
-            uri="bolt://" + util.neo4j_host + ":" + neo4j_port
+            uri="bolt://" + url
             self.graph = Graph(uri, auth=(user, password))
         except:
             self.graph = None
-            print('WARNING: Could not find a neo4j knowledge graph database. knowledge graph will be unavailable.')
+            print('WARNING: Could not find a neo4j knowledge graph database at ' + uri + '; knowledge graph will be unavailable.')
             return
 
         try:
-            kg_ids = util.get_knowledge_graph_node_id_index(li.pubmed_path)
-            if kg_ids:
+            self.graph_name = url.split(':')[0]
+            kg_ids = util.get_knowledge_graph_node_id_index(li.pubmed_path, self.graph_name)
+            if os.path.exists(kg_ids):
+                self.load_node_id_index(kg_ids)
+            else:
+                self.write_node_id_index()
                 self.load_node_id_index(kg_ids)
         except:
             self.node_ids = dict()
-            print('WARNING: Problem loading graph node IDs. knowledge graph queries may be slower than normal.')
+            print('WARNING: Problem loading graph node IDs (expected path ' + kg_ids + '). knowledge graph queries may be slower than normal.')
 
     def query(self, a_term: str, b_term: str, censor_year = None):
         if not self.graph:
-            return [{'a_term': a_term, 'a_type': '', 'relationship': 'neo4j connection error', 'b_term': b_term, 'b_type': '', 'pmids': []}]
+            return [self._construct_rel_response(a_term, '', b_term, '', 'neo4j connection error', [], self.graph_name)]
 
         if index.logical_and in a_term or index.logical_and in b_term:
             return [self._null_rel_response(a_term, b_term)]
@@ -112,7 +116,7 @@ def query(self, a_term: str, b_term: str, censor_year = None):
             else:
                 pmids = relation['pmids']
 
-            relation_json = {'a_term': node1_name, 'a_type': node1_type, 'relationship': relationship, 'b_term': node2_name, 'b_type': node2_type, 'pmids': pmids[:100]}
+            relation_json = self._construct_rel_response(node1_name, node1_type, node2_name, node2_type, relationship, pmids[:100], self.graph_name)
             result.append(relation_json)
 
         if not result:
@@ -121,7 +125,9 @@ def query(self, a_term: str, b_term: str, censor_year = None):
         self.query_cache[sanitized_ab_tuple] = result
         return result
 
-    def write_node_id_index(self, path: str):
+    def write_node_id_index(self):
+        path = util.get_knowledge_graph_node_id_index(li.pubmed_path, self.graph_name)
+
         dir = os.path.dirname(path)
 
         if not os.path.exists(dir):
@@ -237,7 +243,10 @@ def _post_rels(self, rels: dict):
             merge_relationships(self.graph.auto(), batch, r_type, start_node_key=(n1_type, "name"), end_node_key=(n2_type, "name"))
 
     def _null_rel_response(self, a_term, b_term):
-        return {'a_term': a_term, 'a_type': '', 'relationship': '', 'b_term': b_term, 'b_type': '', 'pmids': []}
+        return self._construct_rel_response(a_term, '', b_term, '', '', [], self.graph_name)
+
+    def _construct_rel_response(self, a_term: str, a_type: str, b_term: str, b_type: str, relationship: str, pmids: list, source: str):
+        return {'a_term': a_term, 'a_type': a_type, 'relationship': relationship, 'b_term': b_term, 'b_type': b_type, 'pmids': pmids, 'source': source}
 
 def _sanitize_txt(term: str):
     subterms = set()
diff --git a/src/run_worker.py b/src/run_worker.py
index f5ab4e8..c7451a6 100644
--- a/src/run_worker.py
+++ b/src/run_worker.py
@@ -10,6 +10,7 @@
 parser.add_argument('--high_priority', default=0, required=False)
 parser.add_argument('--medium_priority', default=0, required=False)
 parser.add_argument('--low_priority', default=0, required=False)
+parser.add_argument('--neo4j_address', default='neo4j:7687', required=False)
 args = parser.parse_args()
 
 def start_workers(do_multiprocessing = True):
@@ -17,6 +18,7 @@ def start_workers(do_multiprocessing = True):
     high_priority = int(args.high_priority)
     medium_priority = int(args.medium_priority)
     low_priority = int(args.low_priority)
+    km_util.neo4j_host = [x.strip() for x in args.neo4j_address.split(',')]
 
     if do_multiprocessing:
         worker_processes = []
diff --git a/src/workers/km_worker.py b/src/workers/km_worker.py
index c6b896c..e7f8a6a 100644
--- a/src/workers/km_worker.py
+++ b/src/workers/km_worker.py
@@ -9,7 +9,7 @@ class KmWorker(Worker):
     def __init__(self, queues=None, *args, **kwargs):
         super().__init__(queues, *args, **kwargs)
 
-def start_worker(queues: 'list[str]' = [km_util.JobPriority.MEDIUM.name]):
+def start_worker(queues: 'list[str]' = [km_util.JobPriority.MEDIUM.name], neo4j_addresses: 'list[str]' = ['neo4j']):
     print('INFO: worker sleeping for 5 sec before starting...')
     time.sleep(5)
 
diff --git a/src/workers/work.py b/src/workers/work.py
index 75ed317..e413067 100644
--- a/src/workers/work.py
+++ b/src/workers/work.py
@@ -17,7 +17,7 @@ def km_work(json: list):
 
     _initialize_mongo_caching()
 
-    knowledge_graph = connect_to_neo4j()
+    knowledge_graphs = connect_to_neo4j()
 
     return_val = []
 
@@ -48,8 +48,11 @@ def km_work(json: list):
         query_kg = bool(item['query_knowledge_graph'])
 
         if query_kg and res['pvalue'] < rel_pvalue_cutoff:
-            rel = knowledge_graph.query(a_term, b_term)
-            res['relationship'] = rel
+            res['relationship'] = []
+
+            for kg in knowledge_graphs:
+                rel = kg.query(a_term, b_term)
+                res['relationship'].extend(rel)
 
         return_val.append(res)
 
@@ -57,7 +60,7 @@
 def km_work_all_vs_all(json: dict):
     _initialize_mongo_caching()
 
-    knowledge_graph = connect_to_neo4j()
+    knowledge_graphs = connect_to_neo4j()
 
     return_val = []
     km_only = False
@@ -171,8 +174,11 @@
 
         if query_kg:
             if abc_result['ab_pvalue'] < _rel_pvalue_cutoff:
-                rel = knowledge_graph.query(abc_result['a_term'], abc_result['b_term'], censor_year)
-                abc_result['ab_relationship'] = rel
+                abc_result['ab_relationship'] = []
+
+                for kg in knowledge_graphs:
+                    rel = kg.query(abc_result['a_term'], abc_result['b_term'], censor_year)
+                    abc_result['ab_relationship'].extend(rel)
             else:
                 abc_result['ab_relationship'] = None
 
@@ -197,8 +203,11 @@
 
         if query_kg:
             if abc_result['bc_pvalue'] < _rel_pvalue_cutoff:
-                rel = knowledge_graph.query(abc_result['b_term'], abc_result['c_term'], censor_year)
-                abc_result['bc_relationship'] = rel
+                abc_result['bc_relationship'] = []
+
+                for kg in knowledge_graphs:
+                    rel = kg.query(abc_result['b_term'], abc_result['c_term'], censor_year)
+                    abc_result['bc_relationship'].extend(rel)
             else:
                 abc_result['bc_relationship'] = None
 
@@ -326,8 +335,11 @@ def _initialize_mongo_caching():
     # such as 'fever' to save the current state of the index
     li.the_index._check_if_mongo_should_be_refreshed()
 
-def connect_to_neo4j():
-    return KnowledgeGraph()
+def connect_to_neo4j() -> 'list[KnowledgeGraph]':
+    graphs = []
+    for url in km_util.neo4j_host:
+        graphs.append(KnowledgeGraph(url))
+    return graphs
 
 def _queue_jobs(jobs):
     for job in jobs:
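
For reviewers, a minimal sketch (not part of the patch) of how the pieces above fit together: docker-compose.yml passes a comma-separated --neo4j_address value, run_worker.py splits it into the km_util.neo4j_host list, work.py builds one KnowledgeGraph per entry, and each instance derives its graph_name (surfaced as the 'source' field of every relationship row) and its per-graph node-ID index filename from the host portion of the address. The snippet below only mirrors that string handling in plain Python; it opens no Neo4j connection, and the example address values are the ones from docker-compose.yml.

    # Illustrative sketch only; mirrors run_worker.py, KnowledgeGraph.__init__ and km_util.
    address_flag = 'neo4j_kg:7687,neo4j_semmeddb:7687'    # value of --neo4j_address in docker-compose.yml

    hosts = [x.strip() for x in address_flag.split(',')]  # what run_worker.py assigns to km_util.neo4j_host
    graph_names = [url.split(':')[0] for url in hosts]    # KnowledgeGraph.graph_name, reported as 'source'
    id_indexes = [name + '_node_ids.txt' for name in graph_names]  # filenames built by get_knowledge_graph_node_id_index

    print(hosts)        # ['neo4j_kg:7687', 'neo4j_semmeddb:7687']
    print(graph_names)  # ['neo4j_kg', 'neo4j_semmeddb']
    print(id_indexes)   # ['neo4j_kg_node_ids.txt', 'neo4j_semmeddb_node_ids.txt']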