Skip to content

Commit

Permalink
[feat] Skip metrics check when run is known to yield false result (#3288
Browse files Browse the repository at this point in the history
)

* [feat] Skip metrics check when run is known to yield false result

* [fix] Code style checks

* [fix] More styling errors
  • Loading branch information
alberttorosyan authored Feb 12, 2025
1 parent c6e0c7f commit 2d9f3b8
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 15 deletions.
151 changes: 151 additions & 0 deletions aim/sdk/query_analyzer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import ast

from typing import Any, List, Tuple


class Unknown(ast.AST):
    """Placeholder AST node for a sub-expression whose value cannot be determined."""


# The class name is deliberately rebound to its only instance, so every
# ``x is Unknown`` check below is an identity test against this singleton.
Unknown = Unknown()  # create a single instance of <unknown> value node


class QueryExpressionTransformer(ast.NodeTransformer):
    """Strip the parts of a query expression that depend on unknown variables.

    Any sub-expression that references one of ``var_names`` is replaced by the
    ``Unknown`` sentinel, and ``Unknown`` propagates upwards through every
    enclosing node. The single exception is ``and``: ``A and <unknown>`` is
    relaxed to ``A``, because whenever ``A`` is falsy the full expression is
    falsy too. The result is therefore a *necessary* condition for the original
    expression to be truthy.

    The relaxation is only sound while the ``and`` sits in a position where
    weakening it can only weaken the whole expression (top level, or nested
    directly inside other ``and``/``or`` operators). Under ``not``, inside a
    comparison, as a call argument, etc. the whole ``and`` is treated as
    ``Unknown`` instead.
    """

    # Node types through which the ``and`` relaxation stays sound.
    _TRANSPARENT_NODES = (ast.Expression, ast.Expr, ast.BoolOp)

    # Number of enclosing nodes that are *not* transparent. A class-level
    # default keeps instances usable even if ``__init__`` was bypassed.
    _opaque_depth = 0

    def __init__(self, *, var_names: List[str]):
        """
        Args:
            var_names: Names of the variables whose values are not known;
                every sub-expression that references one becomes ``Unknown``.
        """
        # A frozenset gives O(1) membership tests in visit_Name.
        self._var_names = frozenset(var_names)
        self._opaque_depth = 0

    def transform(self, expr: str) -> Tuple[str, bool]:
        """Return the reduced expression and whether a reduction was possible.

        Args:
            expr: Source text of a single Python expression.

        Returns:
            ``(reduced_source, True)`` when something that does not depend on
            the unknown variables remains, otherwise ``(expr, False)`` with the
            input string untouched.

        Raises:
            SyntaxError: If ``expr`` is not a valid Python expression.
        """
        tree = ast.parse(expr, mode='eval')
        reduced = self.visit(tree)
        if reduced is Unknown:
            return expr, False
        return ast.unparse(reduced), True

    def visit(self, node: ast.AST) -> Any:
        """Dispatch to the node handler while tracking opaque nesting depth."""
        if isinstance(node, self._TRANSPARENT_NODES):
            return super().visit(node)
        self._opaque_depth += 1
        try:
            return super().visit(node)
        finally:
            self._opaque_depth -= 1

    def generic_visit(self, node: ast.AST) -> Any:
        """Visit the children; return ``Unknown`` if any direct child is ``Unknown``.

        Acting as the fallback for every node type (including ``keyword``,
        ``Set``, ``Starred``, comprehensions, ...) guarantees that ``Unknown``
        never stays embedded in a tree that is later unparsed.
        """
        node = super().generic_visit(node)
        for _, value in ast.iter_fields(node):
            if value is Unknown:
                return Unknown
            if isinstance(value, list) and any(item is Unknown for item in value):
                return Unknown
        return node

    # These node types need nothing beyond plain propagation. The explicit
    # hook names are kept so that direct callers of them keep working.
    visit_Expression = visit_Expr = generic_visit
    visit_JoinedStr = visit_FormattedValue = generic_visit
    visit_Compare = visit_List = visit_Tuple = visit_Dict = generic_visit
    visit_UnaryOp = visit_BinOp = visit_IfExp = generic_visit
    visit_Attribute = visit_Call = visit_Subscript = visit_Slice = generic_visit

    def visit_Constant(self, node: ast.Constant) -> Any:
        """Constants are always known."""
        return node

    def visit_Name(self, node: ast.Name) -> Any:
        """A name is ``Unknown`` exactly when it is one of the unknown variables."""
        if node.id in self._var_names:
            return Unknown
        return node

    def visit_BoolOp(self, node: ast.BoolOp) -> Any:
        """Reduce ``and``/``or`` expressions that contain unknown operands.

        * No unknown operand: the node is kept.
        * ``or`` with any unknown operand: ``Unknown`` (the unknown operand
          alone could make the expression truthy).
        * ``and`` with unknown operands: the unknown operands are dropped, but
          only when no opaque node encloses this one; otherwise ``Unknown``.
        """
        # Use the base-class traversal: Unknown operands are handled here
        # rather than being propagated by our generic_visit override.
        node = super().generic_visit(node)
        known = [value for value in node.values if value is not Unknown]
        if len(known) == len(node.values):
            return node
        if not isinstance(node.op, ast.And) or self._opaque_depth:
            return Unknown
        if not known:
            return Unknown
        if len(known) == 1:
            return known[0]
        return ast.BoolOp(op=node.op, values=known)
