From 2d9f3b8b0f4ef23fc94818847e48307dbc5f6564 Mon Sep 17 00:00:00 2001 From: Albert Torosyan <32957250+alberttorosyan@users.noreply.github.com> Date: Wed, 12 Feb 2025 11:24:08 +0400 Subject: [PATCH] [feat] Skip metrics check when run is known to yield false result (#3288) * [feat] Skip metrics check when run is known to yeld false result * [fix] Code style checks * [fix] More styling errors --- aim/sdk/query_analyzer.py | 151 +++++++++++++++++++++++++++++++++ aim/sdk/sequence_collection.py | 46 ++++++---- 2 files changed, 182 insertions(+), 15 deletions(-) create mode 100644 aim/sdk/query_analyzer.py diff --git a/aim/sdk/query_analyzer.py b/aim/sdk/query_analyzer.py new file mode 100644 index 0000000000..1930bcf78d --- /dev/null +++ b/aim/sdk/query_analyzer.py @@ -0,0 +1,151 @@ +import ast + +from typing import Any, List, Tuple + + +class Unknown(ast.AST): + pass + + +Unknown = Unknown() # create a single instance of value node + + +class QueryExpressionTransformer(ast.NodeTransformer): + def __init__(self, *, var_names: List[str]): + self._var_names = var_names + + def transform(self, expr: str) -> Tuple[str, bool]: + node = ast.parse(expr, mode='eval') + transformed = self.visit(node) + if transformed is Unknown: + return expr, False + else: + return ast.unparse(transformed), True + + def visit_Expression(self, node: ast.Expression) -> Any: + node: ast.Expression = self.generic_visit(node) + if node.body is Unknown: + return Unknown + return node + + def visit_Expr(self, node: ast.Expr) -> Any: + node: ast.Expr = self.generic_visit(node) + if node.value is Unknown: + return Unknown + return node + + def visit_Constant(self, node: ast.Constant) -> Any: + return node + + def visit_JoinedStr(self, node: ast.JoinedStr) -> Any: + node: ast.JoinedStr = self.generic_visit(node) + for val in node.values: + if val is Unknown: + return Unknown + return node + + def visit_FormattedValue(self, node: ast.FormattedValue) -> Any: + node: ast.FormattedValue = self.generic_visit(node) + if node.value is Unknown: + return Unknown + return node + + def visit_Name(self, node: ast.Name) -> Any: + if node.id in self._var_names: + return Unknown + else: + return node + + def visit_Compare(self, node: ast.Compare) -> Any: + node: ast.Compare = self.generic_visit(node) + if node.left is Unknown: + return Unknown + for comp in node.comparators: + if comp is Unknown: + return Unknown + return node + + def visit_List(self, node: ast.List) -> Any: + node: ast.List = self.generic_visit(node) + for sub in node.elts: + if sub is Unknown: + return Unknown + return node + + def visit_Tuple(self, node: ast.Tuple) -> Any: + node: ast.Tuple = self.generic_visit(node) + for sub in node.elts: + if sub is Unknown: + return Unknown + return node + + def visit_Dict(self, node: ast.Dict) -> Any: + node: ast.Dict = self.generic_visit(node) + for key in node.keys: + if key is Unknown: + return Unknown + for val in node.values: + if val is Unknown: + return Unknown + return node + + def visit_BoolOp(self, node: ast.BoolOp) -> Any: + node: ast.BoolOp = self.generic_visit(node) + node_values = list(filter(lambda x: x is not Unknown, node.values)) + if isinstance(node.op, ast.And): + if len(node_values) == 1: + return node_values[0] + elif len(node_values) == 0: + return Unknown + else: + if len(node_values) < len(node.values): + return Unknown + return ast.BoolOp(op=node.op, values=node_values) + + def visit_UnaryOp(self, node: ast.UnaryOp) -> Any: + node: ast.UnaryOp = self.generic_visit(node) + if node.operand is Unknown: + return Unknown + return node + + def visit_BinOp(self, node: ast.BinOp) -> Any: + node: ast.BinOp = self.generic_visit(node) + if node.left is Unknown or node.right is Unknown: + return Unknown + return node + + def visit_IfExp(self, node: ast.IfExp) -> Any: + node: ast.IfExp = self.generic_visit(node) + if node.test is Unknown or node.body is Unknown or node.orelse is Unknown: + return Unknown + return node + + def visit_Attribute(self, node: ast.Attribute) -> Any: + node: ast.Attribute = self.generic_visit(node) + if node.value is Unknown: + return Unknown + return node + + def visit_Call(self, node: ast.Call) -> Any: + node: ast.Call = self.generic_visit(node) + if node.func is Unknown: + return Unknown + for arg in node.args: + if arg is Unknown: + return Unknown + for kwarg in node.keywords: + if kwarg is Unknown: + return Unknown + return node + + def visit_Subscript(self, node: ast.Subscript) -> Any: + node: ast.Subscript = self.generic_visit(node) + if node.value is Unknown or node.slice is Unknown: + return Unknown + return node + + def visit_Slice(self, node: ast.Slice) -> Any: + node: ast.Slice = self.generic_visit(node) + if node.lower is Unknown or node.upper is Unknown or node.step is Unknown: + return Unknown + return node diff --git a/aim/sdk/sequence_collection.py b/aim/sdk/sequence_collection.py index 5738a8e280..3c4699bc27 100644 --- a/aim/sdk/sequence_collection.py +++ b/aim/sdk/sequence_collection.py @@ -3,17 +3,20 @@ from abc import abstractmethod from typing import TYPE_CHECKING, Iterator +from tqdm import tqdm + +from aim.sdk.query_analyzer import QueryExpressionTransformer from aim.sdk.query_utils import RunView, SequenceView from aim.sdk.sequence import Sequence from aim.sdk.types import QueryReportMode from aim.storage.query import RestrictedPythonQuery -from tqdm import tqdm if TYPE_CHECKING: + from pandas import DataFrame + from aim.sdk.repo import Repo from aim.sdk.run import Run - from pandas import DataFrame logger = logging.getLogger(__name__) @@ -170,20 +173,33 @@ def iter_runs(self) -> Iterator['SequenceCollection']: if self.report_mode == QueryReportMode.PROGRESS_BAR: progress_bar = tqdm(total=total_runs) + seq_var = self.seq_cls.sequence_name() + t = QueryExpressionTransformer(var_names=[seq_var, ]) + run_expr, is_transformed = t.transform(self.query) + run_query = RestrictedPythonQuery(run_expr) + for run in runs_iterator: - seq_collection = SingleRunSequenceCollection( - run, - self.seq_cls, - self.query, - runs_proxy_cache=self.runs_proxy_cache, - timezone_offset=self._timezone_offset, - ) - if self.report_mode == QueryReportMode.PROGRESS_TUPLE: - yield seq_collection, (runs_counter, total_runs) - else: - if self.report_mode == QueryReportMode.PROGRESS_BAR: - progress_bar.update(1) - yield seq_collection + check_run_sequences = True + if is_transformed: + run_view = RunView(run, runs_proxy_cache=self.runs_proxy_cache, timezone_offset=self._timezone_offset) + match = run_query.check(**{'run': run_view}) + if not match: + check_run_sequences = False + + if check_run_sequences: + seq_collection = SingleRunSequenceCollection( + run, + self.seq_cls, + self.query, + runs_proxy_cache=self.runs_proxy_cache, + timezone_offset=self._timezone_offset, + ) + if self.report_mode == QueryReportMode.PROGRESS_TUPLE: + yield seq_collection, (runs_counter, total_runs) + else: + if self.report_mode == QueryReportMode.PROGRESS_BAR: + progress_bar.update(1) + yield seq_collection runs_counter += 1 def iter(self) -> Iterator[Sequence]: