Skip to content
This repository has been archived by the owner on Jan 17, 2025. It is now read-only.

Commit

Permalink
fix concurrent access to self.cache in nosqldict (#121)
Browse files Browse the repository at this point in the history
* fix concurrent access to self.cache in nosqldict

* bump version to 0.4.2

* add -d to docker run in CI
  • Loading branch information
ldruschk authored May 30, 2021
1 parent 654aff0 commit 6a3016d
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 49 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pythonapp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: pytest
run: |
docker run -d -p 27017:27017 mongo
pip install -r dev-requirements.txt
make test
- name: lint
Expand Down
6 changes: 5 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,9 @@ format:

test:
pip3 install .
coverage run -m pytest
ifdef GITHUB_ACTIONS
coverage run -m pytest -v --with_nosqldict
else
coverage run -m pytest -v
endif
coverage report -m
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

setuptools.setup(
name="enochecker",
version="0.4.1",
version="0.4.2",
author="domenukk",
author_email="dmaier@sect.tu-berlin.de",
description="Library to build checker scripts for EnoEngine A/D CTF Framework in Python",
Expand Down
2 changes: 1 addition & 1 deletion src/enochecker/enochecker.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def __init__(
self.storage_dir = storage_dir

self._setup_logger()
if use_db_cache:
if use_db_cache and not os.getenv("MONGO_ENABLED"):
self._active_dbs: Dict[str, Union[NoSqlDict, StoredDict]] = global_db_cache
else:
self._active_dbs = {}
Expand Down
105 changes: 59 additions & 46 deletions src/enochecker/nosqldict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from collections.abc import MutableMapping
from functools import wraps
from threading import RLock, current_thread
from threading import Lock, RLock, current_thread
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Union

from . import utils
Expand Down Expand Up @@ -135,16 +135,16 @@ def __init__(
self.checker_name = checker_name
self.cache: Dict[Any, Any] = {}
self.hash_cache: Dict[Any, Any] = {}
self._lock: Lock = Lock()
host_: str = host or DB_DEFAULT_HOST
if isinstance(port, int):
port_: int = port
else:
port_ = int(port or DB_DEFAULT_PORT)
username_: Optional[str] = username or DB_DEFAULT_USER
password_: Optional[str] = password or DB_DEFAULT_PASS
self.db = self.get_client(host_, port_, username_, password_, self.logger)[
checker_name
][self.dict_name]
self.client = self.get_client(host_, port_, username_, password_, self.logger)
self.db = self.client[checker_name][self.dict_name]
try:
self.db.index_information()["checker_key"]
except KeyError:
Expand All @@ -162,14 +162,17 @@ def __setitem__(self, key: str, value: Any) -> None:
:param key: key in the dictionary
:param value: value in the dictionary
"""
key = str(key)
with self._lock:
key = str(key)

self.cache[key] = value
hash_ = value_to_hash(value)
if hash_:
self.hash_cache[key] = hash_
self.cache[key] = value
hash_ = value_to_hash(value)
if hash_:
self.hash_cache[key] = hash_
elif key in self.hash_cache:
del self.hash_cache[key]

self._upsert(key, value)
self._upsert(key, value)

def _upsert(self, key: Any, value: Any) -> None:
query_dict = {
Expand Down Expand Up @@ -198,28 +201,31 @@ def __getitem__(self, key: str, print_result: bool = False) -> Any:
:param print_result: TODO
:return: retrieved value
"""
key = str(key)
if key in self.cache.items():
return self.cache[key]
with self._lock:
key = str(key)

to_extract = {
"key": key,
"checker": self.checker_name,
"name": self.dict_name,
}
if key in self.cache:
return self.cache[key]

result = self.db.find_one(to_extract)
to_extract = {
"key": key,
"checker": self.checker_name,
"name": self.dict_name,
}

if print_result:
self.logger.debug(result)
result = self.db.find_one(to_extract)

if result:
self.cache[key] = result["value"]
hash_ = value_to_hash(result)
if hash_:
self.hash_cache[key] = hash_
return result["value"]
raise KeyError("Could not find {} in {}".format(key, self))
if print_result:
self.logger.debug(result)

if result:
val = result["value"]
self.cache[key] = val
hash_ = value_to_hash(val)
if hash_:
self.hash_cache[key] = hash_
return val
raise KeyError("Could not find {} in {}".format(key, self))

@_try_n_times
def __delitem__(self, key: str) -> None:
Expand All @@ -230,16 +236,19 @@ def __delitem__(self, key: str) -> None:
:param key: key to delete
"""
key = str(key)
if key in self.cache:
del self.cache[key]

to_extract = {
"key": key,
"checker": self.checker_name,
"name": self.dict_name,
}
self.db.delete_one(to_extract)
with self._lock:
key = str(key)
if key in self.cache:
del self.cache[key]
if key in self.hash_cache:
del self.hash_cache[key]

to_extract = {
"key": key,
"checker": self.checker_name,
"name": self.dict_name,
}
self.db.delete_one(to_extract)

@_try_n_times
def __len__(self) -> int:
Expand Down Expand Up @@ -267,14 +276,18 @@ def persist(self) -> None:
"""
Persist the changes in the backend.
"""
for (key, value) in self.cache.items():
hash_ = value_to_hash(value)
if (
(not hash_)
or (key not in self.hash_cache)
or (self.hash_cache[key] != hash_)
):
self._upsert(key, value)
with self._lock:
for (key, value) in list(self.cache.items()):
hash_ = value_to_hash(value)
if (
(not hash_)
or (key not in self.hash_cache)
or (self.hash_cache[key] != hash_)
):
self._upsert(key, value)
del self.cache[key]
if key in self.hash_cache:
del self.hash_cache[key]

def __del__(self) -> None:
"""
Expand Down
20 changes: 20 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest


def pytest_addoption(parser):
parser.addoption(
"--with_nosqldict", action="store_true", help="Run the tests with the nosqldict"
)


def pytest_configure(config):
config.addinivalue_line("markers", "nosqldict: mark test as requiring MongoDB")


def pytest_collection_modifyitems(config, items):
if config.getoption("--with_nosqldict"):
return
skip_nosqldict = pytest.mark.skip(reason="need --with_nosqldict option to run")
for item in items:
if "nosqldict" in item.keywords:
item.add_marker(skip_nosqldict)
40 changes: 40 additions & 0 deletions tests/test_enochecker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#!/usr/bin/env python3
import functools
import hashlib
import secrets
import sys
import tempfile
from logging import DEBUG
from unittest import mock

import pytest
from enochecker_core import CheckerMethod, CheckerTaskMessage, CheckerTaskResult
Expand Down Expand Up @@ -310,6 +312,44 @@ def putflagfn(self: CheckerExampleImpl):
assert result.attack_info == attack_info


@pytest.mark.nosqldict
def test_nested_change_enochecker():
import os

with mock.patch.dict(
os.environ,
{
"MONGO_ENABLED": "1",
},
):
dict_name = secrets.token_hex(8)

def putflagfn(self: CheckerExampleImpl):
db = self.db(dict_name)
x = {
"asd": 123,
}
db["test"] = x

x["asd"] = 456

def getflagfn(self: CheckerExampleImpl):
db = self.db(dict_name)
assert db["test"]["asd"] == 456

setattr(CheckerExampleImpl, "putflag", putflagfn)
checker = CheckerExampleImpl(method="putflag")

result = checker.run()
assert result.result == CheckerTaskResult.OK

setattr(CheckerExampleImpl, "getflag", getflagfn)
checker = CheckerExampleImpl(method="getflag")

result = checker.run()
assert result.result == CheckerTaskResult.OK


def main():
pytest.main(sys.argv)

Expand Down
55 changes: 55 additions & 0 deletions tests/test_nosqldict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import secrets

import pytest

from enochecker.nosqldict import NoSqlDict


@pytest.fixture
def nosqldict():
dict_name = secrets.token_hex(8)
checker_name = secrets.token_hex(8)
return NoSqlDict(dict_name, checker_name)


@pytest.mark.nosqldict
def test_basic(nosqldict):
nosqldict["abc"] = "xyz"
assert nosqldict["abc"] == "xyz"

with pytest.raises(KeyError):
_ = nosqldict["xyz"]

nosqldict["abc"] = {"stuff": b"asd"}
assert nosqldict["abc"] == {"stuff": b"asd"}

del nosqldict["abc"]
with pytest.raises(KeyError):
_ = nosqldict["abc"]


@pytest.mark.nosqldict
def test_nested_change():
dict_name = secrets.token_hex(8)
checker_name = secrets.token_hex(8)

def scoped_access(dict_name, checker_name):
nosqldict = NoSqlDict(dict_name, checker_name)

x = {
"asd": 123,
}
nosqldict["test"] = x
x["asd"] = 456

assert nosqldict["test"] == {
"asd": 456,
}

scoped_access(dict_name, checker_name)

nosqldict_new = NoSqlDict(dict_name, checker_name)

assert nosqldict_new["test"] == {
"asd": 456,
}

0 comments on commit 6a3016d

Please sign in to comment.