Skip to content

Commit

Permalink
Output checkers refactor + error_contains checker
Browse files Browse the repository at this point in the history
  • Loading branch information
andreasjansson committed Jan 30, 2025
1 parent 205481a commit ebad265
Show file tree
Hide file tree
Showing 11 changed files with 223 additions and 131 deletions.
15 changes: 11 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,19 @@ predict:
fixed_inputs: {}
predict_timeout: 300
test_cases:
- exact_string: <exact string match>
inputs:
- inputs:
<input1>: <value1>
exact_string: <exact string match>
- inputs:
<input2>: <value2>
match_url: <match output image against url>
- inputs:
<input3>: <value3>
match_prompt: <match output using AI prompt, e.g. 'an image of a cat'>
- inputs:
<input4>: <value4>
error_contains: <assert error and that the error message contains a string>
test_hardware: <hardware, e.g. cpu>
test_model: <test model, or empty to append '-test' to model>
train:
Expand All @@ -192,15 +196,18 @@ train:
iterations: 10
fixed_inputs: {}
test_cases:
- exact_string: <exact string match>
inputs:
- inputs:
<input1>: <value1>
exact_string: <exact string match>
- inputs:
<input2>: <value2>
match_url: <match output image against url>
- inputs:
<input3>: <value3>
match_prompt: <match output using AI prompt, e.g. 'an image of a cat'>
- inputs:
<input4>: <value4>
error_contains: <assert error and that the error message contains a string>
train_timeout: 300
# values between < and > should be edited
Expand Down
5 changes: 3 additions & 2 deletions cog_safe_push/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@ class TestCase(BaseModel):
exact_string: str | None = None
match_url: str | None = None
match_prompt: str | None = None
error_contains: str | None = None

@model_validator(mode="after")
def check_mutually_exclusive(self):
set_fields = sum(
getattr(self, field) is not None
for field in ["exact_string", "match_url", "match_prompt"]
for field in ["exact_string", "match_url", "match_prompt", "error_contains"]
)
if set_fields > 1:
raise ArgumentError(
"At most one of 'exact_string', 'match_url', or 'match_prompt' must be set"
"At most one of 'exact_string', 'match_url', 'match_prompt', or 'error_contains' must be set"
)
return self

Expand Down
9 changes: 4 additions & 5 deletions cog_safe_push/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,11 @@ class PredictionTimeoutError(CogSafePushError):
pass


class PredictionFailedError(CogSafePushError):
pass


class TestCaseFailedError(CogSafePushError):
pass
__test__ = False

def __init__(self, message):
super().__init__(f"Test case failed: {message}")


class AIError(Exception):
Expand Down
4 changes: 2 additions & 2 deletions cog_safe_push/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ def vvv(message):

