Skip to content

Commit

Permalink
support for multiple neo4j containers (#55)
Browse files Browse the repository at this point in the history
* fix bug in PMID set construction with & symbol

* support multiple neo4j containers

* return source with rel

* fix KG querying in work code

* fix type issue

* fix typo
  • Loading branch information
rmillikin authored Aug 31, 2023
1 parent 2e1b315 commit 112161b
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 33 deletions.
3 changes: 2 additions & 1 deletion .env
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# change this to a directory on your local machine to store pubmed articles
PUBMED_DIR=/path/to/pubmed/folder
NEO4J_DIR=/path/to/neo4j/folder
NEO4J_KG_DIR=/path/to/neo4j/folder
NEO4J_SEMMEDDB_DIR=/path/to/neo4j/folder

# password hash (password is 'password' by default; to change it, you need
# to generate a hash yourself using bcrypt and put it here)
Expand Down
27 changes: 20 additions & 7 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ services:
context: .
dockerfile: ./src/workers/Dockerfile
image: fast_km-worker:build
command: --workers 2 --high_priority 1 --medium_priority 0 --low_priority 0
command: --workers 2 --high_priority 1 --medium_priority 0 --low_priority 0 --neo4j_address neo4j_kg:7687,neo4j_semmeddb:7687
volumes:
- ${PUBMED_DIR}:/mnt/pubmed # edit .env file to change pubmed dir
depends_on:
Expand All @@ -37,16 +37,29 @@ services:
networks:
- fast_km-network

neo4j:
neo4j_kg:
image: neo4j:4.4.16
environment:
NEO4J_AUTH: neo4j/mypassword
volumes:
- ${NEO4J_DIR}/data:/data
- ${NEO4J_DIR}/logs:/logs
ports:
- "7474:7474"
- "7687:7687"
- ${NEO4J_KG_DIR}/data:/data
- ${NEO4J_KG_DIR}/logs:/logs
expose:
- 7474
- 7687
networks:
- fast_km-network

neo4j_semmeddb:
image: neo4j:4.4.16
environment:
NEO4J_AUTH: neo4j/mypassword
volumes:
- ${NEO4J_SEMMEDDB_DIR}/data:/data
- ${NEO4J_SEMMEDDB_DIR}/logs:/logs
expose:
- 7474
- 7687
networks:
- fast_km-network

Expand Down
6 changes: 3 additions & 3 deletions src/indexing/km_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

redis_host = 'redis'
mongo_host = 'mongo'
neo4j_host = 'neo4j'
neo4j_host = ['neo4j:7687'] # overridden in run_worker.py
tokenizer = nltk.RegexpTokenizer(r"\w+")
encoding = 'utf-8'

Expand Down Expand Up @@ -74,8 +74,8 @@ def get_index_file(abstracts_dir: str) -> str:
def get_cataloged_files(abstracts_dir: str) -> str:
return os.path.join(get_index_dir(abstracts_dir), 'cataloged.txt')

def get_knowledge_graph_node_id_index(abstracts_dir: str) -> str:
return os.path.join(get_index_dir(abstracts_dir), 'kg_node_ids.txt')
def get_knowledge_graph_node_id_index(abstracts_dir: str, graph_name: str) -> str:
return os.path.join(get_index_dir(abstracts_dir), graph_name + '_node_ids.txt')

def get_icite_file(abstracts_dir: str) -> str:
return os.path.join(get_index_dir(abstracts_dir), 'icite.json')
31 changes: 20 additions & 11 deletions src/knowledge_graph/knowledge_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import indexing.index as index
import workers.loaded_index as li

neo4j_port = "7687"
user = "neo4j"
password = "mypassword"

Expand All @@ -15,29 +14,34 @@
max_synonyms = 9999

class KnowledgeGraph:
def __init__(self):
def __init__(self, url: str):
self.query_cache = dict()
self.node_ids = dict()
self.graph_name = 'neo4j'

try:
uri="bolt://" + util.neo4j_host + ":" + neo4j_port
uri="bolt://" + url
self.graph = Graph(uri, auth=(user, password))
except:
self.graph = None
print('WARNING: Could not find a neo4j knowledge graph database. knowledge graph will be unavailable.')
print('WARNING: Could not find a neo4j knowledge graph database at ' + uri + '; knowledge graph will be unavailable.')
return

try:
kg_ids = util.get_knowledge_graph_node_id_index(li.pubmed_path)
if kg_ids:
self.graph_name = url.split(':')[0]
kg_ids = util.get_knowledge_graph_node_id_index(li.pubmed_path, self.graph_name)
if os.path.exists(kg_ids):
self.load_node_id_index(kg_ids)
else:
self.write_node_id_index()
self.load_node_id_index(kg_ids)
except:
self.node_ids = dict()
print('WARNING: Problem loading graph node IDs. knowledge graph queries may be slower than normal.')
print('WARNING: Problem loading graph node IDs (expected path ' + kg_ids + '). knowledge graph queries may be slower than normal.')

def query(self, a_term: str, b_term: str, censor_year = None):
if not self.graph:
return [{'a_term': a_term, 'a_type': '', 'relationship': 'neo4j connection error', 'b_term': b_term, 'b_type': '', 'pmids': []}]
return [self._construct_rel_response(a_term, '', b_term, '', 'neo4j connection error', [], self.graph_name)]

if index.logical_and in a_term or index.logical_and in b_term:
return [self._null_rel_response(a_term, b_term)]
Expand Down Expand Up @@ -112,7 +116,7 @@ def query(self, a_term: str, b_term: str, censor_year = None):
else:
pmids = relation['pmids']

relation_json = {'a_term': node1_name, 'a_type': node1_type, 'relationship': relationship, 'b_term': node2_name, 'b_type': node2_type, 'pmids': pmids[:100]}
relation_json = self._construct_rel_response(node1_name, node1_type, node2_name, node2_type, relationship, pmids[:100], self.graph_name)
result.append(relation_json)

if not result:
Expand All @@ -121,7 +125,9 @@ def query(self, a_term: str, b_term: str, censor_year = None):
self.query_cache[sanitized_ab_tuple] = result
return result

def write_node_id_index(self, path: str):
def write_node_id_index(self):
path = util.get_knowledge_graph_node_id_index(li.pubmed_path, self.graph_name)

dir = os.path.dirname(path)

if not os.path.exists(dir):
Expand Down Expand Up @@ -237,7 +243,10 @@ def _post_rels(self, rels: dict):
merge_relationships(self.graph.auto(), batch, r_type, start_node_key=(n1_type, "name"), end_node_key=(n2_type, "name"))

def _null_rel_response(self, a_term, b_term):
return {'a_term': a_term, 'a_type': '', 'relationship': '', 'b_term': b_term, 'b_type': '', 'pmids': []}
return self._construct_rel_response(a_term, '', b_term, '', '', [], self.graph_name)

def _construct_rel_response(self, a_term: str, a_type: str, b_term: str, b_type: str, relationship: str, pmids: list, source: str):
return {'a_term': a_term, 'a_type': a_type, 'relationship': relationship, 'b_term': b_term, 'b_type': b_type, 'pmids': pmids, 'source': source}

def _sanitize_txt(term: str):
subterms = set()
Expand Down
2 changes: 2 additions & 0 deletions src/run_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
parser.add_argument('--high_priority', default=0, required=False)
parser.add_argument('--medium_priority', default=0, required=False)
parser.add_argument('--low_priority', default=0, required=False)
parser.add_argument('--neo4j_address', default='neo4j:7687', required=False)
args = parser.parse_args()

def start_workers(do_multiprocessing = True):
n_workers = int(args.workers)
high_priority = int(args.high_priority)
medium_priority = int(args.medium_priority)
low_priority = int(args.low_priority)
km_util.neo4j_host = [x.strip() for x in args.neo4j_address.split(',')]

if do_multiprocessing:
worker_processes = []
Expand Down
2 changes: 1 addition & 1 deletion src/workers/km_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class KmWorker(Worker):
def __init__(self, queues=None, *args, **kwargs):
super().__init__(queues, *args, **kwargs)

def start_worker(queues: 'list[str]' = [km_util.JobPriority.MEDIUM.name]):
def start_worker(queues: 'list[str]' = [km_util.JobPriority.MEDIUM.name], neo4j_addresses: 'list[str]' = ['neo4j']):
print('INFO: worker sleeping for 5 sec before starting...')
time.sleep(5)

Expand Down
32 changes: 22 additions & 10 deletions src/workers/work.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

def km_work(json: list):
_initialize_mongo_caching()
knowledge_graph = connect_to_neo4j()
knowledge_graphs = connect_to_neo4j()

return_val = []

Expand Down Expand Up @@ -48,16 +48,19 @@ def km_work(json: list):
query_kg = bool(item['query_knowledge_graph'])

if query_kg and res['pvalue'] < rel_pvalue_cutoff:
rel = knowledge_graph.query(a_term, b_term)
res['relationship'] = rel
res['relationship'] = []

for kg in knowledge_graphs:
rel = kg.query(a_term, b_term)
res['relationship'].extend(rel)

return_val.append(res)

return return_val

def km_work_all_vs_all(json: dict):
_initialize_mongo_caching()
knowledge_graph = connect_to_neo4j()
knowledge_graphs = connect_to_neo4j()

return_val = []
km_only = False
Expand Down Expand Up @@ -171,8 +174,11 @@ def km_work_all_vs_all(json: dict):

if query_kg:
if abc_result['ab_pvalue'] < _rel_pvalue_cutoff:
rel = knowledge_graph.query(abc_result['a_term'], abc_result['b_term'], censor_year)
abc_result['ab_relationship'] = rel
abc_result['ab_relationship'] = []

for kg in knowledge_graphs:
rel = kg.query(abc_result['a_term'], abc_result['b_term'], censor_year)
abc_result['ab_relationship'].extend(rel)
else:
abc_result['ab_relationship'] = None

Expand All @@ -197,8 +203,11 @@ def km_work_all_vs_all(json: dict):

if query_kg:
if abc_result['bc_pvalue'] < _rel_pvalue_cutoff:
rel = knowledge_graph.query(abc_result['b_term'], abc_result['c_term'], censor_year)
abc_result['bc_relationship'] = rel
abc_result['bc_relationship'] = []

for kg in knowledge_graphs:
rel = kg.query(abc_result['b_term'], abc_result['c_term'], censor_year)
abc_result['bc_relationship'].extend(rel)
else:
abc_result['bc_relationship'] = None

Expand Down Expand Up @@ -326,8 +335,11 @@ def _initialize_mongo_caching():
# such as 'fever' to save the current state of the index
li.the_index._check_if_mongo_should_be_refreshed()

def connect_to_neo4j():
return KnowledgeGraph()
def connect_to_neo4j() -> 'list[KnowledgeGraph]':
graphs = []
for url in km_util.neo4j_host:
graphs.append(KnowledgeGraph(url))
return graphs

def _queue_jobs(jobs):
for job in jobs:
Expand Down

0 comments on commit 112161b

Please sign in to comment.