Skip to content

Commit

Permalink
[fix] Resolve tag addition issue from parallel runs (#3247)
Browse files Browse the repository at this point in the history
  • Loading branch information
mihran113 authored Dec 6, 2024
1 parent 61bee5d commit 979efe0
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 27 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
### Fixes:
- Fix aggregated metrics' computations (mihran113)
- Fix bug in RunStatusReporter raising non-deterministic RuntimeError exception (VassilisVassiliadis)

- Fix tag addition issue from parallel runs (mihran113)

## 3.26.1 Dec 3, 2024
- Re-upload after PyPI size limitation fix
Expand Down
2 changes: 1 addition & 1 deletion aim/cli/runs/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import tqdm

from aim.cli.runs.utils import make_zip_archive, match_runs, upload_repo_runs
from aim.sdk.repo import Repo
from aim.sdk.index_manager import RepoIndexManager
from aim.sdk.repo import Repo
from aim.sdk.sequences.sequence_type_map import SEQUENCE_TYPE_MAP
from psutil import cpu_count

Expand Down
6 changes: 2 additions & 4 deletions aim/sdk/reporter/file_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@

class FileManager(object):
@abstractmethod
def poll(self, pattern: str) -> Optional[str]:
...
def poll(self, pattern: str) -> Optional[str]: ...

@abstractmethod
def touch(self, filename: str, cleanup_file_pattern: Optional[str] = None):
...
def touch(self, filename: str, cleanup_file_pattern: Optional[str] = None): ...


class LocalFileManager(FileManager):
Expand Down
2 changes: 1 addition & 1 deletion aim/sdk/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from aim.sdk.reporter import RunStatusReporter, ScheduledStatusReporter
from aim.sdk.reporter.file_manager import LocalFileManager
from aim.sdk.sequence import Sequence
from aim.sdk.sequences.sequence_type_map import SEQUENCE_TYPE_MAP
from aim.sdk.sequence_collection import SingleRunSequenceCollection
from aim.sdk.sequences.sequence_type_map import SEQUENCE_TYPE_MAP
from aim.sdk.tracker import RunTracker
from aim.sdk.types import AimObject
from aim.sdk.utils import (
Expand Down
9 changes: 3 additions & 6 deletions aim/sdk/run_status_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,13 @@ def __init__(self, *, obj_idx: Optional[str] = None, rank: Optional[int] = None,
self.message = message

@abstractmethod
def is_sent(self):
...
def is_sent(self): ...

@abstractmethod
def update_last_sent(self):
...
def update_last_sent(self): ...

@abstractmethod
def get_msg_details(self):
...
def get_msg_details(self): ...


class StatusNotification(Notification):
Expand Down
2 changes: 1 addition & 1 deletion aim/sdk/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@

from aim.sdk.configs import AIM_ENABLE_TRACKING_THREAD
from aim.sdk.num_utils import convert_to_py_number, is_number
from aim.sdk.sequences.sequence_type_map import SEQUENCE_TYPE_MAP
from aim.sdk.utils import check_types_compatibility, get_object_typename
from aim.storage.context import Context
from aim.storage.hashing import hash_auto
from aim.storage.object import CustomObject
from aim.storage.types import AimObject
from aim.sdk.sequences.sequence_type_map import SEQUENCE_TYPE_MAP


if TYPE_CHECKING:
Expand Down
25 changes: 18 additions & 7 deletions aim/storage/structured/sql_engine/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,14 +216,25 @@ def tags(self) -> List[str]:
return [tag.name for tag in self.tags_obj]

def add_tag(self, value: str) -> None:
def unsafe_add_tag():
if value is None:
tag = None
else:
tag = session.query(TagModel).filter(TagModel.name == value).first()
if not tag:
tag = TagModel(value)
session.add(tag)
self._model.tags.append(tag)
session.add(self._model)

session = self._session
tag = session.query(TagModel).filter(TagModel.name == value).first()
if not tag:
tag = TagModel(value)
session.add(tag)
self._model.tags.append(tag)
session.add(self._model)
session_commit_or_flush(session)
unsafe_add_tag()
try:
session_commit_or_flush(session)
except IntegrityError:
session.rollback()
unsafe_add_tag()
session_commit_or_flush(session)

def remove_tag(self, tag_name: str) -> bool:
session = self._session
Expand Down
8 changes: 2 additions & 6 deletions troubleshooting/base_project_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
import sys
import time

import tqdm

import aim
import tqdm


def count_metrics(run):
Expand All @@ -24,10 +23,7 @@ def count_dict_keys(params):
Count the number of leaf nodes in a nested dictionary.
A leaf node is a value that is not a dictionary.
"""
return sum(
count_dict_keys(value) if isinstance(value, dict) else 1
for value in params.values()
)
return sum(count_dict_keys(value) if isinstance(value, dict) else 1 for value in params.values())


parser = argparse.ArgumentParser(description='Process command line arguments.')
Expand Down

0 comments on commit 979efe0

Please sign in to comment.