def set_verbosity(verbosity):
global level
if verbosity == 0:
if verbosity <= 0:
level = INFO
if verbosity == 1:
level = VERBOSE1
if verbosity == 2:
level = VERBOSE2
if verbosity == 3:
if verbosity >= 3:
level = VERBOSE3
39 changes: 25 additions & 14 deletions cog_safe_push/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@
)
from .config import TestCase as ConfigTestCase
from .exceptions import ArgumentError, CogSafePushError
from .output_checkers import (
AIChecker,
ErrorContainsChecker,
ExactStringChecker,
MatchURLChecker,
NoChecker,
OutputChecker,
)
from .task_context import TaskContext, make_task_context
from .tasks import (
AIOutput,
CheckOutputsMatch,
ExactStringOutput,
ExactURLOutput,
ExpectedOutput,
FuzzModel,
MakeFuzzInputs,
RunTestCase,
Expand Down Expand Up @@ -258,7 +262,7 @@ def cog_safe_push(
train: bool = False,
do_compare_outputs: bool = True,
predict_timeout: int = 300,
test_cases: list[tuple[dict[str, Any], ExpectedOutput]] = [],
test_cases: list[tuple[dict[str, Any], OutputChecker]] = [],
fuzz_fixed_inputs: dict = {},
fuzz_disabled_inputs: list = [],
fuzz_iterations: int = 10,
Expand Down Expand Up @@ -313,12 +317,12 @@ def cog_safe_push(
)

if test_cases:
for inputs, output in test_cases:
for inputs, checker in test_cases:
tasks.append(
RunTestCase(
context=task_context,
inputs=inputs,
output=output,
checker=checker,
predict_timeout=predict_timeout,
)
)
Expand Down Expand Up @@ -460,21 +464,24 @@ def parse_test_case(test_case_str: str) -> ConfigTestCase:

def parse_config_test_case(
config_test_case: ConfigTestCase,
) -> tuple[dict[str, Any], ExpectedOutput]:
output = None
) -> tuple[dict[str, Any], OutputChecker]:
if config_test_case.exact_string:
output = ExactStringOutput(string=config_test_case.exact_string)
checker = ExactStringChecker(string=config_test_case.exact_string)
elif config_test_case.match_url:
output = ExactURLOutput(url=config_test_case.match_url)
checker = MatchURLChecker(url=config_test_case.match_url)
elif config_test_case.match_prompt:
output = AIOutput(prompt=config_test_case.match_prompt)
checker = AIChecker(prompt=config_test_case.match_prompt)
elif config_test_case.error_contains:
checker = ErrorContainsChecker(string=config_test_case.error_contains)
else:
checker = NoChecker()

return (config_test_case.inputs, output)
return (config_test_case.inputs, checker)


def parse_config_test_cases(
config_test_cases: list[ConfigTestCase],
) -> list[tuple[dict[str, Any], ExpectedOutput]]:
) -> list[tuple[dict[str, Any], OutputChecker]]:
return [parse_config_test_case(tc) for tc in config_test_cases]


Expand Down Expand Up @@ -519,6 +526,10 @@ def print_help_config():
inputs={"<input3>": "<value3>"},
match_prompt="<match output using AI prompt, e.g. 'an image of a cat'>",
),
ConfigTestCase(
inputs={"<input3>": "<value3>"},
error_contains="<assert that these inputs throws an error, and that the error message contains a string>",
),
],
),
).model_dump(exclude_none=True),
Expand Down
100 changes: 100 additions & 0 deletions cog_safe_push/output_checkers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from dataclasses import dataclass
from typing import Any, Protocol

from . import log
from .exceptions import (
AIError,
TestCaseFailedError,
)
from .match_outputs import is_url, output_matches_prompt, urls_match
from .utils import truncate


class OutputChecker(Protocol):
async def __call__(self, output: Any | None, error: str | None) -> None: ...


@dataclass
class NoChecker(OutputChecker):
    """Checker that ignores the output and only requires that no error occurred."""

    async def __call__(self, _: Any | None, error: str | None) -> None:
        """Raise ``TestCaseFailedError`` if ``error`` is set; otherwise do nothing."""
        # The output is deliberately unused: any successful prediction passes.
        check_no_error(error)


@dataclass
class ExactStringChecker(OutputChecker):
    """Checker that requires the output to be a string equal to ``string``."""

    # The exact string the prediction output must equal.
    string: str

    async def __call__(self, output: Any | None, error: str | None) -> None:
        """Validate ``output`` against the expected string.

        Raises:
            TestCaseFailedError: if ``error`` is set, if ``output`` is not a
                ``str``, or if it differs from ``self.string``.
        """
        check_no_error(error)

        # Reject non-string outputs first so the equality check below only
        # ever compares two strings.
        if not isinstance(output, str):
            raise TestCaseFailedError(f"Expected string, got {truncate(output, 200)}")

        if output == self.string:
            return

        raise TestCaseFailedError(
            f"Expected '{self.string}', got '{truncate(output, 200)}'"
        )


