diff --git a/singer/utils.py b/singer/utils.py index 85f3d39..84ddb94 100644 --- a/singer/utils.py +++ b/singer/utils.py @@ -104,9 +104,13 @@ def chunk(array, num): yield array[i:i + num] -def load_json(path): - with open(path) as fil: - return json.load(fil) +def load_json(path_or_json): + try: + inline_config = json.loads(path_or_json) + except ValueError: + with open(path_or_json) as fil: + return json.load(fil) + return inline_config def update_state(state, entity, dtime): diff --git a/tests/test_utils.py b/tests/test_utils.py index bb26da6..3794a03 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,8 @@ from datetime import datetime as dt import pytz import logging +import json +import tempfile import singer.utils as u @@ -33,3 +35,28 @@ def test_exception_fn(self): def foo(): raise RuntimeError("foo") self.assertRaises(RuntimeError, foo) + +class TestLoadJson(unittest.TestCase): + def setUp(self): + self.expected_json = """ + { + "key1": false, + "key2": [ + {"field1": 366, "field2": "2018-01-01T00:00:00+00:00"} + ] + } + """ + + def test_inline(self): + inline = u.load_json(self.expected_json) + self.assertEqual(inline, json.loads(self.expected_json)) + + def test_path(self): + # from valid path + with tempfile.NamedTemporaryFile() as fil: + fil.write(self.expected_json.encode()) + fil.seek(0) + from_path = u.load_json(fil.name) + self.assertEqual(from_path, json.loads(self.expected_json)) + # from invalid path + self.assertRaises(FileNotFoundError, u.load_json, 'does_not_exist.json')