46 changes: 31 additions & 15 deletions aim/sdk/sequence_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Iterator

from tqdm import tqdm

from aim.sdk.query_analyzer import QueryExpressionTransformer
from aim.sdk.query_utils import RunView, SequenceView
from aim.sdk.sequence import Sequence
from aim.sdk.types import QueryReportMode
from aim.storage.query import RestrictedPythonQuery
from tqdm import tqdm


if TYPE_CHECKING:
from pandas import DataFrame

from aim.sdk.repo import Repo
from aim.sdk.run import Run
from pandas import DataFrame

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -170,20 +173,33 @@ def iter_runs(self) -> Iterator['SequenceCollection']:
if self.report_mode == QueryReportMode.PROGRESS_BAR:
progress_bar = tqdm(total=total_runs)

seq_var = self.seq_cls.sequence_name()
t = QueryExpressionTransformer(var_names=[seq_var, ])
run_expr, is_transformed = t.transform(self.query)
run_query = RestrictedPythonQuery(run_expr)

for run in runs_iterator:
seq_collection = SingleRunSequenceCollection(
run,
self.seq_cls,
self.query,
runs_proxy_cache=self.runs_proxy_cache,
timezone_offset=self._timezone_offset,
)
if self.report_mode == QueryReportMode.PROGRESS_TUPLE:
yield seq_collection, (runs_counter, total_runs)
else:
if self.report_mode == QueryReportMode.PROGRESS_BAR:
progress_bar.update(1)
yield seq_collection
check_run_sequences = True
if is_transformed:
run_view = RunView(run, runs_proxy_cache=self.runs_proxy_cache, timezone_offset=self._timezone_offset)
match = run_query.check(**{'run': run_view})
if not match:
check_run_sequences = False

if check_run_sequences:
seq_collection = SingleRunSequenceCollection(
run,
self.seq_cls,
self.query,
runs_proxy_cache=self.runs_proxy_cache,
timezone_offset=self._timezone_offset,
)
if self.report_mode == QueryReportMode.PROGRESS_TUPLE:
yield seq_collection, (runs_counter, total_runs)
else:
if self.report_mode == QueryReportMode.PROGRESS_BAR:
progress_bar.update(1)
yield seq_collection
runs_counter += 1

def iter(self) -> Iterator[Sequence]:
Expand Down

0 comments on commit 2d9f3b8

Please sign in to comment.