From 1a58ab3da41f81ef7603c7540859a19615489261 Mon Sep 17 00:00:00 2001 From: Bao Nguyen Date: Wed, 30 Sep 2020 13:17:42 -0700 Subject: [PATCH] replace modeling script with notebook --- h1st/cli/project.py | 5 ++- h1st/cli/templates/notebook.txt | 43 +++-------------------- h1st/model_repository/model_repository.py | 8 +++++ tests/cli/test_cli.py | 5 --- 4 files changed, 15 insertions(+), 46 deletions(-) diff --git a/h1st/cli/project.py b/h1st/cli/project.py index 8cec07ae..a5004fe2 100644 --- a/h1st/cli/project.py +++ b/h1st/cli/project.py @@ -127,9 +127,8 @@ def new_project(project_name, base_path): model_file=f"{model_package}.py", ) - with open(tmppath / f"{project_name_snake_case}_modeling.py", "w") as f: - f.write(_render_template('modeling', { - 'SCRIPT_NAME': f'{project_name_snake_case}_modeling.py', + with open(tmppath / f"{project_name_snake_case}_notebook.ipynb", "w") as f: + f.write(_render_template('notebook', { 'MODEL_CLASS': model_name, 'MODEL_PACKAGE': model_package, })) diff --git a/h1st/cli/templates/notebook.txt b/h1st/cli/templates/notebook.txt index 1482328b..5539805d 100644 --- a/h1st/cli/templates/notebook.txt +++ b/h1st/cli/templates/notebook.txt @@ -6,33 +6,8 @@ "metadata": {}, "outputs": [], "source": [ - "### H1st Configuration ###\n", - "'''\n", - "This part configures the default variable names for various step of the modelling process.\n", - "Modify with care.\n", - "'''\n", - "VAR_MODEL_FILE = '$$MODEL_FILE$$.py'\n", - "VAR_MODEL = 'm'\n", - "VAR_LOADED_DATA = 'loaded_data'\n", - "VAR_PREPARED_DATA = 'prepared_data'\n", - "VAR_PREDICTED_VALUE = 'predicted_val'\n", - "\n", - "\n", - "%load_ext autoreload\n", - "%autoreload 3\n", - "import h1st as h1\n", - "h1.init()\n", - "### End H1st Configuration ###" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from $$PACKAGE_NAME$$.models.$$MODEL_FILE$$ import $$MODEL_NAME$$ \n", - "m = $$MODEL_NAME$$()" + "from $$MODEL_PACKAGE$$ import $$MODEL_CLASS$$ \n", + "m = $$MODEL_CLASS$$()" ] }, { @@ -70,8 +45,7 @@ "metadata": {}, "outputs": [], "source": [ - "eval_data = {} # Load evaluation data\n", - "m.evaluate(eval_data)" + "m.evaluate(prepared_data)" ] }, { @@ -89,16 +63,9 @@ "metadata": {}, "outputs": [], "source": [ - "predict_data = {} # Get your data \n", - "predicted_val = m.predict(predict_data)" + "input_data = {} # Get your data \n", + "m.predict(input_data)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/h1st/model_repository/model_repository.py b/h1st/model_repository/model_repository.py index 952a3e53..f7bbb724 100644 --- a/h1st/model_repository/model_repository.py +++ b/h1st/model_repository/model_repository.py @@ -332,6 +332,14 @@ def get_model_repo(cls, ref=None): except ModuleNotFoundError: repo_path = None + # in the new structure, the config file may be at root folder + if not repo_path: + try: + import config + repo_path = config.MODEL_REPO_PATH + except ModuleNotFoundError: + repo_path = None + if not repo_path: repo_path = os.environ.get('H1ST_MODEL_REPO_PATH', '') diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 8af8379c..9303d598 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -40,11 +40,6 @@ def test_cli_smoketest(self): cwd=tmpdir + '/AutoCyber', ) - subprocess.check_call( - ["python", "auto_cyber_modeling.py"], - cwd=tmpdir + '/AutoCyber', - ) - subprocess.check_call( ["python", "-m", "models.model2"], cwd=tmpdir + '/AutoCyber'