This repository has been archived by the owner on Jan 9, 2025. It is now read-only.

Merge pull request #30 from tsdataclinic/feature/allowed-values
Add ability to limit a column with a set of allowed values
Juan Pablo Sarmiento authored Nov 15, 2024
2 parents 0232ffa + 2198b01 commit 64554ec
Showing 11 changed files with 370 additions and 105 deletions.
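
The change, in short: a field in a fieldset schema can now constrain its values either to a hard-coded list or to a workflow parameter of type "string list" that it references by id. A minimal sketch of the two forms (the attribute name comes from FieldSchema.allowed_values as used in validators.py below; how a full schema is assembled end to end is not shown in this commit):

```python
from server.models.workflow.workflow_schema import ParamReference

# Form 1: the allowed values are hard-coded on the field's schema.
allowed_values_literal = ["John", "Jane"]

# Form 2: the field references a workflow parameter (type "string list") by id,
# so the concrete list is supplied with each run's parameter values instead of
# being baked into the schema.
allowed_values_by_param = ParamReference(paramId="allowed_names_uuid")

# Either form ends up on FieldSchema.allowed_values; during validation, a cell
# outside the resolved list is reported per row, e.g.
# "Value 'Bob' is not allowed for field 'name'".
```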
5 changes: 3 additions & 2 deletions server/api/views.py
@@ -32,7 +32,8 @@
)
from server.models.workflow.db_model import DBWorkflow
from server.models.workflow.workflow_schema import CsvData, WorkflowSchema
from server.workflow_runner.workflow_runner import process_workflow, WorkflowParamValue
from server.workflow_runner.workflow_runner import process_workflow
from server.workflow_runner.validators import WorkflowParamValue

LOG = logging.getLogger(__name__)

@@ -53,7 +54,7 @@ class Settings(BaseSettings):
AZURE_POLICY_AUTH_NAME: str = Field(default="")
AZURE_B2C_SCOPES: str = Field(default="")

model_config: SettingsConfigDict = SettingsConfigDict(
model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore
env_file=".env.server", case_sensitive=True
)

8 changes: 7 additions & 1 deletion server/models/workflow/api_schemas.py
@@ -1,4 +1,5 @@
"""Workflow schemas that are used in the API."""

from datetime import datetime
from typing import Any

@@ -35,6 +36,7 @@ class WorkflowUpdate(FullWorkflow):

pass


class ValidationFailure(BaseModel):
"""
A validation failure with a message.
@@ -44,13 +46,17 @@ class ValidationFailure(BaseModel):
- row_number (int | None) -- The row number of the error. Or None if there
is no row number (e.g. if this is a file type error).
"""

message: str
row_number: int | None = Field(default=None, serialization_alias="rowNumber")


class WorkflowRunReport(BaseModel):
"""Report for a server-side run of a workflow."""

row_count: int = Field(serialization_alias="rowCount")
filename: str
workflow_id: str = Field(serialization_alias="workflowId")
validation_failures: list[ValidationFailure] = Field(serialization_alias="validationFailures")
validation_failures: list[ValidationFailure] = Field(
serialization_alias="validationFailures"
)
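
As a side note on the reshaped models: the camelCase wire names come from Pydantic's serialization_alias, so they only appear when dumping by alias. A minimal sketch, assuming Pydantic v2:

```python
from server.models.workflow.api_schemas import ValidationFailure

failure = ValidationFailure(
    message="Value 'Bob' is not allowed for field 'name'",
    row_number=3,
)

# serialization_alias only takes effect when dumping by alias
print(failure.model_dump(by_alias=True))
# {'message': "Value 'Bob' is not allowed for field 'name'", 'rowNumber': 3}
```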
62 changes: 41 additions & 21 deletions server/tests/workflow_runner/test_validators.py
@@ -2,16 +2,19 @@
from unittest.mock import MagicMock

from server.models.workflow.workflow_schema import (
WorkflowParam,
BasicFieldDataTypeSchema,
FieldSchema,
FileTypeValidation,
ParamReference,
RowCountValidation,
FieldsetSchema,
TimestampDataTypeSchema,
WorkflowParam,
)
from server.workflow_runner.validators import (
ValidationFailure,
WorkflowParamValue,
_check_csv_columns,
_validate_field,
validate_file_type,
@@ -242,9 +245,9 @@ def test_validate_field_case_insensitive(self):
allow_empty_values=False,
allowed_values=None,
)
self.assertEqual(_validate_field(1, {"Name": "John"}, field, {}), [])
self.assertEqual(_validate_field(1, {"Name": "John"}, field, {}, {}), [])
self.assertEqual(
_validate_field(1, {"Name": ""}, field, {}),
_validate_field(1, {"Name": ""}, field, {}, {}),
[
ValidationFailure(
row_number=1, message="Empty value for the field 'name'"
@@ -258,26 +261,26 @@ def test_validate_field_allow_empty(self):
"name",
allow_empty_values=True,
)
self.assertEqual(_validate_field(1, {"name": "John"}, field, {}), [])
self.assertEqual(_validate_field(1, {"name": ""}, field, {}), [])
self.assertEqual(_validate_field(1, {"name": None}, field, {}), [])
self.assertEqual(_validate_field(1, {"name": "John"}, field, {}, {}), [])
self.assertEqual(_validate_field(1, {"name": ""}, field, {}, {}), [])
self.assertEqual(_validate_field(1, {"name": None}, field, {}, {}), [])

# now disallow empty values
field = mock_field_schema(
"name",
allow_empty_values=False,
)
self.assertEqual(_validate_field(1, {"name": "John"}, field, {}), [])
self.assertEqual(_validate_field(1, {"name": "John"}, field, {}, {}), [])
self.assertEqual(
_validate_field(1, {"name": ""}, field, {}),
_validate_field(1, {"name": ""}, field, {}, {}),
[
ValidationFailure(
row_number=1, message="Empty value for the field 'name'"
)
],
)
self.assertEqual(
_validate_field(1, {"name": None}, field, {}),
_validate_field(1, {"name": None}, field, {}, {}),
[
ValidationFailure(
row_number=1, message="Empty value for the field 'name'"
@@ -293,10 +296,10 @@ def test_validate_field_allowed_values(self):
allow_empty_values=False,
allowed_values=["John", "Jane"],
)
self.assertEqual(_validate_field(1, {"name": "John"}, field, {}), [])
self.assertEqual(_validate_field(1, {"name": "Jane"}, field, {}), [])
self.assertEqual(_validate_field(1, {"name": "John"}, field, {}, {}), [])
self.assertEqual(_validate_field(1, {"name": "Jane"}, field, {}, {}), [])
self.assertEqual(
_validate_field(1, {"name": "Bob"}, field, {}),
_validate_field(1, {"name": "Bob"}, field, {}, {}),
[
ValidationFailure(
row_number=1, message="Value 'Bob' is not allowed for field 'name'"
@@ -310,13 +313,27 @@ def test_validate_field_allowed_values_from_param(self):
required=True,
case_sensitive=True,
allow_empty_values=False,
allowed_values=ParamReference(paramId="allowed_names"),
allowed_values=ParamReference(paramId="allowed_names_uuid"),
)
param_schemas = {
"allowed_names_uuid": WorkflowParam(
id="allowed_names_uuid",
name="allowed_names",
displayName="Allowed names",
required=True,
description="",
type="string list",
)
}
params: dict[str, WorkflowParamValue] = {"allowed_names": ["John", "Jane"]}
self.assertEqual(
_validate_field(1, {"name": "John"}, field, param_schemas, params), []
)
self.assertEqual(
_validate_field(1, {"name": "Jane"}, field, param_schemas, params), []
)
params = {"allowed_names": ["John", "Jane"]}
self.assertEqual(_validate_field(1, {"name": "John"}, field, params), [])
self.assertEqual(_validate_field(1, {"name": "Jane"}, field, params), [])
self.assertEqual(
_validate_field(1, {"name": "Bob"}, field, params),
_validate_field(1, {"name": "Bob"}, field, param_schemas, params),
[
ValidationFailure(
row_number=1, message="Value 'Bob' is not allowed for field 'name'"
@@ -333,10 +350,10 @@ def test_validate_field_number_type(self):
allowed_values=None,
data_type_validation=BasicFieldDataTypeSchema(dataType="number"),
)
self.assertEqual(_validate_field(1, {"age": "42"}, field, {}), [])
self.assertEqual(_validate_field(1, {"age": "42.0"}, field, {}), [])
self.assertEqual(_validate_field(1, {"age": "42"}, field, {}, {}), [])
self.assertEqual(_validate_field(1, {"age": "42.0"}, field, {}, {}), [])
self.assertEqual(
_validate_field(1, {"age": "42.0.0"}, field, {}),
_validate_field(1, {"age": "42.0.0"}, field, {}, {}),
[
ValidationFailure(
row_number=1,
@@ -356,9 +373,9 @@ def test_validate_field_timestamp_type(self):
dataType="timestamp", dateTimeFormat="%Y-%m-%d"
),
)
self.assertEqual(_validate_field(1, {"date": "2021-01-01"}, field, {}), [])
self.assertEqual(_validate_field(1, {"date": "2021-01-01"}, field, {}, {}), [])
self.assertEqual(
_validate_field(1, {"date": "2021-01-01 00:00:00.000"}, field, {}),
_validate_field(1, {"date": "2021-01-01 00:00:00.000"}, field, {}, {}),
[
ValidationFailure(
row_number=1,
@@ -406,6 +423,7 @@ def test_validate_fieldset_with_bad_data(self):
[{"name": "John", "age": "42", "city": "New York"}],
fieldset_schema,
{},
{},
),
[],
)
@@ -415,6 +433,7 @@ def test_validate_fieldset_with_bad_data(self):
[{"name": "John", "age": "42", "city": ""}],
fieldset_schema,
{},
{},
),
[
ValidationFailure(
@@ -428,6 +447,7 @@ def test_validate_fieldset_with_bad_data(self):
[{"name": "John", "age": None, "city": "New York"}],
fieldset_schema,
{},
{},
),
[
ValidationFailure(
51 changes: 38 additions & 13 deletions server/workflow_runner/validators.py
@@ -1,20 +1,28 @@
from typing import Any
from dataclasses import dataclass
import datetime

from frictionless import validate, Resource
from pydantic import BaseModel, Field

from server.models.workflow.workflow_schema import (
BasicFieldDataTypeSchema,
FieldSchema,
FieldsetSchema,
FileTypeValidation,
RowCountValidation,
TimestampDataTypeSchema,
CsvData,
WorkflowParam,
)
from server.models.workflow.api_schemas import ValidationFailure
from .exceptions import ParameterDefinitionNotFoundException

WorkflowParamValue = int | str | list[str] | None


def _get_param_schema_by_id(
param_schemas: dict[str, WorkflowParam], param_id: str
) -> WorkflowParam:
if param_id in param_schemas:
return param_schemas[param_id]
raise ParameterDefinitionNotFoundException(
f"Parameter definition for id '{param_id}' not found in schema."
)


def parse_frictionless(
@@ -120,7 +128,11 @@ def _check_csv_columns(


def _validate_field(
row_num: int, row: dict, field: FieldSchema, params: dict[str, Any]
row_num: int,
row: dict,
field: FieldSchema,
param_schemas: dict[str, WorkflowParam],
param_values: dict[str, WorkflowParamValue],
) -> list[ValidationFailure]:
"""Validate a field in a row."""
validations = []
@@ -147,7 +159,7 @@ def _validate_field(
pass # no additional validation needed
case BasicFieldDataTypeSchema(data_type="number"):
try:
float(value)
float(value) # type: ignore
# TypeError is raised if value is None
except (TypeError, ValueError) as e:
# if value is None, and we haven't already added a validation failure for a missing value,
@@ -165,7 +177,7 @@ def _validate_field(
data_type="timestamp", date_time_format=date_time_format
):
try:
datetime.datetime.strptime(value, date_time_format)
datetime.datetime.strptime(value, date_time_format) # type: ignore
except ValueError:
validations.append(
ValidationFailure(
@@ -179,8 +191,12 @@
if isinstance(field.allowed_values, list):
allowed_values = field.allowed_values
else:
allowed_values = params[field.allowed_values.param_id]
if value not in allowed_values:
param = _get_param_schema_by_id(
param_schemas, field.allowed_values.param_id
)
allowed_values = param_values[param.name]

if value not in allowed_values: # type: ignore
validations.append(
ValidationFailure(
row_number=row_num,
@@ -195,14 +211,23 @@ def validate_fieldset(
csv_columns,
csv_data: list[dict],
fieldset_schema: FieldsetSchema,
params: dict[str, Any],
param_schemas: dict[str, WorkflowParam],
param_values: dict[str, WorkflowParamValue],
) -> list[ValidationFailure]:
"""Validate the fieldset schema of a file."""
validations = []

validations.extend(_check_csv_columns(csv_columns, fieldset_schema))
for row_num, row in enumerate(csv_data):
for field in fieldset_schema.fields:
validations.extend(_validate_field(row_num + 1, row, field, params))
validations.extend(
_validate_field(
row_num=row_num + 1,
row=row,
field=field,
param_schemas=param_schemas,
param_values=param_values,
)
)

return validations
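
To tie the new signatures together: validate_fieldset and _validate_field now take the parameter schemas (keyed by parameter id) and the resolved parameter values (keyed by parameter name) separately, and a ParamReference in allowed_values is resolved through _get_param_schema_by_id before the value lookup. A rough sketch of that resolution step, reusing the shapes from the tests above:

```python
from server.models.workflow.workflow_schema import WorkflowParam
from server.workflow_runner.validators import (
    WorkflowParamValue,
    _get_param_schema_by_id,
)

# Parameter definitions are keyed by id; resolved values are keyed by name.
param_schemas = {
    "allowed_names_uuid": WorkflowParam(
        id="allowed_names_uuid",
        name="allowed_names",
        displayName="Allowed names",
        required=True,
        description="",
        type="string list",
    )
}
param_values: dict[str, WorkflowParamValue] = {"allowed_names": ["John", "Jane"]}

# This is the lookup _validate_field performs when allowed_values is a ParamReference:
param = _get_param_schema_by_id(param_schemas, "allowed_names_uuid")
allowed_values = param_values[param.name]  # ["John", "Jane"]
```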
20 changes: 6 additions & 14 deletions server/workflow_runner/workflow_runner.py
@@ -1,6 +1,4 @@
from typing import Any
import csv
import io

from frictionless import Resource

@@ -20,14 +18,14 @@
ParameterDefinitionNotFoundException,
)
from .validators import (
_get_param_schema_by_id,
validate_fieldset,
validate_file_type,
validate_row_count,
parse_frictionless,
WorkflowParamValue,
)

WorkflowParamValue = int | str | list[str] | None


def process_workflow(
file_name: str,
@@ -131,16 +129,9 @@ def _validate_csv(
operation.fieldset_schema, schema.fieldset_schemas
)
case ParamReference():
param_id = operation.fieldset_schema.param_id
param = param_schemas.get(param_id, None)

if not param:
return [
ValidationFailure(
message=f"Param with id {param_id} could not be found in the param schemas."
)
]

param = _get_param_schema_by_id(
param_schemas, operation.fieldset_schema.param_id
)
fieldset_schema_name = param_values[param.name]
if type(fieldset_schema_name) != str:
return [
@@ -157,6 +148,7 @@ def _validate_csv(
csv_data.column_names,
csv_data.data,
fieldset_schema,
param_schemas,
param_values,
)
)
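
One behavioral consequence of this refactor worth noting: the old _validate_csv returned a ValidationFailure when a referenced parameter id was missing from the schemas, whereas _get_param_schema_by_id raises ParameterDefinitionNotFoundException, so (unless it is caught further up the call stack, which this diff does not show) a dangling param id now surfaces as an exception. A small sketch of what that looks like:

```python
from server.workflow_runner.exceptions import ParameterDefinitionNotFoundException
from server.workflow_runner.validators import _get_param_schema_by_id

try:
    _get_param_schema_by_id({}, "missing_param_id")
except ParameterDefinitionNotFoundException as exc:
    # "Parameter definition for id 'missing_param_id' not found in schema."
    print(exc)
```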