@dataclass
class MatchURLChecker(OutputChecker):
    """Checker that compares the file at the output URL with a reference URL."""

    # Reference URL whose file the prediction output must match.
    url: str

    async def __call__(self, output: Any | None, error: str | None) -> None:
        """Validate that ``output`` is a URL whose file matches ``self.url``.

        ``output`` may be a URL string or a single-element list holding one.

        Raises:
            TestCaseFailedError: if ``error`` is set, if ``output`` is not a
                URL (or a one-element list containing a URL), or if
                ``urls_match`` reports that the two files differ.
        """
        check_no_error(error)

        # Unwrap a single-element list so both accepted shapes share one path.
        candidate = output
        if isinstance(output, list) and len(output) == 1:
            candidate = output[0]

        if not (isinstance(candidate, str) and is_url(candidate)):
            raise TestCaseFailedError(f"Expected URL, got '{truncate(output, 200)}'")

        matches, mismatch_reason = await urls_match(
            self.url, candidate, is_deterministic=True
        )
        if not matches:
            raise TestCaseFailedError(
                f"File at URL {self.url} does not match file at URL {candidate}. {mismatch_reason}"
            )

        log.info(f"File at URL {self.url} matched file at URL {candidate}")


@dataclass
class AIChecker(OutputChecker):
    """Checker that asks ``output_matches_prompt`` whether the output fits a prompt."""

    # Prompt the prediction output is matched against.
    prompt: str

    async def __call__(self, output: Any | None, error: str | None) -> None:
        """Validate ``output`` against ``self.prompt``.

        Raises:
            TestCaseFailedError: if ``error`` is set, if the matcher reports
                that the output does not match the prompt, or if the matcher
                itself raises ``AIError`` (chained as the cause).
        """
        check_no_error(error)

        # Keep the try body to the one call that can raise AIError.
        try:
            matches, reason = await output_matches_prompt(output, self.prompt)
        except AIError as e:
            # Chain explicitly so the original AI failure is kept as __cause__.
            raise TestCaseFailedError(f"AI error: {e}") from e

        if not matches:
            raise TestCaseFailedError(reason)


@dataclass
class ErrorContainsChecker(OutputChecker):
    """Checker that expects an error whose message contains ``string``."""

    # Substring that must appear in the error message.
    string: str

    async def __call__(self, _: Any | None, error: str | None) -> None:
        """Validate that ``error`` is set and contains ``self.string``.

        The output argument is ignored.

        Raises:
            TestCaseFailedError: if ``error`` is ``None``, or if it does not
                contain ``self.string``.
        """
        if error is None:
            raise TestCaseFailedError("Expected error, prediction succeeded")

        if self.string in error:
            return

        raise TestCaseFailedError(
            f"Expected error to contain {self.string}, got {error}"
        )


def check_no_error(error: str | None) -> None:
if error is not None:
raise TestCaseFailedError(f"Prediction raised unexpected error: {error}")
19 changes: 7 additions & 12 deletions cog_safe_push/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from . import ai, log
from .exceptions import (
AIError,
PredictionFailedError,
PredictionTimeoutError,
)
from .utils import truncate


async def make_predict_inputs(
Expand Down Expand Up @@ -221,7 +221,7 @@ async def predict(
train_destination: Model | None,
inputs: dict,
timeout_seconds: float,
):
) -> tuple[Any | None, str | None]:
log.vv(
f"Running {'training' if train else 'prediction'} with inputs:\n{json.dumps(inputs, indent=2)}"
)
Expand Down Expand Up @@ -259,22 +259,17 @@ async def predict(
raise PredictionTimeoutError()
prediction.reload()

duration = time.time() - start_time

if prediction.status == "failed":
raise PredictionFailedError(prediction.error)
log.v(f"Got error: {prediction.error} ({duration:.2f} sec)")
return None, prediction.error

duration = time.time() - start_time
log.v(f"Got output: {truncate(prediction.output)} ({duration:.2f} sec)")

output = prediction.output

if _has_output_iterator_array_type(version):
output = "".join(cast(list[str], output))

return output


def truncate(s, max_length=500) -> str:
s = str(s)
if len(s) <= max_length:
return s
return s[:max_length] + "..."
return output, None
Loading

0 comments on commit ebad265

Please sign in to comment.