From 77d3cdeda3c238c9743de0fe308ca6d89a40006e Mon Sep 17 00:00:00 2001 From: LilDojd Date: Sat, 16 Nov 2024 18:20:56 +0400 Subject: [PATCH 01/11] Validate GufeKey for most common exploits before creating ScopedKey --- alchemiscale/models.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/alchemiscale/models.py b/alchemiscale/models.py index 69fc1288..e59da07c 100644 --- a/alchemiscale/models.py +++ b/alchemiscale/models.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, validator, root_validator from gufe.tokenization import GufeKey from re import fullmatch +import unicodedata class Scope(BaseModel): @@ -131,8 +132,27 @@ class Config: frozen = True @validator("gufe_key") - def cast_gufe_key(cls, v): - return GufeKey(v) + def gufe_key_validator(cls, v): + v = str(v) + + # Normalize the input to NFC form + + v_normalized = unicodedata.normalize("NFC", v) + + # Ensure that there are no control characters + if any(unicodedata.category(c) == "Cc" for c in v_normalized): + raise ValueError("gufe_key contains invalid control characters") + + # Allowed characters: letters, numbers, underscores, hyphens + allowed_chars = set( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_-" + ) + + if not set(v_normalized).issubset(allowed_chars): + raise ValueError("gufe_key contains invalid characters") + + # Cast to GufeKey + return GufeKey(v_normalized) def __repr__(self): # pragma: no cover return f"" From dd650071f6a0ef24e7d8c3721952702ee1c1dcdf Mon Sep 17 00:00:00 2001 From: LilDojd Date: Sat, 16 Nov 2024 18:31:35 +0400 Subject: [PATCH 02/11] Add tests for GufeKey validation --- alchemiscale/models.py | 10 ++++++++-- alchemiscale/tests/unit/test_models.py | 19 ++++++++++++++++++- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/alchemiscale/models.py b/alchemiscale/models.py index e59da07c..861fe0c9 100644 --- a/alchemiscale/models.py +++ b/alchemiscale/models.py @@ -114,6 +114,7 @@ def 
specific(self) -> bool: """Return `True` if this Scope has no unspecified elements.""" return all(self.to_tuple()) +class InvalidGufeKeyError(ValueError): ... class ScopedKey(BaseModel): """Unique identifier for GufeTokenizables in state store. @@ -135,13 +136,18 @@ class Config: def gufe_key_validator(cls, v): v = str(v) + # GufeKey is of form - + prefix, token = v.split("-") + if not prefix or not token: + raise InvalidGufeKeyError("gufe_key must be of the form '-'") + # Normalize the input to NFC form v_normalized = unicodedata.normalize("NFC", v) # Ensure that there are no control characters if any(unicodedata.category(c) == "Cc" for c in v_normalized): - raise ValueError("gufe_key contains invalid control characters") + raise InvalidGufeKeyError("gufe_key contains invalid control characters") # Allowed characters: letters, numbers, underscores, hyphens allowed_chars = set( @@ -149,7 +155,7 @@ def gufe_key_validator(cls, v): ) if not set(v_normalized).issubset(allowed_chars): - raise ValueError("gufe_key contains invalid characters") + raise InvalidGufeKeyError("gufe_key contains invalid characters") # Cast to GufeKey return GufeKey(v_normalized) diff --git a/alchemiscale/tests/unit/test_models.py b/alchemiscale/tests/unit/test_models.py index c8285fbf..f6f59af2 100644 --- a/alchemiscale/tests/unit/test_models.py +++ b/alchemiscale/tests/unit/test_models.py @@ -2,7 +2,7 @@ from pydantic import ValidationError -from alchemiscale.models import Scope +from alchemiscale.models import Scope, ScopedKey @pytest.mark.parametrize( @@ -101,3 +101,20 @@ def test_scope_non_alphanumeric_invalid(scope_string): ) def test_underscore_scopes_valid(scope_string): scope = Scope.from_str(scope_string) + +@pytest.mark.parametrize( + "gufe_key", + [ + "White Space-token", + "WhiteSpace-tok en", + "NoToken", + "Unicode-\u0027MATCH", + "CredentialedEntity) DETACH DELETE n //", + "BadPrefix-token`backtick", + ], +) +def test_gufe_key_invalid(gufe_key): + with 
pytest.raises(ValidationError): + ScopedKey( + gufe_key=gufe_key, org="org1", campaign="campaignA", project="projectI" + ) From 99034dc283abcb8383f5ca0ae50255631f0d50ac Mon Sep 17 00:00:00 2001 From: LilDojd Date: Sat, 16 Nov 2024 18:37:19 +0400 Subject: [PATCH 03/11] Add positive tests for GufeKey validation --- alchemiscale/models.py | 4 ++-- alchemiscale/tests/unit/test_models.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/alchemiscale/models.py b/alchemiscale/models.py index 861fe0c9..2144dcf0 100644 --- a/alchemiscale/models.py +++ b/alchemiscale/models.py @@ -136,10 +136,10 @@ class Config: def gufe_key_validator(cls, v): v = str(v) - # GufeKey is of form - + # GufeKey is of form - prefix, token = v.split("-") if not prefix or not token: - raise InvalidGufeKeyError("gufe_key must be of the form '-'") + raise InvalidGufeKeyError("gufe_key must be of the form '-'") # Normalize the input to NFC form diff --git a/alchemiscale/tests/unit/test_models.py b/alchemiscale/tests/unit/test_models.py index f6f59af2..b687ff86 100644 --- a/alchemiscale/tests/unit/test_models.py +++ b/alchemiscale/tests/unit/test_models.py @@ -118,3 +118,16 @@ def test_gufe_key_invalid(gufe_key): ScopedKey( gufe_key=gufe_key, org="org1", campaign="campaignA", project="projectI" ) + +@pytest.mark.parametrize( + "gufe_key", + [ + "ClassName-uuid4hex", + "DummyProtocol-1234567890abcdef", + "DummyProtocol-1234567890abcdef41234567890abcdef", + ], +) +def test_gufe_key_valid(gufe_key): + scoped_key = ScopedKey( + gufe_key=gufe_key, org="org1", campaign="campaignA", project="projectI" + ) From 10456b9bfc9fc680f058060ee8ce3de5aeaaaa85 Mon Sep 17 00:00:00 2001 From: LilDojd Date: Sun, 17 Nov 2024 01:26:17 +0400 Subject: [PATCH 04/11] Use cypher params in _query This commit changes `_query()` function in Neo4jStore and adds test that shows how the exploit could be abused in previous versions --- alchemiscale/storage/statestore.py | 34 ++++++-------- 
.../integration/storage/test_statestore.py | 46 +++++++++++++++++++ 2 files changed, 61 insertions(+), 19 deletions(-) diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 4e6b9c3c..576c9f2b 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -468,26 +468,21 @@ def _get_node( ) -> Union[Node, Tuple[Node, Subgraph]]: """ If `return_subgraph = True`, also return subgraph for gufe object. - """ - qualname = scoped_key.qualname - - properties = {"_scoped_key": str(scoped_key)} - prop_string = ", ".join( - "{}: '{}'".format(key, value) for key, value in properties.items() - ) - prop_string = f" {{{prop_string}}}" + # Safety: qualname comes from GufeKey which is validated + qualname = scoped_key.qualname + parameters = {"scoped_key": str(scoped_key)} q = f""" - MATCH (n:{qualname}{prop_string}) + MATCH (n:{qualname} {{ _scoped_key: $scoped_key }}) """ if return_subgraph: q += """ OPTIONAL MATCH p = (n)-[r:DEPENDS_ON*]->(m) WHERE NOT (m)-[:DEPENDS_ON]->() - RETURN n,p + RETURN n, p """ else: q += """ @@ -497,10 +492,12 @@ def _get_node( nodes = set() subgraph = Subgraph() - for record in self.execute_query(q).records: + result = self.execute_query(q, parameters_=parameters) + + for record in result.records: node = record_data_to_node(record["n"]) nodes.add(node) - if return_subgraph and record["p"] is not None: + if return_subgraph and record.get("p") is not None: subgraph = subgraph | subgraph_from_path_record(record["p"]) else: subgraph = node @@ -521,8 +518,8 @@ def _query( self, *, qualname: str, - additional: Dict = None, - key: GufeKey = None, + additional: Optional[Dict] = None, + key: Optional[GufeKey] = None, scope: Scope = Scope(), return_gufe=False, ): @@ -532,9 +529,8 @@ def _query( "_project": scope.project, } - for k, v in list(properties.items()): - if v is None: - properties.pop(k) + # Remove None values from properties + properties = {k: v for k, v in properties.items() if v is 
not None} if key is not None: properties["_gufe_key"] = str(key) @@ -547,7 +543,7 @@ def _query( prop_string = "" else: prop_string = ", ".join( - "{}: '{}'".format(key, value) for key, value in properties.items() + "{}: ${}".format(key, key) for key in properties.keys() ) prop_string = f" {{{prop_string}}}" @@ -568,7 +564,7 @@ def _query( """ with self.transaction() as tx: - res = tx.run(q).to_eager_result() + res = tx.run(q, **properties).to_eager_result() nodes = list() subgraph = Subgraph() diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py index 3ec702a6..883cd407 100644 --- a/alchemiscale/tests/integration/storage/test_statestore.py +++ b/alchemiscale/tests/integration/storage/test_statestore.py @@ -271,6 +271,52 @@ def test_query_transformations(self, n4js, network_tyk2, multiple_scopes): == 1 ) + def test_query_transformations_exploit(self, n4js, multiple_scopes, network_tyk2): + # This test is to show that common cypher exploits are mitigated by using parameters + + an = network_tyk2 + + n4js.assemble_network(an, multiple_scopes[0]) + n4js.assemble_network(an, multiple_scopes[1]) + + malicious_name = """'}) + WITH {_org: '', _campaign: '', _project: '', _gufe_key: ''} AS n + RETURN n + UNION + MATCH (m) DETACH DELETE m + WITH {_org: '', _campaign: '', _project: '', _gufe_key: ''} AS n + RETURN n + UNION + CREATE (mark:InjectionMark {_scoped_key: 'InjectionMark-12345-test-testcamp-testproj'}) + WITH {_org: '', _campaign: '', _project: '', _gufe_key: ''} AS n // """ + try: + n4js.query_transformations(name=malicious_name) + except AttributeError as e: + # With old _query, AttributeError would be thrown AFTER the transaction has finished, and the database is already corrupted + assert "'dict' object has no attribute 'labels'" in str(e) + assert len(n4js.query_transformations(scope=multiple_scopes[0])) == 0 + + mark = n4js._query(qualname="InjectionMark") + assert len(mark) == 0 + 
+ assert len(n4js.query_transformations()) == len(network_tyk2.edges) * 2 + assert len(n4js.query_transformations(scope=multiple_scopes[0])) == len( + network_tyk2.edges + ) + + assert ( + len(n4js.query_transformations(name="lig_ejm_31_to_lig_ejm_50_complex")) + == 2 + ) + assert ( + len( + n4js.query_transformations( + scope=multiple_scopes[0], name="lig_ejm_31_to_lig_ejm_50_complex" + ) + ) + == 1 + ) + def test_query_chemicalsystems(self, n4js, network_tyk2, multiple_scopes): an = network_tyk2 From 71f1146221338acf5a3965645d024775713d3a27 Mon Sep 17 00:00:00 2001 From: LilDojd Date: Sun, 17 Nov 2024 02:36:39 +0400 Subject: [PATCH 05/11] WIP on transitioning some queries to parametrised in Neo4jStore --- alchemiscale/storage/statestore.py | 403 +++++++++++++++-------------- 1 file changed, 209 insertions(+), 194 deletions(-) diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 576c9f2b..23f97482 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -5,10 +5,11 @@ """ import abc -from datetime import datetime +from datetime import UTC, datetime from contextlib import contextmanager import json from functools import lru_cache +from operator import ne from typing import Dict, List, Optional, Union, Tuple import weakref import numpy as np @@ -98,45 +99,22 @@ def _select_tasks_from_taskpool(taskpool: List[Tuple[str, float]], count) -> Lis return list(np.random.choice(tasks, count, replace=False, p=prob)) - -def _generate_claim_query( - task_sks: List[ScopedKey], compute_service_id: ComputeServiceID -) -> str: - """Generate a query to claim a list of Tasks. - - Parameters - ---------- - task_sks - A list of ScopedKeys of Tasks to claim. - compute_service_id - ComputeServiceID of the claiming service. - - Returns - ------- - query: str - The Cypher query to claim the Task. 
- """ - - task_data = cypher_list_from_scoped_keys(task_sks) - - query = f""" +CLAIM_QUERY = f""" // only match the task if it doesn't have an existing CLAIMS relationship - UNWIND {task_data} AS task_sk + UNWIND $tasks_list AS task_sk MATCH (t:Task {{_scoped_key: task_sk}}) WHERE NOT (t)<-[:CLAIMS]-(:ComputeServiceRegistration) WITH t // create CLAIMS relationship with given compute service - MATCH (csreg:ComputeServiceRegistration {{identifier: '{compute_service_id}'}}) - CREATE (t)<-[cl:CLAIMS {{claimed: localdatetime('{datetime.utcnow().isoformat()}')}}]-(csreg) + MATCH (csreg:ComputeServiceRegistration {{identifier: $compute_service_id}}) + CREATE (t)<-[cl:CLAIMS {{claimed: localdatetime($datetimestr)}}]-(csreg) SET t.status = '{TaskStatusEnum.running.value}' RETURN t - """ - return query - +""" class Neo4jStore(AlchemiscaleStateStore): # uniqueness constraints applied to the database; key is node label, @@ -703,8 +681,8 @@ def delete_network( self.delete_taskhub(network) # then delete the network - q = f""" - MATCH (an:AlchemicalNetwork {{_scoped_key: "{network}"}}) + q = """ + MATCH (an:AlchemicalNetwork {{_scoped_key: $network}}) DETACH DELETE an """ raise NotImplementedError @@ -844,11 +822,14 @@ def query_networks( *, name=None, key=None, - scope: Optional[Scope] = Scope(), + scope: Optional[Scope] = None, state: Optional[str] = None, ) -> List[ScopedKey]: """Query for `AlchemicalNetwork`\s matching given attributes.""" + if scope is None: + scope = Scope() + query_params = dict( name_pattern=name, org_pattern=scope.org, @@ -912,14 +893,14 @@ def query_chemicalsystems(self, *, name=None, key=None, scope: Scope = Scope()): def get_network_transformations(self, network: ScopedKey) -> List[ScopedKey]: """List ScopedKeys for Transformations associated with the given AlchemicalNetwork.""" - q = f""" - MATCH (:AlchemicalNetwork {{_scoped_key: '{network}'}})-[:DEPENDS_ON]->(t:Transformation|NonTransformation) + q = """ + MATCH (:AlchemicalNetwork {_scoped_key: 
$network})-[:DEPENDS_ON]->(t:Transformation|NonTransformation) WITH t._scoped_key as sk RETURN sk """ sks = [] with self.transaction() as tx: - res = tx.run(q) + res = tx.run(q, network=str(network)) for rec in res: sks.append(rec["sk"]) @@ -927,14 +908,14 @@ def get_network_transformations(self, network: ScopedKey) -> List[ScopedKey]: def get_transformation_networks(self, transformation: ScopedKey) -> List[ScopedKey]: """List ScopedKeys for AlchemicalNetworks associated with the given Transformation.""" - q = f""" - MATCH (:Transformation|NonTransformation {{_scoped_key: '{transformation}'}})<-[:DEPENDS_ON]-(an:AlchemicalNetwork) + q = """ + MATCH (:Transformation|NonTransformation {_scoped_key: $transformation})<-[:DEPENDS_ON]-(an:AlchemicalNetwork) WITH an._scoped_key as sk RETURN sk """ sks = [] with self.transaction() as tx: - res = tx.run(q) + res = tx.run(q, transformation=str(transformation)) for rec in res: sks.append(rec["sk"]) @@ -942,14 +923,14 @@ def get_transformation_networks(self, transformation: ScopedKey) -> List[ScopedK def get_network_chemicalsystems(self, network: ScopedKey) -> List[ScopedKey]: """List ScopedKeys for ChemicalSystems associated with the given AlchemicalNetwork.""" - q = f""" - MATCH (:AlchemicalNetwork {{_scoped_key: '{network}'}})-[:DEPENDS_ON]->(cs:ChemicalSystem) + q = """ + MATCH (:AlchemicalNetwork {_scoped_key: $network})-[:DEPENDS_ON]->(cs:ChemicalSystem) WITH cs._scoped_key as sk RETURN sk """ sks = [] with self.transaction() as tx: - res = tx.run(q) + res = tx.run(q, network=str(network)) for rec in res: sks.append(rec["sk"]) @@ -957,14 +938,14 @@ def get_network_chemicalsystems(self, network: ScopedKey) -> List[ScopedKey]: def get_chemicalsystem_networks(self, chemicalsystem: ScopedKey) -> List[ScopedKey]: """List ScopedKeys for AlchemicalNetworks associated with the given ChemicalSystem.""" - q = f""" - MATCH (:ChemicalSystem {{_scoped_key: '{chemicalsystem}'}})<-[:DEPENDS_ON]-(an:AlchemicalNetwork) + q = """ + MATCH 
(:ChemicalSystem {_scoped_key: $chemicalsystem})<-[:DEPENDS_ON]-(an:AlchemicalNetwork) WITH an._scoped_key as sk RETURN sk """ sks = [] with self.transaction() as tx: - res = tx.run(q) + res = tx.run(q, chemicalsystem=str(chemicalsystem)) for rec in res: sks.append(rec["sk"]) @@ -974,14 +955,14 @@ def get_transformation_chemicalsystems( self, transformation: ScopedKey ) -> List[ScopedKey]: """List ScopedKeys for the ChemicalSystems associated with the given Transformation.""" - q = f""" - MATCH (:Transformation|NonTransformation {{_scoped_key: '{transformation}'}})-[:DEPENDS_ON]->(cs:ChemicalSystem) + q = """ + MATCH (:Transformation|NonTransformation {_scoped_key: $transformation})-[:DEPENDS_ON]->(cs:ChemicalSystem) WITH cs._scoped_key as sk RETURN sk """ sks = [] with self.transaction() as tx: - res = tx.run(q) + res = tx.run(q, transformation=str(transformation)) for rec in res: sks.append(rec["sk"]) @@ -991,14 +972,14 @@ def get_chemicalsystem_transformations( self, chemicalsystem: ScopedKey ) -> List[ScopedKey]: """List ScopedKeys for the Transformations associated with the given ChemicalSystem.""" - q = f""" - MATCH (:ChemicalSystem {{_scoped_key: '{chemicalsystem}'}})<-[:DEPENDS_ON]-(t:Transformation|NonTransformation) + q = """ + MATCH (:ChemicalSystem {_scoped_key: $chemicalsystem})<-[:DEPENDS_ON]-(t:Transformation|NonTransformation) WITH t._scoped_key as sk RETURN sk """ sks = [] with self.transaction() as tx: - res = tx.run(q) + res = tx.run(q, chemicalsystem=str(chemicalsystem)) for rec in res: sks.append(rec["sk"]) @@ -1089,10 +1070,10 @@ def deregister_computeservice(self, compute_service_id: ComputeServiceID): """ q = f""" - MATCH (n:ComputeServiceRegistration {{identifier: '{compute_service_id}'}}) + MATCH (n:ComputeServiceRegistration {{identifier: $compute_service_id}}) - OPTIONAL MATCH (n)-[cl:CLAIMS]->(t:Task {{status: 'running'}}) - SET t.status = 'waiting' + OPTIONAL MATCH (n)-[cl:CLAIMS]->(t:Task {{status: '{TaskStatusEnum.running.value}'}}) 
+ SET t.status = '{TaskStatusEnum.waiting.value}' WITH n, n.identifier as identifier @@ -1102,7 +1083,7 @@ def deregister_computeservice(self, compute_service_id: ComputeServiceID): """ with self.transaction() as tx: - res = tx.run(q) + res = tx.run(q, compute_service_id=str(compute_service_id)) identifier = next(res)["identifier"] return ComputeServiceID(identifier) @@ -1113,12 +1094,12 @@ def heartbeat_computeservice( """Update the heartbeat for the given ComputeServiceID.""" q = f""" - MATCH (n:ComputeServiceRegistration {{identifier: '{compute_service_id}'}}) + MATCH (n:ComputeServiceRegistration {{identifier: $compute_service_id}}) SET n.heartbeat = localdatetime('{heartbeat.isoformat()}') """ with self.transaction() as tx: - tx.run(q) + tx.run(q, compute_service_id=str(compute_service_id)) return compute_service_id @@ -1130,8 +1111,8 @@ def expire_registrations(self, expire_time: datetime): WITH n - OPTIONAL MATCH (n)-[cl:CLAIMS]->(t:Task {{status: 'running'}}) - SET t.status = 'waiting' + OPTIONAL MATCH (n)-[cl:CLAIMS]->(t:Task {{status: '{TaskStatusEnum.running.value}'}}) + SET t.status = '{TaskStatusEnum.waiting.value}' WITH n, n.identifier as ident @@ -1217,13 +1198,15 @@ def get_taskhub( "`network` ScopedKey does not correspond to an `AlchemicalNetwork`" ) - q = f""" - match (th:TaskHub {{network: "{network}"}})-[:PERFORMS]->(an:AlchemicalNetwork) - return th - """ + q = """ + MATCH (th:TaskHub {network: $network})-[:PERFORMS]->(an:AlchemicalNetwork) + RETURN th + """ try: - node = record_data_to_node(self.execute_query(q).records[0]["th"]) + node = record_data_to_node( + self.execute_query(q, network=str(network)).records[0]["th"] + ) except IndexError: raise KeyError("No such object in database") @@ -1245,11 +1228,11 @@ def delete_taskhub( taskhub = self.get_taskhub(network) - q = f""" - MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}}), + q = """ + MATCH (th:TaskHub {_scoped_key: $taskhub}) DETACH DELETE th """ - self.execute_query(q) + 
self.execute_query(q, taskhub=str(taskhub)) return taskhub @@ -1310,14 +1293,14 @@ def get_taskhub_actioned_tasks( A list of dicts, one per TaskHub, which contains the Task ScopedKeys that are actioned on the given TaskHub as keys, with their weights as values. """ - - q = f""" - UNWIND {cypher_list_from_scoped_keys(taskhubs)} as th_sk - MATCH (th: TaskHub {{_scoped_key: th_sk}})-[a:ACTIONS]->(t:Task) + th_scoped_keys = [str(taskhub) for taskhub in taskhubs if taskhub is not None] + q = """ + UNWIND $taskhubs as th_sk + MATCH (th: TaskHub {_scoped_key: th_sk})-[a:ACTIONS]->(t:Task) RETURN t._scoped_key, a.weight, th._scoped_key """ - results = self.execute_query(q) + results = self.execute_query(q, taskhubs=th_scoped_keys) data = {taskhub: {} for taskhub in taskhubs} for record in results.records: @@ -1370,13 +1353,17 @@ def get_taskhub_weight(self, networks: List[ScopedKey]) -> List[float]: "`network` ScopedKey does not correspond to an `AlchemicalNetwork`" ) - q = f""" - UNWIND {cypher_list_from_scoped_keys(networks)} as network - MATCH (th:TaskHub {{network: network}}) + networks_scoped_keys = [ + str(network) for network in networks if network is not None + ] + + q = """ + UNWIND $networks as network + MATCH (th:TaskHub {network: network}) RETURN network, th.weight """ - results = self.execute_query(q) + results = self.execute_query(q, networks=networks_scoped_keys) network_weights = {str(network): None for network in networks} for record in results.records: @@ -1407,10 +1394,12 @@ def action_tasks( # so we can properly return `None` if needed task_map = {str(task): None for task in tasks} + tasks_scoped_keys = [str(task) for task in tasks if task is not None] + q = f""" // get our TaskHub - UNWIND {cypher_list_from_scoped_keys(tasks)} AS task_sk - MATCH (th:TaskHub {{_scoped_key: "{taskhub}"}})-[:PERFORMS]->(an:AlchemicalNetwork) + UNWIND $tasks as task_sk + MATCH (th:TaskHub {{_scoped_key: $taskhub}})-[:PERFORMS]->(an:AlchemicalNetwork) // get the task we 
want to add to the hub; check that it connects to same network MATCH (task:Task {{_scoped_key: task_sk}})-[:PERFORMS]->(tf:Transformation|NonTransformation)<-[:DEPENDS_ON]-(an) @@ -1419,7 +1408,7 @@ def action_tasks( // and where the task is either in 'waiting', 'running', or 'error' status WITH th, an, task WHERE NOT (th)-[:ACTIONS]->(task) - AND task.status IN ['{TaskStatusEnum.waiting.value}', '{TaskStatusEnum.running.value}', '{TaskStatusEnum.error.value}'] + AND task.status IN ['{TaskStatusEnum.waiting.value}', '{TaskStatusEnum.running.value}', '{TaskStatusEnum.error.value}'] // create the connection CREATE (th)-[ar:ACTIONS {{weight: 0.5}}]->(task) @@ -1430,7 +1419,7 @@ def action_tasks( RETURN task """ - results = self.execute_query(q) + results = self.execute_query(q, tasks=tasks_scoped_keys, taskhub=str(taskhub)) # update our map with the results, leaving None for tasks that aren't found for task_record in results.records: @@ -1492,13 +1481,19 @@ def set_task_weights( if not all([0 <= weight <= 1 for weight in tasks.values()]): raise ValueError("weights must be between 0 and 1 (inclusive)") - for t, w in tasks.items(): - q = f""" - MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}})-[ar:ACTIONS]->(task:Task {{_scoped_key: '{t}'}}) - SET ar.weight = {w} - RETURN task, ar - """ - results.append(tx.run(q).to_eager_result()) + tasks_list = [{"task": str(t), "weight": w} for t, w in tasks.items()] + + q = """ + UNWIND $tasks_list AS item + MATCH (th:TaskHub {_scoped_key: $taskhub})-[ar:ACTIONS]->(task:Task {_scoped_key: item.task}) + SET ar.weight = item.weight + RETURN task, ar + """ + results.append( + tx.run( + q, taskhub=str(taskhub), tasks_list=tasks_list + ).to_eager_result() + ) elif isinstance(tasks, list): if weight is None: @@ -1509,14 +1504,19 @@ def set_task_weights( if not 0 <= weight <= 1: raise ValueError("weight must be between 0 and 1 (inclusive)") - # TODO: remove for loop with an unwind clause - for t in tasks: - q = f""" - MATCH (th:TaskHub 
{{_scoped_key: '{taskhub}'}})-[ar:ACTIONS]->(task:Task {{_scoped_key: '{t}'}}) - SET ar.weight = {weight} - RETURN task, ar - """ - results.append(tx.run(q).to_eager_result()) + tasks_list = [str(t) for t in tasks] + + q = """ + UNWIND $tasks_list AS task_sk + MATCH (th:TaskHub {_scoped_key: $taskhub})-[ar:ACTIONS]->(task:Task {_scoped_key: task_sk}) + SET ar.weight = $weight + RETURN task, ar + """ + results.append( + tx.run( + q, taskhub=str(taskhub), tasks_list=tasks_list, weight=weight + ).to_eager_result() + ) # return ScopedKeys for Tasks we changed; `None` for tasks we didn't for res in results: @@ -1549,22 +1549,20 @@ def get_task_weights( weights Weights for the list of Tasks, in the same order. """ - weights = [] + with self.transaction() as tx: - for t in tasks: - q = f""" - MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}})-[ar:ACTIONS]->(task:Task {{_scoped_key: '{t}'}}) - RETURN ar.weight - """ - result = tx.run(q) + tasks_list = [str(t) for t in tasks if t is not None] - weight = [record.get("ar.weight") for record in result] + q = """ + UNWIND $tasks_list AS task_scoped_key + OPTIONAL MATCH (th:TaskHub {_scoped_key: $taskhub})-[ar:ACTIONS]->(task:Task {_scoped_key: task_scoped_key}) + RETURN task_scoped_key, ar.weight AS weight + """ - # if no match for the given Task, we put a `None` as result - if len(weight) == 0: - weights.append(None) - else: - weights.extend(weight) + result = tx.run(q, taskhub=str(taskhub), tasks_list=tasks_list) + results = result.data() + + weights = [record["weight"] for record in results] return weights @@ -1605,13 +1603,13 @@ def get_taskhub_tasks( ) -> Union[List[ScopedKey], Dict[ScopedKey, Task]]: """Get a list of Tasks on the TaskHub.""" - q = f""" + q = """ // get list of all tasks associated with the taskhub - MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}})-[:ACTIONS]->(task:Task) + MATCH (th:TaskHub {_scoped_key: $taskhub})-[:ACTIONS]->(task:Task) RETURN task """ with self.transaction() as tx: - res = 
tx.run(q).to_eager_result() + res = tx.run(q, taskhub=str(taskhub)).to_eager_result() tasks = [] subgraph = Subgraph() @@ -1632,14 +1630,14 @@ def get_taskhub_unclaimed_tasks( ) -> Union[List[ScopedKey], Dict[ScopedKey, Task]]: """Get a list of unclaimed Tasks in the TaskHub.""" - q = f""" + q = """ // get list of all unclaimed tasks in the hub - MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}})-[:ACTIONS]->(task:Task) + MATCH (th:TaskHub {_scoped_key: $taskhub})-[:ACTIONS]->(task:Task) WHERE NOT (task)<-[:CLAIMS]-(:ComputeServiceRegistration) RETURN task """ with self.transaction() as tx: - res = tx.run(q).to_eager_result() + res = tx.run(q, taskhub=str(taskhub)).to_eager_result() tasks = [] subgraph = Subgraph() @@ -1691,7 +1689,7 @@ def claim_taskhub_tasks( raise ValueError("`protocols` must be either `None` or not empty") q = f""" - MATCH (th:TaskHub {{`_scoped_key`: '{taskhub}'}})-[actions:ACTIONS]-(task:Task) + MATCH (th:TaskHub {{_scoped_key: $taskhub}})-[actions:ACTIONS]-(task:Task) WHERE task.status = '{TaskStatusEnum.waiting.value}' AND actions.weight > 0 OPTIONAL MATCH (task)-[:EXTENDS]->(other_task:Task) @@ -1721,14 +1719,15 @@ def claim_taskhub_tasks( _tasks = {} with self.transaction() as tx: tx.run( - f""" - MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}}) + """ + MATCH (th:TaskHub {_scoped_key: $taskhub}) - // lock the TaskHub to avoid other queries from changing its state while we claim - SET th._lock = True - """ + // lock the TaskHub to avoid other queries from changing its state while we claim + SET th._lock = True + """, + taskhub=str(taskhub), ) - _taskpool = tx.run(q) + _taskpool = tx.run(q, taskhub=str(taskhub)) def task_count(task_dict: dict): return sum(map(len, task_dict.values())) @@ -1793,16 +1792,21 @@ def task_count(task_dict: dict): # if tasks is not empty, proceed with claiming if tasks: - q = _generate_claim_query(tasks, compute_service_id) - tx.run(q) + tx.run( + CLAIM_QUERY, + tasks_list=[str(task) for task in tasks if task is not 
None], + datetimestr=datetime.now(UTC).isoformat(), + compute_service_id=str(compute_service_id), + ) tx.run( - f""" - MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}}) + """ + MATCH (th:TaskHub {_scoped_key: $taskhub}) - // remove lock on the TaskHub now that we're done with it - SET th._lock = null - """ + // remove lock on the TaskHub now that we're done with it + SET th._lock = null + """, + taskhub=str(taskhub), ) return tasks + [None] * (count - len(tasks)) @@ -1814,13 +1818,13 @@ def _validate_extends_tasks(self, task_list) -> Dict[str, Tuple[Node, str]]: if not task_list: return {} - q = f""" - UNWIND {cypher_list_from_scoped_keys(task_list)} as task - MATCH (t:Task {{`_scoped_key`: task}})-[PERFORMS]->(tf:Transformation|NonTransformation) + q = """ + UNWIND $task_list AS task + MATCH (t:Task {_scoped_key: task})-[PERFORMS]->(tf:Transformation|NonTransformation) return t, tf._scoped_key as tf_sk """ - results = self.execute_query(q) + results = self.execute_query(q, task_list=list(map(str, task_list))) nodes = {} @@ -1915,12 +1919,14 @@ def create_tasks( continue q = f""" - UNWIND {cypher_list_from_scoped_keys(transformation_subset)} as sk + UNWIND $transformation_subset AS sk MATCH (n:{node_type} {{`_scoped_key`: sk}}) RETURN n """ - results = self.execute_query(q) + results = self.execute_query( + q, transformation_subset=list(map(str, transformation_subset)) + ) transformation_nodes = {} for record in results.records: @@ -2003,14 +2009,14 @@ def get_network_tasks( self, network: ScopedKey, status: Optional[TaskStatusEnum] = None ) -> List[ScopedKey]: """List ScopedKeys for all Tasks associated with the given AlchemicalNetwork.""" - q = f""" - MATCH (an:AlchemicalNetwork {{_scoped_key: "{network}"}})-[:DEPENDS_ON]->(tf:Transformation|NonTransformation), + q = """ + MATCH (an:AlchemicalNetwork {_scoped_key: $network})-[:DEPENDS_ON]->(tf:Transformation|NonTransformation), (tf)<-[:PERFORMS]-(t:Task) """ if status is not None: - q += f""" - WHERE t.status = 
'{status.value}' + q += """ + WHERE t.status = $status """ q += """ @@ -2019,7 +2025,9 @@ def get_network_tasks( """ sks = [] with self.transaction() as tx: - res = tx.run(q) + res = tx.run( + q, network=str(network), status=status.value if status else None + ) for rec in res: sks.append(rec["sk"]) @@ -2027,15 +2035,15 @@ def get_network_tasks( def get_task_networks(self, task: ScopedKey) -> List[ScopedKey]: """List ScopedKeys for AlchemicalNetworks associated with the given Task.""" - q = f""" - MATCH (t:Task {{_scoped_key: '{task}'}})-[:PERFORMS]->(tf:Transformation|NonTransformation), + q = """ + MATCH (t:Task {_scoped_key: $task})-[:PERFORMS]->(tf:Transformation|NonTransformation), (tf)<-[:DEPENDS_ON]-(an:AlchemicalNetwork) WITH an._scoped_key as sk RETURN sk """ sks = [] with self.transaction() as tx: - res = tx.run(q) + res = tx.run(q, task=str(task)) for rec in res: sks.append(rec["sk"]) @@ -2064,18 +2072,18 @@ def get_transformation_tasks( extends """ - q = f""" - MATCH (trans:Transformation|NonTransformation {{_scoped_key: '{transformation}'}})<-[:PERFORMS]-(task:Task) + q = """ + MATCH (trans:Transformation|NonTransformation {_scoped_key: $transformation})<-[:PERFORMS]-(task:Task) """ if status is not None: - q += f""" - WHERE task.status = '{status.value}' + q += """ + WHERE task.status = $status """ if extends: - q += f""" - MATCH (trans)<-[:PERFORMS]-(extends:Task {{_scoped_key: '{extends}'}}) + q += """ + MATCH (trans)<-[:PERFORMS]-(extends:Task {_scoped_key: $extends}) WHERE (task)-[:EXTENDS*]->(extends) RETURN task """ @@ -2085,7 +2093,12 @@ def get_transformation_tasks( """ with self.transaction() as tx: - res = tx.run(q).to_eager_result() + res = tx.run( + q, + transformation=str(transformation), + status=status.value if status else None, + extends=str(extends) if extends else None, + ).to_eager_result() tasks = [] for record in res.records: @@ -2119,14 +2132,14 @@ def get_task_transformation( `ScopedKey`\s for these instead. 
""" - q = f""" - MATCH (task:Task {{_scoped_key: "{task}"}})-[:PERFORMS]->(trans:Transformation|NonTransformation) + q = """ + MATCH (task:Task {_scoped_key: $task})-[:PERFORMS]->(trans:Transformation|NonTransformation) OPTIONAL MATCH (task)-[:EXTENDS]->(prev:Task)-[:RESULTS_IN]->(result:ProtocolDAGResultRef) RETURN trans, result """ with self.transaction() as tx: - res = tx.run(q).to_eager_result() + res = tx.run(q, task=str(task)).to_eager_result() transformations = [] results = [] @@ -2225,7 +2238,9 @@ def set_task_priority( RETURN scoped_key, t """ res = tx.run( - q, scoped_keys=[str(t) for t in tasks], priority=priority + q, + scoped_keys=[str(t) for t in tasks if t is not None], + priority=priority, ).to_eager_result() task_results = [] @@ -2262,7 +2277,7 @@ def get_task_priority(self, tasks: List[ScopedKey]) -> List[Optional[int]]: WHERE t._scoped_key = scoped_key RETURN t.priority as priority """ - res = tx.run(q, scoped_keys=[str(t) for t in tasks]) + res = tx.run(q, scoped_keys=[str(t) for t in tasks if t is not None]) priorities = [rec["priority"] for rec in res] return priorities @@ -2304,7 +2319,7 @@ def get_scope_status( } prop_string = ", ".join( - "{}: '{}'".format(key, value) + "{}: ${}".format(key, key) for key, value in properties.items() if value is not None ) @@ -2321,22 +2336,22 @@ def get_scope_status( RETURN n.status AS status, count(DISTINCT n) as counts """ with self.transaction() as tx: - res = tx.run(q, state_pattern=network_state) + res = tx.run(q, state_pattern=network_state, **properties) counts = {rec["status"]: rec["counts"] for rec in res} return counts def get_network_status(self, networks: List[ScopedKey]) -> List[Dict[str, int]]: """Return status counts for all Tasks associated with the given AlchemicalNetworks.""" - q = f""" - UNWIND {cypher_list_from_scoped_keys(networks)} as network - MATCH (an:AlchemicalNetwork {{_scoped_key: network}})-[:DEPENDS_ON]->(tf:Transformation|NonTransformation), + q = """ + UNWIND $networks AS 
network + MATCH (an:AlchemicalNetwork {_scoped_key: network})-[:DEPENDS_ON]->(tf:Transformation|NonTransformation), (tf)<-[:PERFORMS]-(t:Task) RETURN an._scoped_key AS sk, t.status AS status, count(t) as counts """ network_data = {str(network_sk): {} for network_sk in networks} - for rec in self.execute_query(q).records: + for rec in self.execute_query(q, networks=list(map(str, networks))).records: sk = rec["sk"] status = rec["status"] counts = rec["counts"] @@ -2346,12 +2361,12 @@ def get_network_status(self, networks: List[ScopedKey]) -> List[Dict[str, int]]: def get_transformation_status(self, transformation: ScopedKey) -> Dict[str, int]: """Return status counts for all Tasks associated with the given Transformation.""" - q = f""" - MATCH (:Transformation|NonTransformation {{_scoped_key: "{transformation}"}})<-[:PERFORMS]-(t:Task) + q = """ + MATCH (:Transformation|NonTransformation {_scoped_key: $transformation})<-[:PERFORMS]-(t:Task) RETURN t.status AS status, count(t) as counts """ with self.transaction() as tx: - res = tx.run(q) + res = tx.run(q, transformation=str(transformation)) counts = {rec["status"]: rec["counts"] for rec in res} return counts @@ -2503,15 +2518,15 @@ def set_task_waiting( """ - q = """ + q = f""" WITH $scoped_keys AS batch UNWIND batch AS scoped_key - OPTIONAL MATCH (t:Task {_scoped_key: scoped_key}) + OPTIONAL MATCH (t:Task {{_scoped_key: scoped_key}}) - OPTIONAL MATCH (t_:Task {_scoped_key: scoped_key}) - WHERE t_.status IN ['waiting', 'running', 'error'] - SET t_.status = 'waiting' + OPTIONAL MATCH (t_:Task {{_scoped_key: scoped_key}}) + WHERE t_.status IN ['{TaskStatusEnum.waiting.value}', '{TaskStatusEnum.running.value}', '{TaskStatusEnum.error.value}'] + SET t_.status = '{TaskStatusEnum.waiting.value}' WITH scoped_key, t, t_ @@ -2537,15 +2552,15 @@ def set_task_running( """ - q = """ + q = f""" WITH $scoped_keys AS batch UNWIND batch AS scoped_key - OPTIONAL MATCH (t:Task {_scoped_key: scoped_key}) + OPTIONAL MATCH (t:Task 
{{_scoped_key: scoped_key}}) - OPTIONAL MATCH (t_:Task {_scoped_key: scoped_key}) - WHERE t_.status IN ['running', 'waiting'] - SET t_.status = 'running' + OPTIONAL MATCH (t_:Task {{_scoped_key: scoped_key}}) + WHERE t_.status IN ['{TaskStatusEnum.running.value}', '{TaskStatusEnum.waiting.value}'] + SET t_.status = '{TaskStatusEnum.running.value}' RETURN scoped_key, t, t_ """ @@ -2564,15 +2579,15 @@ def set_task_complete( """ - q = """ + q = f""" WITH $scoped_keys AS batch UNWIND batch AS scoped_key - OPTIONAL MATCH (t:Task {_scoped_key: scoped_key}) + OPTIONAL MATCH (t:Task {{_scoped_key: scoped_key}}) - OPTIONAL MATCH (t_:Task {_scoped_key: scoped_key}) - WHERE t_.status IN ['complete', 'running'] - SET t_.status = 'complete' + OPTIONAL MATCH (t_:Task {{_scoped_key: scoped_key}}) + WHERE t_.status IN ['{TaskStatusEnum.complete.value}', '{TaskStatusEnum.running.value}'] + SET t_.status = '{TaskStatusEnum.complete.value}' WITH scoped_key, t, t_ @@ -2605,15 +2620,15 @@ def set_task_error( """ - q = """ + q = f""" WITH $scoped_keys AS batch UNWIND batch AS scoped_key - OPTIONAL MATCH (t:Task {_scoped_key: scoped_key}) + OPTIONAL MATCH (t:Task {{_scoped_key: scoped_key}}) - OPTIONAL MATCH (t_:Task {_scoped_key: scoped_key}) - WHERE t_.status IN ['error', 'running'] - SET t_.status = 'error' + OPTIONAL MATCH (t_:Task {{_scoped_key: scoped_key}}) + WHERE t_.status IN ['{TaskStatusEnum.error.value}', '{TaskStatusEnum.running.value}'] + SET t_.status = '{TaskStatusEnum.error.value}' WITH scoped_key, t, t_ @@ -2643,20 +2658,20 @@ def set_task_invalid( # set the status and delete the ACTIONS relationship # make sure we follow the extends chain and set all tasks to invalid # and remove actions relationships - q = """ + q = f""" WITH $scoped_keys AS batch UNWIND batch AS scoped_key - OPTIONAL MATCH (t:Task {_scoped_key: scoped_key}) + OPTIONAL MATCH (t:Task {{_scoped_key: scoped_key}}) - OPTIONAL MATCH (t_:Task {_scoped_key: scoped_key}) - WHERE NOT t_.status IN ['deleted'] - 
SET t_.status = 'invalid' + OPTIONAL MATCH (t_:Task {{_scoped_key: scoped_key}}) + WHERE NOT t_.status IN ['{TaskStatusEnum.deleted.value}'] + SET t_.status = '{TaskStatusEnum.invalid.value}' WITH scoped_key, t, t_ OPTIONAL MATCH (t_)<-[er:EXTENDS*]-(extends_task:Task) - SET extends_task.status = 'invalid' + SET extends_task.status = '{TaskStatusEnum.invalid.value}' WITH scoped_key, t, t_, extends_task @@ -2693,20 +2708,20 @@ def set_task_deleted( # set the status and delete the ACTIONS relationship # make sure we follow the extends chain and set all tasks to deleted # and remove actions relationships - q = """ + q = f""" WITH $scoped_keys AS batch UNWIND batch AS scoped_key - OPTIONAL MATCH (t:Task {_scoped_key: scoped_key}) + OPTIONAL MATCH (t:Task {{_scoped_key: scoped_key}}) - OPTIONAL MATCH (t_:Task {_scoped_key: scoped_key}) - WHERE NOT t_.status IN ['invalid'] - SET t_.status = 'deleted' + OPTIONAL MATCH (t_:Task {{_scoped_key: scoped_key}}) + WHERE NOT t_.status IN ['{TaskStatusEnum.invalid.value}'] + SET t_.status = '{TaskStatusEnum.deleted.value}' WITH scoped_key, t, t_ OPTIONAL MATCH (t_)<-[er:EXTENDS*]-(extends_task:Task) - SET extends_task.status = 'deleted' + SET extends_task.status = '{TaskStatusEnum.deleted.value}' WITH scoped_key, t, t_, extends_task From 80332c0d5a53acd57743ee36bdeb80a22f002025 Mon Sep 17 00:00:00 2001 From: LilDojd Date: Sun, 17 Nov 2024 02:39:38 +0400 Subject: [PATCH 06/11] Fix tests --- alchemiscale/storage/statestore.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 23f97482..366c8a54 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -5,7 +5,7 @@ """ import abc -from datetime import UTC, datetime +from datetime import datetime from contextlib import contextmanager import json from functools import lru_cache @@ -1795,7 +1795,7 @@ def task_count(task_dict: dict): tx.run( CLAIM_QUERY, 
tasks_list=[str(task) for task in tasks if task is not None], - datetimestr=datetime.now(UTC).isoformat(), + datetimestr=str(datetime.utcnow().isoformat()), compute_service_id=str(compute_service_id), ) From 1da0d87ac11632c13b5a10e11d07ebb647a89114 Mon Sep 17 00:00:00 2001 From: "David L. Dotson" Date: Mon, 18 Nov 2024 17:27:08 -0700 Subject: [PATCH 07/11] Update alchemiscale/storage/statestore.py Co-authored-by: Ian Kenney --- alchemiscale/storage/statestore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 366c8a54..7544210c 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -682,7 +682,7 @@ def delete_network( # then delete the network q = """ - MATCH (an:AlchemicalNetwork {{_scoped_key: $network}}) + MATCH (an:AlchemicalNetwork {_scoped_key: $network}) DETACH DELETE an """ raise NotImplementedError From fb9c9b830b976496e19f640e054428c34e89974d Mon Sep 17 00:00:00 2001 From: LilDojd Date: Tue, 19 Nov 2024 14:10:26 +0400 Subject: [PATCH 08/11] Address comments in review #330 --- alchemiscale/models.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/alchemiscale/models.py b/alchemiscale/models.py index 2144dcf0..90b70423 100644 --- a/alchemiscale/models.py +++ b/alchemiscale/models.py @@ -9,7 +9,7 @@ from gufe.tokenization import GufeKey from re import fullmatch import unicodedata - +import string class Scope(BaseModel): org: Optional[str] = None @@ -137,22 +137,16 @@ def gufe_key_validator(cls, v): v = str(v) # GufeKey is of form - - prefix, token = v.split("-") - if not prefix or not token: + try: + _prefix, _token = v.split("-") + except ValueError: raise InvalidGufeKeyError("gufe_key must be of the form '-'") # Normalize the input to NFC form - v_normalized = unicodedata.normalize("NFC", v) - # Ensure that there are no control characters - if any(unicodedata.category(c) == "Cc" for c in
v_normalized): - raise InvalidGufeKeyError("gufe_key contains invalid control characters") - # Allowed characters: letters, numbers, underscores, hyphens - allowed_chars = set( - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_-" - ) + allowed_chars = set(string.ascii_letters + string.digits + "_-") if not set(v_normalized).issubset(allowed_chars): raise InvalidGufeKeyError("gufe_key contains invalid characters") From cd59ed948207950d6f356190ebe46b521efcccb7 Mon Sep 17 00:00:00 2001 From: LilDojd Date: Tue, 19 Nov 2024 14:16:13 +0400 Subject: [PATCH 09/11] Stay true to signature review #330 --- alchemiscale/storage/statestore.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 7544210c..438aa0e2 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -1551,15 +1551,13 @@ def get_task_weights( """ with self.transaction() as tx: - tasks_list = [str(t) for t in tasks if t is not None] - q = """ UNWIND $tasks_list AS task_scoped_key OPTIONAL MATCH (th:TaskHub {_scoped_key: $taskhub})-[ar:ACTIONS]->(task:Task {_scoped_key: task_scoped_key}) RETURN task_scoped_key, ar.weight AS weight """ - result = tx.run(q, taskhub=str(taskhub), tasks_list=tasks_list) + result = tx.run(q, taskhub=str(taskhub), tasks_list=list(map(str, tasks))) results = result.data() weights = [record["weight"] for record in results] @@ -2239,7 +2237,7 @@ def set_task_priority( """ res = tx.run( q, - scoped_keys=[str(t) for t in tasks if t is not None], + scoped_keys=list(map(str, tasks)), priority=priority, ).to_eager_result() @@ -2277,7 +2275,7 @@ def get_task_priority(self, tasks: List[ScopedKey]) -> List[Optional[int]]: WHERE t._scoped_key = scoped_key RETURN t.priority as priority """ - res = tx.run(q, scoped_keys=[str(t) for t in tasks if t is not None]) + res = tx.run(q, scoped_keys=list(map(str, tasks))) priorities = [rec["priority"] for rec 
in res] return priorities From 4b97635a7a548f7d120e148a0d941b1cf24ab6cf Mon Sep 17 00:00:00 2001 From: LilDojd Date: Tue, 19 Nov 2024 14:31:30 +0400 Subject: [PATCH 10/11] Add additional check in query test review #330 --- .../tests/integration/storage/test_statestore.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py index 883cd407..f2f25ef5 100644 --- a/alchemiscale/tests/integration/storage/test_statestore.py +++ b/alchemiscale/tests/integration/storage/test_statestore.py @@ -296,8 +296,15 @@ def test_query_transformations_exploit(self, n4js, multiple_scopes, network_tyk2 assert "'dict' object has no attribute 'labels'" in str(e) assert len(n4js.query_transformations(scope=multiple_scopes[0])) == 0 - mark = n4js._query(qualname="InjectionMark") - assert len(mark) == 0 + mark_from__query = n4js._query(qualname="InjectionMark") + # Just to be double sure, check explicitly + q = """ + match (m:InjectionMark) + return m + """ + mark_explicit = n4js.execute_query(q).records + + assert len(mark_from__query) == len(mark_explicit) == 0 assert len(n4js.query_transformations()) == len(network_tyk2.edges) * 2 assert len(n4js.query_transformations(scope=multiple_scopes[0])) == len( From 4848236230e5af400b92311726409527b684f6a8 Mon Sep 17 00:00:00 2001 From: LilDojd Date: Tue, 19 Nov 2024 14:37:07 +0400 Subject: [PATCH 11/11] Black black black --- alchemiscale/models.py | 3 +++ alchemiscale/storage/statestore.py | 2 ++ alchemiscale/tests/unit/test_models.py | 2 ++ 3 files changed, 7 insertions(+) diff --git a/alchemiscale/models.py b/alchemiscale/models.py index 90b70423..ed7a6cfb 100644 --- a/alchemiscale/models.py +++ b/alchemiscale/models.py @@ -11,6 +11,7 @@ import unicodedata import string + class Scope(BaseModel): org: Optional[str] = None campaign: Optional[str] = None @@ -114,8 +115,10 @@ def specific(self) -> bool: 
"""Return `True` if this Scope has no unspecified elements.""" return all(self.to_tuple()) + class InvalidGufeKeyError(ValueError): ... + class ScopedKey(BaseModel): """Unique identifier for GufeTokenizables in state store. diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 438aa0e2..801b6600 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -99,6 +99,7 @@ def _select_tasks_from_taskpool(taskpool: List[Tuple[str, float]], count) -> Lis return list(np.random.choice(tasks, count, replace=False, p=prob)) + CLAIM_QUERY = f""" // only match the task if it doesn't have an existing CLAIMS relationship UNWIND $tasks_list AS task_sk @@ -116,6 +117,7 @@ def _select_tasks_from_taskpool(taskpool: List[Tuple[str, float]], count) -> Lis RETURN t """ + class Neo4jStore(AlchemiscaleStateStore): # uniqueness constraints applied to the database; key is node label, # 'property' is the property on which uniqueness is guaranteed for nodes diff --git a/alchemiscale/tests/unit/test_models.py b/alchemiscale/tests/unit/test_models.py index b687ff86..ba7fc389 100644 --- a/alchemiscale/tests/unit/test_models.py +++ b/alchemiscale/tests/unit/test_models.py @@ -102,6 +102,7 @@ def test_scope_non_alphanumeric_invalid(scope_string): def test_underscore_scopes_valid(scope_string): scope = Scope.from_str(scope_string) + @pytest.mark.parametrize( "gufe_key", [ @@ -119,6 +120,7 @@ def test_gufe_key_invalid(gufe_key): gufe_key=gufe_key, org="org1", campaign="campaignA", project="projectI" ) + @pytest.mark.parametrize( "gufe_key", [