diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0f97a06..d892e29 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -5,18 +5,18 @@ on: [push] jobs: test: - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 strategy: max-parallel: 2 fail-fast: false matrix: # There is is no Python 3.4 on ubuntu-latest - python-version: [2.7, 3.7] + python-version: [3.11] steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -30,13 +30,13 @@ jobs: run: | pytest tests lint: - runs-on: ubuntu-18.04 + runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v4 - name: Set up Python 3.7 - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: - python-version: 3.7 + python-version: 3.11 - name: Install pre-commit run: | python -m pip install --upgrade pip diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 36915ce..cb0935f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,20 +1,20 @@ repos: - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort - repo: https://github.com/ambv/black - rev: 22.10.0 + rev: 23.11.0 hooks: - id: black - repo: https://github.com/pycqa/flake8 - rev: 5.0.4 + rev: 6.1.0 hooks: - id: flake8 args: ['--config=.flake8'] - additional_dependencies: ['flake8-coding==1.3.2', 'flake8-copyright==0.2.3', 'flake8-debugger==4.1.2', 'flake8-mypy==17.8.0'] + additional_dependencies: ['flake8-coding==1.3.2', 'flake8-copyright==0.2.4', 'flake8-debugger==4.1.2'] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.5.0 hooks: - id: check-json - id: check-merge-conflict @@ -24,7 +24,7 @@ repos: - id: requirements-txt-fixer - id: trailing-whitespace - repo: https://github.com/codespell-project/codespell - rev: v2.2.2 + rev: v2.2.6 hooks: - id: codespell exclude_types: [json] diff --git a/comet_for_mlflow/comet_for_mlflow.py b/comet_for_mlflow/comet_for_mlflow.py index 8f132bb..52c656b 100644 --- a/comet_for_mlflow/comet_for_mlflow.py +++ b/comet_for_mlflow/comet_for_mlflow.py @@ -31,7 +31,6 @@ import tempfile import traceback from os.path import abspath -from typing import Optional from zipfile import ZipFile from comet_ml import API @@ -49,7 +48,6 @@ from .compat import ( get_artifact_repository, - get_mlflow_model_name, get_mlflow_run_id, search_mlflow_store_experiments, search_mlflow_store_runs, @@ -104,7 +102,6 @@ def __init__( answer, email, ): - # type: (bool, str, str, bool, str, Optional[bool], str) -> None self.answer = answer self.email = email self.config = get_config() @@ -164,7 +161,6 @@ def prepare(self): # First prepare all the data except the metadata as we need a project name for experiment_number, experiment in enumerate(self.mlflow_experiments): - experiment_name = experiment.experiment_id if experiment.name: experiment_name = experiment.name @@ -381,17 +377,11 @@ def prepare_single_mlflow_run(self, run, original_experiment_name): LOGGER.debug("### Importing artifacts") artifact_store = get_artifact_repository(run.info.artifact_uri) - # List all the registered models if possible - models_prefixes = {} - if self.model_registry_store: - query = "run_id='%s'" % run.info.run_id - registered_models = self.model_registry_store.search_model_versions( - query - ) + # Get all of the artifact list as we need to search for the + # specific MLModel file to detect models + all_artifacts = list(walk_run_artifacts(artifact_store)) - for model in registered_models: - model_relpath = os.path.relpath(model.source, run.info.artifact_uri) - models_prefixes[model_relpath] = model + models_prefixes = self.get_model_prefixes(all_artifacts) for artifact in walk_run_artifacts(artifact_store): artifact_path = artifact.path @@ -405,27 +395,33 @@ def prepare_single_mlflow_run(self, run, original_experiment_name): self.summary["artifacts"] += 1 # Check if it's belonging to one of the registered model - matching_model = None - for model_prefix, model in models_prefixes.items(): + matching_model_name = None + for model_prefix, model_name in models_prefixes.items(): if artifact_path.startswith(model_prefix): - matching_model = model + matching_model_name = model_name # We should match at most one model break - if matching_model: - model_name = get_mlflow_model_name(matching_model) - + if matching_model_name: prefix = "models/" + if artifact_path.startswith(prefix): comet_artifact_path = artifact_path[len(prefix) :] else: comet_artifact_path = artifact_path + if comet_artifact_path.startswith(model_prefix): + comet_artifact_path = comet_artifact_path[ + len(model_prefix) + 1 : + ] + else: + comet_artifact_path = comet_artifact_path + json_writer.log_artifact_as_model( local_artifact_path, comet_artifact_path, run_start_time, - model_name, + matching_model_name, ) else: json_writer.log_artifact_as_asset( @@ -436,6 +432,22 @@ def prepare_single_mlflow_run(self, run, original_experiment_name): return self.compress_archive(run.info.run_id) + def get_model_prefixes(self, artifact_list): + """Return the model names from a list of artifacts""" + + # Dict of model prefix to model name + models = {} + + for artifact in artifact_list: + # Similar logic to MLFlw UI + # https://github.com/mlflow/mlflow/blob/v2.2.2/mlflow/server/js/src/experiment-tracking/components/ArtifactView.js#L253 + parts = artifact.path.split("/") + if parts[-1].lower() == "MLmodel".lower(): + # Comet don't support model names with / in their name + models["/".join(parts[:-1])] = parts[-2] + + return models + def upload(self, prepared_data): LOGGER.info("# Start uploading data to Comet.ml") diff --git a/comet_for_mlflow/compat.py b/comet_for_mlflow/compat.py index 8ab2619..f82d613 100644 --- a/comet_for_mlflow/compat.py +++ b/comet_for_mlflow/compat.py @@ -67,10 +67,3 @@ def get_mlflow_run_id(mlflow_run): return mlflow_run.info.run_id else: return mlflow_run.run_id - - -def get_mlflow_model_name(mlflow_model): - if hasattr(mlflow_model, "name"): - return mlflow_model.name - else: - return mlflow_model.registered_model.name diff --git a/comet_for_mlflow/file_writer.py b/comet_for_mlflow/file_writer.py index 10a8871..725e048 100644 --- a/comet_for_mlflow/file_writer.py +++ b/comet_for_mlflow/file_writer.py @@ -294,7 +294,6 @@ def log_artifact_as_visualization( def log_artifact_as_model( self, artifact_path, artifact_name, timestamp, model_name ): - _, extension = os.path.splitext( artifact_path ) # TODO: Support extension less file names? @@ -328,7 +327,6 @@ def log_artifact_as_model( self.write_line_data(data) def log_artifact_as_asset(self, artifact_path, artifact_name, timestamp): - _, extension = os.path.splitext( artifact_path ) # TODO: Support extension less file names?