From ed7c2aa028eddf67f24091db72803e87e5c74df0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sa=C5=A1a=20Zejnilovi=C4=87?= Date: Mon, 13 Dec 2021 13:30:30 +0100 Subject: [PATCH] Base feature (#1) * Base feature * Update sbt version * #1 github workflow to build, sbt-header check + github workflow * #1 headers update (newline) to be compatible with sbt header plugin * #1 Constants reverted to Enceladus values (test passing and Enceladus-compatiblity), created issue #5 to follow up * #1 testfix (missing test config, typeName wrapper - typeOf[Boolean].toString returning `Boolean` vs `scala.Boolean` + test * #1 cleanup, SectionSuite added (originating in Enceladus). `FieldValidationFailure` renamed to `FieldValidationIssue` * #1 sbt `in` -> `/` * #1 removed unused test resources * #1 added header checking for src/test/scala, too. + header NL fixes, build name = small caps, * #1 assembly removed (plugin + config) * #1 StructFieldImplicitsSuite added, `StandardizationRerunSuite` renamed to `StandardizationCsvSuite` * #1 Enceladus#677 changed to #7 * `StdInterpreterSuite` and `StandardizationInterpreterSuite` merged into the latter. * #1 `StandardizationInterpreterSuite` test fix Co-authored-by: Daniel Kavan --- .editorconfig | 36 + .github/workflows/build.yml | 41 + .github/workflows/licence_check.yml | 41 + .gitignore | 1 + build.sbt | 41 + project/build.properties | 1 + project/plugins.sbt | 1 + src/main/resources/reference.conf | 31 + .../ArrayTransformations.scala | 97 +++ .../absa/standardization/ConfigReader.scala | 180 +++++ .../co/absa/standardization/Constants.scala | 26 + .../absa/standardization/ErrorMessage.scala | 109 +++ .../co/absa/standardization/FileReader.scala | 41 + .../co/absa/standardization/FlatField.scala | 21 + .../co/absa/standardization/JsonUtils.scala | 60 ++ .../standardization/RecordIdGeneration.scala | 73 ++ .../standardization/SchemaValidator.scala | 138 ++++ .../standardization/Standardization.scala | 109 +++ .../StandardizationConfig.scala | 22 + .../standardization/ValidationException.scala | 25 + .../standardization/ValidationIssue.scala | 23 + .../implicits/ColumnImplicits.scala | 112 +++ .../implicits/DataFrameImplicits.scala | 73 ++ .../implicits/OptionImplicits.scala | 27 + .../implicits/StringImplicits.scala | 208 +++++ .../implicits/StructFieldImplicits.scala | 48 ++ .../numeric/DecimalSymbols.scala | 105 +++ .../numeric/NumericPattern.scala | 38 + .../absa/standardization/numeric/Radix.scala | 71 ++ .../standardization/schema/MetadataKeys.scala | 46 ++ .../standardization/schema/SchemaUtils.scala | 605 +++++++++++++++ .../standardization/schema/SparkUtils.scala | 108 +++ .../stages/PlainSchemaGenerator.scala | 56 ++ .../stages/SchemaChecker.scala | 70 ++ .../standardization/stages/TypeParser.scala | 658 ++++++++++++++++ .../time/DateTimePattern.scala | 184 +++++ .../time/TimeZoneNormalizer.scala | 54 ++ .../typeClasses/DoubleLike.scala | 41 + .../typeClasses/LongLike.scala | 68 ++ .../absa/standardization/types/Defaults.scala | 109 +++ .../types/DefaultsByFormat.scala | 89 +++ .../standardization/types/ParseOutput.scala | 21 + .../absa/standardization/types/Section.scala | 341 +++++++++ .../standardization/types/TypePattern.scala | 30 + .../types/TypedStructField.scala | 476 ++++++++++++ .../types/parsers/DateTimeParser.scala | 124 +++ .../types/parsers/DecimalParser.scala | 57 ++ .../types/parsers/FractionalParser.scala | 68 ++ .../types/parsers/IntegralParser.scala | 171 +++++ .../types/parsers/NumericParser.scala | 106 +++ 
.../types/parsers/ParseViaDecimalFormat.scala | 56 ++ .../absa/standardization/udf/UDFBuilder.scala | 56 ++ .../absa/standardization/udf/UDFLibrary.scala | 99 +++ .../absa/standardization/udf/UDFNames.scala | 34 + .../absa/standardization/udf/UDFResult.scala | 40 + .../field/BinaryFieldValidator.scala | 68 ++ .../validation/field/DateFieldValidator.scala | 79 ++ .../field/DateTimeFieldValidator.scala | 91 +++ .../field/DecimalFieldValidator.scala | 28 + .../field/FieldValidationIssue.scala | 24 + .../validation/field/FieldValidator.scala | 83 ++ .../field/FractionalFieldValidator.scala | 29 + .../field/IntegralFieldValidator.scala | 48 ++ .../field/NumericFieldValidator.scala | 42 ++ .../field/ScalarFieldValidator.scala | 39 + .../field/TimestampFieldValidator.scala | 81 ++ src/test/resources/application.conf | 18 + src/test/resources/data/bug.json | 26 + src/test/resources/data/data1Schema.json | 83 ++ .../resources/data/dateTimestampSchemaOk.json | 93 +++ .../data/dateTimestampSchemaWrong.json | 107 +++ .../resources/data/integral_overflow_test.csv | 11 + .../data/integral_overflow_test_numbers.json | 1 + .../data/integral_overflow_test_text.json | 1 + src/test/resources/data/patients.json | 59 ++ .../resources/data/standardizeJsonSrc.json | 1 + .../ArrayTransformationsSuite.scala | 90 +++ .../standardization/ErrorMessageFactory.scala | 36 + .../absa/standardization/JsonUtilsSuite.scala | 63 ++ .../absa/standardization/LoggerTestBase.scala | 47 ++ .../RecordIdGenerationSuite.scala | 91 +++ .../SchemaValidationSuite.scala | 227 ++++++ .../absa/standardization/SparkTestBase.scala | 137 ++++ .../StandardizationCsvSuite.scala | 133 ++++ .../StandardizationParquetSuite.scala | 361 +++++++++ .../co/absa/standardization/TestSamples.scala | 88 +++ .../implicits/DataFrameImplicitsSuite.scala | 157 ++++ .../implicits/StringImplicitsSuite.scala | 310 ++++++++ .../implicits/StructFieldImplicitsSuite.scala | 67 ++ .../interpreter/CounterPartySuite.scala | 61 ++ .../interpreter/DateTimeSuite.scala | 107 +++ .../interpreter/SampleDataSuite.scala | 47 ++ .../StandardizationInterpreterSuite.scala | 375 +++++++++ ...tandardizationInterpreter_ArraySuite.scala | 210 ++++++ ...andardizationInterpreter_BinarySuite.scala | 192 +++++ ...StandardizationInterpreter_DateSuite.scala | 361 +++++++++ ...ndardizationInterpreter_DecimalSuite.scala | 280 +++++++ ...rdizationInterpreter_FractionalSuite.scala | 385 ++++++++++ ...dardizationInterpreter_IntegralSuite.scala | 554 ++++++++++++++ ...ardizationInterpreter_TimestampSuite.scala | 368 +++++++++ .../stages/PlainSchemaGeneratorSuite.scala | 71 ++ .../stages/SchemaCheckerSuite.scala | 37 + .../interpreter/stages/TypeParserSuite.scala | 55 ++ .../stages/TypeParserSuiteTemplate.scala | 263 +++++++ .../TypeParser_FromBooleanTypeSuite.scala | 119 +++ .../stages/TypeParser_FromDateTypeSuite.scala | 120 +++ .../TypeParser_FromDecimalTypeSuite.scala | 121 +++ .../TypeParser_FromDoubleTypeSuite.scala | 132 ++++ .../stages/TypeParser_FromLongTypeSuite.scala | 121 +++ .../TypeParser_FromStringTypeSuite.scala | 121 +++ .../TypeParser_FromTimestampTypeSuite.scala | 120 +++ .../standardizationInterpreter_RowTypes.scala | 132 ++++ .../schema/SchemaUtilsSuite.scala | 477 ++++++++++++ .../schema/SparkUtilsSuite.scala | 119 +++ .../time/DateTimePatternSuite.scala | 271 +++++++ .../types/DefaultsByFormatSuite.scala | 92 +++ .../standardization/types/DefaultsSuite.scala | 94 +++ .../standardization/types/SectionSuite.scala | 713 ++++++++++++++++++ .../types/TypedStructFieldSuite.scala 
| 281 +++++++ .../types/parsers/DateTimeParserSuite.scala | 203 +++++ .../types/parsers/DecimalParserSuite.scala | 132 ++++ .../types/parsers/FractionalParserSuite.scala | 184 +++++ ...ralParser_PatternIntegralParserSuite.scala | 130 ++++ ...egralParser_RadixIntegralParserSuite.scala | 173 +++++ .../standardization/udf/UDFBuilderSuite.scala | 126 ++++ .../field/BinaryValidatorSuite.scala | 61 ++ .../field/DateFieldValidatorSuite.scala | 213 ++++++ .../field/FieldValidatorSuite.scala | 32 + .../field/FractionalFieldValidatorSuite.scala | 91 +++ .../field/IntegralFieldValidatorSuite.scala | 94 +++ .../field/NumericFieldValidatorSuite.scala | 92 +++ .../field/ScalarFieldValidatorSuite.scala | 52 ++ .../field/TimestampFieldValidatorSuite.scala | 233 ++++++ 133 files changed, 16600 insertions(+) create mode 100644 .editorconfig create mode 100644 .github/workflows/build.yml create mode 100644 .github/workflows/licence_check.yml create mode 100644 build.sbt create mode 100644 project/build.properties create mode 100644 project/plugins.sbt create mode 100644 src/main/resources/reference.conf create mode 100644 src/main/scala/za/co/absa/standardization/ArrayTransformations.scala create mode 100644 src/main/scala/za/co/absa/standardization/ConfigReader.scala create mode 100644 src/main/scala/za/co/absa/standardization/Constants.scala create mode 100644 src/main/scala/za/co/absa/standardization/ErrorMessage.scala create mode 100644 src/main/scala/za/co/absa/standardization/FileReader.scala create mode 100644 src/main/scala/za/co/absa/standardization/FlatField.scala create mode 100644 src/main/scala/za/co/absa/standardization/JsonUtils.scala create mode 100644 src/main/scala/za/co/absa/standardization/RecordIdGeneration.scala create mode 100644 src/main/scala/za/co/absa/standardization/SchemaValidator.scala create mode 100644 src/main/scala/za/co/absa/standardization/Standardization.scala create mode 100644 src/main/scala/za/co/absa/standardization/StandardizationConfig.scala create mode 100644 src/main/scala/za/co/absa/standardization/ValidationException.scala create mode 100644 src/main/scala/za/co/absa/standardization/ValidationIssue.scala create mode 100644 src/main/scala/za/co/absa/standardization/implicits/ColumnImplicits.scala create mode 100644 src/main/scala/za/co/absa/standardization/implicits/DataFrameImplicits.scala create mode 100644 src/main/scala/za/co/absa/standardization/implicits/OptionImplicits.scala create mode 100644 src/main/scala/za/co/absa/standardization/implicits/StringImplicits.scala create mode 100644 src/main/scala/za/co/absa/standardization/implicits/StructFieldImplicits.scala create mode 100644 src/main/scala/za/co/absa/standardization/numeric/DecimalSymbols.scala create mode 100644 src/main/scala/za/co/absa/standardization/numeric/NumericPattern.scala create mode 100644 src/main/scala/za/co/absa/standardization/numeric/Radix.scala create mode 100644 src/main/scala/za/co/absa/standardization/schema/MetadataKeys.scala create mode 100644 src/main/scala/za/co/absa/standardization/schema/SchemaUtils.scala create mode 100644 src/main/scala/za/co/absa/standardization/schema/SparkUtils.scala create mode 100644 src/main/scala/za/co/absa/standardization/stages/PlainSchemaGenerator.scala create mode 100644 src/main/scala/za/co/absa/standardization/stages/SchemaChecker.scala create mode 100644 src/main/scala/za/co/absa/standardization/stages/TypeParser.scala create mode 100644 src/main/scala/za/co/absa/standardization/time/DateTimePattern.scala create mode 100644 
src/main/scala/za/co/absa/standardization/time/TimeZoneNormalizer.scala create mode 100644 src/main/scala/za/co/absa/standardization/typeClasses/DoubleLike.scala create mode 100644 src/main/scala/za/co/absa/standardization/typeClasses/LongLike.scala create mode 100644 src/main/scala/za/co/absa/standardization/types/Defaults.scala create mode 100644 src/main/scala/za/co/absa/standardization/types/DefaultsByFormat.scala create mode 100644 src/main/scala/za/co/absa/standardization/types/ParseOutput.scala create mode 100644 src/main/scala/za/co/absa/standardization/types/Section.scala create mode 100644 src/main/scala/za/co/absa/standardization/types/TypePattern.scala create mode 100644 src/main/scala/za/co/absa/standardization/types/TypedStructField.scala create mode 100644 src/main/scala/za/co/absa/standardization/types/parsers/DateTimeParser.scala create mode 100644 src/main/scala/za/co/absa/standardization/types/parsers/DecimalParser.scala create mode 100644 src/main/scala/za/co/absa/standardization/types/parsers/FractionalParser.scala create mode 100644 src/main/scala/za/co/absa/standardization/types/parsers/IntegralParser.scala create mode 100644 src/main/scala/za/co/absa/standardization/types/parsers/NumericParser.scala create mode 100644 src/main/scala/za/co/absa/standardization/types/parsers/ParseViaDecimalFormat.scala create mode 100644 src/main/scala/za/co/absa/standardization/udf/UDFBuilder.scala create mode 100644 src/main/scala/za/co/absa/standardization/udf/UDFLibrary.scala create mode 100644 src/main/scala/za/co/absa/standardization/udf/UDFNames.scala create mode 100644 src/main/scala/za/co/absa/standardization/udf/UDFResult.scala create mode 100644 src/main/scala/za/co/absa/standardization/validation/field/BinaryFieldValidator.scala create mode 100644 src/main/scala/za/co/absa/standardization/validation/field/DateFieldValidator.scala create mode 100644 src/main/scala/za/co/absa/standardization/validation/field/DateTimeFieldValidator.scala create mode 100644 src/main/scala/za/co/absa/standardization/validation/field/DecimalFieldValidator.scala create mode 100644 src/main/scala/za/co/absa/standardization/validation/field/FieldValidationIssue.scala create mode 100644 src/main/scala/za/co/absa/standardization/validation/field/FieldValidator.scala create mode 100644 src/main/scala/za/co/absa/standardization/validation/field/FractionalFieldValidator.scala create mode 100644 src/main/scala/za/co/absa/standardization/validation/field/IntegralFieldValidator.scala create mode 100644 src/main/scala/za/co/absa/standardization/validation/field/NumericFieldValidator.scala create mode 100644 src/main/scala/za/co/absa/standardization/validation/field/ScalarFieldValidator.scala create mode 100644 src/main/scala/za/co/absa/standardization/validation/field/TimestampFieldValidator.scala create mode 100644 src/test/resources/application.conf create mode 100644 src/test/resources/data/bug.json create mode 100644 src/test/resources/data/data1Schema.json create mode 100644 src/test/resources/data/dateTimestampSchemaOk.json create mode 100644 src/test/resources/data/dateTimestampSchemaWrong.json create mode 100644 src/test/resources/data/integral_overflow_test.csv create mode 100644 src/test/resources/data/integral_overflow_test_numbers.json create mode 100644 src/test/resources/data/integral_overflow_test_text.json create mode 100644 src/test/resources/data/patients.json create mode 100644 src/test/resources/data/standardizeJsonSrc.json create mode 100644 
src/test/scala/za/co/absa/standardization/ArrayTransformationsSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/ErrorMessageFactory.scala create mode 100644 src/test/scala/za/co/absa/standardization/JsonUtilsSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/LoggerTestBase.scala create mode 100644 src/test/scala/za/co/absa/standardization/RecordIdGenerationSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/SchemaValidationSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/SparkTestBase.scala create mode 100644 src/test/scala/za/co/absa/standardization/StandardizationCsvSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/StandardizationParquetSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/TestSamples.scala create mode 100644 src/test/scala/za/co/absa/standardization/implicits/DataFrameImplicitsSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/implicits/StringImplicitsSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/implicits/StructFieldImplicitsSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/CounterPartySuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/DateTimeSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/SampleDataSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreterSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_ArraySuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_BinarySuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_DateSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_DecimalSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_FractionalSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_IntegralSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_TimestampSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/stages/PlainSchemaGeneratorSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/stages/SchemaCheckerSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParserSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParserSuiteTemplate.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromBooleanTypeSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromDateTypeSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromDecimalTypeSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromDoubleTypeSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromLongTypeSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromStringTypeSuite.scala create mode 100644 
src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromTimestampTypeSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/interpreter/standardizationInterpreter_RowTypes.scala create mode 100644 src/test/scala/za/co/absa/standardization/schema/SchemaUtilsSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/schema/SparkUtilsSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/time/DateTimePatternSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/types/DefaultsByFormatSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/types/DefaultsSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/types/SectionSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/types/TypedStructFieldSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/types/parsers/DateTimeParserSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/types/parsers/DecimalParserSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/types/parsers/FractionalParserSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/types/parsers/IntegralParser_PatternIntegralParserSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/types/parsers/IntegralParser_RadixIntegralParserSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/udf/UDFBuilderSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/validation/field/BinaryValidatorSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/validation/field/DateFieldValidatorSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/validation/field/FieldValidatorSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/validation/field/FractionalFieldValidatorSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/validation/field/IntegralFieldValidatorSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/validation/field/NumericFieldValidatorSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/validation/field/ScalarFieldValidatorSuite.scala create mode 100644 src/test/scala/za/co/absa/standardization/validation/field/TimestampFieldValidatorSuite.scala diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..eb319bc --- /dev/null +++ b/.editorconfig @@ -0,0 +1,36 @@ +# +# Copyright 2021 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# top-most EditorConfig file +root = true + +[*] +charset = utf-8 +end_of_line = lf +trim_trailing_whitespace = true + +[*.xml] +indent_size = 4 +indent_style = space +insert_final_newline = true + +[*.{java,scala,js,json,css}] +indent_size = 2 +indent_style = space +insert_final_newline = true +max_line_length = 120 + +[*.md] +trim_trailing_whitespace = false diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..00dca55 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,41 @@ +# +# Copyright 2021 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +name: Build + +on: + push: + branches: [ main, develop, master ] + pull_request: + branches: [ master, develop ] + types: [ assigned, opened, synchronize, reopened, labeled ] + +jobs: + test-sbt: + runs-on: ubuntu-latest + strategy: + fail-fast: false + name: SBT Test + steps: + - name: Checkout code + uses: actions/checkout@v2 + - uses: coursier/cache-action@v5 + - name: Setup Scala + uses: olafurpg/setup-scala@v10 + with: + java-version: "adopt@1.8" + - name: Build and run tests + run: sbt test diff --git a/.github/workflows/licence_check.yml b/.github/workflows/licence_check.yml new file mode 100644 index 0000000..6d96c06 --- /dev/null +++ b/.github/workflows/licence_check.yml @@ -0,0 +1,41 @@ +# +# Copyright 2021 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +name: License Check + +on: + push: + branches: [ main, develop, master ] + pull_request: + branches: [ master ] + types: [ assigned, opened, synchronize, reopened, labeled ] + +jobs: + license-test: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + - name: Setup Scala + uses: olafurpg/setup-scala@v10 + with: + java-version: "adopt@1.8" + # note, that task "headerCheck" defaults to just "compile:headerCheck" - see https://github.com/sbt/sbt-header/issues/14 + - name: SBT src licence header check + run: sbt Compile/headerCheck + - name: SBT test licence header check + run: sbt Test/headerCheck + diff --git a/.gitignore b/.gitignore index 545f3fc..56b7b5e 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at +# # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software diff --git a/build.sbt b/build.sbt new file mode 100644 index 0000000..eef3587 --- /dev/null +++ b/build.sbt @@ -0,0 +1,41 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.00 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +ThisBuild / name := "standardization" +ThisBuild / organization := "za.co.absa" +ThisBuild / version := "0.0.1-SNAPSHOT" +ThisBuild / scalaVersion := "2.11.12" + +libraryDependencies ++= List( + "org.apache.spark" %% "spark-core" % "2.4.7" % "provided", + "org.apache.spark" %% "spark-sql" % "2.4.7" % "provided", + "za.co.absa" %% "spark-hats" % "0.2.2", + "za.co.absa" %% "spark-hofs" % "0.4.0", + "org.scalatest" %% "scalatest" % "3.2.2" % Test, + "com.typesafe" % "config" % "1.4.1" +) + +Test / parallelExecution := false + +// licenceHeader check: + +ThisBuild / organizationName := "ABSA Group Limited" +ThisBuild / startYear := Some(2021) +ThisBuild / licenses += "Apache-2.0" -> url("https://www.apache.org/licenses/LICENSE-2.0.txt") + +// linting +Global / excludeLintKeys += ThisBuild / name // will be used in publish, todo #3 - confirm if lint ignore is still needed diff --git a/project/build.properties b/project/build.properties new file mode 100644 index 0000000..10fd9ee --- /dev/null +++ b/project/build.properties @@ -0,0 +1 @@ +sbt.version=1.5.5 diff --git a/project/plugins.sbt b/project/plugins.sbt new file mode 100644 index 0000000..c107291 --- /dev/null +++ b/project/plugins.sbt @@ -0,0 +1 @@ +addSbtPlugin("de.heikoseeberger" % "sbt-header" % "5.6.0") diff --git a/src/main/resources/reference.conf b/src/main/resources/reference.conf new file mode 100644 index 0000000..6ab630d --- /dev/null +++ b/src/main/resources/reference.conf @@ -0,0 +1,31 @@ +# Copyright 2021 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Configuration added here is considered the application default and it will be used +# for keys that are not specified in the provided 'application.conf' or system properties. +# Here is the precedence of configuration (top ones have higher precedence): +# 1. System Properties (e.g. passed as '-Dkey=value') +# 2. application.conf (e.g. provided as '-Dconfig.file=...') +# 3. reference.conf + +# 'enceladus_record_id' with an id can be added containing either true UUID, always the same IDs (row-hash-based) or the +# column will not be added at all. 
Allowed values: "uuid", "stableHashId", "none" +standardization.recordId.generation.strategy="uuid" + +# system-wide time zone +timezone="UTC" + +standardization.testUtils.sparkTestBaseMaster="local[4]" + +standardization.failOnInputNotPerSchema=false diff --git a/src/main/scala/za/co/absa/standardization/ArrayTransformations.scala b/src/main/scala/za/co/absa/standardization/ArrayTransformations.scala new file mode 100644 index 0000000..3364edc --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/ArrayTransformations.scala @@ -0,0 +1,97 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization + +import org.apache.spark.sql.{Column, Dataset, Row, SparkSession} +import org.apache.spark.sql.api.java.UDF1 +import org.apache.spark.sql.functions.{callUDF, col, struct} +import org.apache.spark.sql.types.{ArrayType, DataType, StructType} +import org.slf4j.LoggerFactory +import za.co.absa.standardization.schema.SchemaUtils + +object ArrayTransformations { + private val logger = LoggerFactory.getLogger(this.getClass) + def flattenArrays(df: Dataset[Row], colName: String)(implicit spark: SparkSession): Dataset[Row] = { + val typ = SchemaUtils.getFieldType(colName, df.schema).getOrElse(throw new Error(s"Field $colName does not exist in ${df.schema.printTreeString()}")) + if (!typ.isInstanceOf[ArrayType]) { + logger.info(s"Field $colName is not an ArrayType, returning the original dataset!") + df + } else { + val arrType = typ.asInstanceOf[ArrayType] + if (!arrType.elementType.isInstanceOf[ArrayType]) { + logger.info(s"Field $colName is not a nested array, returning the original dataset!") + df + } else { + val udfName = colName.replace('.', '_') + System.currentTimeMillis() + + spark.udf.register(udfName, new UDF1[Seq[Seq[Row]], Seq[Row]] { + def call(t1: Seq[Seq[Row]]): Seq[Row] = if (t1 == null) null.asInstanceOf[Seq[Row]] else t1.filter(_ != null).flatten // scalastyle:ignore null + }, arrType.elementType) + + nestedWithColumn(df)(colName, callUDF(udfName, col(colName))) + } + } + + } + + def nestedWithColumn(ds: Dataset[Row])(columnName: String, column: Column): Dataset[Row] = { + val toks = columnName.split("\\.").toList + + def helper(tokens: List[String], pathAcc: Seq[String]): Column = { + val currPath = (pathAcc :+ tokens.head).mkString(".") + val topType = SchemaUtils.getFieldType(currPath, ds.schema) + + // got a match + if (currPath == columnName) { + column as tokens.head + } // some other attribute + else if (!columnName.startsWith(currPath)) { + arrCol(currPath) + } // partial match, keep going + else if (topType.isEmpty) { + struct(helper(tokens.tail, pathAcc ++ List(tokens.head))) as tokens.head + } else { + topType.get match { + case s: StructType => + val cols = s.fields.map(_.name) + val fields = if (tokens.size > 1 && !cols.contains(tokens(1))) { + cols :+ tokens(1) + } else { + cols + } + struct(fields.map(field => helper((List(field) ++ tokens.tail).distinct, pathAcc :+ tokens.head) 
as field): _*) as tokens.head + case _: ArrayType => throw new IllegalStateException("Cannot reconstruct array columns. Please use this within arrayTransform.") + case _: DataType => arrCol(currPath) as tokens.head + } + } + } + + ds.withColumn(toks.head, helper(toks, Seq())) + } + + def arrCol(any: String): Column = { + val toks = any.replaceAll("\\[(\\d+)\\]", "\\.$1").split("\\.") + toks.tail.foldLeft(col(toks.head)){ + case (acc, tok) => + if (tok.matches("\\d+")) { + acc(tok.toInt) + } else { + acc(tok) + } + } + } +} diff --git a/src/main/scala/za/co/absa/standardization/ConfigReader.scala b/src/main/scala/za/co/absa/standardization/ConfigReader.scala new file mode 100644 index 0000000..b91c2aa --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/ConfigReader.scala @@ -0,0 +1,180 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization + +import com.typesafe.config.{Config, ConfigException, ConfigFactory, ConfigRenderOptions, ConfigValueFactory} +import org.slf4j.{Logger, LoggerFactory} +import scala.collection.JavaConverters._ + +object ConfigReader { + type ConfigExceptionBadValue = ConfigException.BadValue + + val redactedReplacement: String = "*****" + private val defaultConfig: ConfigReader = new ConfigReader(ConfigFactory.load()) + + def apply(): ConfigReader = defaultConfig + def apply(config: Config): ConfigReader = new ConfigReader(config) + def apply(configMap: Map[String, String]): ConfigReader = { + val config = ConfigFactory.parseMap(configMap.asJava) + apply(config) + } + + def parseString(configLine: String): ConfigReader = { + val config = ConfigFactory.parseString(configLine) + apply(config) + } +} + +class ConfigReader(val config: Config = ConfigFactory.load()) { + import ConfigReader._ + + + def hasPath(path: String): Boolean = { + config.hasPath(path) + } + + def getString(path: String): String = { + config.getString(path) + } + + def getInt(path: String): Int = { + config.getInt(path) + } + + def getBoolean(path: String): Boolean = { + config.getBoolean(path) + } + + /** + * Inspects the config for the presence of the `path` and returns an optional result. + * + * @param path path to look for, e.g. "group1.subgroup2.value3 + * @return None if not found or defined Option[String] + */ + def getStringOption(path: String): Option[String] = { + getIfExists(path)(getString) + } + + def getIntOption(path: String): Option[Int] = { + getIfExists(path)(getInt) + } + + /** + * Inspects the config for the presence of the `path` and returns an optional result. + * + * @param path path to look for, e.g. 
"group1.subgroup2.value3 + * @return None if not found or defined Option[Boolean] + */ + def getBooleanOption(path: String): Option[Boolean] = { + getIfExists(path)(getBoolean) + } + + /** Handy shorthand of frequent `config.withValue(key, ConfigValueFactory.fromAnyRef(value))` */ + def withAnyRefValue(key: String, value: AnyRef) : ConfigReader = { + ConfigReader(config.withValue(key, ConfigValueFactory.fromAnyRef(value))) + } + + /** + * Given a configuration returns a new configuration which has all sensitive keys redacted. + * + * @param keysToRedact A set of keys to be redacted. + */ + def getRedactedConfig(keysToRedact: Set[String]): ConfigReader = { + def withAddedKey(accumulatedConfig: Config, key: String): Config = { + if (config.hasPath(key)) { + accumulatedConfig.withValue(key, ConfigValueFactory.fromAnyRef(redactedReplacement)) + } else { + accumulatedConfig + } + } + + val redactingConfig = keysToRedact.foldLeft(ConfigFactory.empty)(withAddedKey) + + ConfigReader(redactingConfig.withFallback(config)) + } + + def getLong(path: String): Long = { + config.getLong(path) + } + + def getLongOption(path: String): Option[Long] = { + getIfExists(path)(getLong) + } + + /** + * Flattens TypeSafe config tree and returns the effective configuration + * while redacting sensitive keys. + * + * @param keysToRedact A set of keys for which should be redacted. + * @return the effective configuration as a map + */ + def getFlatConfig(keysToRedact: Set[String] = Set()): Map[String, AnyRef] = { + getRedactedConfig(keysToRedact).config.entrySet().asScala.map({ entry => + entry.getKey -> entry.getValue.unwrapped() + }).toMap + } + + /** + * Logs the effective configuration while redacting sensitive keys + * in HOCON format. + * + * @param keysToRedact A set of keys for which values shouldn't be logged. + */ + def logEffectiveConfigHocon(keysToRedact: Set[String] = Set(), log: Logger = LoggerFactory.getLogger(this.getClass)): Unit = { + val redactedConfig = getRedactedConfig(keysToRedact) + + val renderOptions = ConfigRenderOptions.defaults() + .setComments(false) + .setOriginComments(false) + .setJson(false) + + val rendered = redactedConfig.config.root().render(renderOptions) + + log.info(s"Effective configuration:\n$rendered") + } + + /** + * Logs the effective configuration while redacting sensitive keys + * in Properties format. + * + * @param keysToRedact A set of keys for which values shouldn't be logged. + */ + def logEffectiveConfigProps(keysToRedact: Set[String] = Set(), log: Logger = LoggerFactory.getLogger(this.getClass)): Unit = { + val redactedConfig = getFlatConfig(keysToRedact) + + val rendered = redactedConfig.map { + case (k, v) => s"$k = $v" + }.toArray + .sortBy(identity) + .mkString("\n") + + log.info(s"Effective configuration:\n$rendered") + } + + private def getIfExists[T](path: String)(readFnc: String => T): Option[T] = { + if (config.hasPathOrNull(path)) { + if (config.getIsNull(path)) { + None + } else { + Option(readFnc(path)) + } + } else { + None + } + } + +} diff --git a/src/main/scala/za/co/absa/standardization/Constants.scala b/src/main/scala/za/co/absa/standardization/Constants.scala new file mode 100644 index 0000000..eb75dfd --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/Constants.scala @@ -0,0 +1,26 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization + +object Constants { + // ToDo Configurable - Issue #5 + final val InfoDateColumn = "enceladus_info_date" // TODO #5 "standardization_info_date" + final val InfoDateColumnString = s"${InfoDateColumn}_string" + final val ReportDateFormat = "yyyy-MM-dd" + final val InfoVersionColumn = "enceladus_info_version" // TODO #5 "standardization_info_version" + final val EnceladusRecordId = "enceladus_record_id" // TODO #5 "standardization_record_id" +} diff --git a/src/main/scala/za/co/absa/standardization/ErrorMessage.scala b/src/main/scala/za/co/absa/standardization/ErrorMessage.scala new file mode 100644 index 0000000..aa76001 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/ErrorMessage.scala @@ -0,0 +1,109 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.types.StructType + +/** + * Case class to represent an error message + * + * @param errType - Type or source of the error + * @param errCode - Internal error code + * @param errMsg - Textual description of the error + * @param errCol - The name of the column where the error occurred + * @param rawValues - Sequence of raw values (which are the potential culprits of the error) + * @param mappings - Sequence of Mappings i.e. Mapping Table Column -> Equivalent Mapped Dataset column + */ +case class ErrorMessage(errType: String, errCode: String, errMsg: String, errCol: String, rawValues: Seq[String], mappings: Seq[Mapping] = Seq()) +case class Mapping(mappingTableColumn: String, mappedDatasetColumn: String) + +object ErrorMessage { + val errorColumnName = "errCol" + + def stdCastErr(errCol: String, rawValue: String): ErrorMessage = ErrorMessage( + errType = "stdCastError", + errCode = ErrorCodes.StdCastError, + errMsg = "Standardization Error - Type cast", + errCol = errCol, + rawValues = Seq(rawValue)) + def stdNullErr(errCol: String): ErrorMessage = ErrorMessage( + errType = "stdNullError", + errCode = ErrorCodes.StdNullError, + errMsg = "Standardization Error - Null detected in non-nullable attribute", + errCol = errCol, + rawValues = Seq("null")) + def stdTypeError(errCol: String, sourceType: String, targetType: String): ErrorMessage = ErrorMessage( + errType = "stdTypeError", + errCode = ErrorCodes.StdTypeError, + errMsg = s"Standardization Error - Type '$sourceType' cannot be cast to '$targetType'", + errCol = errCol, + rawValues = Seq.empty) + def stdSchemaError(errRow: String): ErrorMessage = ErrorMessage( + errType = "stdSchemaError", + errCode = ErrorCodes.StdSchemaError, + errMsg = "The input data does not adhere to the requested schema", + errCol = null, // scalastyle:ignore null + rawValues = Seq(errRow)) + def confMappingErr(errCol: String, rawValues: Seq[String], mappings: Seq[Mapping]): ErrorMessage = ErrorMessage( + errType = "confMapError", + errCode = ErrorCodes.ConfMapError, + errMsg = "Conformance Error - Null produced by mapping conformance rule", + errCol = errCol, + rawValues = rawValues, mappings = mappings) + def confCastErr(errCol: String, rawValue: String): ErrorMessage = ErrorMessage( + errType = "confCastError", + errCode = ErrorCodes.ConfCastErr, + errMsg = "Conformance Error - Null returned by casting conformance rule", + errCol = errCol, + rawValues = Seq(rawValue)) + def confNegErr(errCol: String, rawValue: String): ErrorMessage = ErrorMessage( + errType = "confNegError", + errCode = ErrorCodes.ConfNegErr, + errMsg = "Conformance Error - Negation of numeric type with minimum value overflows and remains unchanged", + errCol = errCol, + rawValues = Seq(rawValue)) + def confLitErr(errCol: String, rawValue: String): ErrorMessage = ErrorMessage( + errType = "confLitError", + errCode = ErrorCodes.ConfLitErr, + errMsg = "Conformance Error - Special column value has changed", + errCol = errCol, + rawValues = Seq(rawValue)) + + def errorColSchema(implicit spark: SparkSession): StructType = { + import spark.implicits._ + spark.emptyDataset[ErrorMessage].schema + } + + /** + * The purpose of this object is to group the error codes together to decrease the chance of them being in conflict. + */ + object ErrorCodes { + final val StdCastError = "E00000" + final val ConfMapError = "E00001" + final val StdNullError = "E00002" + final val ConfCastErr = "E00003" + final val
ConfNegErr = "E00004" + final val ConfLitErr = "E00005" + final val StdTypeError = "E00006" + final val StdSchemaError = "E00007" + + val standardizationErrorCodes = Seq(StdCastError, StdNullError, StdTypeError, StdSchemaError) + val conformanceErrorCodes = Seq(ConfMapError, ConfCastErr, ConfNegErr, ConfLitErr) + } +} + diff --git a/src/main/scala/za/co/absa/standardization/FileReader.scala b/src/main/scala/za/co/absa/standardization/FileReader.scala new file mode 100644 index 0000000..4da1169 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/FileReader.scala @@ -0,0 +1,41 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization + +import scala.io.Source + +object FileReader { + def readFileAsListOfLines(filename: String): List[String] = { + val sourceFile = Source.fromFile(filename) + try { + sourceFile.getLines().toList // making it a List to copy the content of the file into memory before it's closed + } finally { + sourceFile.close() + } + } + + def readFileAsString(filename: String, lineSeparator: String = "\n"): String = { + val sourceFile = Source.fromFile(filename) + try { + sourceFile.getLines().mkString(lineSeparator) + } finally { + sourceFile.close() + } + } + + +} diff --git a/src/main/scala/za/co/absa/standardization/FlatField.scala b/src/main/scala/za/co/absa/standardization/FlatField.scala new file mode 100644 index 0000000..337a237 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/FlatField.scala @@ -0,0 +1,21 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization + +import org.apache.spark.sql.types.StructField + +case class FlatField(structPath: String, field: StructField) diff --git a/src/main/scala/za/co/absa/standardization/JsonUtils.scala b/src/main/scala/za/co/absa/standardization/JsonUtils.scala new file mode 100644 index 0000000..ef97360 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/JsonUtils.scala @@ -0,0 +1,60 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization + +import com.fasterxml.jackson.databind.ObjectMapper +import org.apache.spark.sql.{DataFrame, SparkSession} + +object JsonUtils { + + /** + * Formats a JSON string so it looks pretty. + * + * @param jsonIn A JSON string + * @return A pretty formatted JSON string + */ + def prettyJSON(jsonIn: String): String = { + val mapper = new ObjectMapper() + + val jsonUnindented = mapper.readValue(jsonIn, classOf[Any]) + val indented = mapper.writerWithDefaultPrettyPrinter.writeValueAsString(jsonUnindented) + indented.replace("\r\n", "\n") + } + + /** + * Formats a Spark-generated JSON strings that are returned by + * applying `.toJSON.collect()` to a DataFrame. + * + * @param jsons A list of JSON documents + * @return A pretty formatted JSON string + */ + def prettySparkJSON(jsons: Seq[String]): String = { + val singleJSON = jsons.mkString("[", ",", "]") + prettyJSON(singleJSON) + } + + /** + * Creates a Spark DataFrame from a JSON document(s). + * + * @param json A json string to convert to a DataFrame + * @return A data frame + */ + def getDataFrameFromJson(spark: SparkSession, json: Seq[String]): DataFrame = { + import spark.implicits._ + spark.read.json(json.toDS) + } +} diff --git a/src/main/scala/za/co/absa/standardization/RecordIdGeneration.scala b/src/main/scala/za/co/absa/standardization/RecordIdGeneration.scala new file mode 100644 index 0000000..f90aa6b --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/RecordIdGeneration.scala @@ -0,0 +1,73 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization + +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.{col, expr, hash} +import org.slf4j.{Logger, LoggerFactory} + +object RecordIdGeneration { + + sealed trait IdType + + object IdType { + case object TrueUuids extends IdType + case object StableHashId extends IdType + case object NoId extends IdType + } + + private val log: Logger = LoggerFactory.getLogger(this.getClass) + + def getRecordIdGenerationType(idTypeName: String): IdType = { + idTypeName.toLowerCase match { + case "uuid" => IdType.TrueUuids + case "stablehashid" => IdType.StableHashId + case "none" => IdType.NoId + case _ => throw new IllegalArgumentException(s"Invalid value '$idTypeName' was encountered for id generation strategy, use one of: uuid, stableHashId, none.") + } + } + + /** + * The supplied dataframe `origDf` is either kept as-is (`strategy` = [[IdType.NoId]]) or has a column appended + * with an (presumably) unique value for each record. 
These are true UUIDs (`strategy` = [[IdType.TrueUuids]]) or + * values always the same for the same row, mainly for testing purposes (`strategy` = [[IdType.StableHashId]] + * + * @param origDf dataframe to be possibly extended + * @param idColumnName name of the id column to be used (usually [[Constants.EnceladusRecordId]]) + * @param strategy decides if and what ids will be appended to the origDf + * @return possibly updated `origDf` + */ + def addRecordIdColumnByStrategy(origDf: DataFrame, idColumnName: String, strategy: IdType): DataFrame = { + strategy match { + case IdType.NoId => + log.info("Record id generation is off.") + origDf + + case IdType.StableHashId => + log.info(s"Record id generation is set to 'stableHashId' - all runs will yield the same IDs.") + origDf.transform(hashFromAllColumns(Constants.EnceladusRecordId, _)) // adds hashId + + case IdType.TrueUuids => + log.info("Record id generation is on and true UUIDs will be added to output.") + origDf.withColumn(Constants.EnceladusRecordId, expr("uuid()")) + } + } + + private def hashFromAllColumns(hashColName: String, df: DataFrame): DataFrame = + df.withColumn(hashColName, hash(df.columns.map(col): _*)) + +} diff --git a/src/main/scala/za/co/absa/standardization/SchemaValidator.scala b/src/main/scala/za/co/absa/standardization/SchemaValidator.scala new file mode 100644 index 0000000..8dc631c --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/SchemaValidator.scala @@ -0,0 +1,138 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.types._ +import za.co.absa.standardization.types.{Defaults, GlobalDefaults, TypedStructField} +import za.co.absa.standardization.validation.field.FieldValidationIssue + +import scala.collection.mutable.ListBuffer + +/** + * Object responsible for Spark schema validation against self inconsistencies (not against the actual data) + */ +object SchemaValidator { + private implicit val defaults: Defaults = GlobalDefaults + + /** + * Validate a schema + * + * @param schema A Spark schema + * @return A list of ValidationErrors objects, each containing a column name and the list of errors and warnings + */ + def validateSchema(schema: StructType): List[FieldValidationIssue] = { + var errorsAccumulator = new ListBuffer[FieldValidationIssue] + val flatSchema = flattenSchema(schema) + for {s <- flatSchema} { + val fieldWithPath = if (s.structPath.isEmpty) s.field else s.field.copy(name = s.structPath + "." + s.field.name) + val issues = validateColumnName(s.field.name, s.structPath) ++ TypedStructField(fieldWithPath).validate() + if (issues.nonEmpty) { + val pattern = if (s.field.metadata contains "pattern") s.field.metadata.getString("pattern") else "" + errorsAccumulator += FieldValidationIssue(fieldWithPath.name, pattern, issues) + } + } + errorsAccumulator.toList + } + + /** + * Validates the error column. 
+ * Most of the time the error column should not exist in the input schema. But if it does exist, it should + * conform to the expected type. + * Nullability of the error column is not an issue. + * + * @param schema A Spark schema + * @return A list of ValidationErrors, each containing a column name and the list of errors and warnings + */ + def validateErrorColumn(schema: StructType) + (implicit spark: SparkSession) + : List[FieldValidationIssue] = { + val expectedTypeNonNullable = ArrayType(ErrorMessage.errorColSchema, containsNull = false) + val expectedTypeNullable = ArrayType(ErrorMessage.errorColSchema, containsNull = true) + val errCol = schema.find(f => f.name == ErrorMessage.errorColumnName) + errCol match { + case Some(errColField) => + if (errColField.dataType != expectedTypeNonNullable && errColField.dataType != expectedTypeNullable) { + val actualType = errColField.dataType + List(FieldValidationIssue(errColField.name, "", + ValidationError("The error column in the input data does not conform to the expected type. " + + s"Expected: $expectedTypeNonNullable, actual: $actualType") :: Nil)) + } else { + Nil + } + case None => + Nil + } + } + + /** + * Validate a column name, check for illegal characters. + * Currently it checks for dots only, but it is extendable. + * + * @param columnName A column name + * @param structPath A path to the column name inside the nested structures + * @return A list of ValidationErrors objects, each containing a column name and the list of errors and warnings + */ + private def validateColumnName(columnName: String, structPath: String = "") : Seq[ValidationIssue] = { + if (columnName contains '.') { + val structMsg = if (structPath.isEmpty) "" else s" of the struct '$structPath'" + Seq(ValidationError(s"Column name '$columnName'$structMsg contains an illegal character: '.'")) + } else { + Nil + } + } + + /** + * This method flattens an input schema to a list of columns and their types + * Struct types are collapsed as 'column.element' and arrays as 'column[].element', arrays as 'column[][].element'. + * + * @param schema A Spark schema + * @return A sequence of all fields as a StructField + */ + private def flattenSchema(schema: StructType): Seq[FlatField] = { + + def flattenStruct(schema: StructType, structPath: String): Seq[FlatField] = { + var fields = new ListBuffer[FlatField] + val prefix = if (structPath.isEmpty) structPath else structPath + "." 
+ for (field <- schema) { + field.dataType match { + case s: StructType => fields ++= flattenStruct(s, prefix + field.name) + case a: ArrayType => fields ++= flattenArray(field, a, prefix + field.name + "[]") + case _ => + val prefixedField = FlatField(structPath, field) + fields += prefixedField + } + } + fields + } + + def flattenArray(field: StructField, arr: ArrayType, structPath: String): Seq[FlatField] = { + var arrayFields = new ListBuffer[FlatField] + arr.elementType match { + case stuctInArray: StructType => arrayFields ++= flattenStruct(stuctInArray, structPath) + case arrayType: ArrayType => arrayFields ++= flattenArray(field, arrayType, structPath + "[]") + case _ => + val prefixedField = FlatField(structPath, field) + arrayFields += prefixedField + } + arrayFields + } + + flattenStruct(schema, "") + } + +} diff --git a/src/main/scala/za/co/absa/standardization/Standardization.scala b/src/main/scala/za/co/absa/standardization/Standardization.scala new file mode 100644 index 0000000..140c282 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/Standardization.scala @@ -0,0 +1,109 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization + +import com.typesafe.config.{Config, ConfigFactory} +import org.apache.hadoop.conf.Configuration +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{Column, DataFrame, SparkSession} +import org.apache.spark.sql.types.StructType +import org.slf4j.{Logger, LoggerFactory} +import za.co.absa.standardization.RecordIdGeneration.getRecordIdGenerationType +import za.co.absa.standardization.schema.{SchemaUtils, SparkUtils} +import za.co.absa.standardization.stages.{SchemaChecker, TypeParser} +import za.co.absa.standardization.types.{Defaults, GlobalDefaults, ParseOutput} +import za.co.absa.standardization.udf.{UDFLibrary, UDFNames} + +object Standardization { + private implicit val defaults: Defaults = GlobalDefaults + private val logger: Logger = LoggerFactory.getLogger(this.getClass) + + def standardize(df: DataFrame, schema: StructType, generalConfig: Config = ConfigFactory.load()) + (implicit sparkSession: SparkSession): DataFrame = { + implicit val udfLib: UDFLibrary = new UDFLibrary + implicit val hadoopConf: Configuration = sparkSession.sparkContext.hadoopConfiguration + + logger.info(s"Step 1: Schema validation") + validateSchemaAgainstSelfInconsistencies(schema) + + logger.info(s"Step 2: Standardization") + val std = standardizeDataset(df, schema, generalConfig.getBoolean("standardization.failOnInputNotPerSchema")) + + logger.info(s"Step 3: Clean the final error column") + val cleanedStd = cleanTheFinalErrorColumn(std) + + val idedStd = if (SchemaUtils.fieldExists(Constants.EnceladusRecordId, cleanedStd.schema)) { + cleanedStd // no new id regeneration + } else { + val recordIdGenerationStrategy = getRecordIdGenerationType(generalConfig.getString("standardization.recordId.generation.strategy")) + 
RecordIdGeneration.addRecordIdColumnByStrategy(cleanedStd, Constants.EnceladusRecordId, recordIdGenerationStrategy) + } + + logger.info(s"Standardization process finished, returning to the application...") + idedStd + } + + + private def validateSchemaAgainstSelfInconsistencies(expSchema: StructType) + (implicit spark: SparkSession): Unit = { + val validationErrors = SchemaChecker.validateSchemaAndLog(expSchema) + if (validationErrors._1.nonEmpty) { + throw new ValidationException("A fatal schema validation error occurred.", validationErrors._1) + } + } + + private def standardizeDataset(df: DataFrame, expSchema: StructType, failOnInputNotPerSchema: Boolean) + (implicit spark: SparkSession, udfLib: UDFLibrary): DataFrame = { + + val rowErrors: List[Column] = gatherRowErrors(df.schema) + val (stdCols, errorCols, oldErrorColumn) = expSchema.fields.foldLeft(List.empty[Column], rowErrors, None: Option[Column]) { + (acc, field) => + logger.info(s"Standardizing field: ${field.name}") + val (accCols, accErrorCols, accOldErrorColumn) = acc + if (field.name == ErrorMessage.errorColumnName) { + (accCols, accErrorCols, Option(df.col(field.name))) + } else { + val ParseOutput(stdColumn, errColumn) = TypeParser.standardize(field, "", df.schema, failOnInputNotPerSchema) + logger.info(s"Applying standardization plan for ${field.name}") + (stdColumn :: accCols, errColumn :: accErrorCols, accOldErrorColumn) + } + } + + val errorColsAllInCorrectOrder: List[Column] = (oldErrorColumn.toList ++ errorCols).reverse + val cols = (array(errorColsAllInCorrectOrder: _*) as ErrorMessage.errorColumnName) :: stdCols + df.select(cols.reverse: _*) + } + + private def cleanTheFinalErrorColumn(dataFrame: DataFrame) + (implicit spark: SparkSession, udfLib: UDFLibrary): DataFrame = { + ArrayTransformations.flattenArrays(dataFrame, ErrorMessage.errorColumnName) + .withColumn(ErrorMessage.errorColumnName, callUDF(UDFNames.cleanErrCol, col(ErrorMessage.errorColumnName))) + } + + private def gatherRowErrors(origSchema: StructType)(implicit spark: SparkSession): List[Column] = { + val corruptRecordColumn = spark.conf.get(SparkUtils.ColumnNameOfCorruptRecordConf) + SchemaUtils.getField(corruptRecordColumn, origSchema).map {_ => + val column = col(corruptRecordColumn) + when(column.isNotNull, // input row was not per expected schema + array(callUDF(UDFNames.stdSchemaErr, column.cast(StringType)) //column should be StringType but better to be sure + )).otherwise( // schema is OK + typedLit(Seq.empty[ErrorMessage]) + ) + }.toList + } +} diff --git a/src/main/scala/za/co/absa/standardization/StandardizationConfig.scala b/src/main/scala/za/co/absa/standardization/StandardizationConfig.scala new file mode 100644 index 0000000..82db673 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/StandardizationConfig.scala @@ -0,0 +1,22 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
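For orientation, a minimal, hypothetical driver for the `standardize` entry point above might look like the sketch below. The input path and two-column schema are illustrative only; the config keys the method reads (`standardization.failOnInputNotPerSchema`, `standardization.recordId.generation.strategy`) are assumed to be supplied by the library's reference.conf or the application config.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.types.{StringType, StructField, StructType}
    import za.co.absa.standardization.Standardization

    implicit val spark: SparkSession = SparkSession.builder().master("local[*]").getOrCreate()

    // illustrative target schema
    val desiredSchema = StructType(Seq(
      StructField("id", StringType, nullable = true),
      StructField("name", StringType, nullable = true)))

    val rawDf = spark.read.json("/path/to/input.json")                    // hypothetical input
    val standardizedDf = Standardization.standardize(rawDf, desiredSchema)
    // the result carries the standardized columns plus the error column (ErrorMessage.errorColumnName)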
+ */ + +package za.co.absa.standardization + +case class StandardizationConfig(recordIdGenerationStrategy: RecordIdGeneration.IdType, + failOnInputNotPerSchema: Boolean) { + +} diff --git a/src/main/scala/za/co/absa/standardization/ValidationException.scala b/src/main/scala/za/co/absa/standardization/ValidationException.scala new file mode 100644 index 0000000..7aa0cd0 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/ValidationException.scala @@ -0,0 +1,25 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization + +class ValidationException(val msg: String, val errors: Seq[String]) + extends Exception(s"$msg Due to errors: ${errors.mkString(",")}") + +object ValidationException { + def unapply(arg: ValidationException): Option[(String, Seq[String])] = Some(arg.msg, arg.errors) +} + diff --git a/src/main/scala/za/co/absa/standardization/ValidationIssue.scala b/src/main/scala/za/co/absa/standardization/ValidationIssue.scala new file mode 100644 index 0000000..a6db6bb --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/ValidationIssue.scala @@ -0,0 +1,23 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization + +sealed abstract class ValidationIssue + +case class ValidationWarning(msg: String) extends ValidationIssue + +case class ValidationError(msg: String) extends ValidationIssue diff --git a/src/main/scala/za/co/absa/standardization/implicits/ColumnImplicits.scala b/src/main/scala/za/co/absa/standardization/implicits/ColumnImplicits.scala new file mode 100644 index 0000000..1af730c --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/implicits/ColumnImplicits.scala @@ -0,0 +1,112 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.implicits + +import org.apache.spark.sql.Column +import org.apache.spark.sql.functions._ +import za.co.absa.standardization.types.Section + +object ColumnImplicits { + implicit class ColumnEnhancements(column: Column) { + def isInfinite: Column = { + column.isin(Double.PositiveInfinity, Double.NegativeInfinity) + } + + /** + * Spark string functions are 1-based (position of first char is 1) unlike 0-based in Java/Scala. The function shifts the substring indexation to be in accordance with + * Scala/ Java. + * Another enhancement is, that the function allows a negative index, denoting counting of the index from back + * This version takes the substring from the startPos until the end. + * @param startPos the index (zero-based) where to start the substring from, if negative it's counted from end + * @return column with requested substring + */ + def zeroBasedSubstr(startPos: Int): Column = { + if (startPos >= 0) { + zeroBasedSubstr(startPos, Int.MaxValue - startPos) + } else { + zeroBasedSubstr(startPos, -startPos) + } + } + + /** + * Spark strings are base on 1 unlike scala. The function shifts the substring indexation to be in accordance with + * Scala/ Java. + * Another enhancement is, that the function allows a negative index, denoting counting of the index from back + * This version takes the substring from the startPos and takes up to the given number of characters (less. + * @param startPos the index (zero based) where to start the substring from, if negative it's counted from end + * @param len length of the desired substring, if longer then the rest of the string, all the remaining characters are taken + * @return column with requested substring + */ + def zeroBasedSubstr(startPos: Int, len: Int): Column = { + if (startPos >= 0) { + column.substr(startPos + 1, len) + } else { + val startPosColumn = greatest(length(column) + startPos + 1, lit(1)) + val lenColumn = lit(len) + when(length(column) + startPos <= 0, length(column) + startPos).otherwise(0) + column.substr(startPosColumn,lenColumn) + } + } + + /** + * Spark strings are base on 1 unlike scala. The function shifts the substring indexation to be in accordance with + * Scala/ Java. + * Another enhancement is, that the function allows a negative index, denoting counting of the index from back + * This version takes the startPos and len from the provided Section object. + * @param section the start and requested length of the substring encoded within the Section object + * @return column with requested substring + */ + def zeroBasedSubstr(section: Section): Column = zeroBasedSubstr(section.start, section.length) + + /** + * Removes part of a StringType column, defined by the provided section. A column containing the remaining part of + * the string is returned + * @param section Definition of the part of the string to remove + * @return Column with the remaining parts of the string (concatenated) + */ + def removeSection(section: Section): Column = { + splitBySection(section) match { + case Left(result) => result + case Right((leftColumn, rightColumn)) => concat(leftColumn, rightColumn) + } + } + + /** + * Removes multiple sections from a StringType column. The operation is done in a way, as if the sections would be + * removed all "at once". E.g. removing Section(3, 1) and Section(6,1) removes the 3rd and 6th character (zero based), + * NOT the 3rd and 7th. 
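A quick sketch of the column enhancements above, with the expected values inferred from the documented semantics (assuming `Section(start, length)` construction); this is not taken from the test suite.

    import org.apache.spark.sql.functions.lit
    import za.co.absa.standardization.implicits.ColumnImplicits.ColumnEnhancements
    import za.co.absa.standardization.types.Section

    val c = lit("Standardization")
    c.zeroBasedSubstr(0, 5)          // "Stand"   - zero-based, unlike Column.substr
    c.zeroBasedSubstr(-4)            // "tion"    - negative start counts from the end
    c.removeSection(Section(0, 8))   // "ization" - first 8 characters removed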
+ * @param sections Sections to removed + * @return Column with the remainders of the string concatenated + */ + def removeSections(sections: Seq[Section]): Column = { + val mergedSections = Section.mergeTouchingSectionsAndSort(sections) + mergedSections.foldLeft(column) ((columnAcc, item) => columnAcc.removeSection(item)) //TODO try more effectively #678 + } + + private def splitBySection(section: Section): Either[Column, (Column, Column)] = { + def upToNegative(negativeIndex: Int): Column = column.substr(lit(1), greatest(length(column) + negativeIndex, lit(0))) + + section match { + case Section(_, 0) => Left(column) + case Section(0, l) => Left(zeroBasedSubstr(l)) + case Section(s, l) if (s < 0) && (s + l >= 0) => Left(upToNegative(s)) //till the end + case Section(s, l) if s >= 0 => Right(zeroBasedSubstr(0, s), zeroBasedSubstr(s + l, Int.MaxValue)) + case Section(s, l) => Right(upToNegative(s), zeroBasedSubstr(s + l)) + } + } + + } +} diff --git a/src/main/scala/za/co/absa/standardization/implicits/DataFrameImplicits.scala b/src/main/scala/za/co/absa/standardization/implicits/DataFrameImplicits.scala new file mode 100644 index 0000000..86c91b6 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/implicits/DataFrameImplicits.scala @@ -0,0 +1,73 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.implicits + +import java.io.ByteArrayOutputStream +import org.apache.spark.sql.{Column, DataFrame} +import za.co.absa.standardization.schema.SparkUtils + +object DataFrameImplicits { + implicit class DataFrameEnhancements(val df: DataFrame) { + + private def gatherData(showFnc: () => Unit): String = { + val outCapture = new ByteArrayOutputStream + Console.withOut(outCapture) { + showFnc() + } + val dfData = new String(outCapture.toByteArray).replace("\r\n", "\n") + dfData + } + + def dataAsString(): String = { + val showFnc: () => Unit = df.show + gatherData(showFnc) + } + + def dataAsString(truncate: Boolean): String = { + val showFnc: () => Unit = ()=>{df.show(truncate)} + gatherData(showFnc) + } + + def dataAsString(numRows: Int, truncate: Boolean): String = { + val showFnc: ()=>Unit = () => df.show(numRows, truncate) + gatherData(showFnc) + } + + def dataAsString(numRows: Int, truncate: Int): String = { + val showFnc: ()=>Unit = () => df.show(numRows, truncate) + gatherData(showFnc) + } + + def dataAsString(numRows: Int, truncate: Int, vertical: Boolean): String = { + val showFnc: ()=>Unit = () => df.show(numRows, truncate, vertical) + gatherData(showFnc) + } + + /** + * Adds a column to a dataframe if it does not exist + * + * @param colName A column to add if it does not exist already + * @param col An expression for the column to add + * @return a new dataframe with the new column + */ + def withColumnIfDoesNotExist(colName: String, col: Column): DataFrame = { + SparkUtils.withColumnIfDoesNotExist(df, colName, col) + } + + } + +} diff --git a/src/main/scala/za/co/absa/standardization/implicits/OptionImplicits.scala b/src/main/scala/za/co/absa/standardization/implicits/OptionImplicits.scala new file mode 100644 index 0000000..861c563 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/implicits/OptionImplicits.scala @@ -0,0 +1,27 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.implicits + +import scala.util.{Failure, Success, Try} + +object OptionImplicits { + implicit class OptionEnhancements[T](option: Option[T]) { + def toTry(failure: Exception): Try[T] = { + option.fold[Try[T]](Failure(failure))(Success(_)) + } + } +} diff --git a/src/main/scala/za/co/absa/standardization/implicits/StringImplicits.scala b/src/main/scala/za/co/absa/standardization/implicits/StringImplicits.scala new file mode 100644 index 0000000..8d6fd80 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/implicits/StringImplicits.scala @@ -0,0 +1,208 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.implicits + +import java.security.InvalidParameterException +import scala.annotation.tailrec + +object StringImplicits { + implicit class StringEnhancements(string: String) { + + /** + * Replaces all occurrences of the provided characters with their mapped values + * @param replacements the map of replacements where key's are chars to search for and values are their replacements + * @return a string with characters replaced + */ + def replaceChars(replacements: Map[Char, Char]): String = { + if (replacements.isEmpty) { + string + } else { + val result = new StringBuilder(string.length) + string.foreach(char => result.append(replacements.getOrElse(char, char))) + result.toString + } + } + + /** + * Function to find the first occurrence of any of the characters from the charsToFind in the string. The + * occurrence is not considered if the character is part of a sequence within a pair of quote characters specified + * by quoteChars param. + * Escape character in front of a quote character will cancel its "quoting" function. + * Escape character in front of a searched-for character will not result in positive identification of a find + * Double escape character is considered to be escape character itself, without its special meaning + * The escape character can be part of the charsToFind set or quoteChars set and the function will work as + * expected (e.g. double escape character being recognized as a searched-for character or quote character), but it + * cannot be both - that will fire an exception. 
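A short illustration of the quote-aware search described above; the expected results are inferred from the documented semantics rather than copied from the tests.

    import za.co.absa.standardization.implicits.StringImplicits.StringEnhancements

    "a,'b,c',d".findFirstUnquoted(Set(','), Set('\''))  // Some(1) - the comma inside '...' is ignored
    "'a,b',c".findFirstUnquoted(Set(','), Set('\''))    // Some(5) - only the unquoted comma counts
    "'a,b'".findFirstUnquoted(Set(','), Set('\''))      // None    - every comma is quoted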
+ * @param charsToFind set of characters to look for + * @param quoteChars set of characters that are considered as quotes, everything within two (same) quote characters + * is not considered + * @param escape the special character to escape the expected behavior within string + * @return the index of the first find within the string, or None in case of no find + */ + def findFirstUnquoted(charsToFind: Set[Char], quoteChars: Set[Char], escape: Char = '\\'): Option[Integer] = { + @tailrec + def scan(sub: String, idx: Integer, charToExitQuotes: Option[Char], escaped: Boolean = false): Option[Integer] = { + //escaped flag defaults to false, as every non-escape character clears it + val head = sub.headOption + (head, examineChar(head, charsToFind, quoteChars, escape, charToExitQuotes, escaped)) match { + case (None, _) => None // scanned the whole string without a hit + case (_, None) => Option(idx) // hit found + case (_, Some((nextCharToExitQuotes, nextEscaped))) => + scan(sub.tail, idx + 1, nextCharToExitQuotes, nextEscaped) // continue search + } + } + + checkInputsOverlap(charsToFind, quoteChars, escape: Char) + scan(string, 0, charToExitQuotes = None) + } + + /** + * Similar to above, only te result is true if anything is found, false otherwise + * @param charsToFind set of characters to look for + * @param quoteChars set of characters that are considered as quotes, everything within two (same) quote characters + * is not considered + * @param escape the special character to escape the expected behavior within string + * @return true if anything is found, false otherwise + */ + def hasUnquoted(charsToFind: Set[Char], quoteChars: Set[Char], escape: Char = '\\' ): Boolean = { + findFirstUnquoted(charsToFind, quoteChars, escape).nonEmpty + } + + /** + * Counts the occurrences of the chars to find. The occurrence is not considered if the character is part of a + * sequence within a pair of quote characters specified by quoteChars param. + * Escape character in front of a quote character will cancel its "quoting" function. + * Escape character in front of a searched-for character will not result in positive identification of a find + * Double escape character is considered to be escape character itself, without its special meaning + * The escape character can be part of the charsToFind set or quoteChars set and the function will work as + * expected (e.g. double escape character being recognized as a searched-for character or quote character), but it + * cannot be both - that will fire an exception. 
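Likewise, a sketch of what the counting variant described above should return under the same inferred semantics.

    import za.co.absa.standardization.implicits.StringImplicits.StringEnhancements

    "a,b;'c,d'".countUnquoted(Set(',', ';'), Set('\''))  // Map(',' -> 1, ';' -> 1) - the quoted comma is not counted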
+ * + * @param charsToFind set of characters to look for + * @param quoteChars set of characters that are considered as quotes, everything within two (same) quote characters + * is not considered + * @param escape the special character to escape the expected behavior within string + * @return map where charsToFind are the keys and values are the respective number of occurrences + */ + def countUnquoted(charsToFind: Set[Char], quoteChars: Set[Char], escape: Char = '\\'): Map[Char, Int] = { + checkInputsOverlap(charsToFind, quoteChars, escape: Char) + val resultInit: Map[Char, Int] = charsToFind.map((_, 0)).toMap + val examineInit: (Option[Char], Boolean) = (Option.empty, false) + val (result, _) = string.foldLeft((resultInit, examineInit)) ((acc, char) => { + val (resultAcc, (charToExitQuotes, escaped)) = acc + val examineResult = examineChar(Option(char), charsToFind, quoteChars, escape, charToExitQuotes, escaped) + examineResult.map((resultAcc, _)) //no hit, propagate the examineResult + .getOrElse(resultAcc ++ Map(char->(resultAcc(char) + 1)), examineInit) + }) + result + } + + private def checkInputsOverlap(charsToFind: Set[Char], quoteChars: Set[Char], escape: Char = '\\'):Unit = { + if (charsToFind.contains(escape) && quoteChars.contains(escape)) { + throw new InvalidParameterException( + s"Escape character '$escape 'is both between charsToFind and quoteChars. That's not allowed." + ) + } + } + + /** + * Investigates if the character in the relation to previous characters and charsToFind + * @param char character to examine, for easier matching and also supprot end of stirng, it's an Option + * @param charsToFind set of characters to look for + * @param quoteChars set of characters that are considered as quotes, everything within two (same) quote characters + * is not considered + * @param escape the special character to escape the expected behavior within string + * @param charToExitQuotes character that would is awaited to exit the "quotes"; if not empty means scan is within + * "quotes" + * @param escaped if true the previous character was the escape character + * @return Optional 2-tuple, None means hit (char is one of the charaToFind), otherwise it's the value of + * charToExitQuotes and escaped for the next character + */ + private def examineChar(char: Option[Char], + charsToFind: Set[Char], + quoteChars: Set[Char], + escape: Char, + charToExitQuotes: Option[Char], + escaped: Boolean = false): Option[(Option[Char], Boolean)] = { + (char, escaped) match { + // no more chars on input probably + case (None, _) => Option(None, false) + // following cases are to address situations when the char character is within quotes (not yet closed) + // exit quote unless it's escaped or is the escape character itself + case (`charToExitQuotes`, false) if !charToExitQuotes.contains(escape) => Option(None, false) + // escaped exit quote means exit only if it's the escape character itself + case (`charToExitQuotes`, true) if charToExitQuotes.contains(escape) => Option(None, false) + // escape charter found (not necessary withing quotes, but has to be handled it this order) + case (Some(`escape`), false) => Option(charToExitQuotes, true) + // any other character within quotes, no special case + case _ if charToExitQuotes.nonEmpty => Option(charToExitQuotes, false) + // following cases addresses situations when the char character is outside quotes + //escaped escape character if it's also a quote character + case (Some(`escape`), true) if quoteChars.contains(escape) => Option(Option(escape), false) 
+ //escaped escape character if it's also a character to find + case (Some(`escape`), true) if charsToFind.contains(escape) => None + // entering quotes + case (Some(c), false) if quoteChars.contains(c) => Option(Option(c), false) + // found one of the characters to search for + case (Some(c), false) if charsToFind.contains(c) => None + // an escaped quote character that is also within the characters to find + case (Some(c), true) if quoteChars.contains(c) && charsToFind.contains(c) => None + //all other cases, continue scan + case _ => Option(None, false) + } + } + + + private[implicits] def joinWithSingleSeparator(another: String, sep: String): String = { + val sb = new StringBuilder + sb.append(string.stripSuffix(sep)) + sb.append(sep) + sb.append(another.stripPrefix(sep)) + sb.mkString + } + + /** + * Joins two strings with / while stripping single existing trailing/leading "/" in between: + * {{{ + * "abc" / "123" -> "abc/123" + * "abc/" / "123" -> "abc/123" + * "abc" / "/123" -> "abc/123" + * "abc/" / "/123" -> "abc/123", + * but: + * "file:///" / "path" -> "file:///path", + * }}} + * + * @param another the second string we are appending after the `/` separator + * @return this/another (this has stripped trailing / if present, another has leading / stripped if present) + */ + def /(another: String): String = { // scalastyle:ignore method.name + joinWithSingleSeparator(another, "/") + } + + def nonEmpyOrElse(default: => String): String = { + if (string.isEmpty) { + default + } else { + string + } + } + + def coalesce(alternatives: String*): String = { + alternatives.foldLeft(string)(_.nonEmpyOrElse(_)) + } + } +} diff --git a/src/main/scala/za/co/absa/standardization/implicits/StructFieldImplicits.scala b/src/main/scala/za/co/absa/standardization/implicits/StructFieldImplicits.scala new file mode 100644 index 0000000..d1787c9 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/implicits/StructFieldImplicits.scala @@ -0,0 +1,48 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
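A small sketch of the metadata accessors defined just below, exercised on a hand-built field (the field name and metadata key are illustrative).

    import org.apache.spark.sql.types.{MetadataBuilder, StringType, StructField}
    import za.co.absa.standardization.implicits.StructFieldImplicits.StructFieldEnhancements

    val metadata = new MetadataBuilder().putString("pattern", "yyyy-MM-dd").build()
    val field = StructField("dateOfBirth", StringType, nullable = true, metadata)

    field.getMetadataString("pattern")           // Some("yyyy-MM-dd")
    field.getMetadataChar("pattern")             // None - the value is longer than one character
    field.getMetadataStringAsBoolean("pattern")  // None - not parseable as a Boolean
    field.hasMetadataKey("default")              // false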
+ */ + +package za.co.absa.standardization.implicits + +import org.apache.spark.sql.types._ +import scala.util.Try + +object StructFieldImplicits { + implicit class StructFieldEnhancements(val structField: StructField) { + def getMetadataString(key: String): Option[String] = { + Try(structField.metadata.getString(key)).toOption + } + + def getMetadataChar(key: String): Option[Char] = { + val resultString = Try(structField.metadata.getString(key)).toOption + resultString.flatMap { s => + if (s.length == 1) { + Option(s(0)) + } else { + None + } + } + } + + def getMetadataStringAsBoolean(key: String): Option[Boolean] = { + Try(structField.metadata.getString(key).toBoolean).toOption + } + + + def hasMetadataKey(key: String): Boolean = { + structField.metadata.contains(key) + } + } +} diff --git a/src/main/scala/za/co/absa/standardization/numeric/DecimalSymbols.scala b/src/main/scala/za/co/absa/standardization/numeric/DecimalSymbols.scala new file mode 100644 index 0000000..053c7f6 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/numeric/DecimalSymbols.scala @@ -0,0 +1,105 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.numeric + +import java.text.DecimalFormatSymbols +import java.util.Locale + +/** + * It's an immutable wrapper of Java's DecimalFormatSymbols + * @param decimalSeparator the character used for decimal sign + * @param groupingSeparator the character used for thousands separator + * @param minusSign the character used to represent minus sign + * @param patternSeparator the character used to separate positive and negative subpatterns in a pattern + * @param percentSign the character used for percent sign + * @param permillSign the character used for per mille sign + * @param exponentSeparator the string used to separate the mantissa from the exponent + * @param infinityValue the string used to represent infinity + * @param naNValue the string used to represent "not a number" + */ +case class DecimalSymbols( + decimalSeparator: Char, + groupingSeparator: Char, + minusSign: Char, + patternSeparator: Char, + percentSign: Char, + permillSign: Char, + exponentSeparator: String, + infinityValue: String, + naNValue: String) { + val negativeInfinityValue = s"$minusSign$infinityValue" + + def toDecimalFormatSymbols: DecimalFormatSymbols = { + val result = new DecimalFormatSymbols(Locale.US) + result.setDecimalSeparator(decimalSeparator) + result.setGroupingSeparator(groupingSeparator) + result.setMinusSign(minusSign) + result.setPatternSeparator(patternSeparator) + result.setPercent(percentSign) + result.setPerMill(permillSign) + result.setExponentSeparator(exponentSeparator) + result.setInfinity(infinityValue) + result.setNaN(naNValue) + result + } + + def charSymbolsDifference(from: DecimalSymbols): Map[Char, Char] = { + Map( + from.decimalSeparator -> decimalSeparator, + from.minusSign -> minusSign, + from.patternSeparator -> patternSeparator, + from.percentSign -> percentSign, + 
from.permillSign -> permillSign, + from.groupingSeparator -> groupingSeparator + ).filter(charsDiffer) + } + + def basicSymbolsDifference(from: DecimalSymbols): Map[Char, Char] = { + Map( + from.decimalSeparator -> decimalSeparator, + from.minusSign -> minusSign, + // replace the standard symbols to "invalidate" them + decimalSeparator -> from.decimalSeparator, + minusSign -> from.minusSign + ).filter(charsDiffer) + } + + private def charsDiffer(chars: (Char, Char)): Boolean = { + chars._1 != chars._2 + } + +} + +object DecimalSymbols { + def apply(locale: Locale): DecimalSymbols = { + DecimalSymbols(new DecimalFormatSymbols(locale)) + } + + def apply(dfs: DecimalFormatSymbols): DecimalSymbols = { + DecimalSymbols( + decimalSeparator = dfs.getDecimalSeparator, + minusSign = dfs.getMinusSign, + patternSeparator = dfs.getPatternSeparator, + percentSign = dfs.getPercent, + permillSign = dfs.getPerMill, + infinityValue = dfs.getInfinity, + exponentSeparator = dfs.getExponentSeparator, + naNValue = dfs.getNaN, + groupingSeparator = dfs.getGroupingSeparator + ) + } +} diff --git a/src/main/scala/za/co/absa/standardization/numeric/NumericPattern.scala b/src/main/scala/za/co/absa/standardization/numeric/NumericPattern.scala new file mode 100644 index 0000000..43888a6 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/numeric/NumericPattern.scala @@ -0,0 +1,38 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.numeric + +import za.co.absa.standardization.numeric.NumericPattern._ +import za.co.absa.standardization.types.TypePattern + +case class NumericPattern (override val pattern: String, decimalSymbols: DecimalSymbols) + extends TypePattern(pattern, isDefault = (pattern == DefaultPatternValue)) { + + def specifiedPattern: Option[String] = { + if (isDefault) { + None + } else { + Option(pattern) + } + } +} + +object NumericPattern { + val DefaultPatternValue = "" + + def apply(decimalSymbols: DecimalSymbols): NumericPattern = NumericPattern(DefaultPatternValue, decimalSymbols) +} diff --git a/src/main/scala/za/co/absa/standardization/numeric/Radix.scala b/src/main/scala/za/co/absa/standardization/numeric/Radix.scala new file mode 100644 index 0000000..bbda7ae --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/numeric/Radix.scala @@ -0,0 +1,71 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
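As a sketch of the DecimalSymbols comparison defined above: deriving the replacement map between US and (for example) German symbols should yield only the separators that actually differ.

    import java.util.Locale
    import za.co.absa.standardization.numeric.DecimalSymbols

    val us = DecimalSymbols(Locale.US)       // '.' decimal separator, ',' grouping separator
    val de = DecimalSymbols(Locale.GERMANY)  // ',' decimal separator, '.' grouping separator

    us.charSymbolsDifference(de)             // Map(',' -> '.', '.' -> ',')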
+ */ + +package za.co.absa.standardization.numeric + +import scala.util.control.NonFatal + +class Radix private(val value: Int) extends AnyVal { + override def toString: String = { + s"Radix($value)" + } +} + +object Radix { + private val MaxSupportedRadixValue = 36 //that's up to 0..9A..Z (case insensitive) + + implicit object RadixOrdering extends Ordering[Radix] { + override def compare(a: Radix, b: Radix): Int = a.value compare b.value + } + + + val MaxSupportedRadix = Radix(MaxSupportedRadixValue) + val DefaultRadix = Radix(10) // scalastyle:ignore magic.number + + + def apply(value: Int): Radix = { + if (value <= 0) { + throw new RadixFormatException(s"Radix has to be greater then 0, $value was entered") + } + if (value > MaxSupportedRadixValue) { + throw new RadixFormatException(s"Maximum supported radix is ${Radix.MaxSupportedRadix.value}, $value was entered") + } + + new Radix(value) + } + def apply(string: String): Radix = { + // scalastyle:off magic.number obvious meaning + val value = string.toLowerCase() match { + case "" | "dec" | "decimal" => 10 + case "hex" | "hexadecimal" => 16 + case "bin" | "binary" => 2 + case "oct" | "octal" => 8 + case x => + try { + x.toInt + } + catch { + case NonFatal(e) => throw new RadixFormatException(s"'$x' was not recognized as a Radix value") + } + } + // scalastyle:on magic.number + Radix(value) + } + + def unapply(arg: Radix): Option[Int] = Some(arg.value) + + class RadixFormatException(s: String = "") extends NumberFormatException(s) +} diff --git a/src/main/scala/za/co/absa/standardization/schema/MetadataKeys.scala b/src/main/scala/za/co/absa/standardization/schema/MetadataKeys.scala new file mode 100644 index 0000000..050a2d2 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/schema/MetadataKeys.scala @@ -0,0 +1,46 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
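The Radix factory above accepts both numeric and symbolic inputs; a few expected outcomes, for illustration:

    import za.co.absa.standardization.numeric.Radix

    Radix("hex").value   // 16
    Radix("")            // same as Radix.DefaultRadix (base 10)
    Radix(2).value       // 2
    // Radix(0) or Radix(40) throw RadixFormatException (supported range is 1..36)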
+ */ + +package za.co.absa.standardization.schema + +object MetadataKeys { + // all + val SourceColumn = "sourcecolumn" + val DefaultValue = "default" + // date & timestamp + val DefaultTimeZone = "timezone" + // date & timestamp & all numeric + val Pattern = "pattern" + // all numeric + val DecimalSeparator = "decimal_separator" + val GroupingSeparator = "grouping_separator" + val MinusSign = "minus_sign" + // float and double + val AllowInfinity = "allow_infinity" + // integral types + val Radix = "radix" + // binary + val Encoding = "encoding" + //decimal + val StrictParsing = "strict_parsing" +} + +object MetadataValues { + object Encoding { + val Base64 = "base64" + val None = "none" + } +} diff --git a/src/main/scala/za/co/absa/standardization/schema/SchemaUtils.scala b/src/main/scala/za/co/absa/standardization/schema/SchemaUtils.scala new file mode 100644 index 0000000..65cf4f6 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/schema/SchemaUtils.scala @@ -0,0 +1,605 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.schema + +import org.apache.spark.sql.types._ +import scala.annotation.tailrec +import scala.util.{Random, Try} + +object SchemaUtils { + + /** + * Returns the parent path of a field. Returns an empty string if a root level field name is provided. 
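For instance, based on the contract described here, the parent-path helper behaves as follows (a sketch, not a repository test):

    import za.co.absa.standardization.schema.SchemaUtils

    SchemaUtils.getParentPath("person.address.city")  // "person.address"
    SchemaUtils.getParentPath("person")               // "" - root level fields have no parent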
+ * + * @param columnName A fully qualified column name + * @return The parent column name or an empty string if the input column is a root level column + */ + def getParentPath(columnName: String): String = { + val index = columnName.lastIndexOf('.') + if (index > 0) { + columnName.substring(0, index) + } else { + "" + } + } + + /** + * Get a field from a text path and a given schema + * @param path The dot-separated path to the field + * @param schema The schema which should contain the specified path + * @return Some(the requested field) or None if the field does not exist + */ + def getField(path: String, schema: StructType): Option[StructField] = { + + @tailrec + def goThroughArrayDataType(dataType: DataType): DataType = { + dataType match { + case ArrayType(dt, _) => goThroughArrayDataType(dt) + case result => result + } + } + + @tailrec + def examineStructField(names: List[String], structField: StructField): Option[StructField] = { + if (names.isEmpty) { + Option(structField) + } else { + structField.dataType match { + case struct: StructType => examineStructField(names.tail, struct(names.head)) + case ArrayType(el: DataType, _) => + goThroughArrayDataType(el) match { + case struct: StructType => examineStructField(names.tail, struct(names.head)) + case _ => None + } + case _ => None + } + } + } + + val pathTokens = path.split('.').toList + Try{ + examineStructField(pathTokens.tail, schema(pathTokens.head)) + }.getOrElse(None) + } + + /** + * Get a type of a field from a text path and a given schema + * + * @param path The dot-separated path to the field + * @param schema The schema which should contain the specified path + * @return Some(the type of the field) or None if the field does not exist + */ + def getFieldType(path: String, schema: StructType): Option[DataType] = { + getField(path, schema).map(_.dataType) + } + + /** + * Checks if the specified path is an array of structs + * + * @param path The dot-separated path to the field + * @param schema The schema which should contain the specified path + * @return true if the field is an array of structs + */ + def isColumnArrayOfStruct(path: String, schema: StructType): Boolean = { + getFieldType(path, schema) match { + case Some(dt) => + dt match { + case arrayType: ArrayType => + arrayType.elementType match { + case _: StructType => true + case _ => false + } + case _ => false + } + case None => false + } + } + + /** + * Get nullability of a field from a text path and a given schema + * + * @param path The dot-separated path to the field + * @param schema The schema which should contain the specified path + * @return Some(nullable) or None if the field does not exist + */ + def getFieldNullability(path: String, schema: StructType): Option[Boolean] = { + getField(path, schema).map(_.nullable) + } + + /** + * Checks if a field specified by a path and a schema exists + * @param path The dot-separated path to the field + * @param schema The schema which should contain the specified path + * @return True if the field exists false otherwise + */ + def fieldExists(path: String, schema: StructType): Boolean = { + getField(path, schema).nonEmpty + } + + /** + * Returns all renames in the provided schema. + * @param schema schema to examine + * @param includeIfPredecessorChanged if set to true, fields are included even if their name have not changed but + * a predecessor's (parent, grandparent etc.) 
has + * @return the keys of the returned map are the columns' names after renames, the values are the source columns; + * the name are full paths denoted with dot notation + */ + def getRenamesInSchema(schema: StructType, includeIfPredecessorChanged: Boolean = true): Map[String, String] = { + + def getRenamesRecursively(path: String, + sourcePath: String, + struct: StructType, + renamesAcc: Map[String, String], + predecessorChanged: Boolean): Map[String, String] = { + import za.co.absa.standardization.implicits.StructFieldImplicits.StructFieldEnhancements + + struct.fields.foldLeft(renamesAcc) { (renamesSoFar, field) => + val fieldFullName = appendPath(path, field.name) + val fieldSourceName = field.getMetadataString(MetadataKeys.SourceColumn).getOrElse(field.name) + val fieldFullSourceName = appendPath(sourcePath, fieldSourceName) + + val (renames, renameOnPath) = if ((fieldSourceName != field.name) || (predecessorChanged && includeIfPredecessorChanged)) { + (renamesSoFar + (fieldFullName -> fieldFullSourceName), true) + } else { + (renamesSoFar, predecessorChanged) + } + + field.dataType match { + case st: StructType => getRenamesRecursively(fieldFullName, fieldFullSourceName, st, renames, renameOnPath) + case at: ArrayType => getStructInArray(at.elementType).fold(renames) { item => + getRenamesRecursively(fieldFullName, fieldFullSourceName, item, renames, renameOnPath) + } + case _ => renames + } + } + } + + @tailrec + def getStructInArray(dataType: DataType): Option[StructType] = { + dataType match { + case st: StructType => Option(st) + case at: ArrayType => getStructInArray(at.elementType) + case _ => None + } + } + + getRenamesRecursively("", "", schema, Map.empty, predecessorChanged = false) + } + + /** + * Get first array column's path out of complete path. + * + * E.g if the path argument is "a.b.c.d.e" where b and d are arrays, "a.b" will be returned. + * + * @param path The path to the attribute + * @param schema The schema of the whole dataset + * @return The path of the first array field or "" if none were found + */ + def getFirstArrayPath(path: String, schema: StructType): String = { + @tailrec + def helper(remPath: Seq[String], pathAcc: Seq[String]): Seq[String] = { + if (remPath.isEmpty) Seq() else { + val currPath = (pathAcc :+ remPath.head).mkString(".") + val currType = getFieldType(currPath, schema) + currType match { + case Some(_: ArrayType) => pathAcc :+ remPath.head + case Some(_) => helper(remPath.tail, pathAcc :+ remPath.head) + case None => Seq() + } + } + } + + val pathToks = path.split('.') + helper(pathToks, Seq()).mkString(".") + } + + /** + * Get paths for all array subfields of this given datatype + */ + def getAllArraySubPaths(path: String, name: String, dt: DataType): Seq[String] = { + val currPath = appendPath(path, name) + dt match { + case s: StructType => s.fields.flatMap(f => getAllArraySubPaths(currPath, f.name, f.dataType)) + case _@ArrayType(elType, _) => getAllArraySubPaths(path, name, elType) :+ currPath + case _ => Seq() + } + } + + /** + * Get all array columns' paths out of complete path. + * + * E.g. if the path argument is "a.b.c.d.e" where b and d are arrays, "a.b" and "a.b.c.d" will be returned. 
+ * + * @param path The path to the attribute + * @param schema The schema of the whole dataset + * @return Seq of dot-separated paths for all array fields in the provided path + */ + def getAllArraysInPath(path: String, schema: StructType): Seq[String] = { + @tailrec + def helper(remPath: Seq[String], pathAcc: Seq[String], arrayAcc: Seq[String]): Seq[String] = { + if (remPath.isEmpty) arrayAcc else { + val currPath = (pathAcc :+ remPath.head).mkString(".") + val currType = getFieldType(currPath, schema) + currType match { + case Some(_: ArrayType) => + val strings = pathAcc :+ remPath.head + helper(remPath.tail, strings, arrayAcc :+ strings.mkString(".")) + case Some(_) => helper(remPath.tail, pathAcc :+ remPath.head, arrayAcc) + case None => arrayAcc + } + } + } + + val pathToks = path.split("\\.") + helper(pathToks, Seq(), Seq()) + } + + /** + * For a given list of field paths determines the deepest common array path. + * + * For instance, if given 'a.b', 'a.b.c', 'a.b.c.d' where b and c are arrays the common deepest array + * path is 'a.b.c'. + * + * If any of the arrays are on diverging paths this function returns None. + * + * The purpose of the function is to determine the order of explosions to be made before the dataframe can be + * joined on a field inside an array. + * + * @param schema A Spark schema + * @param fieldPaths A list of paths to analyze + * @return Returns a common array path if there is one and None if any of the arrays are on diverging paths + */ + def getDeepestCommonArrayPath(schema: StructType, fieldPaths: Seq[String]): Option[String] = { + val arrayPaths = fieldPaths.flatMap(path => getAllArraysInPath(path, schema)).distinct + + if (arrayPaths.nonEmpty && isCommonSubPath(arrayPaths: _*)) { + Some(arrayPaths.maxBy(_.length)) + } else { + None + } + } + + /** + * For a field path determines the deepest array path. + * + * For instance, if given 'a.b.c.d' where b and c are arrays the deepest array is 'a.b.c'. + * + * @param schema A Spark schema + * @param fieldPath A path to analyze + * @return Returns a common array path if there is one and None if any of the arrays are on diverging paths + */ + def getDeepestArrayPath(schema: StructType, fieldPath: String): Option[String] = { + val arrayPaths = getAllArraysInPath(fieldPath, schema) + + if (arrayPaths.nonEmpty) { + Some(arrayPaths.maxBy(_.length)) + } else { + None + } + } + + /** + * For a given list of field paths determines if any path pair is a subset of one another. + * + * For instance, + * - 'a.b', 'a.b.c', 'a.b.c.d' have this property. + * - 'a.b', 'a.b.c', 'a.x.y' does NOT have it, since 'a.b.c' and 'a.x.y' have diverging subpaths. 
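Restating the property above as code (expected results follow the scaladoc examples):

    import za.co.absa.standardization.schema.SchemaUtils

    SchemaUtils.isCommonSubPath("a.b", "a.b.c", "a.b.c.d")  // true  - each path is a prefix of the next
    SchemaUtils.isCommonSubPath("a.b", "a.b.c", "a.x.y")    // false - "a.b.c" and "a.x.y" diverge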
+ * + * @param paths A list of paths to be analyzed + * @return true if for all pathe the above property holds + */ + def isCommonSubPath(paths: String*): Boolean = { + def sliceRoot(paths: Seq[Seq[String]]): Seq[Seq[String]] = { + paths.map(path => path.drop(1)).filter(_.nonEmpty) + } + + var isParentCommon = true // For Seq() the property holds by [my] convention + var restOfPaths: Seq[Seq[String]] = paths.map(_.split('.').toSeq).filter(_.nonEmpty) + while (isParentCommon && restOfPaths.nonEmpty) { + val parent = restOfPaths.head.head + isParentCommon = restOfPaths.forall(path => path.head == parent) + restOfPaths = sliceRoot(restOfPaths) + } + isParentCommon + } + + /** + * Get paths for all array fields in the schema + * + * @param schema The schema in which to look for array fields + * @return Seq of dot separated paths of fields in the schema, which are of type Array + */ + def getAllArrayPaths(schema: StructType): Seq[String] = { + schema.fields.flatMap(f => getAllArraySubPaths("", f.name, f.dataType)).toSeq + } + + /** + * Append a new attribute to path or empty string. + * + * @param path The dot-separated existing path + * @param fieldName Name of the field to be appended to the path + * @return The path with the new field appended or the field itself if path is empty + */ + def appendPath(path: String, fieldName: String): String = { + if (path.isEmpty) { + fieldName + } else if (fieldName.isEmpty) { + path + } else { + s"$path.$fieldName" + } + } + + /** + * Determine if a datatype is a primitive one + */ + def isPrimitive(dt: DataType): Boolean = dt match { + case _: BinaryType + | _: BooleanType + | _: ByteType + | _: DateType + | _: DecimalType + | _: DoubleType + | _: FloatType + | _: IntegerType + | _: LongType + | _: NullType + | _: ShortType + | _: StringType + | _: TimestampType => true + case _ => false + } + + /** + * Determine the name of a field + * Will override to "sourcecolumn" in the Metadata if it exists + * + * @param field field to work with + * @return Metadata "sourcecolumn" if it exists or field.name + */ + def getFieldNameOverriddenByMetadata(field: StructField): String = { + if (field.metadata.contains(MetadataKeys.SourceColumn)) { + field.metadata.getString(MetadataKeys.SourceColumn) + } else { + field.name + } + } + + /** + * For an array of arrays of arrays, ... 
get the final element type at the bottom of the array + * + * @param arrayType An array data type from a Spark dataframe schema + * @return A non-array data type at the bottom of array nesting + */ + @tailrec + def getDeepestArrayType(arrayType: ArrayType): DataType = { + arrayType.elementType match { + case a: ArrayType => getDeepestArrayType(a) + case b => b + } + } + + /** + * Generate a unique column name + * + * @param prefix A prefix to use for the column name + * @param schema An optional schema to validate if the column already exists (a very low probability) + * @return A name that can be used as a unique column name + */ + def getUniqueName(prefix: String, schema: Option[StructType]): String = { + schema match { + case None => + s"${prefix}_${Random.nextLong().abs}" + case Some(sch) => + var exists = true + var columnName = "" + while (exists) { + columnName = s"${prefix}_${Random.nextLong().abs}" + exists = sch.fields.exists(_.name.compareToIgnoreCase(columnName) == 0) + } + columnName + } + } + + /** + * Get a closest unique column name + * + * @param desiredName A prefix to use for the column name + * @param schema A schema to validate if the column already exists + * @return A name that can be used as a unique column name + */ + def getClosestUniqueName(desiredName: String, schema: StructType): String = { + var exists = true + var columnName = "" + var i = 0 + while (exists) { + columnName = if (i == 0) desiredName else s"${desiredName}_$i" + exists = schema.fields.exists(_.name.compareToIgnoreCase(columnName) == 0) + i += 1 + } + columnName + } + + /** + * Checks if a casting between types always succeeds + * + * @param sourceType A type to be casted + * @param targetType A type to be casted to + * @return true if casting never fails + */ + def isCastAlwaysSucceeds(sourceType: DataType, targetType: DataType): Boolean = { + (sourceType, targetType) match { + case (_: StructType, _) | (_: ArrayType, _) => false + case (a, b) if a == b => true + case (_, _: StringType) => true + case (_: ByteType, _: ShortType | _: IntegerType | _: LongType) => true + case (_: ShortType, _: IntegerType | _: LongType) => true + case (_: IntegerType, _: LongType) => true + case (_: DateType, _: TimestampType) => true + case _ => false + } + } + + /** + * Checks if a field is an array + * + * @param schema A schema + * @param fieldPathName A field to check + * @return true if the specified field is an array + */ + def isArray(schema: StructType, fieldPathName: String): Boolean = { + @tailrec + def arrayHelper(arrayField: ArrayType, path: Seq[String]): Boolean = { + val currentField = path.head + val isLeaf = path.lengthCompare(1) <= 0 + + arrayField.elementType match { + case st: StructType => structHelper(st, path.tail) + case ar: ArrayType => arrayHelper(ar, path) + case _ => + if (!isLeaf) { + throw new IllegalArgumentException( + s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName") + } + false + } + } + + def structHelper(structField: StructType, path: Seq[String]): Boolean = { + val currentField = path.head + val isLeaf = path.lengthCompare(1) <= 0 + var isArray = false + structField.fields.foreach(field => + if (field.name == currentField) { + field.dataType match { + case st: StructType => + if (!isLeaf) { + isArray = structHelper(st, path.tail) + } + case ar: ArrayType => + if (isLeaf) { + isArray = true + } else { + isArray = arrayHelper(ar, path) + } + case _ => + if (!isLeaf) { + throw new IllegalArgumentException( + s"Primitive fields cannot have 
child fields $currentField is a primitive in $fieldPathName") + } + } + } + ) + isArray + } + + val path = fieldPathName.split('.') + structHelper(schema, path) + } + + /** + * Checks if a field is an array that is not nested in another array + * + * @param schema A schema + * @param fieldPathName A field to check + * @return true if a field is an array that is not nested in another array + */ + def isNonNestedArray(schema: StructType, fieldPathName: String): Boolean = { + def structHelper(structField: StructType, path: Seq[String]): Boolean = { + val currentField = path.head + val isLeaf = path.lengthCompare(1) <= 0 + var isArray = false + structField.fields.foreach(field => + if (field.name == currentField) { + field.dataType match { + case st: StructType => + if (!isLeaf) { + isArray = structHelper(st, path.tail) + } + case _: ArrayType => + if (isLeaf) { + isArray = true + } + case _ => + if (!isLeaf) { + throw new IllegalArgumentException( + s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName") + } + } + } + ) + isArray + } + + val path = fieldPathName.split('.') + structHelper(schema, path) + } + + /** + * Checks if a field is the only field in a struct + * + * @param schema A schema + * @param column A column to check + * @return true if the column is the only column in a struct + */ + def isOnlyField(schema: StructType, column: String): Boolean = { + def structHelper(structField: StructType, path: Seq[String]): Boolean = { + val currentField = path.head + val isLeaf = path.lengthCompare(1) <= 0 + var isOnlyField = false + structField.fields.foreach(field => + if (field.name == currentField) { + if (isLeaf) { + isOnlyField = structField.fields.length == 1 + } else { + field.dataType match { + case st: StructType => + isOnlyField = structHelper(st, path.tail) + case _: ArrayType => + throw new IllegalArgumentException( + s"SchemaUtils.isOnlyField() does not support checking struct fields inside an array") + case _ => + throw new IllegalArgumentException( + s"Primitive fields cannot have child fields $currentField is a primitive in $column") + } + } + } + ) + isOnlyField + } + val path = column.split('.') + structHelper(schema, path) + } + + /** + * Converts a fully qualified field name (including its path, e.g. containing fields) to a unique field name without + * dot notation + * @param path the fully qualified field name + * @return unique top level field name + */ + def unpath(path: String): String = { + path.replace("_", "__") + .replace('.', '_') + } + + implicit class FieldWithSource(val structField: StructField) { + def sourceName: String = { + SchemaUtils.getFieldNameOverriddenByMetadata(structField.structField) + } + } + +} diff --git a/src/main/scala/za/co/absa/standardization/schema/SparkUtils.scala b/src/main/scala/za/co/absa/standardization/schema/SparkUtils.scala new file mode 100644 index 0000000..7a2c61e --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/schema/SparkUtils.scala @@ -0,0 +1,108 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
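A few more expected outcomes for the SchemaUtils helpers defined above, inferred from their implementations:

    import org.apache.spark.sql.types.{DateType, IntegerType, LongType, StringType, TimestampType}
    import za.co.absa.standardization.schema.SchemaUtils

    SchemaUtils.isCastAlwaysSucceeds(IntegerType, LongType)     // true  - widening cast
    SchemaUtils.isCastAlwaysSucceeds(LongType, IntegerType)     // false - may overflow
    SchemaUtils.isCastAlwaysSucceeds(DateType, TimestampType)   // true
    SchemaUtils.isCastAlwaysSucceeds(TimestampType, StringType) // true  - everything casts to string
    SchemaUtils.unpath("person.first_name")                     // "person_first__name"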
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.schema + +import org.apache.log4j.{LogManager, Logger} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{Column, DataFrame, SparkSession} +import za.co.absa.standardization.ErrorMessage +import za.co.absa.standardization.udf.UDFLibrary +import za.co.absa.spark.hats.transformations.NestedArrayTransformations + + +/** + * General Spark utils + */ +object SparkUtils { + private val log: Logger = LogManager.getLogger(this.getClass) + private final val DefaultColumnNameOfCorruptRecord = "_corrupt_record" + + final val ColumnNameOfCorruptRecordConf = "spark.sql.columnNameOfCorruptRecord" + + /** + * Ensures that the 'spark.sql.columnNameOfCorruptRecord' Spark setting is set to unique field name not present in the + * provided schema + * @param spark the spark session to set the + * @param schema the schema to check uniqueness against + * @return the field name set + */ + def setUniqueColumnNameOfCorruptRecord(spark: SparkSession, schema: StructType): String = { + val result = if (SchemaUtils.fieldExists(DefaultColumnNameOfCorruptRecord, schema)) { + SchemaUtils.getClosestUniqueName(DefaultColumnNameOfCorruptRecord, schema) + } else { + DefaultColumnNameOfCorruptRecord + } + spark.conf.set(ColumnNameOfCorruptRecordConf, result) + result + } + + /** + * Adds a column to a dataframe if it does not exist + * + * @param df A dataframe + * @param colName A column to add if it does not exist already + * @param colExpr An expression for the column to add + * @return a new dataframe with the new column + */ + def withColumnIfDoesNotExist(df: DataFrame, colName: String, colExpr: Column): DataFrame = { + if (df.schema.exists(field => field.name.equalsIgnoreCase(colName))) { + log.warn(s"Column '$colName' already exists. The content of the column will be overwritten.") + overwriteWithErrorColumn(df, colName, colExpr) + } else { + df.withColumn(colName, colExpr) + } + } + + /** + * Overwrites a column with a value provided by an expression. + * If the value in the column does not match the one provided by the expression, an error will be + * added to the error column. + * + * @param df A dataframe + * @param colName A column to be overwritten + * @param colExpr An expression for the value to write + * @return a new dataframe with the value of the column being overwritten + */ + private def overwriteWithErrorColumn(df: DataFrame, colName: String, colExpr: Column): DataFrame = { + implicit val spark: SparkSession = df.sparkSession + implicit val udfLib: UDFLibrary = new UDFLibrary + + + val tmpColumn = SchemaUtils.getUniqueName("tmpColumn", Some(df.schema)) + val tmpErrColumn = SchemaUtils.getUniqueName("tmpErrColumn", Some(df.schema)) + val litErrUdfCall = callUDF("confLitErr", lit(colName), col(tmpColumn)) + + // Rename the original column to a temporary name. We need it for comparison. 
+ val dfWithColRenamed = df.withColumnRenamed(colName, tmpColumn) + + // Add new column with the intended value + val dfWithIntendedColumn = dfWithColRenamed.withColumn(colName, colExpr) + + // Add a temporary error column containing errors if the original value does not match the intended one + val dfWithErrorColumn = dfWithIntendedColumn + .withColumn(tmpErrColumn, array(when(col(tmpColumn) =!= colExpr, litErrUdfCall).otherwise(null))) // scalastyle:ignore null + + // Gather all errors in errCol + val dfWithAggregatedErrColumn = NestedArrayTransformations + .gatherErrors(dfWithErrorColumn, tmpErrColumn, ErrorMessage.errorColumnName) + + // Drop the temporary column + dfWithAggregatedErrColumn.drop(tmpColumn) + } + +} diff --git a/src/main/scala/za/co/absa/standardization/stages/PlainSchemaGenerator.scala b/src/main/scala/za/co/absa/standardization/stages/PlainSchemaGenerator.scala new file mode 100644 index 0000000..288a21a --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/stages/PlainSchemaGenerator.scala @@ -0,0 +1,56 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.stages + +import org.apache.spark.sql.types._ +import za.co.absa.standardization.implicits.StructFieldImplicits.StructFieldEnhancements +import za.co.absa.standardization.schema.MetadataKeys + +/** + * This component is used in the standardization job. We've got a strongly typed (target) schema. When reading the data however, we do not want spark to apply casts + * automatically. Instead we want to read all primitive types as strings. This component takes the target (desired) schema and generates a plain one which is used for + * reading source data in the job. 
+ */ +object PlainSchemaGenerator { + + private def structTypeFieldsConversion(fields: Array[StructField]): Array[StructField] = { + import za.co.absa.standardization.implicits.StructFieldImplicits.StructFieldEnhancements + fields.map { field => + // If the meta data value sourcecolumn is set override the field name + val fieldName = field.getMetadataString(MetadataKeys.SourceColumn).getOrElse(field.name) + val dataType = inputSchemaAsStringTypes(field.dataType) + StructField(fieldName, dataType, nullable = true, field.metadata) + } + } + + private def inputSchemaAsStringTypes(inp: DataType): DataType = { + inp match { + case st: StructType => StructType(structTypeFieldsConversion(st.fields)) + case at: ArrayType => ArrayType(inputSchemaAsStringTypes(at.elementType), containsNull = true) + case _: DataType => StringType + } + } + + def generateInputSchema(structType: StructType, corruptRecordFieldName: Option[String] = None): StructType = { + val inputSchema = structTypeFieldsConversion(structType.fields) + val corruptRecordField = corruptRecordFieldName.map(StructField(_, StringType)).toArray + StructType(inputSchema ++ corruptRecordField) + + } + + +} diff --git a/src/main/scala/za/co/absa/standardization/stages/SchemaChecker.scala b/src/main/scala/za/co/absa/standardization/stages/SchemaChecker.scala new file mode 100644 index 0000000..ff6ced0 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/stages/SchemaChecker.scala @@ -0,0 +1,70 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.stages + +import org.apache.log4j.{LogManager, Logger} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.types.StructType +import za.co.absa.standardization.SchemaValidator.{validateErrorColumn, validateSchema} +import za.co.absa.standardization.{ValidationError, ValidationIssue, ValidationWarning} + +object SchemaChecker { + + val log: Logger = LogManager.getLogger(this.getClass) + + /** + * Validate a schema, log all errors and warnings, throws if there are fatal errors + * + * @param schema A Spark schema + */ + def validateSchemaAndLog(schema: StructType) + (implicit spark: SparkSession): (Seq[String], Seq[String]) = { + val failures = validateSchema(schema) ::: validateErrorColumn(schema) + + type ColName = String + type Pattern = String + + val flattenedIssues: Seq[(ColName, Pattern, ValidationIssue)] = + for { + failure <- failures + issue <- failure.issues + } yield (failure.fieldName, failure.pattern, issue) + + // This code crafts 2 lists of messages. The first one will contain all errors and the second will contain all warnings. + // + // This code was a result of a long discussion. It's still not perfect, but it balances + // trade-offs between readability, conciseness and performance. 
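To make the intent of `PlainSchemaGenerator` concrete, a small sketch (target schema invented) of what `generateInputSchema` produces: every primitive leaf is relaxed to a nullable string, the nesting shape survives, and the optional corrupt-record column is appended at the end:

```scala
import org.apache.spark.sql.types._
import za.co.absa.standardization.stages.PlainSchemaGenerator

val target = StructType(Seq(
  StructField("id", LongType, nullable = false),
  StructField("prices", ArrayType(DecimalType(10, 2)))
))

val readingSchema = PlainSchemaGenerator.generateInputSchema(target, Some("_corrupt_record"))
// struct<id: string, prices: array<string>, _corrupt_record: string>
// All generated fields are nullable so that malformed input does not fail the read.
```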
+ // + val errorMessages: (Seq[String], Seq[String]) = flattenedIssues.foldLeft(Nil: Seq[String], Nil: Seq[String])((accumulator, current) => { + val (errors, warnings) = accumulator + val (column, pattern, issue) = current + issue match { + case ValidationError(text) => + val msg = s"Validation error for column '$column', pattern '$pattern': $text" + log.error(msg) + (errors :+ msg, warnings) + case ValidationWarning(text) => + val msg = s"Validation warning for column '$column', pattern '$pattern': $text" + log.warn(msg) + (errors, warnings :+ msg) + } + }) + + errorMessages + } + +} diff --git a/src/main/scala/za/co/absa/standardization/stages/TypeParser.scala b/src/main/scala/za/co/absa/standardization/stages/TypeParser.scala new file mode 100644 index 0000000..07818b6 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/stages/TypeParser.scala @@ -0,0 +1,658 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.stages + +import java.security.InvalidParameterException +import java.sql.Timestamp +import java.util.Date +import java.util.regex.Pattern +import org.apache.spark.sql.Column +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.slf4j.{Logger, LoggerFactory} +import za.co.absa.standardization.ErrorMessage +import za.co.absa.standardization.schema.MetadataValues +import za.co.absa.standardization.schema.SchemaUtils.FieldWithSource +import za.co.absa.standardization.types.Defaults +import za.co.absa.standardization.types.TypedStructField._ +import za.co.absa.standardization.udf.{UDFBuilder, UDFLibrary, UDFNames} +import za.co.absa.standardization.schema.SchemaUtils +import za.co.absa.standardization.time.DateTimePattern +import za.co.absa.standardization.typeClasses.{DoubleLike, LongLike} +import za.co.absa.standardization.types.{ParseOutput, TypedStructField} +import za.co.absa.spark.hofs.transform + +import scala.reflect.runtime.universe._ +import scala.util.{Random, Try} + +/** + * Base trait for standardization function + * Each final class in the hierarchy represents a `standardize` function for its specific data type field + * Class hierarchy: + * TypeParser + * ArrayParser ! + * StructParser ! + * PrimitiveParser + * ScalarParser + * NumericParser ! + * StringParser ! + * BooleanParser ! + * BinaryParser ! + * DateTimeParser + * TimestampParser ! + * DateParser ! 
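A possible call site for `SchemaChecker.validateSchemaAndLog` (the error handling below is only illustrative): errors typically abort the run, warnings are merely surfaced:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType
import za.co.absa.standardization.stages.SchemaChecker

implicit val spark: SparkSession = SparkSession.builder().master("local[*]").getOrCreate()

def checkOrFail(schema: StructType): Unit = {
  val (errors, warnings) = SchemaChecker.validateSchemaAndLog(schema)
  warnings.foreach(w => println(s"WARN: $w"))   // already logged inside, repeated here only for illustration
  if (errors.nonEmpty) {
    throw new IllegalStateException(s"Schema validation failed:\n${errors.mkString("\n")}")
  }
}
```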
+ */ +sealed trait TypeParser[T] { + + def standardize()(implicit logger: Logger): ParseOutput = { + checkSetupForFailure().getOrElse( + standardizeAfterCheck() + ) + } + + protected val failOnInputNotPerSchema: Boolean + protected val field: TypedStructField + protected val metadata: Metadata = field.structField.metadata + protected val path: String + protected val origType: DataType + protected val fieldInputName: String = field.structField.sourceName + protected val fieldOutputName: String = field.name + protected val inputFullPathName: String = SchemaUtils.appendPath(path, fieldInputName) + protected val isArrayElement: Boolean + protected val columnIdForUdf: String = if (isArrayElement) { + s"$inputFullPathName[*]" + } else { + inputFullPathName + } + + protected val column: Column + + protected def fieldType: DataType = field.dataType + + // Error should never appear here due to validation + protected def defaultValue: Option[field.BaseType] = field.defaultValueWithGlobal.get + + protected def checkSetupForFailure()(implicit logger: Logger): Option[ParseOutput] = { + def noCastingPossible: Option[ParseOutput] = { + val message = s"Cannot standardize field '$inputFullPathName' from type ${origType.typeName} into ${fieldType.typeName}" + if (failOnInputNotPerSchema) { + throw new TypeParserException(message) + } else { + logger.info(message) + Option(ParseOutput( + lit(defaultValue.orNull).cast(fieldType) as(fieldOutputName, metadata), + typedLit(Seq(ErrorMessage.stdTypeError(inputFullPathName, origType.typeName, fieldType.typeName))) + )) + } + } + + (origType, fieldType) match { + case (ArrayType(_, _), ArrayType(_, _)) => None + case (StructType(_), StructType(_)) => None + case (ArrayType(_, _), _) => noCastingPossible + case (_, ArrayType(_, _)) => noCastingPossible + case (StructType(_), _) => noCastingPossible + case (_, StructType(_)) => noCastingPossible + case _ => None + } + } + + protected def standardizeAfterCheck()(implicit logger: Logger): ParseOutput +} + +object TypeParser { + import za.co.absa.standardization.implicits.ColumnImplicits.ColumnEnhancements + + private val decimalType = DecimalType(30,9) // scalastyle:ignore magic.number + private implicit val logger: Logger = LoggerFactory.getLogger(this.getClass) + + private val MillisecondsPerSecond = 1000 + private val MicrosecondsPerSecond = 1000000 + private val NanosecondsPerSecond = 1000000000 + private val InfinityStr = "Infinity" + private val nullColumn = lit(null) //scalastyle:ignore null + + + def standardize(field: StructField, path: String, origSchema: StructType, failOnInputNotPerSchema: Boolean = true) + (implicit udfLib: UDFLibrary, defaults: Defaults): ParseOutput = { + // udfLib implicit is present for error column UDF implementation + val sourceName = SchemaUtils.appendPath(path, field.sourceName) + val origField = SchemaUtils.getField(sourceName, origSchema) + val origFieldType = origField.map(_.dataType).getOrElse(NullType) + val column = origField.fold(nullColumn)(_ => col(sourceName)) + TypeParser(field, path, column, origFieldType, failOnInputNotPerSchema).standardize() + } + + sealed trait Parent { + val parentColumn: Column + def childColumn(fieldName: String): Column + } + + private final case class ArrayParent (parentColumn: Column) extends Parent { + override def childColumn(fieldName: String): Column = parentColumn + } + private final case class StructParent (parentColumn: Column) extends Parent { + override def childColumn(fieldName: String): Column = parentColumn(fieldName) + } + + 
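The public `standardize` entry point above is typically invoked once per field of the target schema. A sketch of the call shape, with an invented field and raw schema; the `UDFLibrary` construction mirrors how `SparkUtils` builds it and assumes an active `SparkSession` in implicit scope:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._
import za.co.absa.standardization.stages.TypeParser
import za.co.absa.standardization.types.{Defaults, GlobalDefaults, ParseOutput}
import za.co.absa.standardization.udf.UDFLibrary

implicit val spark: SparkSession = SparkSession.builder().master("local[*]").getOrCreate()
implicit val defaults: Defaults = GlobalDefaults
implicit val udfLib: UDFLibrary = new UDFLibrary     // registers the error-reporting UDFs

// The raw data arrives as strings; the target field asks for a decimal.
val origSchema = StructType(Seq(StructField("amount", StringType)))
val targetField = StructField("amount", DecimalType(18, 2))

// One ParseOutput per target field: the standardized column plus its error-array column.
val out: ParseOutput = TypeParser.standardize(targetField, path = "", origSchema)
```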
private def apply(field: StructField, + path: String, + column: Column, + origType: DataType, + failOnInputNotPerSchema: Boolean, + isArrayElement: Boolean = false) + (implicit defaults: Defaults): TypeParser[_] = { + val parserClass: (String, Column, DataType, Boolean, Boolean) => TypeParser[_] = field.dataType match { + case _: ArrayType => ArrayParser(TypedStructField.asArrayTypeStructField(field), _, _, _, _, _) + case _: StructType => StructParser(TypedStructField.asStructTypeStructField(field), _, _, _, _, _) + case _: ByteType => + IntegralParser(TypedStructField.asNumericTypeStructField[Byte](field), _, _, _, _, _, Set(ShortType, IntegerType, LongType)) + case _: ShortType => + IntegralParser(TypedStructField.asNumericTypeStructField[Short](field), _, _, _, _, _, Set(IntegerType, LongType)) + case _: IntegerType => IntegralParser(TypedStructField.asNumericTypeStructField[Int](field), _, _, _, _, _, Set(LongType)) + case _: LongType => IntegralParser(TypedStructField.asNumericTypeStructField[Long](field), _, _, _, _, _, Set.empty) + case _: FloatType => FractionalParser(TypedStructField.asNumericTypeStructField[Float](field), _, _, _, _, _) + case _: DoubleType => FractionalParser(TypedStructField.asNumericTypeStructField[Double](field), _, _, _, _, _) + case _: DecimalType => DecimalParser(TypedStructField.asNumericTypeStructField[BigDecimal](field), _, _, _, _, _) + case _: StringType => StringParser(TypedStructField(field), _, _, _, _, _) + case _: BinaryType => BinaryParser(TypedStructField.asBinaryTypeStructField(field), _, _, _, _, _) + case _: BooleanType => BooleanParser(TypedStructField(field), _, _, _, _, _) + case _: DateType => DateParser(TypedStructField.asDateTimeTypeStructField(field), _, _, _, _, _) + case _: TimestampType => TimestampParser(TypedStructField.asDateTimeTypeStructField(field), _, _, _, _, _) + case t => throw new IllegalStateException(s"${t.typeName} is not a supported type in this version of Enceladus") + } + parserClass(path, column, origType, failOnInputNotPerSchema, isArrayElement) + } + + private final case class ArrayParser(override val field: ArrayTypeStructField, + path: String, + column: Column, + origType: DataType, + failOnInputNotPerSchema: Boolean, + isArrayElement: Boolean) + (implicit defaults: Defaults) extends TypeParser[Any] { + + override def fieldType: ArrayType = { + field.dataType + } + + override protected def standardizeAfterCheck()(implicit logger: Logger): ParseOutput = { + logger.info(s"Creating standardization plan for Array $inputFullPathName") + val origArrayType = origType.asInstanceOf[ArrayType] // this should never throw an exception because of `checkSetupForFailure` + val arrayField = StructField(fieldInputName, fieldType.elementType, fieldType.containsNull, field.structField.metadata) + val lambdaVariableName = s"${SchemaUtils.unpath(inputFullPathName)}_${Random.nextLong().abs}" + val lambda = (forCol: Column) => TypeParser(arrayField, path, forCol, origArrayType.elementType, failOnInputNotPerSchema, isArrayElement = true) + .standardize() + + val lambdaErrCols = lambda.andThen(_.errors) + val lambdaStdCols = lambda.andThen(_.stdCol) + val nullErrCond = column.isNull and lit(!field.nullable) + + val finalErrs = when(nullErrCond, + array(typedLit(ErrorMessage.stdNullErr(inputFullPathName)))) + .otherwise( + typedLit(flatten(transform(column, lambdaErrCols, lambdaVariableName))) + ) + val stdCols = transform(column, lambdaStdCols, lambdaVariableName) + logger.info(s"Finished standardization plan for Array 
$inputFullPathName") + ParseOutput(stdCols as (fieldOutputName, metadata), finalErrs) + } + } + + private final case class StructParser(override val field: StructTypeStructField, + path: String, + column: Column, + origType: DataType, + failOnInputNotPerSchema: Boolean, + isArrayElement: Boolean) + (implicit defaults: Defaults) extends TypeParser[Any] { + override def fieldType: StructType = { + field.dataType + } + + override protected def standardizeAfterCheck()(implicit logger: Logger): ParseOutput = { + val origStructType = origType.asInstanceOf[StructType] // this should never throw an exception because of `checkSetupForFailure` + val out = fieldType.fields.map{f => + val origSubField = Try{origStructType(f.sourceName)}.toOption + val origSubFieldType = origSubField.map(_.dataType).getOrElse(NullType) + val subColumn = origSubField.map(x => column(x.name)).getOrElse(nullColumn) + TypeParser(f, inputFullPathName, subColumn, origSubFieldType, failOnInputNotPerSchema).standardize()} + val cols = out.map(_.stdCol) + val errs = out.map(_.errors) + // condition for nullable error of the struct itself + val nullErrCond = column.isNull and lit(!field.nullable) + val dropChildrenErrsCond = column.isNull + // first remove all child errors if this is null + val errs1 = concat( + flatten(array(errs.map(x => when(dropChildrenErrsCond, typedLit(Seq[ErrorMessage]())).otherwise(x)): _*)), + // then add an error if this one is null + when(nullErrCond, + array(callUDF(UDFNames.stdNullErr, lit(inputFullPathName)))) + .otherwise( + typedLit(Seq[ErrorMessage]()) + ) + ) + // rebuild the struct + val outputColumn = struct(cols: _*) as (fieldOutputName, metadata) + + ParseOutput(outputColumn, errs1) + } + } + + private abstract class PrimitiveParser[T](implicit defaults: Defaults) extends TypeParser[T] { + override protected def standardizeAfterCheck()(implicit logger: Logger): ParseOutput = { + val castedCol: Column = assemblePrimitiveCastLogic + val castHasError: Column = assemblePrimitiveCastErrorLogic(castedCol) + + val err: Column = if (field.nullable) { + when(column.isNotNull and castHasError, // cast failed + array(callUDF(UDFNames.stdCastErr, lit(columnIdForUdf), column.cast(StringType))) + ).otherwise( // everything is OK + typedLit(Seq.empty[ErrorMessage]) + ) + } else { + when(column.isNull, // NULL not allowed + array(callUDF(UDFNames.stdNullErr, lit(columnIdForUdf))) + ).otherwise( when(castHasError, // cast failed + array(callUDF(UDFNames.stdCastErr, lit(columnIdForUdf), column.cast(StringType))) + ).otherwise( // everything is OK + typedLit(Seq.empty[ErrorMessage]) + )) + } + + val std: Column = when(size(err) > lit(0), // there was an error on cast + defaultValue.orNull // converting Option to value or Null + ).otherwise( when (column.isNotNull, + castedCol + ) //.otherwise(null) - no need to explicitly mention + ) as (fieldOutputName, metadata) + + ParseOutput(std, err) + } + + protected def assemblePrimitiveCastLogic: Column //this differs based on the field data type + + protected def assemblePrimitiveCastErrorLogic(castedCol: Column): Column = { + castedCol.isNull //this one is sufficient for most primitive data types + } + } + + private abstract class ScalarParser[T](implicit defaults: Defaults) extends PrimitiveParser[T] { + override def assemblePrimitiveCastLogic: Column = column.cast(field.dataType) + } + + private abstract class NumericParser[N: TypeTag](override val field: NumericTypeStructField[N]) + (implicit defaults: Defaults) extends ScalarParser[N] { + override protected def 
standardizeAfterCheck()(implicit logger: Logger): ParseOutput = { + if (field.needsUdfParsing) { + standardizeUsingUdf() + } else { + super.standardizeAfterCheck() + } + } + + override def assemblePrimitiveCastLogic: Column = { + if (origType == StringType) { + // in case of string as source some adjustments might be needed + val decimalSymbols = field.pattern.toOption.flatten.map( + _.decimalSymbols + ).getOrElse(defaults.getDecimalSymbols) + val replacements: Map[Char, Char] = decimalSymbols.basicSymbolsDifference(defaults.getDecimalSymbols) + + val columnWithProperDecimalSymbols: Column = if (replacements.nonEmpty) { + val from = replacements.keys.mkString + val to = replacements.values.mkString + translate(column, from, to) + } else { + column + } + + val columnToCast = if (field.allowInfinity && (decimalSymbols.infinityValue != InfinityStr)) { + // because Spark uses Java's conversion from String, which in turn checks for hardcoded "Infinity" string not + // DecimalFormatSymbols content + val infinityEscaped = Pattern.quote(decimalSymbols.infinityValue) + regexp_replace(regexp_replace(columnWithProperDecimalSymbols, InfinityStr, s"${InfinityStr}_"), infinityEscaped, InfinityStr) + } else { + columnWithProperDecimalSymbols + } + columnToCast.cast(field.dataType) + } else { + super.assemblePrimitiveCastLogic + } + } + + private def standardizeUsingUdf(): ParseOutput = { + val udfFnc: UserDefinedFunction = UDFBuilder.stringUdfViaNumericParser(field.parser.get, field.nullable, columnIdForUdf, defaultValue) + ParseOutput(udfFnc(column)("result").cast(field.dataType).as(fieldOutputName), udfFnc(column)("error")) + } + } + + private final case class IntegralParser[N: LongLike: TypeTag](override val field: NumericTypeStructField[N], + path: String, + column: Column, + origType: DataType, + failOnInputNotPerSchema: Boolean, + isArrayElement: Boolean, + overflowableTypes: Set[DataType]) + (implicit defaults: Defaults) extends NumericParser[N](field) { + override protected def assemblePrimitiveCastErrorLogic(castedCol: Column): Column = { + val basicLogic: Column = super.assemblePrimitiveCastErrorLogic(castedCol) + + origType match { + case dt: DecimalType => + // decimal can be too big, to catch overflow or imprecision issues compare to original + basicLogic or (column =!= castedCol.cast(dt)) + case DoubleType | FloatType => + // same as Decimal but directly comparing fractional values is not reliable, + // best check for whole number is considered modulo 1.0 + basicLogic or (column % 1.0 =!= 0.0) or column > field.typeMax or column < field.typeMin + case ot if overflowableTypes.contains(ot) => + // from these types there is the possibility of under-/overflow, extra check is needed + basicLogic or (castedCol =!= column.cast(LongType)) + case StringType => + // string of decimals are not allowed + basicLogic or column.contains(".") + case _ => + basicLogic + } + } + } + + private final case class DecimalParser(override val field: NumericTypeStructField[BigDecimal], + path: String, + column: Column, + origType: DataType, + failOnInputNotPerSchema: Boolean, + isArrayElement: Boolean) + (implicit defaults: Defaults) + extends NumericParser[BigDecimal](field) + // NB! loss of precision is not addressed for any DecimalType + // e.g. 
3.141592 will be Standardized to Decimal(10,2) as 3.14 + + private final case class FractionalParser[N: DoubleLike: TypeTag](override val field: NumericTypeStructField[N], + path: String, + column: Column, + origType: DataType, + failOnInputNotPerSchema: Boolean, + isArrayElement: Boolean) + (implicit defaults: Defaults) + extends NumericParser[N](field) { + override protected def assemblePrimitiveCastErrorLogic(castedCol: Column): Column = { + //NB! loss of precision is not addressed for any fractional type + if (field.allowInfinity) { + castedCol.isNull or castedCol.isNaN + } else { + castedCol.isNull or castedCol.isNaN or castedCol.isInfinite + } + } + } + + private final case class StringParser(field: TypedStructField, + path: String, + column: Column, + origType: DataType, + failOnInputNotPerSchema: Boolean, + isArrayElement: Boolean) + (implicit defaults: Defaults) extends ScalarParser[String] + + private final case class BinaryParser(field: BinaryTypeStructField, + path: String, + column: Column, + origType: DataType, + failOnInputNotPerSchema: Boolean, + isArrayElement: Boolean) + (implicit defaults: Defaults) extends PrimitiveParser[Array[Byte]] { + override protected def assemblePrimitiveCastLogic: Column = { + origType match { + case BinaryType => column + case StringType => + // already validated in Standardization + field.normalizedEncoding match { + case Some(MetadataValues.Encoding.Base64) => callUDF(UDFNames.binaryUnbase64, column) + case Some(MetadataValues.Encoding.None) | None => + if (field.normalizedEncoding.isEmpty) { + logger.warn(s"Binary field ${field.structField.name} does not have encoding setup in metadata. Reading as-is.") + } + column.cast(field.dataType) // use as-is + case _ => throw new IllegalStateException(s"Unsupported encoding for Binary field ${field.structField.name}:" + + s" '${field.normalizedEncoding.get}'") + } + + case _ => throw new IllegalStateException(s"Unsupported conversion from BinaryType to ${field.dataType}") + } + } + } + + private final case class BooleanParser(field: TypedStructField, + path: String, + column: Column, + origType: DataType, + failOnInputNotPerSchema: Boolean, + isArrayElement: Boolean) + (implicit defaults: Defaults) extends ScalarParser[Boolean] + + /** + * Timestamp conversion logic + * Original type | TZ in pattern/without TZ | Has default TZ + * ~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Fractional | ->Decimal->String->to_timestamp | ->Decimal->String->to_timestamp->to_utc_timestamp + * Decimal | ->String->to_timestamp | ->String->to_timestamp->to_utc_timestamp + * String | ->to_timestamp | ->to_timestamp->to_utc_timestamp + * Timestamp | O | ->to_utc_timestamp + * Date | ->to_timestamp(no pattern) | ->to_utc_timestamp + * Other | ->String->to_timestamp | ->String->to_timestamp->to_utc_timestamp + * + * + * Date conversion logic + * Original type | TZ in pattern/without TZ | Has default TZ (the last to_date is always without pattern) + * ~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Float | ->Decimal->String->to_date | ->Decimal->String->to_timestamp->to_utc_timestamp->to_date + * Decimal | ->String->to_date | ->String->->to_timestamp->->to_utc_timestamp->to_date + * String | ->to_date | ->to_timestamp->->to_utc_timestamp->to_date + * Timestamp | ->to_date(no pattern) | ->to_utc_timestamp->to_date + * Date | O | ->to_utc_timestamp->to_date + * Other | ->String->to_date 
| ->String->to_timestamp->to_utc_timestamp->to_date + */ + private abstract class DateTimeParser[T](implicit defaults: Defaults) extends PrimitiveParser[T] { + override val field: DateTimeTypeStructField[T] + protected val pattern: DateTimePattern = field.pattern.get.get + + override protected def assemblePrimitiveCastLogic: Column = { + if (pattern.isEpoch) { + castEpoch() + } else { + castWithPattern() + } + } + + private def patternNeeded(originType: DataType): Unit = { + if (pattern.isDefault) { + throw new InvalidParameterException( + s"Dates & times represented as ${originType.typeName} values need specified 'pattern' metadata" + ) + } + } + + private def castWithPattern(): Column = { + // sadly with parquet support, incoming might not be all `plain` +// underlyingType match { + origType match { + case _: NullType => nullColumn + case _: DateType => castDateColumn(column) + case _: TimestampType => castTimestampColumn(column) + case _: StringType => castStringColumn(column) + case ot: DoubleType => + // this case covers some IBM date format where it's represented as a double ddmmyyyy.hhmmss + patternNeeded(ot) + castFractionalColumn(column, ot) + case ot: FloatType => + // this case covers some IBM date format where it's represented as a double ddmmyyyy.hhmmss + patternNeeded(ot) + castFractionalColumn(column, ot) + case ot => + patternNeeded(ot) + castNonStringColumn(column, ot) + } + } + + private def castFractionalColumn(fractionalColumn: Column, originType: DataType): Column = { + val index = pattern.indexOf(".") //This can stop working when Spark becomes Locale aware + val (precision, scale) = if (index == -1) { + (pattern.length, 0) + } else { + (pattern.length-1, pattern.length - index - 1) + } + castNonStringColumn(fractionalColumn.cast(DecimalType.apply(precision, scale)), originType) + } + + private def castNonStringColumn(nonStringColumn: Column, originType: DataType): Column = { + logger.warn( + s"$inputFullPathName is specified as timestamp or date, but original type is ${originType.typeName}. Trying to interpret as string." 
+ ) + castStringColumn(nonStringColumn.cast(StringType)) + } + + protected def castEpoch(): Column = { + (column.cast(decimalType) / pattern.epochFactor).cast(TimestampType) + } + + protected def castStringColumn(stringColumn: Column): Column + + protected def castDateColumn(dateColumn: Column): Column + + protected def castTimestampColumn(timestampColumn: Column): Column + + } + + private final case class DateParser(field: DateTimeTypeStructField[Date], + path: String, + column: Column, + origType: DataType, + failOnInputNotPerSchema: Boolean, + isArrayElement: Boolean) + (implicit defaults: Defaults) extends DateTimeParser[Date] { + private val defaultTimeZone: Option[String] = field + .defaultTimeZone + .map(Option(_)) + .getOrElse(defaults.getDefaultDateTimeZone) + + private def applyPatternToStringColumn(column: Column, pattern: String): Column = { + defaultTimeZone.map(tz => + to_date(to_utc_timestamp(to_timestamp(column, pattern), tz)) + ).getOrElse( + to_date(column, pattern) + ) + } + + override def castEpoch(): Column = { + // number cannot be cast to date directly, so first casting to timestamp and then truncating + to_date(super.castEpoch()) + } + + override protected def castStringColumn(stringColumn: Column): Column = { + if (pattern.containsSecondFractions) { + // date doesn't need to care about second fractions + applyPatternToStringColumn( + stringColumn.removeSections( + Seq(pattern.millisecondsPosition, pattern.microsecondsPosition, pattern.nanosecondsPosition).flatten + ), pattern.patternWithoutSecondFractions) + } else { + applyPatternToStringColumn(stringColumn, pattern) + } + } + + override protected def castDateColumn(dateColumn: Column): Column = { + defaultTimeZone.map( + tz => to_date(to_utc_timestamp(dateColumn, tz)) + ).getOrElse( + dateColumn + ) + } + + override protected def castTimestampColumn(timestampColumn: Column): Column = { + to_date(defaultTimeZone.map( + to_utc_timestamp(timestampColumn, _) + ).getOrElse( + timestampColumn + )) + } + } + + private final case class TimestampParser(field: DateTimeTypeStructField[Timestamp], + path: String, + column: Column, + origType: DataType, + failOnInputNotPerSchema: Boolean, + isArrayElement: Boolean) + (implicit defaults: Defaults) extends DateTimeParser[Timestamp] { + + private val defaultTimeZone: Option[String] = field + .defaultTimeZone + .map(Option(_)) + .getOrElse(defaults.getDefaultTimestampTimeZone) + + private def applyPatternToStringColumn(column: Column, pattern: String): Column = { + val interim: Column = to_timestamp(column, pattern) + defaultTimeZone.map(to_utc_timestamp(interim, _)).getOrElse(interim) + } + + override protected def castStringColumn(stringColumn: Column): Column = { + if (pattern.containsSecondFractions) { + //this is a trick how to enforce fractions of seconds into the timestamp + // - turn into timestamp up to seconds precision and that into unix_timestamp, + // - the second fractions turn into numeric fractions + // - add both together and convert to timestamp + val colSeconds = unix_timestamp(applyPatternToStringColumn( + stringColumn.removeSections( + Seq(pattern.millisecondsPosition, pattern.microsecondsPosition, pattern.nanosecondsPosition).flatten + ), pattern.patternWithoutSecondFractions)) + + val colMilliseconds: Option[Column] = + pattern.millisecondsPosition.map(stringColumn.zeroBasedSubstr(_).cast(decimalType) / MillisecondsPerSecond) + val colMicroseconds: Option[Column] = + pattern.microsecondsPosition.map(stringColumn.zeroBasedSubstr(_).cast(decimalType) / 
MicrosecondsPerSecond) + val colNanoseconds: Option[Column] = + pattern.nanosecondsPosition.map(stringColumn.zeroBasedSubstr(_).cast(decimalType) / NanosecondsPerSecond) + val colFractions: Column = + (colMilliseconds ++ colMicroseconds ++ colNanoseconds).reduceOption(_ + _).getOrElse(lit(0)) + + (colSeconds + colFractions).cast(TimestampType) + } else { + applyPatternToStringColumn(stringColumn, pattern) + } + } + + override protected def castDateColumn(dateColumn: Column): Column = { + defaultTimeZone.map( + to_utc_timestamp(dateColumn, _) + ).getOrElse( + to_timestamp(dateColumn) + ) + } + + override protected def castTimestampColumn(timestampColumn: Column): Column = { + defaultTimeZone.map( + to_utc_timestamp(timestampColumn, _) + ).getOrElse( + timestampColumn + ) + } + } +} + +class TypeParserException(message: String) extends Exception(message: String) diff --git a/src/main/scala/za/co/absa/standardization/time/DateTimePattern.scala b/src/main/scala/za/co/absa/standardization/time/DateTimePattern.scala new file mode 100644 index 0000000..35e2e97 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/time/DateTimePattern.scala @@ -0,0 +1,184 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
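The second-fraction trick used by `TimestampParser.castStringColumn` can be illustrated in plain Spark terms (input format and column invented): parse everything up to whole seconds, then add the extracted fraction back as a decimal before casting to a timestamp:

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DecimalType, TimestampType}

// Input like "2021-12-13 13:30:30.123" with pattern "yyyy-MM-dd HH:mm:ss.SSS".
def withMillis(raw: Column): Column = {
  val seconds  = unix_timestamp(to_timestamp(raw.substr(1, 19), "yyyy-MM-dd HH:mm:ss"))
  val fraction = raw.substr(21, 3).cast(DecimalType(30, 9)) / 1000
  (seconds + fraction).cast(TimestampType)
}
```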
+ */ + +package za.co.absa.standardization.time + +import za.co.absa.standardization.implicits.StringImplicits.StringEnhancements +import za.co.absa.standardization.types.{Section, TypePattern} + +/** + * Class to carry enhanced information about date/time formatting pattern in conversion from/to string + * + * @param pattern actual pattern to format the type conversion + * @param isDefault marks if the pattern is actually an assigned value or taken for global defaults + */ +abstract sealed class DateTimePattern(pattern: String, isDefault: Boolean = false) + extends TypePattern(pattern, isDefault){ + + val isEpoch: Boolean + val epochFactor: Long + + val timeZoneInPattern: Boolean + val defaultTimeZone: Option[String] + val isTimeZoned: Boolean + + val millisecondsPosition: Option[Section] + val microsecondsPosition: Option[Section] + val nanosecondsPosition: Option[Section] + + val secondFractionsSections: Seq[Section] + val patternWithoutSecondFractions: String + def containsSecondFractions: Boolean = secondFractionsSections.nonEmpty + + override def toString: String = { + val q = "\"" + s"pattern: $q$pattern$q" + defaultTimeZone.map(x => s" (default time zone: $q$x$q)").getOrElse("") + } + +} + +object DateTimePattern { + + val EpochKeyword = "epoch" + val EpochMilliKeyword = "epochmilli" + val EpochMicroKeyword = "epochmicro" + val EpochNanoKeyword = "epochnano" + + private val epochUnitFactor = 1 + private val epoch1kFactor = 1000 + private val epoch1MFactor = 1000000 + private val epoch1GFactor = 1000000000 + + private val patternTimeZoneChars = Set('X','z','Z') + + private val patternMilliSecondChar = 'S' + private val patternMicroSecondChar = 'i' + private val patternNanoSecondChat = 'n' + + // scalastyle:off magic.number + private val last3Chars = Section(-3, 3) + private val last6Chars = Section(-6, 6) + private val last9Chars = Section(-9, 9) + private val trio6Back = Section(-6, 3) + private val trio9Back = Section(-9, 3) + // scalastyle:on magic.number + + private final case class EpochDTPattern(override val pattern: String, + override val isDefault: Boolean = false) + extends DateTimePattern(pattern, isDefault) { + + override val isEpoch: Boolean = true + override val epochFactor: Long = DateTimePattern.epochFactor(pattern) + + override val timeZoneInPattern: Boolean = true + override val defaultTimeZone: Option[String] = None + override val isTimeZoned: Boolean = true + + override val millisecondsPosition: Option[Section] = pattern match { + case EpochMilliKeyword => Option(last3Chars) + case EpochMicroKeyword => Option(trio6Back) + case EpochNanoKeyword => Option(trio9Back) + case _ => None + } + override val microsecondsPosition: Option[Section] = pattern match { + case EpochMicroKeyword => Option(last3Chars) + case EpochNanoKeyword => Option(trio6Back) + case _ => None + } + override val nanosecondsPosition: Option[Section] = pattern match { + case EpochNanoKeyword => Option(last3Chars) + case _ => None + } + override val secondFractionsSections: Seq[Section] = pattern match { + case EpochMilliKeyword => Seq(last3Chars) + case EpochMicroKeyword => Seq(last6Chars) + case EpochNanoKeyword => Seq(last9Chars) + case _ => Seq.empty + } + override val patternWithoutSecondFractions: String = EpochKeyword + } + + private final case class StandardDTPattern(override val pattern: String, + assignedDefaultTimeZone: Option[String] = None, + override val isDefault: Boolean = false) + extends DateTimePattern(pattern, isDefault) { + + override val isEpoch: Boolean = false + override val 
epochFactor: Long = 0 + + override val timeZoneInPattern: Boolean = DateTimePattern.timeZoneInPattern(pattern) + override val defaultTimeZone: Option[String] = assignedDefaultTimeZone.filterNot(_ => timeZoneInPattern) + override val isTimeZoned: Boolean = timeZoneInPattern || defaultTimeZone.nonEmpty + + val (millisecondsPosition, microsecondsPosition, nanosecondsPosition) = analyzeSecondFractionsPositions(pattern) + override val secondFractionsSections: Seq[Section] = Section.mergeTouchingSectionsAndSort(Seq(millisecondsPosition, microsecondsPosition, nanosecondsPosition).flatten) + override val patternWithoutSecondFractions: String = Section.removeMultipleFrom(pattern, secondFractionsSections) + + private def scanForPlaceholder(withinString: String, placeHolder: Char): Option[Section] = { + val start = withinString.findFirstUnquoted(Set(placeHolder), Set(''')) + start.map(index => Section.ofSameChars(withinString, index)) + } + + private def analyzeSecondFractionsPositions(withinString: String): (Option[Section], Option[Section], Option[Section]) = { + val clearedPattern = withinString + + // TODO as part of #7 fix (originally Enceladus#677) + val milliSP = scanForPlaceholder(clearedPattern, patternMilliSecondChar) + val microSP = scanForPlaceholder(clearedPattern, patternMicroSecondChar) + val nanoSP = scanForPlaceholder(clearedPattern, patternNanoSecondChat) + (milliSP, microSP, nanoSP) + } + } + + private def create(pattern: String, assignedDefaultTimeZone: Option[String], isDefault: Boolean): DateTimePattern = { + if (isEpoch(pattern)) { + EpochDTPattern(pattern, isDefault) + } else { + StandardDTPattern(pattern, assignedDefaultTimeZone, isDefault) + } + } + + def apply(pattern: String, + assignedDefaultTimeZone: Option[String] = None): DateTimePattern = { + create(pattern, assignedDefaultTimeZone, isDefault = false) + } + + def asDefault(pattern: String, + assignedDefaultTimeZone: Option[String] = None): DateTimePattern = { + create(pattern, assignedDefaultTimeZone, isDefault = true) + } + + def isEpoch(pattern: String): Boolean = { + pattern.toLowerCase match { + case EpochKeyword | EpochMilliKeyword | EpochMicroKeyword | EpochNanoKeyword => true + case _ => false + } + } + + def epochFactor(pattern: String): Long = { + pattern.toLowerCase match { + case EpochKeyword => epochUnitFactor + case EpochMilliKeyword => epoch1kFactor + case EpochMicroKeyword => epoch1MFactor + case EpochNanoKeyword => epoch1GFactor + case _ => 0 + } + } + + def timeZoneInPattern(pattern: String): Boolean = { + isEpoch(pattern) || pattern.hasUnquoted(patternTimeZoneChars, Set(''')) + } +} diff --git a/src/main/scala/za/co/absa/standardization/time/TimeZoneNormalizer.scala b/src/main/scala/za/co/absa/standardization/time/TimeZoneNormalizer.scala new file mode 100644 index 0000000..c013fa0 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/time/TimeZoneNormalizer.scala @@ -0,0 +1,54 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
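A brief sketch of how the epoch keywords and time-zone detection are interpreted (the values shown follow from the constants defined above):

```scala
import za.co.absa.standardization.time.DateTimePattern

DateTimePattern.isEpoch("epochmilli")            // true
DateTimePattern.isEpoch("yyyy-MM-dd HH:mm:ss")   // false
DateTimePattern.epochFactor("epochmilli")        // 1000 (milliseconds per second)

val p = DateTimePattern("yyyy-MM-dd HH:mm:ssXXX")
p.timeZoneInPattern                              // true: 'X' marks a zone offset in the pattern
p.isTimeZoned                                    // true

val q = DateTimePattern("yyyy-MM-dd HH:mm:ss", assignedDefaultTimeZone = Some("UTC"))
q.defaultTimeZone                                // Some("UTC"): kept because the pattern itself has no zone
```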
+ */ + +package za.co.absa.standardization.time + +import com.typesafe.config.{Config, ConfigFactory} +import org.apache.log4j.{LogManager, Logger} +import org.apache.spark.sql.SparkSession + +import java.util.TimeZone + +/** + * Sets the system time zone per application configuration, recommended value being UTC + */ +object TimeZoneNormalizer { + private val log: Logger = LogManager.getLogger(this.getClass) + private val generalConfig: Config = ConfigFactory.load() + val timeZone: String = if (generalConfig.hasPath("timezone")){ + generalConfig.getString("timezone") + } else { + val default = "UTC" + log.warn(s"No time zone (timezone) setting found. Setting to default, which is $default.") + default + } + + def normalizeJVMTimeZone(): Unit = { + TimeZone.setDefault(TimeZone.getTimeZone(timeZone)) + log.debug(s"JVM time zone set to $timeZone") + } + + def normalizeSessionTimeZone(spark: SparkSession): Unit = { + spark.conf.set("spark.sql.session.timeZone", timeZone) + log.debug(s"Spark session ${spark.sparkContext.applicationId} time zone of name ${spark.sparkContext.appName} set to $timeZone") + } + + def normalizeAll(spark: SparkSession): Unit = { + normalizeJVMTimeZone() + normalizeSessionTimeZone(spark) + } + +} diff --git a/src/main/scala/za/co/absa/standardization/typeClasses/DoubleLike.scala b/src/main/scala/za/co/absa/standardization/typeClasses/DoubleLike.scala new file mode 100644 index 0000000..43808d4 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/typeClasses/DoubleLike.scala @@ -0,0 +1,41 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.typeClasses + +import scala.annotation.implicitNotFound + +@implicitNotFound("No member of type class DoubleLike in scope for ${T}") +trait DoubleLike[T] extends Ordering[T]{ + def toDouble(x: T): Double + def toT(d: Double): T +} + +object DoubleLike { + + implicit object DoubleLikeForDouble extends DoubleLike[Double] { + override def toDouble(x: Double): Double = x + override def toT(d: Double): Double = d + override def compare(x: Double, y: Double): Int = x.compare(y) + } + + implicit object DoubleLikeForFloat extends DoubleLike[Float] { + override def toDouble(x: Float): Double = x.toDouble + override def toT(d: Double): Float = d.toFloat + override def compare(x: Float, y: Float): Int = x.compare(y) + } +} + diff --git a/src/main/scala/za/co/absa/standardization/typeClasses/LongLike.scala b/src/main/scala/za/co/absa/standardization/typeClasses/LongLike.scala new file mode 100644 index 0000000..5a0e318 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/typeClasses/LongLike.scala @@ -0,0 +1,68 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
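Typical use of `TimeZoneNormalizer` is a single call at application start, normalizing both the JVM and the Spark session to the configured `timezone` (UTC by default, as described above):

```scala
import org.apache.spark.sql.SparkSession
import za.co.absa.standardization.time.TimeZoneNormalizer

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Sets java.util.TimeZone's default and spark.sql.session.timeZone to the configured zone.
TimeZoneNormalizer.normalizeAll(spark)
```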
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.typeClasses + +import scala.annotation.implicitNotFound + +@implicitNotFound("No member of type class LongLike in scope for ${T}") +trait LongLike[T] extends Ordering[T]{ + val MinValue: Long + val MaxValue: Long + def toLong(x: T): Long + def toT(l: Long): T + def stringToT(s: String): T +} + +object LongLike { + + implicit object LongLikeForLong extends LongLike[Long] { + override val MinValue: Long = Long.MinValue + override val MaxValue: Long = Long.MaxValue + override def toLong(x: Long): Long = x + override def toT(l: Long): Long = l + override def stringToT(s: String): Long = s.toLong + override def compare(x: Long, y: Long): Int = x.compare(y) + } + + implicit object LongLikeForInt extends LongLike[Int] { + override val MinValue: Long = Int.MinValue + override val MaxValue: Long = Int.MaxValue + override def toLong(x: Int): Long = x.toLong + override def toT(l: Long): Int = l.toInt + override def stringToT(s: String): Int = s.toInt + override def compare(x: Int, y: Int): Int = x.compare(y) + } + + implicit object LongLikeForShort extends LongLike[Short] { + override val MinValue: Long = Short.MinValue + override val MaxValue: Long = Short.MaxValue + override def toLong(x: Short): Long = x.toLong + override def toT(l: Long): Short = l.toShort + override def stringToT(s: String): Short = s.toShort + override def compare(x: Short, y: Short): Int = x.compare(y) + } + + implicit object LongLikeForByte extends LongLike[Byte] { + override val MinValue: Long = Byte.MinValue + override val MaxValue: Long = Byte.MaxValue + override def toLong(x: Byte): Long = x.toLong + override def toT(l: Long): Byte = l.toByte + override def stringToT(s: String): Byte = s.toByte + override def compare(x: Byte, y: Byte): Int = x.compare(y) + } +} + diff --git a/src/main/scala/za/co/absa/standardization/types/Defaults.scala b/src/main/scala/za/co/absa/standardization/types/Defaults.scala new file mode 100644 index 0000000..9c022f9 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/types/Defaults.scala @@ -0,0 +1,109 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
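The `LongLike` type class is what lets a single integral implementation cover Byte, Short, Int and Long. A hypothetical helper (made up for illustration) only needs an instance in scope to enforce the target type's range:

```scala
import za.co.absa.standardization.typeClasses.LongLike

// Hypothetical helper: parse a string and reject values outside the target type's range.
def parseBounded[T](s: String)(implicit ev: LongLike[T]): Either[String, T] = {
  val asLong = s.toLong
  if (asLong < ev.MinValue || asLong > ev.MaxValue) Left(s"$s overflows the target type")
  else Right(ev.toT(asLong))
}

parseBounded[Byte]("200")   // Left(...): 200 does not fit into a Byte
parseBounded[Short]("200")  // Right(200)
```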
+ */ + +package za.co.absa.standardization.types + +import org.apache.spark.sql.types._ +import za.co.absa.standardization.numeric.DecimalSymbols +import com.typesafe.config._ +import java.sql.{Date, Timestamp} +import java.util.{Locale, TimeZone} +import scala.util.{Success, Try} + +abstract class Defaults { + /** A function which defines default values for primitive types */ + def getDataTypeDefaultValue(dt: DataType): Any + + /** A function which defines default values for primitive types, allowing possible Null*/ + def getDataTypeDefaultValueWithNull(dt: DataType, nullable: Boolean): Try[Option[Any]] + + /** A function which defines default formats for primitive types */ + def getStringPattern(dt: DataType): String + + def getDefaultTimestampTimeZone: Option[String] + def getDefaultDateTimeZone: Option[String] + + def getDecimalSymbols: DecimalSymbols +} + +object GlobalDefaults extends Defaults { + /** A function which defines default values for primitive types */ + override def getDataTypeDefaultValue(dt: DataType): Any = + dt match { + case _: IntegerType => 0 + case _: FloatType => 0f + case _: ByteType => 0.toByte + case _: ShortType => 0.toShort + case _: DoubleType => 0.0d + case _: LongType => 0L + case _: StringType => "" + case _: BinaryType => Array.empty[Byte] + case _: DateType => new Date(0) //linux epoch + case _: TimestampType => new Timestamp(0) + case _: BooleanType => false + case t: DecimalType => + val rest = t.precision - t.scale + BigDecimal(("0" * rest) + "." + ("0" * t.scale)) + case _ => throw new IllegalStateException(s"No default value defined for data type ${dt.typeName}") + } + + /** A function which defines default values for primitive types, allowing possible Null*/ + override def getDataTypeDefaultValueWithNull(dt: DataType, nullable: Boolean): Try[Option[Any]] = { + if (nullable) { + Success(None) + } else { + Try{ + getDataTypeDefaultValue(dt) + }.map(Some(_)) + } + } + + /** A function which defines default formats for primitive types */ + override def getStringPattern(dt: DataType): String = + dt match { + case DateType => "yyyy-MM-dd" + case TimestampType => "yyyy-MM-dd HH:mm:ss" + case _: IntegerType + | FloatType + | ByteType + | ShortType + | DoubleType + | LongType => "" + case _: DecimalType => "" + case _ => throw new IllegalStateException(s"No default format defined for data type ${dt.typeName}") + } + + override def getDefaultTimestampTimeZone: Option[String] = defaultTimestampTimeZone + override def getDefaultDateTimeZone: Option[String] = defaultDateTimeZone + + override def getDecimalSymbols: DecimalSymbols = decimalSymbols + + private val defaultTimestampTimeZone: Option[String] = readTimezone("defaultTimestampTimeZone") + private val defaultDateTimeZone: Option[String] = readTimezone("defaultDateTimeZone") + private val decimalSymbols = DecimalSymbols(Locale.US) + + private def readTimezone(path: String): Option[String] = { + val generalConfig = ConfigFactory.load() + if (generalConfig.hasPath(path)){ + val result = generalConfig.getString(path) + + if (TimeZone.getAvailableIDs().contains(result)) + throw new IllegalStateException(s"The setting '$result' of '$path' is not recognized as known time zone") + + Some(result) + } else None + } +} diff --git a/src/main/scala/za/co/absa/standardization/types/DefaultsByFormat.scala b/src/main/scala/za/co/absa/standardization/types/DefaultsByFormat.scala new file mode 100644 index 0000000..1f88f81 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/types/DefaultsByFormat.scala @@ -0,0 
+1,89 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.types + +import org.apache.spark.sql.types.DataType +import za.co.absa.standardization.ConfigReader +import za.co.absa.standardization.numeric.DecimalSymbols +import za.co.absa.standardization.types.DefaultsByFormat._ + +import java.util.TimeZone +import scala.util.Try + +class DefaultsByFormat(formatName: String, + globalDefaults: Defaults = GlobalDefaults, + private val config: ConfigReader = new ConfigReader()) extends Defaults { + + /** A function which defines default values for primitive types */ + override def getDataTypeDefaultValue(dt: DataType): Any = globalDefaults.getDataTypeDefaultValue(dt) + + /** A function which defines default values for primitive types, allowing possible Null */ + override def getDataTypeDefaultValueWithNull(dt: DataType, nullable: Boolean): Try[Option[Any]] = { + globalDefaults.getDataTypeDefaultValueWithNull(dt, nullable) + } + + /** A function which defines default formats for primitive types */ + override def getStringPattern(dt: DataType): String = { + globalDefaults.getStringPattern(dt) + } + + override def getDefaultTimestampTimeZone: Option[String] = { + defaultTimestampTimeZone.orElse(globalDefaults.getDefaultTimestampTimeZone) + } + + override def getDefaultDateTimeZone: Option[String] = { + defaultDateTimeZone.orElse(globalDefaults.getDefaultDateTimeZone) + } + + override def getDecimalSymbols: DecimalSymbols = globalDefaults.getDecimalSymbols + + private def readTimezone(path: String): Option[String] = { + val result = config.getStringOption(path) + result.foreach(tz => + if (!TimeZone.getAvailableIDs().contains(tz )) { + throw new IllegalStateException(s"The setting '$tz' of '$path' is not recognized as known time zone") + } + ) + result + } + + private def formatSpecificConfigurationName(configurationName: String): String = { + configurationFullName(configurationName, formatName) + } + + private def configurationFullName(base: String, suffix: String): String = { + s"$base.$suffix" + } + + private val defaultTimestampTimeZone: Option[String] = + readTimezone(formatSpecificConfigurationName(TimestampTimeZoneKeyName)) + .orElse(readTimezone(configurationFullName(TimestampTimeZoneKeyName, DefaultKeyName))) + .orElse(readTimezone(DefaultsByFormat.ObsoleteTimestampTimeZoneName)) + + private val defaultDateTimeZone: Option[String] = + readTimezone(formatSpecificConfigurationName(DateTimeZoneKeyName)) + .orElse(readTimezone(configurationFullName(DateTimeZoneKeyName, DefaultKeyName))) + .orElse(readTimezone(DefaultsByFormat.ObsoleteDateTimeZoneName)) +} + +object DefaultsByFormat { + private final val DefaultKeyName = "default" + private final val ObsoleteTimestampTimeZoneName = "defaultTimestampTimeZone" + private final val ObsoleteDateTimeZoneName = "defaultDateTimeZone" + private final val TimestampTimeZoneKeyName = "standardization.defaultTimestampTimeZone" + private final val DateTimeZoneKeyName = 
"standardization.defaultDateTimeZone" +} diff --git a/src/main/scala/za/co/absa/standardization/types/ParseOutput.scala b/src/main/scala/za/co/absa/standardization/types/ParseOutput.scala new file mode 100644 index 0000000..e05e663 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/types/ParseOutput.scala @@ -0,0 +1,21 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.types + +import org.apache.spark.sql.Column + +case class ParseOutput(stdCol: Column, errors: Column) diff --git a/src/main/scala/za/co/absa/standardization/types/Section.scala b/src/main/scala/za/co/absa/standardization/types/Section.scala new file mode 100644 index 0000000..a61e802 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/types/Section.scala @@ -0,0 +1,341 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.types + +import java.security.InvalidParameterException +import scala.util.{Failure, Try} + +/** + * Represents a section of a string defined by its starting index and length of the section. + * It supports negative indexes for start, denoting a position counted from end + * The class is marked as `sealed abstract` to limit the constructors to only the apply method within the companion object + * _Comparison_ + * The class implements the `Ordered` trait + * Sections with negative start are BIGGER then the ones with positive start + * The section with smaller ABSOLUTE value of start is considered smaller, if starts equal the shorter section is the + * smaller + * NB! 
Interesting consequence of the ordering is, that if the sections are not overlapping and removal applied on + * string from greatest to smallest one by one, the result is the same as removing all sections "at once" + * + * @param start the start position of the section, if negative the position is counted from the end + * @param length length of the section, cannot be negative + */ +sealed abstract case class Section(start: Int, length: Int) extends Ordered[Section] { + + override def compare(that: Section): Int = { + if (start == that.start) { + length.compare(that.length) //shorter is smaller + } else if ((start < 0) && (that.start >= 0)) { + 1 // negative start is bigger then the one of positive+zero + } else if ((start >= 0) && (that.start < 0)) { + -1 + } else if (start.abs < that.start.abs) { + -1 + } else { + 1 + } + } + + def copy(start: Int = this.start, length: Int = this.length): Section = { + Section(start, length) + } + + /** + * Converts the Section to actual indexes representing the section on the given string, if used in the substring + * function. The result respects the boundaries of the string, the indexes in the result not to cause OutOfBound exception + * For example Section(2,3) for string "Hello!" gives (2,5), for "abc" (2,3) + * @param forString the string which the section would be applied to + * @return tuple representing the beginIndex and endIndex parameters for the substring function + */ + def toSubstringParameters(forString: String): (Int, Int) = { + val (realStart, after) = if (start >= 0) { + (start min forString.length, start + length) + } else { + val startIndex = forString.length + start + if (startIndex >= 0) { + (startIndex, startIndex + length) + } else { // the distance from end is longer than the string itself + (0, Math.max(length + startIndex, 0)) + } + } + (realStart, Math.min(after, forString.length)) + } + + /** + * The substring represented by this Section within the provided string + * Complementary to `remove` + * @param string the string to apply the section to + * @return substring defined by this section + */ + def extractFrom(string: String): String = { + val (realStart, after) = toSubstringParameters(string) + string.substring(realStart, after) + } + + /** + * Creates a string that is the remainder if the substring represented by this section is removed from the provided string + * Complementary to `extract` + * @param string the string to apply the section to + * @return concatenation of the string before and after the Section + */ + def removeFrom(string: String): String = { + val (realStart, after) = toSubstringParameters(string) + string.substring(0, realStart) + string.substring(after) + } + + /** + * Inverse function for `remove`, inserts the `what` string into the `into` string as defined by the `section` + * The `what` string needs to have the same length as the section; unless the placement of the `what` is outside + * (beyond or before) of `string` in which case it can be shorter + * @param string the string to inject into + * @param what the string to inject + * @return the newly created string + */ + def injectInto(string: String, what: String): Try[String] = { + + def fail(): Try[String] = { + Failure(new InvalidParameterException( + s"The length of the string to inject (${what.length}) doesn't match Section($start, $length) for string of length ${string.length}." 
+ )) + } + + if (what.length > length) { + fail() + } else if ((what == "") && ((length == 0) || (start > string.length) || (start + string.length + length < 0))) { + // injecting empty string is easy if valid; which is either if the section length = 0, or the index to inject to + // is beyond the limits of the final string + Try(string) + } else if (start >= 0) { + if (start > string.length) { + // beyond the into string + fail() + } else if (start == string.length) { + // at the end of the into string + Try(string + what) + } else if (what.length == length) { + // injection in the middle (or beginning) + Try(string.substring(0, start) + what + string.substring(start)) + } else { + // wrong size of injection + fail() + } + } else { + val index = string.length + start + what.length + val whatLengthDeficit = what.length - length + if (index == string.length) { + // at the end of the into string + Try(string + what) + } else if (index == whatLengthDeficit) { + // somewhere withing the into string + Try(what + string) + } else if (whatLengthDeficit == 0 && index > 0 && index < string.length) { + // at the beginning of the into string, maybe appropriately shorter if to be place "before" 0 index + Try(string.substring(0, index) + what + string.substring(index)) + } else { + fail() + } + } + } + + /** + * Metrics defined on Section, it equals the number of positions (characters) between two sections + * @param secondSection the Section to compute the distance from/to + * @return None - if one Section has a negative start and the other positive or zero + * The end of the smaller section subtracted from the start of the greater one (see comparison), + * can be negative + */ + def distance(secondSection: Section): Option[Int] = { + def calculateDistance(first: Section, second: Section) = { + second.start - first.start - first.length + } + + (start >= 0, secondSection.start >= 0) match { + case (false, true) | (true, false) => + // two sections of differently signed starts don't have a distance defined + None + case (true, true) => + if (this <= secondSection) { + Option(calculateDistance(this, secondSection)) + } else { + Option(calculateDistance(secondSection, this)) + } + case (false, false) => + if (this <= secondSection) { + Option(calculateDistance(secondSection, this)) + } else { + Option(calculateDistance(this, secondSection)) + } + } + } + + /** + * Checks if two sections overlap + * @param that the other Section + * @return true if the greater Section starts before the smaller one ends (see comparison) + * false otherwise + */ + def overlaps(that: Section): Boolean = { + distance(that).exists(_ < 0) + } + + /** + * Checks if two sections touch or overlap + * @param that the other Section + * @return true if the greater Section starts before the smaller one ends (see comparison) or right after it + * false otherwise + */ + def touches(that: Section): Boolean = { + distance(that).exists(_ <= 0) + } +} + +object Section { + /** + * The only possible constructor for Section class. 
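As an illustrative aside (not part of the patch), a minimal sketch of how `extractFrom`, `removeFrom` and `injectInto` complement each other; the expected results in the comments follow from the definitions above.

import za.co.absa.standardization.types.Section

object SectionRoundTripExample extends App {
  val section = Section(2, 3)
  val extracted = section.extractFrom("Hello!")        // "llo"
  val remainder = section.removeFrom("Hello!")         // "He!"
  println(section.injectInto(remainder, extracted))    // Success(Hello!) - remove and inject are inverse operations
  println(Section(1, 2) < Section(-1, 1))              // true - sections with a negative start compare as greater
}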
It ensures that the input values for the created object are within bounds + * @param start the start position of the section, if negative the position is counted from the end + * @param length length of the section, cannot be negative + * @return the new Section object + */ + def apply(start: Int, length: Int): Section = { + val realLength = if (length < 0) { + 0 + } else if ((start >= 0) && (start.toLong + length.toLong > Int.MaxValue)) { + Int.MaxValue - start + } else { + length + } + new Section(start, realLength) {} + } + + /** + * Alternative constructor to create a section from starting and ending indexes + * If start is bigger then end, they will be swapped for the Section creation + * @param start start of the section, inclusive + * @param end end of the section, inclusive + * @return the new Section object + */ + def fromIndexes(start: Int, end: Int): Section = { + val realStart = Math.min(start, end) + val realEnd = Math.max(start, end) + Section(realStart, realEnd - start + 1) + } + + /** + * Alternative constructor to create a Section based on the repeated character within the provided string + * The Section will start per the `start` provided, and the length will be determined by the number of same characters + * in row, as the character on the `start` index + * E.g. ofSameChars("abbccccdef", 3) -> Section(3, 4) + * @param inputString the string which to scan + * @param start start of the Section, and also the index of the character whose repetition will determine the + * length of the Section; if negative, index is counted from the end of the string + * @return the new Section object + */ + def ofSameChars(inputString: String, start: Int): Section = { + val index = if (start >= 0) { + start + } else { + inputString.length + start + } + if ((index >= inputString.length) || (index < 0)) { + Section(start, 0) + } else { + val char = inputString(index) + var res = index + while ((res < inputString.length) && (inputString(res) == char)) { + res += 1 + } + Section(start, res - index) + } + } + + /** + * Removes sections of the string in a way, that string is considered intact until all removals are executed. In + * other words the indexes are not shifted. + * @param string the string to operate upon + * @param sections sections to apply + * @return the string as a result if all the sections would be removed "at once" + */ + def removeMultipleFrom(string: String, sections: Seq[Section]): String = { + if (sections.isEmpty) { + string + } else { + val charsPresent = Array.fill(string.length)(true) + sections.foreach{section => + val (realStart, after) = section.toSubstringParameters(string) + for (i <- realStart until after) { + charsPresent(i) = false + } + } + val paring: Seq[(Char, Boolean)] = string.toSeq.zip(charsPresent) + + paring.collect{case (c, true) => c}.mkString + } + } + + /** + * Merges all touching (that includes overlaps too) sections. Sections that are not touching are left as they were. 
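A short, hedged example of the alternative constructors and the bulk removal defined above; the expected values in the comments are derived from the code, not from the patch's test suite.

import za.co.absa.standardization.types.Section

object SectionConstructorsExample extends App {
  println(Section.fromIndexes(2, 5))                        // Section(2,4) - inclusive indexes 2..5
  println(Section.ofSameChars("abbccccdef", 3))             // Section(3,4) - the run of 'c' characters
  val sections = Seq(Section(1, 2), Section(7, 2))
  println(Section.removeMultipleFrom("abbccccdef", sections)) // "accccf" - indexes are not shifted between removals
}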
+ * The resulting section is sorted + * Example: + * For a string: + * 01234567890ACDFEFGHIJKLMNOPQUSTUVWXYZ + * ^ ^^^ ^-^=^--^ ^ + * | | | | | | + * | | Section(5,1) | | Section(-1,1) + * | Section(3,2) | Section(-9,6) + * Section(1,1) Section(-11,5) + * Output of the merge: + * 01234567890ACDFEFGHIJKLMNOPQUSTUVWXYZ + * ^ ^-^ ^------^ ^ + * | | | | + * | Section(3,3) | Section(-1,1) + * Section(1,1) Section(-11,8) + * + * @param sections the sections to merge + * @return an ordered from greater to smaller sequence of distinct sections (their distance is at least 1 or undefined) + */ + def mergeTouchingSectionsAndSort(sections: Seq[Section]): Seq[Section] = { + def fuse(into: Section, what: Section): Section = { + if (into.start + into.length >= what.start + what.length) { + //as the sequence where the sections are coming from is sorter, this condition is enough to check that `what` is within `into` + into + } else { + //actual fusion + //the length expression is simplified: into.length + what.length - ((into.start + into.length) - what.start) + Section(into.start, what.length - into.start + what.start) + } + } + + def doMerge(input: Seq[Section]): Seq[Section] = { + if (input.isEmpty) { + input + } else { + val sorted = input.sorted + sorted.tail.foldLeft(List(sorted.head)) { (resultAcc, item) => + if (item touches resultAcc.head) { + val newHead = if (item.start >= 0) fuse(resultAcc.head, item) else fuse(item, resultAcc.head) + newHead :: resultAcc.tail + } else { + item :: resultAcc + } + } + } + } + + val (negativeOnes, positiveOnes) = sections.partition(_.start < 0) + doMerge(negativeOnes) ++ doMerge(positiveOnes) + } +} diff --git a/src/main/scala/za/co/absa/standardization/types/TypePattern.scala b/src/main/scala/za/co/absa/standardization/types/TypePattern.scala new file mode 100644 index 0000000..c42b79f --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/types/TypePattern.scala @@ -0,0 +1,30 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
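To make the merge semantics concrete, a small illustrative sketch reproducing the scaladoc example above:

import za.co.absa.standardization.types.Section

object SectionMergeExample extends App {
  val sections = Seq(Section(5, 1), Section(1, 1), Section(3, 2),
                     Section(-9, 6), Section(-11, 5), Section(-1, 1))
  println(Section.mergeTouchingSectionsAndSort(sections))
  // List(Section(-11,8), Section(-1,1), Section(3,3), Section(1,1)) - greater (negative-start) sections come first
}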
+ */ + +package za.co.absa.standardization.types + +import scala.language.implicitConversions + +/** + * Class to carry enhanced information about formatting pattern in conversion from/to string + * @param pattern actual pattern to format the type conversion + * @param isDefault marks if the pattern is actually an assigned value or taken for global defaults + */ +abstract class TypePattern(val pattern: String, val isDefault: Boolean = false) extends Serializable + +object TypePattern { + implicit def patternToString(pattern: TypePattern): String = pattern.pattern +} diff --git a/src/main/scala/za/co/absa/standardization/types/TypedStructField.scala b/src/main/scala/za/co/absa/standardization/types/TypedStructField.scala new file mode 100644 index 0000000..0859a98 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/types/TypedStructField.scala @@ -0,0 +1,476 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.types + +import org.apache.spark.sql.types._ +import za.co.absa.standardization.ValidationIssue +import za.co.absa.standardization.implicits.StructFieldImplicits.StructFieldEnhancements +import za.co.absa.standardization.numeric.{DecimalSymbols, NumericPattern, Radix} +import za.co.absa.standardization.schema.{MetadataKeys, MetadataValues} +import za.co.absa.standardization.time.DateTimePattern +import za.co.absa.standardization.typeClasses.{DoubleLike, LongLike} +import za.co.absa.standardization.types.parsers._ +import za.co.absa.standardization.validation.field._ + +import java.sql.{Date, Timestamp} +import java.util.Base64 +import scala.util.{Failure, Success, Try} + +sealed abstract class TypedStructField(structField: StructField)(implicit defaults: Defaults) + extends StructFieldEnhancements(structField) with Serializable { + + type BaseType + + protected def convertString(string: String): Try[BaseType] + + def validate(): Seq[ValidationIssue] + + def stringToTyped(string: String): Try[Option[BaseType]] = { + def errMsg: String = { + s"'$string' cannot be cast to ${dataType.typeName}" + } + + if (string == null) { + if (structField.nullable) { + Success(None) + } else { + Failure(new IllegalArgumentException(s"null is not a valid value for field '${structField.name}'")) + } + } else { + convertString(string) match { + case Failure(e: NumberFormatException) if e.getClass == classOf[NumberFormatException] => + // replacing some not very informative exception message with better one + Failure(new NumberFormatException(errMsg)) + case Failure(e: IllegalArgumentException) if e.getClass == classOf[IllegalArgumentException]=> + // replacing some not very informative exception message with better one + Failure(new IllegalArgumentException(errMsg, e.getCause)) + case Failure(e) => + // other failures stay unchanged + Failure(e) + case Success(good) => + // good result is put withing the option as the return type requires + Success(Some(good)) + } + } + + } + + /** + * The default value 
defined in the metadata of the field, if present + * @return Try - because the gathering may fail in conversion between types + * outer Option - None means no default was defined within the metadata of the field + * inner Option - the actual default value or None in case the default is null + */ + def ownDefaultValue: Try[Option[Option[BaseType]]] = { + if (hasMetadataKey(MetadataKeys.DefaultValue)) { + for { + defaultValueString <- Try{structField.metadata.getString(MetadataKeys.DefaultValue)} + defaultValueTyped <- stringToTyped(defaultValueString) + } yield Some(defaultValueTyped) + } else { + Success(None) + } + } + + /** + * The default value that will be used for the field, local if defined otherwise global + * @return Try - because the gathering of local default may fail in conversion between types + * Option - the actual default value or None in case the default is null + */ + def defaultValueWithGlobal: Try[Option[BaseType]] = { + for { + localDefault <- ownDefaultValue + result <- localDefault match { + case Some(value) => Success(value) + case None => defaults.getDataTypeDefaultValueWithNull(dataType, nullable).map(_.map(_.asInstanceOf[BaseType])) + } + } yield result + } + + def pattern: Try[Option[TypePattern]] = Success(None) + def needsUdfParsing: Boolean = false + + def name: String = structField.name + def nullable: Boolean = structField.nullable + def dataType: DataType = structField.dataType + + def canEqual(any: Any): Boolean = any.isInstanceOf[TypedStructField] + + override def equals(other: Any): Boolean = other match { + case that: TypedStructField => that.canEqual(this) && structField == that.structField + case _ => false + } + + override def hashCode(): Int = { + /* one of the suggested ways to implement the hasCode logic */ + val prime = 31 + var result = 1 + result = prime * result + (if (structField == null) 0 else structField.hashCode) + result + } +} + +object TypedStructField { + /** + * This is to be the only accessible constructor for TypedStructField sub-classes + * The point is, that sub-classes have private constructors to prevent their instantiation outside this apply + * constructor. 
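A minimal usage sketch of the wrapper described above; it assumes `GlobalDefaults` (the object used as the default `globalDefaults` argument of `DefaultsByFormat` earlier in this patch) is available as the implicit `Defaults` instance.

import org.apache.spark.sql.types.{IntegerType, StructField}
import za.co.absa.standardization.types.{Defaults, GlobalDefaults, TypedStructField}

object TypedStructFieldExample extends App {
  implicit val defaults: Defaults = GlobalDefaults
  val typed = TypedStructField(StructField("count", IntegerType, nullable = false))
  println(typed.stringToTyped("42"))    // Success(Some(42))
  println(typed.stringToTyped("4x2"))   // Failure: '4x2' cannot be cast to integer
  println(typed.stringToTyped(null))    // Failure: null is not a valid value for field 'count'
  println(typed.validate())             // no issues expected for this simple field
}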
This is to ensure at compile time there is a bound between the provided StructField.dataType and the + * class created + * @param structField the structField to wrap TypedStructField around + * @return the object of non-abstract TypedStructField successor class relevant to the StructField dataType + */ + def apply(structField: StructField)(implicit defaults: Defaults): TypedStructField = { + structField.dataType match { + case _: StringType => new StringTypeStructField(structField) + case _: BinaryType => new BinaryTypeStructField(structField) + case _: BooleanType => new BooleanTypeStructField(structField) + case _: ByteType => new ByteTypeStructField(structField) + case _: ShortType => new ShortTypeStructField(structField) + case _: IntegerType => new IntTypeStructField(structField) + case _: LongType => new LongTypeStructField(structField) + case _: FloatType => new FloatTypeStructField(structField) + case _: DoubleType => new DoubleTypeStructField(structField) + case dt: DecimalType => new DecimalTypeStructField(structField, dt) + case _: TimestampType => new TimestampTypeStructField(structField) + case _: DateType => new DateTypeStructField(structField) + case at: ArrayType => new ArrayTypeStructField(structField, at) + case st: StructType => new StructTypeStructField(structField, st) + case _ => new GeneralTypeStructField(structField) + } + } + + def asNumericTypeStructField[N](structField: StructField)(implicit defaults: Defaults): NumericTypeStructField[N] = + TypedStructField(structField).asInstanceOf[NumericTypeStructField[N]] + def asDateTimeTypeStructField[T](structField: StructField)(implicit defaults: Defaults): DateTimeTypeStructField[T] = + TypedStructField(structField).asInstanceOf[DateTimeTypeStructField[T]] + def asArrayTypeStructField(structField: StructField)(implicit defaults: Defaults): ArrayTypeStructField = + TypedStructField(structField).asInstanceOf[ArrayTypeStructField] + def asStructTypeStructField(structField: StructField)(implicit defaults: Defaults): StructTypeStructField = + TypedStructField(structField).asInstanceOf[StructTypeStructField] + def asBinaryTypeStructField(structField: StructField)(implicit defaults: Defaults): BinaryTypeStructField = + TypedStructField(structField).asInstanceOf[BinaryTypeStructField] + + def unapply[T](typedStructField: TypedStructField): Option[StructField] = Some(typedStructField.structField) + + abstract class TypedStructFieldTagged[T](structField: StructField)(implicit defaults: Defaults) + extends TypedStructField(structField) { + override type BaseType = T + } + // StringTypeStructField + final class StringTypeStructField private[TypedStructField](structField: StructField) + (implicit defaults: Defaults) + extends TypedStructFieldTagged[String](structField) { + override protected def convertString(string: String): Try[String] = { + Success(string) + } + + override def validate(): Seq[ValidationIssue] = { + ScalarFieldValidator.validate(this) + } + } + + // BinaryTypeStructField + final class BinaryTypeStructField private[TypedStructField](structField: StructField) + (implicit defaults: Defaults) + extends TypedStructFieldTagged[Array[Byte]](structField) { + val normalizedEncoding: Option[String] = structField.getMetadataString(MetadataKeys.Encoding).map(_.toLowerCase) + + // used to convert the default value from metadata's [[MetadataKeys.DefaultValue]] + override protected def convertString(string: String): Try[Array[Byte]] = { + normalizedEncoding match { + case Some(MetadataValues.Encoding.Base64) => 
Try(Base64.getDecoder.decode(string)) + case Some(MetadataValues.Encoding.None) | None => Success(string.getBytes) // use as-is + case _ => + Failure(new IllegalStateException(s"Unsupported encoding for Binary field ${structField.name}: '${normalizedEncoding.get}'")) + } + } + + override def validate(): Seq[ValidationIssue] = { + BinaryFieldValidator.validate(this) + } + } + + // BooleanTypeStructField + final class BooleanTypeStructField private[TypedStructField](structField: StructField) + (implicit defaults: Defaults) + extends TypedStructFieldTagged[Boolean](structField) { + override protected def convertString(string: String): Try[Boolean] = { + Try{string.toBoolean} + } + + override def validate(): Seq[ValidationIssue] = { + ScalarFieldValidator.validate(this) + } + } + + // NumericTypeStructField + sealed abstract class NumericTypeStructField[N](structField: StructField, val typeMin: N, val typeMax: N) + (implicit defaults: Defaults) + extends TypedStructFieldTagged[N](structField) { + val allowInfinity: Boolean = false + val parser: Try[NumericParser[N]] + + override def pattern: Try[Option[NumericPattern]] = Success(readNumericPatternFromMetadata) + + override def needsUdfParsing: Boolean = { + pattern.toOption.flatten.exists(!_.isDefault) + } + + override protected def convertString(string: String): Try[N] = { + for { + parserToUse <- parser + parsed <- parserToUse.parse(string) + } yield parsed + } + + private def readNumericPatternFromMetadata: Option[NumericPattern] = { + val stringPatternOpt = getMetadataString(MetadataKeys.Pattern) + val decimalSymbolsOpt = readDecimalSymbolsFromMetadata() + + if (stringPatternOpt.nonEmpty) { + stringPatternOpt.map(NumericPattern(_, decimalSymbolsOpt.getOrElse(defaults.getDecimalSymbols))) + } else { + decimalSymbolsOpt.map(NumericPattern(_)) + } + } + + private def readDecimalSymbolsFromMetadata(): Option[DecimalSymbols] = { + val ds = defaults.getDecimalSymbols + val minusSign = getMetadataChar(MetadataKeys.MinusSign).getOrElse(ds.minusSign) + val decimalSeparator = getMetadataChar(MetadataKeys.DecimalSeparator).getOrElse(ds.decimalSeparator) + val groupingSeparator = getMetadataChar(MetadataKeys.GroupingSeparator).getOrElse(ds.groupingSeparator) + + if ((ds.minusSign != minusSign) || (ds.decimalSeparator != decimalSeparator) || (ds.groupingSeparator != groupingSeparator)) { + Option(ds.copy(minusSign = minusSign, decimalSeparator = decimalSeparator, groupingSeparator = groupingSeparator)) + } else { + None + } + } + } + + // IntegralTypeStructField + sealed abstract class IntegralTypeStructField[L: LongLike] private[TypedStructField](structField: StructField, + override val typeMin: L, + override val typeMax: L) + (implicit defaults: Defaults) + extends NumericTypeStructField[L](structField, typeMin, typeMax) { + + private val radix: Radix = readRadixFromMetadata + + override val parser: Try[IntegralParser[L]] = { + pattern.flatMap { patternForParser => + if (radix != Radix.DefaultRadix) { + val decimalSymbols = patternForParser.map(_.decimalSymbols).getOrElse(defaults.getDecimalSymbols) + Try(IntegralParser.ofRadix(radix, decimalSymbols, Option(typeMin), Option(typeMax))) + } else { + Success(IntegralParser(patternForParser + .getOrElse(NumericPattern(defaults.getDecimalSymbols)), Option(typeMin), Option(typeMax))) + }} + } + + override def validate(): Seq[ValidationIssue] = { + IntegralFieldValidator.validate(this) + } + + override def needsUdfParsing: Boolean = { + (radix != Radix.DefaultRadix) || super.needsUdfParsing + } + + private 
def readRadixFromMetadata:Radix = { + Try(getMetadataString(MetadataKeys.Radix).map(Radix(_))).toOption.flatten.getOrElse(Radix.DefaultRadix) + } + } + + final class ByteTypeStructField private[TypedStructField](structField: StructField)(implicit defaults: Defaults) + extends IntegralTypeStructField(structField, Byte.MinValue, Byte.MaxValue) + + final class ShortTypeStructField private[TypedStructField](structField: StructField)(implicit defaults: Defaults) + extends IntegralTypeStructField(structField, Short.MinValue, Short.MaxValue) + + final class IntTypeStructField private[TypedStructField](structField: StructField)(implicit defaults: Defaults) + extends IntegralTypeStructField(structField, Int.MinValue, Int.MaxValue) + + final class LongTypeStructField private[TypedStructField](structField: StructField)(implicit defaults: Defaults) + extends IntegralTypeStructField(structField, Long.MinValue, Long.MaxValue) + + // FractionalTypeStructField + sealed abstract class FractionalTypeStructField[D: DoubleLike] private[TypedStructField](structField: StructField, + override val typeMin: D, + override val typeMax: D) + (implicit defaults: Defaults) + extends NumericTypeStructField[D](structField, typeMin, typeMax) { + + override val allowInfinity: Boolean = getMetadataStringAsBoolean(MetadataKeys.AllowInfinity).getOrElse(false) + + override val parser: Try[NumericParser[D]] = { + pattern.map {patternOpt => + val patternForParser = patternOpt.getOrElse(NumericPattern(defaults.getDecimalSymbols)) + if (allowInfinity) { + FractionalParser.withInfinity(patternForParser) + } else { + FractionalParser(patternForParser, typeMin, typeMax) + } + } + } + + override def validate(): Seq[ValidationIssue] = { + FractionalFieldValidator.validate(this) + } + } + + // FloatTypeStructField + final class FloatTypeStructField private[TypedStructField](structField: StructField)(implicit defaults: Defaults) + extends FractionalTypeStructField(structField, Float.MinValue, Float.MaxValue) + + // DoubleTypeStructField + final class DoubleTypeStructField private[TypedStructField](structField: StructField)(implicit defaults: Defaults) + extends FractionalTypeStructField(structField, Double.MinValue, Double.MaxValue) + + // DecimalTypeStructField + final class DecimalTypeStructField private[TypedStructField](structField: StructField, + override val dataType: DecimalType) + (implicit defaults: Defaults) + extends NumericTypeStructField[BigDecimal]( + structField, + DecimalTypeStructField.minPossible(dataType), + DecimalTypeStructField.maxPossible(dataType) + ){ + val strictParsing: Boolean = getMetadataStringAsBoolean(MetadataKeys.StrictParsing).getOrElse(false) + + override val parser: Try[DecimalParser] = { + val maxScale = if(strictParsing) Some(scale) else None + pattern.map { patternOpt => + val pattern: NumericPattern = patternOpt.getOrElse(NumericPattern(defaults.getDecimalSymbols)) + DecimalParser(pattern, Option(typeMin), Option(typeMax), maxScale) + } + } + + override def needsUdfParsing: Boolean = strictParsing || super.needsUdfParsing + + override def validate(): Seq[ValidationIssue] = { + DecimalFieldValidator.validate(this) + } + + def precision: Int = dataType.precision + def scale: Int = dataType.scale + } + + object DecimalTypeStructField { + def maxPossible(decimalType: DecimalType): BigDecimal = { + val precision: Int = decimalType.precision + val scale: Int = decimalType.scale + val postDecimalString = "9" * scale + val preDecimalString = "9" * (precision - scale) + 
BigDecimal(s"$preDecimalString.$postDecimalString") + } + + def minPossible(decimalType: DecimalType): BigDecimal = { + -maxPossible(decimalType) + } + } + + // DateTimeTypeStructField + sealed abstract class DateTimeTypeStructField[T] private[TypedStructField](structField: StructField, validator: DateTimeFieldValidator) + (implicit defaults: Defaults) + extends TypedStructFieldTagged[T](structField) { + + override def pattern: Try[Option[DateTimePattern]] = { + parser.map(x => Some(x.pattern)) + } + + lazy val parser: Try[DateTimeParser] = { + val patternToUse = readDateTimePattern + Try{ + DateTimeParser(patternToUse) + } + } + + def defaultTimeZone: Option[String] = { + getMetadataString(MetadataKeys.DefaultTimeZone) + } + + override def validate(): Seq[ValidationIssue] = { + validator.validate(this) + } + + private def readDateTimePattern: DateTimePattern = { + getMetadataString(MetadataKeys.Pattern).map { pattern => + val timeZoneOpt = getMetadataString(MetadataKeys.DefaultTimeZone) + DateTimePattern(pattern, timeZoneOpt) + }.getOrElse( + DateTimePattern.asDefault(defaults.getStringPattern(structField.dataType), None) + ) + } + } + + // TimestampTypeStructField + final class TimestampTypeStructField private[TypedStructField](structField: StructField)(implicit defaults: Defaults) + extends DateTimeTypeStructField[Timestamp](structField, TimestampFieldValidator) { + + override protected def convertString(string: String): Try[Timestamp] = { + parser.map(_.parseTimestamp(string)) + } + + } + + // DateTypeStructField + final class DateTypeStructField private[TypedStructField](structField: StructField)(implicit defaults: Defaults) + extends DateTimeTypeStructField[Date](structField, DateFieldValidator) { + + override protected def convertString(string: String): Try[Date] = { + parser.map(_.parseDate(string)) + } + } + + sealed trait WeakSupport[T] { + this: TypedStructFieldTagged[T] => + + def structField: StructField + + def convertString(string: String): Try[T] = { + Failure(new IllegalStateException(s"No converter defined for data type ${structField.dataType.typeName}")) + } + + def validate(): Seq[ValidationIssue] = { + FieldValidator.validate(this) + } + } + + final class ArrayTypeStructField private[TypedStructField](structField: StructField, override val dataType: ArrayType) + (implicit defaults: Defaults) + extends TypedStructFieldTagged[Any](structField) with WeakSupport[Any] { + + override def validate(): Seq[ValidationIssue] = { + val typedSubField = Try { + val subField = StructField(name, dataType.elementType, dataType.containsNull, structField.metadata) + TypedStructField(subField) + } + + super.validate() ++ FieldValidator.tryToValidationIssues(typedSubField.map(_.validate())) + } + + override def ownDefaultValue: Try[Option[Option[Any]]] = { + Success(None) // array type doesn't have own default value, if defined it is to be applied to element type + } + } + + final class StructTypeStructField(structField: StructField, override val dataType: StructType)(implicit defaults: Defaults) + extends TypedStructFieldTagged[Any](structField) with WeakSupport[Any] + + final class GeneralTypeStructField private[TypedStructField](structField: StructField)(implicit defaults: Defaults) + extends TypedStructFieldTagged[Any](structField) with WeakSupport[Any] +} diff --git a/src/main/scala/za/co/absa/standardization/types/parsers/DateTimeParser.scala b/src/main/scala/za/co/absa/standardization/types/parsers/DateTimeParser.scala new file mode 100644 index 0000000..c35be74 --- /dev/null +++ 
b/src/main/scala/za/co/absa/standardization/types/parsers/DateTimeParser.scala @@ -0,0 +1,124 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.types.parsers + +import za.co.absa.standardization.time.DateTimePattern +import za.co.absa.standardization.types.Section +import za.co.absa.standardization.types.parsers.DateTimeParser.{MillisecondsInSecond, NanosecondsInMicrosecond, NanosecondsInMillisecond, SecondsPerDay} + +import java.sql.{Date, Timestamp} +import java.text.SimpleDateFormat +import java.util.Locale + +/** + * Enables to parse string to date and timestamp based on the provided format + * Unlike SimpleDateFormat it also supports keywords to format epoch related values + * @param pattern the formatting string, in case it's an epoch format the values wil need to be convertible to Long + */ +case class DateTimeParser(pattern: DateTimePattern) { + private val formatter: Option[SimpleDateFormat] = if (pattern.isEpoch) { + None + } else { + // locale here is hardcoded to the same value as Spark uses, lenient set to false also per Spark usage + val sdf = new SimpleDateFormat(pattern.patternWithoutSecondFractions, Locale.US) + sdf.setLenient(false) + Some(sdf) + } + + def parseDate(dateValue: String): Date = { + val seconds = extractSeconds(dateValue) + new Date((seconds - (seconds % SecondsPerDay)) * MillisecondsInSecond) + } + + def parseTimestamp(timestampValue: String): Timestamp = { + val seconds = extractSeconds(timestampValue) + val nanoseconds = extractNanoseconds(timestampValue) + makePreciseTimestamp(seconds, nanoseconds) + } + + def format(time: java.util.Date): String = { + //up to milliseconds it's easy with the formatter + val preliminaryResult = formatter.map(_.format(time)).getOrElse( + (time.getTime / MillisecondsInSecond).toString + ) + if (pattern.containsSecondFractions) { + // fractions of second present + // scalastyle:off magic.number + // 9 has the relation that nano- is a 10^-9 prefix, micro- is 10^-6 and milli is 10^-3 + val nanoString = time match { + case ts: Timestamp => "%09d".format(ts.getNanos) + case _ => "000000000" + } + + val injections: Map[Section, String] = Seq( + pattern.millisecondsPosition.map(x => (x, Section(-x.length, x.length).extractFrom(nanoString.substring(0, 3)))), + pattern.microsecondsPosition.map(x => (x, Section(-x.length, x.length).extractFrom(nanoString.substring(0, 6)))), + pattern.nanosecondsPosition.map(x => (x, Section(-x.length, x.length).extractFrom(nanoString))) + ).flatten.toMap + + val sections: Seq[Section] = Seq( + pattern.millisecondsPosition, + pattern.microsecondsPosition, + pattern.nanosecondsPosition + ).flatten.sorted + + sections.foldLeft(preliminaryResult) ((result, section) => + section.injectInto(result, injections(section)).getOrElse(result) + ) + // scalastyle:on magic.number + } else { + // no fractions of second + preliminaryResult + } + } + + private def makePreciseTimestamp(seconds: Long, nanoseconds: Int): 
Timestamp = { + val result = new Timestamp(seconds * MillisecondsInSecond) + if (nanoseconds > 0) { + result.setNanos(nanoseconds) + } + result + } + + private def extractSeconds(value: String): Long = { + val valueToParse = if (pattern.containsSecondFractions) { + Section.removeMultipleFrom(value, pattern.secondFractionsSections) + } else { + value + } + formatter.map(_.parse(valueToParse).getTime / MillisecondsInSecond).getOrElse( + valueToParse.toLong + ) + } + + private def extractNanoseconds(value: String): Int = { + var result = 0 + pattern.millisecondsPosition.foreach(result += _.extractFrom(value).toInt * NanosecondsInMillisecond) + pattern.microsecondsPosition.foreach(result += _.extractFrom(value).toInt * NanosecondsInMicrosecond) + pattern.nanosecondsPosition.foreach(result += _.extractFrom(value).toInt) + result + } +} + +object DateTimeParser { + private val SecondsPerDay = 24*60*60 + private val MillisecondsInSecond = 1000 + private val NanosecondsInMillisecond = 1000000 + private val NanosecondsInMicrosecond = 1000 + + def apply(pattern: String): DateTimeParser = new DateTimeParser(DateTimePattern(pattern)) +} diff --git a/src/main/scala/za/co/absa/standardization/types/parsers/DecimalParser.scala b/src/main/scala/za/co/absa/standardization/types/parsers/DecimalParser.scala new file mode 100644 index 0000000..9bda2b6 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/types/parsers/DecimalParser.scala @@ -0,0 +1,57 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
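An illustrative round trip through `DateTimeParser` (a sketch, not part of the patch); it assumes the JVM default time zone is stable, for example after `TimeZoneNormalizer` has been applied.

import za.co.absa.standardization.types.parsers.DateTimeParser

object DateTimeParserExample extends App {
  val parser = DateTimeParser("yyyy-MM-dd HH:mm:ss")
  val ts  = parser.parseTimestamp("2021-12-13 13:30:30")
  val day = parser.parseDate("2021-12-13 13:30:30")   // seconds truncated to whole days
  println(parser.format(ts))                          // "2021-12-13 13:30:30" - round-trips the input
  println(day)
}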
+ */ + +package za.co.absa.standardization.types.parsers + +import za.co.absa.standardization.numeric.NumericPattern + +import java.text.DecimalFormat +import scala.util.{Failure, Success, Try} + +class DecimalParser(override val pattern: NumericPattern, + override val min: Option[BigDecimal], + override val max: Option[BigDecimal], + val maxScale: Option[Int] = None) + extends NumericParser(pattern, min, max) with ParseViaDecimalFormat[BigDecimal] { + + override protected val stringConversion: String => BigDecimal = BigDecimal(_) + override protected val numberConversion: Number => BigDecimal = {n => BigDecimal(n.asInstanceOf[java.math.BigDecimal])} + + protected val decimalFormat: Option[DecimalFormat] = pattern.specifiedPattern.map (s => { + val format = new DecimalFormat(s, pattern.decimalSymbols.toDecimalFormatSymbols) + format.setParseBigDecimal(true) + format + }) + + override def parse(string: String): Try[BigDecimal] = { + super.parse(string).flatMap(number => { + maxScale match { + case Some(maxSc) if maxSc < number.scale => + Failure(new IllegalArgumentException(s"$string exceeds the defined scale limit in the schema")) + case _ => Success(number) + } + }) + } +} + +object DecimalParser { + def apply(pattern: NumericPattern, + min: Option[BigDecimal] = None, + max: Option[BigDecimal] = None, + maxScale: Option[Int] = None): DecimalParser = { + new DecimalParser(pattern, min, max, maxScale) + } +} diff --git a/src/main/scala/za/co/absa/standardization/types/parsers/FractionalParser.scala b/src/main/scala/za/co/absa/standardization/types/parsers/FractionalParser.scala new file mode 100644 index 0000000..949c4c8 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/types/parsers/FractionalParser.scala @@ -0,0 +1,68 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
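A small sketch of the scale check added by `DecimalParser.parse`; the expected outcomes follow from the code above.

import za.co.absa.standardization.numeric.NumericPattern
import za.co.absa.standardization.types.parsers.{DecimalParser, NumericParser}

object DecimalParserExample extends App {
  val pattern = NumericPattern(NumericParser.defaultDecimalSymbols)
  val parser = DecimalParser(pattern, maxScale = Some(2))
  println(parser.parse("123.45"))   // Success(123.45)
  println(parser.parse("123.456"))  // Failure - exceeds the defined scale limit in the schema
}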
+ */ + +package za.co.absa.standardization.types.parsers + +import za.co.absa.standardization.numeric.NumericPattern +import za.co.absa.standardization.typeClasses.DoubleLike + +import java.text.DecimalFormat + +class FractionalParser[D: DoubleLike] private(override val pattern: NumericPattern, + override val min: Option[D], + override val max: Option[D]) + extends NumericParser(pattern, min, max) with ParseViaDecimalFormat[D] { + + private val ev = implicitly[DoubleLike[D]] + + override protected val stringConversion: String => D = stringToDWithoutPattern + override protected val numberConversion: Number => D = {number => ev.toT(number.doubleValue())} + + protected val decimalFormat: Option[DecimalFormat] = pattern + .specifiedPattern + .map(new DecimalFormat(_, pattern.decimalSymbols.toDecimalFormatSymbols)) + + private def stringToDWithoutPattern(string: String): D = { + val resultAsDouble = string match { + case pattern.decimalSymbols.naNValue => Double.NaN + case pattern.decimalSymbols.infinityValue => Double.PositiveInfinity + case pattern.decimalSymbols.negativeInfinityValue => Double.NegativeInfinity + case _ => string.toDouble + } + ev.toT(resultAsDouble) + } +} + +object FractionalParser { + def apply(pattern: NumericPattern, + min: Double = Double.MinValue, + max: Double = Double.MaxValue): FractionalParser[Double] = { + new FractionalParser(pattern, Option(min), Option(max)) + } + + def apply[D: DoubleLike](pattern: NumericPattern, + min: D, + max: D): FractionalParser[D] = { + new FractionalParser[D](pattern, Option(min), Option(max)) + } + + def withInfinity[D: DoubleLike](pattern: NumericPattern): FractionalParser[D] = { + val ev = implicitly[DoubleLike[D]] + val negativeInfinity = ev.toT(Double.NegativeInfinity) + val positiveInfinity = ev.toT(Double.PositiveInfinity) + new FractionalParser(pattern, Option(negativeInfinity), Option(positiveInfinity)) + } +} diff --git a/src/main/scala/za/co/absa/standardization/types/parsers/IntegralParser.scala b/src/main/scala/za/co/absa/standardization/types/parsers/IntegralParser.scala new file mode 100644 index 0000000..a5e9428 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/types/parsers/IntegralParser.scala @@ -0,0 +1,171 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
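For orientation, a minimal sketch of `FractionalParser` used without an explicit pattern (bounds default to the `Double` range):

import za.co.absa.standardization.numeric.NumericPattern
import za.co.absa.standardization.types.parsers.{FractionalParser, NumericParser}

object FractionalParserExample extends App {
  val pattern = NumericPattern(NumericParser.defaultDecimalSymbols)
  val parser = FractionalParser(pattern)
  println(parser.parse("3.14"))    // Success(3.14)
  println(parser.parse("three"))   // Failure(NumberFormatException)
}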
+ */ + +package za.co.absa.standardization.types.parsers + +import za.co.absa.standardization.numeric.{DecimalSymbols, NumericPattern, Radix} +import za.co.absa.standardization.typeClasses.LongLike + +import java.math.BigInteger +import java.text.DecimalFormat +import scala.util.{Failure, Success, Try} + +abstract class IntegralParser[N: LongLike] private(override val pattern: NumericPattern, + override val min: Option[N], + override val max: Option[N]) + extends NumericParser[N](pattern, min, max) { + protected val ev: LongLike[N] = implicitly[LongLike[N]] + + val radix: Radix + + override protected val stringConversion: String => N = {string => ev.stringToT(string)} +} + +object IntegralParser { + def apply(pattern: NumericPattern, + min: Long = Long.MinValue, + max: Long = Long.MaxValue): IntegralParser[Long] = { + new PatternIntegralParser(pattern, Option(min), Option(max)) + } + + def apply[N: LongLike](pattern: NumericPattern, + min: Option[N], + max: Option[N]): IntegralParser[N] = { + new PatternIntegralParser(pattern, min, max) + } + + def ofRadix(radix: Radix, + decimalSymbols: DecimalSymbols = NumericParser.defaultDecimalSymbols, + min: Long = Long.MinValue, + max: Long = Long.MaxValue): IntegralParser[Long] = { + new RadixIntegralParser(radix, decimalSymbols, Option(min), Option(max)) + } + + def ofRadix[N: LongLike](radix: Radix, + decimalSymbols: DecimalSymbols, + min: Option[N], + max: Option[N]): IntegralParser[N] = { + new RadixIntegralParser[N](radix, decimalSymbols, min, max) + } + + def ofStringRadix(stringRadix: String, + decimalSymbols: DecimalSymbols = NumericParser.defaultDecimalSymbols, + min: Long = Long.MinValue, + max: Long = Long.MaxValue): IntegralParser[Long] = { + ofRadix(Radix(stringRadix), decimalSymbols, min, max) + } + + def ofStringRadix[N: LongLike](stringRadix: String, + decimalSymbols: DecimalSymbols, + min: Option[N], + max: Option[N]): IntegralParser[N] = { + ofRadix(Radix(stringRadix), decimalSymbols, min, max) + } + + def tryStringToBase(string: String): Try[Radix] = { + Try(Radix(string)) + } + + final class RadixIntegralParser[N: LongLike] (override val radix: Radix, + decimalSymbols: DecimalSymbols, + override val min: Option[N], + override val max: Option[N]) + extends IntegralParser(NumericPattern(decimalSymbols), min, max) { + + private val minBI = BigInteger.valueOf(min.map(ev.toLong).getOrElse(Long.MinValue)) + private val maxBI = BigInteger.valueOf(max.map(ev.toLong).getOrElse(Long.MaxValue)) + + override def parse(string: String): Try[N] = { + val preprocessed = normalizeBasicSymbols(string) + radix match { + case Radix.DefaultRadix => Try{ev.stringToT(preprocessed)}.flatMap(valueWithinBounds(_, string)) + case Radix(16) => toNWithBoundCheck(clearHexString(preprocessed), string)// scalastyle:ignore magic.number obvious meaning + case _ => toNWithBoundCheck(preprocessed, string) + } + } + + override def format(value: N): String = { + // scalastyle:off magic.number obvious meaning + val longValue = ev.toLong(value) + val result = radix match { + case Radix(10) => longValue.toString + case Radix(16) => longValue.toHexString + case Radix(2) => longValue.toBinaryString + case Radix(8) => longValue.toOctalString + case Radix(b) => + val bigValue = BigInteger.valueOf(longValue) + bigValue.toString(b) + } + // scalastyle:on magic.number + denormalizeBasicSymbols(result) + } + + private def toNWithBoundCheck(string: String, originalInput: String): Try[N] = { + val bigIntegerTry = Try{new BigInteger(string, radix.value)} + 
bigIntegerTry.flatMap(bigValue => { + if (bigValue.compareTo(maxBI) > 0) { // too big + Failure(outOfBoundsException(originalInput)) + } else if (bigValue.compareTo(minBI) < 0) { // too small + Failure(outOfBoundsException(originalInput)) + } else { + Success(ev.toT(bigValue.longValue())) + } + }) + } + + private def clearHexString(string: String): String = { + // supporting 0xFF style format of hexadecimals + val longEnoughString = string + " " + val prefix2 = longEnoughString.substring(0, 2).toLowerCase + if (prefix2 == "0x") { + string.substring(2) + } else { + val prefix3 = longEnoughString.substring(0, 3).toLowerCase + if (prefix3 == "-0x") { + "-" + string.substring(3) + } else if (prefix3 == "+0x") { + string.substring(3) + } else { + string + } + } + } + + protected def parseUsingPattern(stringToParse: String):Try[N] = parse(stringToParse) + protected def formatUsingPattern(value: N): String = format(value) + + } + + final class PatternIntegralParser[N: LongLike](override val pattern: NumericPattern, + override val min: Option[N], + override val max: Option[N]) + extends IntegralParser(pattern, min, max) with ParseViaDecimalFormat[N] { + override val radix: Radix = Radix.DefaultRadix + + override protected val numberConversion: Number => N = {number => + val longValue = number.longValue() + if ((longValue > ev.MaxValue) || (longValue < ev.MinValue)) { + throw outOfBoundsException(number.toString) + } + ev.toT(longValue) + } + override protected val decimalFormat: Option[DecimalFormat] = pattern.specifiedPattern.map(s => { + val df = new DecimalFormat(s, pattern.decimalSymbols.toDecimalFormatSymbols) + df.setParseIntegerOnly(true) + df + }) + } +} diff --git a/src/main/scala/za/co/absa/standardization/types/parsers/NumericParser.scala b/src/main/scala/za/co/absa/standardization/types/parsers/NumericParser.scala new file mode 100644 index 0000000..938261f --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/types/parsers/NumericParser.scala @@ -0,0 +1,106 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
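A hedged sketch of radix-based parsing as implemented above; `Radix(16)` mirrors the pattern matches used in `RadixIntegralParser`.

import za.co.absa.standardization.numeric.Radix
import za.co.absa.standardization.types.parsers.IntegralParser

object RadixParsingExample extends App {
  val hexParser = IntegralParser.ofRadix(Radix(16))
  println(hexParser.parse("0xFF"))   // Success(255) - the 0x prefix is stripped by clearHexString
  println(hexParser.parse("ff"))     // Success(255)
  println(hexParser.format(255L))    // "ff"
}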
+ */ + +package za.co.absa.standardization.types.parsers + +import za.co.absa.standardization.implicits.StringImplicits.StringEnhancements +import za.co.absa.standardization.numeric.{DecimalSymbols, NumericPattern} +import za.co.absa.standardization.types.parsers.NumericParser.NumericParserException + +import java.util.Locale +import scala.util.{Failure, Success, Try} + +abstract class NumericParser[N: Ordering](val pattern: NumericPattern, + val min: Option[N], + val max: Option[N]) + extends Serializable { + + protected val stringConversion: String => N + + protected def parseUsingPattern(stringToParse: String):Try[N] + protected def formatUsingPattern(value: N): String + + def parse(string: String): Try[N] = { + for { + parsed <- if (pattern.specifiedPattern.isDefined) { + parseUsingPattern(string) + } else { + parseWithoutPattern(string) + } + result <- valueWithinBounds(parsed) + } yield result + } + + def format(value: N): String = { + if (pattern.isDefault) { + formatWithoutPattern(value) + } else { + formatUsingPattern(value) + } + } + + protected def outOfBoundsException(originalInput: String): NumericParserException = { + new NumericParserException(s"The number '$originalInput' is out of range <$min, $max>") + } + + protected def valueWithinBounds(value: N): Try[N] = { + valueWithinBounds(value, value.toString) + } + + protected def valueWithinBounds(value: N, originalInput: String): Try[N] = { + // to be abe to use comparison on N type + val ordering = implicitly[Ordering[N]] + import ordering._ + + (this.min, this.max) match { // using this because of ambiguity with imports from ordering + case (Some(minDefined), _) if value < minDefined => Failure(outOfBoundsException(originalInput)) + case (_, Some(maxDefined)) if value > maxDefined => Failure(outOfBoundsException(originalInput)) + case _ => Success(value) + } + } + + protected def parseWithoutPattern(stringToParse: String):Try[N] = { + Try(stringConversion(normalizeBasicSymbols(stringToParse))) + } + + protected def formatWithoutPattern(value: N): String = { + denormalizeBasicSymbols(value.toString) + } + + protected def normalizeBasicSymbols(string: String): String = { + val replacements = NumericParser.defaultDecimalSymbols.basicSymbolsDifference(pattern.decimalSymbols) + if (replacements.nonEmpty) { + string.replaceChars(replacements) + } else { + string + } + } + + protected def denormalizeBasicSymbols(string: String): String = { + val replacements = pattern.decimalSymbols.basicSymbolsDifference(NumericParser.defaultDecimalSymbols) + if (replacements.nonEmpty) { + string.replaceChars(replacements) + } else { + string + } + } +} + +object NumericParser { + val defaultDecimalSymbols: DecimalSymbols = DecimalSymbols(Locale.US) + + class NumericParserException(s: String = "") extends NumberFormatException(s) +} diff --git a/src/main/scala/za/co/absa/standardization/types/parsers/ParseViaDecimalFormat.scala b/src/main/scala/za/co/absa/standardization/types/parsers/ParseViaDecimalFormat.scala new file mode 100644 index 0000000..3a41528 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/types/parsers/ParseViaDecimalFormat.scala @@ -0,0 +1,56 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.types.parsers + +import za.co.absa.standardization.implicits.OptionImplicits.OptionEnhancements +import za.co.absa.standardization.types.parsers.NumericParser.NumericParserException + +import java.text.{DecimalFormat, ParsePosition} +import scala.util.{Failure, Success, Try} + +/** + * Trait to implement the common logic of parsing and formatting using DecimalFormat object + * + * @tparam N a numeric types + */ +trait ParseViaDecimalFormat[N] { + protected val decimalFormat: Option[DecimalFormat] + protected val numberConversion: Number => N + + protected def parseUsingPattern(stringToParse: String):Try[N] = { + + def checkPosAtEnd(pos: ParsePosition): Try[Unit] = { + if (pos.getIndex < stringToParse.length) { + Failure(new NumericParserException(s"Parsing of '$stringToParse' failed.")) + } else { + Success(Unit) + } + } + + for { + formatter <- decimalFormat.toTry(new NumericParserException("No pattern provided")) + pos = new ParsePosition(0) + parsed <- Try(formatter.parse(stringToParse, pos)) + _ <- checkPosAtEnd(pos) + result <- Try(numberConversion(parsed)) + } yield result + } + + protected def formatUsingPattern(value: N): String = { + decimalFormat.map(_.format(value).trim).getOrElse(value.toString) + } +} diff --git a/src/main/scala/za/co/absa/standardization/udf/UDFBuilder.scala b/src/main/scala/za/co/absa/standardization/udf/UDFBuilder.scala new file mode 100644 index 0000000..5545107 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/udf/UDFBuilder.scala @@ -0,0 +1,56 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.udf + +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.functions.udf +import za.co.absa.standardization.types.parsers.NumericParser +import za.co.absa.standardization.types.parsers.NumericParser.NumericParserException + +import scala.reflect.runtime.universe._ +import scala.util.{Failure, Success} + +object UDFBuilder { + def stringUdfViaNumericParser[T: TypeTag](parser: NumericParser[T], + columnNullable: Boolean, + columnNameForError: String, + defaultValue: Option[T] + ): UserDefinedFunction = { + // ensuring all values sent to the UDFBuilder are instantiated + val vParser = parser + val vColumnNameForError = columnNameForError + val vDefaultValue = defaultValue + val vColumnNullable = columnNullable + + udf[UDFResult[T], String](numericParserToTyped(_, vParser, vColumnNullable, vColumnNameForError, vDefaultValue)) + } + + private def numericParserToTyped[T](input: String, + parser: NumericParser[T], + columnNullable: Boolean, + columnNameForError: String, + defaultValue: Option[T]): UDFResult[T] = { + val result = Option(input) match { + case Some(string) => parser.parse(string).map(Some(_)) + case None if columnNullable => Success(None) + case None => Failure(nullException) + } + UDFResult.fromTry(result, columnNameForError, input, defaultValue) + } + + private val nullException = new NumericParserException("Null value on input for non-nullable field") +} diff --git a/src/main/scala/za/co/absa/standardization/udf/UDFLibrary.scala b/src/main/scala/za/co/absa/standardization/udf/UDFLibrary.scala new file mode 100644 index 0000000..691ea51 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/udf/UDFLibrary.scala @@ -0,0 +1,99 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
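A minimal Spark sketch of `UDFBuilder.stringUdfViaNumericParser` (illustrative only; the session setup and column names are assumptions):

import org.apache.spark.sql.SparkSession
import za.co.absa.standardization.numeric.NumericPattern
import za.co.absa.standardization.types.parsers.{IntegralParser, NumericParser}
import za.co.absa.standardization.udf.UDFBuilder

object UdfBuilderExample extends App {
  val spark = SparkSession.builder().master("local[*]").appName("udf-builder-example").getOrCreate()
  import spark.implicits._

  val parser = IntegralParser(NumericPattern(NumericParser.defaultDecimalSymbols))
  val parseAge = UDFBuilder.stringUdfViaNumericParser(parser, columnNullable = true,
    columnNameForError = "age", defaultValue = None)

  val df = Seq("42", "not-a-number", null).toDF("age")
  // each output row is a struct of (result, error): the parsed value plus any ErrorMessage entries
  df.select(parseAge($"age").as("age_std")).show(truncate = false)
}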
+ */ + +package za.co.absa.standardization.udf + +import org.apache.spark.sql.api.java._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{Row, SparkSession} +import za.co.absa.standardization.{ErrorMessage, Mapping} +import za.co.absa.standardization.udf.UDFNames._ + +import java.util.Base64 +import scala.collection.mutable +import scala.util.{Failure, Success, Try} + +class UDFLibrary()(implicit val spark: SparkSession) { + + spark.udf.register(stdCastErr, { (errCol: String, rawValue: String) => + ErrorMessage.stdCastErr(errCol, rawValue) + }) + + spark.udf.register(stdNullErr, { errCol: String => ErrorMessage.stdNullErr(errCol) }) + + spark.udf.register(stdSchemaErr, { errRow: String => ErrorMessage.stdSchemaError(errRow) }) + + spark.udf.register(confMappingErr, { (errCol: String, rawValues: Seq[String], mappings: Seq[Mapping]) => + ErrorMessage.confMappingErr(errCol, rawValues, mappings) + }) + + spark.udf.register(confCastErr, { (errCol: String, rawValue: String) => + ErrorMessage.confCastErr(errCol, rawValue) + }) + + spark.udf.register(confNegErr, { (errCol: String, rawValue: String) => + ErrorMessage.confNegErr(errCol, rawValue) + }) + + spark.udf.register(confLitErr, { (errCol: String, rawValue: String) => + ErrorMessage.confLitErr(errCol, rawValue) + }) + + spark.udf.register(arrayDistinctErrors, // this UDF is registered for _spark-hats_ library sake + (arr: mutable.WrappedArray[ErrorMessage]) => + if (arr != null) { + arr.distinct.filter((a: AnyRef) => a != null) + } else { + Seq[ErrorMessage]() + } + ) + + spark.udf.register(cleanErrCol, + UDFLibrary.cleanErrCol, + ArrayType.apply(ErrorMessage.errorColSchema, containsNull = false)) + + spark.udf.register(errorColumnAppend, + UDFLibrary.errorColumnAppend, + ArrayType.apply(ErrorMessage.errorColSchema, containsNull = false)) + + + spark.udf.register(binaryUnbase64, + {stringVal: String => Try { + Base64.getDecoder.decode(stringVal) + } match { + case Success(decoded) => decoded + case Failure(_) => null //scalastyle:ignore null + }}) +} + +object UDFLibrary { + private val cleanErrCol = new UDF1[Seq[Row], Seq[Row]] { + override def call(t1: Seq[Row]): Seq[Row] = { + t1.filter({ row => + row != null && { + val typ = row.getString(0) + typ != null + } + }) + } + } + + private val errorColumnAppend = new UDF2[Seq[Row], Row, Seq[Row]] { + override def call(t1: Seq[Row], t2: Row): Seq[Row] = { + t1 :+ t2 + } + } +} diff --git a/src/main/scala/za/co/absa/standardization/udf/UDFNames.scala b/src/main/scala/za/co/absa/standardization/udf/UDFNames.scala new file mode 100644 index 0000000..796b922 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/udf/UDFNames.scala @@ -0,0 +1,34 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.udf + +object UDFNames { + final val stdCastErr = "stdCastErr" + final val stdNullErr = "stdNullErr" + final val stdSchemaErr = "stdSchemaErr" + + final val confMappingErr = "confMappingErr" + final val confCastErr = "confCastErr" + final val confNegErr = "confNegErr" + final val confLitErr = "confLitErr" + + final val arrayDistinctErrors = "arrayDistinctErrors" + final val cleanErrCol = "cleanErrCol" + final val errorColumnAppend = "errorColumnAppend" + + final val binaryUnbase64 = "binaryUnbase64" +} diff --git a/src/main/scala/za/co/absa/standardization/udf/UDFResult.scala b/src/main/scala/za/co/absa/standardization/udf/UDFResult.scala new file mode 100644 index 0000000..661c53e --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/udf/UDFResult.scala @@ -0,0 +1,40 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.udf + +import za.co.absa.standardization.ErrorMessage + +import scala.util.{Failure, Success, Try} + +case class UDFResult[T] ( + result: Option[T], + error: Seq[ErrorMessage] + ) + +object UDFResult { + def success[T](result: Option[T]): UDFResult[T] = { + UDFResult(result, Seq.empty) + } + + def fromTry[T](result: Try[Option[T]], columnName: String, rawValue: String, defaultValue: Option[T] = None): UDFResult[T] = { + result match { + case Success(success) => UDFResult.success(success) + case Failure(_) if Option(rawValue).isEmpty => UDFResult(defaultValue, Seq(ErrorMessage.stdNullErr(columnName))) + case Failure(_) => UDFResult(defaultValue, Seq(ErrorMessage.stdCastErr(columnName, rawValue))) + } + } +} diff --git a/src/main/scala/za/co/absa/standardization/validation/field/BinaryFieldValidator.scala b/src/main/scala/za/co/absa/standardization/validation/field/BinaryFieldValidator.scala new file mode 100644 index 0000000..d9c096e --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/validation/field/BinaryFieldValidator.scala @@ -0,0 +1,68 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.validation.field + +import za.co.absa.standardization.{ValidationError, ValidationIssue, ValidationWarning} +import za.co.absa.standardization.implicits.StructFieldImplicits._ +import za.co.absa.standardization.schema.{MetadataKeys, MetadataValues} +import za.co.absa.standardization.types.TypedStructField +import za.co.absa.standardization.types.TypedStructField.BinaryTypeStructField + +import java.util.Base64 +import scala.util.{Failure, Success, Try} + +object BinaryFieldValidator extends FieldValidator { + + private def validateDefaultValueWithGlobal(field: BinaryTypeStructField): Seq[ValidationIssue] = { + tryToValidationIssues(field.defaultValueWithGlobal) + } + + private def validateExplicitBase64DefaultValue(field: BinaryTypeStructField): Seq[ValidationIssue] = { + val defaultValue: Option[String] = field.structField.getMetadataString(MetadataKeys.DefaultValue) + + (field.normalizedEncoding, defaultValue) match { + case (None, Some(encodedDefault)) => + Seq(ValidationWarning(s"Default value of '$encodedDefault' found, but no encoding is specified. Assuming 'none'.")) + case (Some(MetadataValues.Encoding.Base64), Some(encodedValue)) => + Try { + Base64.getDecoder.decode(encodedValue) + } match { + case Success(_) => Seq.empty + case Failure(_) => Seq(ValidationError(s"Invalid default value $encodedValue for Base64 encoding (cannot be decoded)!")) + } + case _ => Seq.empty + } + } + + private def validateEncoding(field: BinaryTypeStructField): Seq[ValidationIssue] = { + field.normalizedEncoding match { + case Some(MetadataValues.Encoding.Base64) | Some(MetadataValues.Encoding.None) => + Seq.empty + case None => Seq.empty + case _ => Seq(ValidationError(s"Unsupported encoding for Binary field ${field.structField.name}: '${field.normalizedEncoding.get}'")) + } + } + + override def validate(field: TypedStructField): Seq[ValidationIssue] = { + super.validate(field) ++ ( + field match { + case bField: BinaryTypeStructField => + validateDefaultValueWithGlobal(bField) ++ validateExplicitBase64DefaultValue(bField) ++ validateEncoding(bField) + case _ => Seq(ValidationError("BinaryFieldValidator can validate only fields of type BinaryTypeStructField")) + }) + } +} diff --git a/src/main/scala/za/co/absa/standardization/validation/field/DateFieldValidator.scala b/src/main/scala/za/co/absa/standardization/validation/field/DateFieldValidator.scala new file mode 100644 index 0000000..f1d244a --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/validation/field/DateFieldValidator.scala @@ -0,0 +1,79 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.validation.field + +import za.co.absa.standardization.implicits.StringImplicits.StringEnhancements +import za.co.absa.standardization.time.DateTimePattern +import za.co.absa.standardization.types.parsers.DateTimeParser +import za.co.absa.standardization.{ValidationIssue, ValidationWarning} + +import java.util.Date + +object DateFieldValidator extends DateTimeFieldValidator { + + override protected def patternAnalysisIssues(pattern: DateTimePattern, + defaultValue: Option[String], + defaultTimeZone: Option[String]): Seq[ValidationIssue] = { + val doubleTimeZoneIssue: Seq[ValidationIssue] = if (pattern.timeZoneInPattern && defaultTimeZone.nonEmpty) { + Seq(ValidationWarning( + "Pattern includes time zone placeholder and default time zone is also defined (will never be used)" + )) + } else { + Nil + } + + val timeZoneIssue: Option[ValidationIssue] = if (!pattern.isEpoch && pattern.isTimeZoned) { + Option(ValidationWarning( + "Time zone is defined in pattern for date. While it's valid, it can lead to unexpected outcomes." + )) + } else { + None + } + + val patternIssues: Seq[ValidationIssue] = if (!pattern.isEpoch) { + val placeholders = Set('y', 'M', 'd', 'H', 'm', 's', 'D', 'S', 'i', 'n', 'a', 'k', 'K', 'h') + val patternChars: Map[Char, Int] = pattern.pattern.countUnquoted(placeholders, Set(''')) + patternChars.foldLeft(List.empty[ValidationIssue]) {(acc, item) => (item._1, item._2 > 0) match { + case ('y', false) => ValidationWarning("No year placeholder 'yyyy' found.")::acc + case ('M', false) => ValidationWarning("No month placeholder 'MM' found.")::acc + case ('d', false) => ValidationWarning("No day placeholder 'dd' found.")::acc + case ('H', true) => ValidationWarning("Redundant hour placeholder 'H' found.")::acc + case ('m', true) => ValidationWarning("Redundant minute placeholder 'm' found.")::acc + case ('s', true) => ValidationWarning("Redundant second placeholder 's' found.")::acc + case ('S', true) => ValidationWarning("Redundant millisecond placeholder 'S' found.")::acc + case ('i', true) => ValidationWarning("Redundant microsecond placeholder 'i' found.")::acc + case ('n', true) => ValidationWarning("Redundant nanosecond placeholder 'n' found.")::acc + case ('a', true) => ValidationWarning("Redundant am/pm placeholder 'a' found.")::acc + case ('k', true) => ValidationWarning("Redundant hour placeholder 'k' found.")::acc + case ('h', true) => ValidationWarning("Redundant hour placeholder 'h' found.")::acc + case ('K', true) => ValidationWarning("Redundant hour placeholder 'K' found.")::acc + case ('D', true) if patternChars('d') == 0 => + ValidationWarning("Rarely used DayOfYear placeholder 'D' found. Possibly DayOfMonth 'd' intended.")::acc + case _ => acc + }} + } else { + Nil + } + + patternIssues ++ doubleTimeZoneIssue ++ timeZoneIssue.toSet + } + + override def verifyStringDateTime(dateTime: String)(implicit parser: DateTimeParser): Date = { + parser.parseDate(dateTime) + } + +} diff --git a/src/main/scala/za/co/absa/standardization/validation/field/DateTimeFieldValidator.scala b/src/main/scala/za/co/absa/standardization/validation/field/DateTimeFieldValidator.scala new file mode 100644 index 0000000..bfc5b5c --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/validation/field/DateTimeFieldValidator.scala @@ -0,0 +1,91 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.validation.field + +import za.co.absa.standardization.{ValidationError, ValidationIssue} +import za.co.absa.standardization.schema.MetadataKeys +import za.co.absa.standardization.time.DateTimePattern +import za.co.absa.standardization.types.TypedStructField +import za.co.absa.standardization.types.TypedStructField.DateTimeTypeStructField +import za.co.absa.standardization.types.parsers.DateTimeParser + +import java.sql.Timestamp +import java.util.{Date, TimeZone} +import scala.util.control.NonFatal + +abstract class DateTimeFieldValidator extends FieldValidator { + override def validate(field: TypedStructField): Seq[ValidationIssue] = { + super.validate(field) ++ ( + field match { + case dateTimeField: DateTimeTypeStructField[_] => validateDateTimeTypeStructField(dateTimeField) + case _ => Seq(ValidationError("DateTimeFieldValidator can validate only fields of type Date or Timestamp")) + }) + } + + private def validateDateTimeTypeStructField(field: DateTimeTypeStructField[_]): Seq[ValidationIssue] = { + val result = for { + parser <- field.parser + defaultValue: Option[String] = field.getMetadataString(MetadataKeys.DefaultValue) + defaultTimeZone: Option[String] = field.getMetadataString(MetadataKeys.DefaultTimeZone) + } yield patternConversionIssues(field, parser).toSeq ++ + defaultTimeZoneIssues(defaultTimeZone) ++ + patternAnalysisIssues(parser.pattern, defaultValue, defaultTimeZone) + + tryToValidationIssues(result) + } + + private def patternConversionIssues(field: DateTimeTypeStructField[_], parser: DateTimeParser): Option[ValidationIssue] = { + + try { + implicit val implicitParser: DateTimeParser = parser + + field.ownDefaultValue.get // if the default value is a Failure, this rethrows its exception + + val exampleDateStr = parser.format(DateTimeFieldValidator.exampleDate) + verifyStringDateTime(exampleDateStr) + val epochStartStr = parser.format(DateTimeFieldValidator.epochStart) + verifyStringDateTime(epochStartStr) + val epochStartDayEndStr = parser.format(DateTimeFieldValidator.epochStartDayEnd) + verifyStringDateTime(epochStartDayEndStr) + + None + } + catch { + case NonFatal(e) => Option(ValidationError(e.getMessage)) + } + } + + private def defaultTimeZoneIssues(defaultTimeZone: Option[String]): Seq[ValidationIssue] = { + defaultTimeZone.filterNot(TimeZone.getAvailableIDs().contains(_)).map(tz => + ValidationError(""""%s" is not a valid time zone designation""".format(tz)) + ).toSeq + } + + protected def patternAnalysisIssues(pattern: DateTimePattern, + defaultValue: Option[String], + defaultTimeZone: Option[String]): Seq[ValidationIssue] + + protected def verifyStringDateTime(dateTime: String)(implicit parser: DateTimeParser): Date +} + +object DateTimeFieldValidator { + private val dayMilliSeconds = 24 * 60 * 60 * 1000 + private val exampleDate = new Timestamp(System.currentTimeMillis) + private val epochStart = new Timestamp(0) + private val epochStartDayEnd = new Timestamp(dayMilliSeconds - 1) +} + diff --git a/src/main/scala/za/co/absa/standardization/validation/field/DecimalFieldValidator.scala 
b/src/main/scala/za/co/absa/standardization/validation/field/DecimalFieldValidator.scala new file mode 100644 index 0000000..59fc8f3 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/validation/field/DecimalFieldValidator.scala @@ -0,0 +1,28 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.validation.field + +import za.co.absa.standardization.ValidationIssue +import za.co.absa.standardization.schema.MetadataKeys +import za.co.absa.standardization.types.TypedStructField + +object DecimalFieldValidator extends NumericFieldValidator { + override def validate(field: TypedStructField): Seq[ValidationIssue] = { + super.validate(field) ++ + this.checkMetadataKey[Boolean](field, MetadataKeys.StrictParsing) + } +} diff --git a/src/main/scala/za/co/absa/standardization/validation/field/FieldValidationIssue.scala b/src/main/scala/za/co/absa/standardization/validation/field/FieldValidationIssue.scala new file mode 100644 index 0000000..8a0f77f --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/validation/field/FieldValidationIssue.scala @@ -0,0 +1,24 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.validation.field + +import za.co.absa.standardization.ValidationIssue + +/** + * This class contains the list of issues found during schema validation of a particular column + */ +case class FieldValidationIssue(fieldName: String, pattern: String, issues: Seq[ValidationIssue]) diff --git a/src/main/scala/za/co/absa/standardization/validation/field/FieldValidator.scala b/src/main/scala/za/co/absa/standardization/validation/field/FieldValidator.scala new file mode 100644 index 0000000..cce4be1 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/validation/field/FieldValidator.scala @@ -0,0 +1,83 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.validation.field + +import za.co.absa.standardization.{ValidationError, ValidationIssue} +import za.co.absa.standardization.types.TypedStructField + +import scala.reflect.runtime.universe._ +import scala.util.{Failure, Success, Try} +import FieldValidator._ + +class FieldValidator { + + def validate(field: TypedStructField): Seq[ValidationIssue] = { + Nil + } + + /** + * Function to convert a Try type to a sequence of ValidationIssue. Naming follows the pattern StringToInt; Try is a noun here + * @param tryValue Try value to convert to ValidationIssue - a Failure is converted to a ValidationError, any ValidationIssue + * included within a Success is returned in the sequence, anything else results in an empty sequence + * @return sequence of ValidationIssue that were either part of the input or, if the input was a failure, + * its conversion into a ValidationError + */ + def tryToValidationIssues(tryValue: Try[Any]): Seq[ValidationIssue] = { + tryValue match { + case Failure(e) => Seq(ValidationError(e.getMessage)) + case Success(seq: Seq[_]) => seq.collect{case x:ValidationIssue => x} //have to use collect because of type erasure + case Success(opt: Option[_]) => opt.collect{case x:ValidationIssue => x}.toSeq + case Success(issue: ValidationIssue) => Seq(issue) + case _ => Nil + } + } + + + + protected def checkMetadataKey[T: TypeTag](field: TypedStructField, + metadataKey: String, + issueConstructor: String => ValidationIssue = ValidationError.apply): Seq[ValidationIssue] = { + + def optionToValidationIssueSeq(option: Option[_], typeName: String): Seq[ValidationIssue] = { + option.map(_ => Nil).getOrElse( + Seq(issueConstructor(s"$metadataKey metadata value of field '${field.name}' is not ${simpleTypeName(typeName)} in String format")) + ) + } + + if (field.hasMetadataKey(metadataKey)) { + typeOf[T] match { + case t if t =:= typeOf[String] => optionToValidationIssueSeq(field.getMetadataString(metadataKey), t.toString) + case t if t =:= typeOf[Boolean] => optionToValidationIssueSeq(field.getMetadataStringAsBoolean(metadataKey), t.toString) + case t if t =:= typeOf[Char] => optionToValidationIssueSeq(field.getMetadataChar(metadataKey), t.toString) + case _ => Seq(ValidationError(s"Unsupported metadata validation type for key '$metadataKey' of field '${field.name}'")) + } + } else { + Nil + } + } +} + +object FieldValidator extends FieldValidator { + /** + * Keeps the part of the string after the last dot, e.g. `scala.Boolean` -> `Boolean`. Does nothing if there is no dot. + * @param typeName possibly dot-separated type name + * @return simple type name + */ + private[field] def simpleTypeName(typeName: String) = { + typeName.split("""\.""").last + } +} diff --git a/src/main/scala/za/co/absa/standardization/validation/field/FractionalFieldValidator.scala b/src/main/scala/za/co/absa/standardization/validation/field/FractionalFieldValidator.scala new file mode 100644 index 0000000..97faa12 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/validation/field/FractionalFieldValidator.scala @@ -0,0 +1,29 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.validation.field + +import za.co.absa.standardization.ValidationIssue +import za.co.absa.standardization.schema.MetadataKeys +import za.co.absa.standardization.types.TypedStructField + +object FractionalFieldValidator extends NumericFieldValidator { + override def validate(field: TypedStructField): Seq[ValidationIssue] = { + super.validate(field) ++ + this.checkMetadataKey[Boolean](field, MetadataKeys.AllowInfinity) + } +} + diff --git a/src/main/scala/za/co/absa/standardization/validation/field/IntegralFieldValidator.scala b/src/main/scala/za/co/absa/standardization/validation/field/IntegralFieldValidator.scala new file mode 100644 index 0000000..d5f68af --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/validation/field/IntegralFieldValidator.scala @@ -0,0 +1,48 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.validation.field + +import za.co.absa.standardization.numeric.Radix +import za.co.absa.standardization.{ValidationIssue, ValidationWarning} +import za.co.absa.standardization.schema.MetadataKeys +import za.co.absa.standardization.types.TypedStructField + +import scala.util.Try + +object IntegralFieldValidator extends NumericFieldValidator { + + private def radixIssues(field: TypedStructField): Seq[ValidationIssue] = { + field.getMetadataString(MetadataKeys.Radix).map { radixString => + val result = for { + radix <- Try(Radix(radixString)) + pattern <- field.pattern + conflict = if ((radix != Radix.DefaultRadix) && (pattern.exists(!_.isDefault))) { + ValidationWarning( + s"Both Radix and Pattern defined for field ${field.name}, for Radix different from ${Radix.DefaultRadix} Pattern is ignored" + ) + } + } yield conflict + tryToValidationIssues(result) + }.getOrElse(Nil) + } + + override def validate(field: TypedStructField): Seq[ValidationIssue] = { + super.validate(field) ++ + checkMetadataKey[String](field, MetadataKeys.Radix) ++ + radixIssues(field) + } +} diff --git a/src/main/scala/za/co/absa/standardization/validation/field/NumericFieldValidator.scala b/src/main/scala/za/co/absa/standardization/validation/field/NumericFieldValidator.scala new file mode 100644 index 0000000..a63ff23 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/validation/field/NumericFieldValidator.scala @@ -0,0 +1,42 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.validation.field + +import za.co.absa.standardization.{ValidationError, ValidationIssue} +import za.co.absa.standardization.schema.MetadataKeys +import za.co.absa.standardization.types.TypedStructField +import za.co.absa.standardization.types.TypedStructField.NumericTypeStructField + +object NumericFieldValidator extends NumericFieldValidator + +class NumericFieldValidator extends ScalarFieldValidator { + private def validateNumericTypeStructField(field: NumericTypeStructField[_]): Seq[ValidationIssue] = { + tryToValidationIssues(field.parser) + } + + override def validate(field: TypedStructField): Seq[ValidationIssue] = { + super.validate(field) ++ + checkMetadataKey[String](field, MetadataKeys.Pattern) ++ + checkMetadataKey[Char](field, MetadataKeys.DecimalSeparator) ++ + checkMetadataKey[Char](field, MetadataKeys.MinusSign) ++ + checkMetadataKey[Char](field, MetadataKeys.GroupingSeparator) ++ ( + field match { + case numericField: NumericTypeStructField[_] => validateNumericTypeStructField(numericField) + case _ => Seq(ValidationError("NumericFieldValidator can validate only fields of numeric types")) + }) + } +} diff --git a/src/main/scala/za/co/absa/standardization/validation/field/ScalarFieldValidator.scala b/src/main/scala/za/co/absa/standardization/validation/field/ScalarFieldValidator.scala new file mode 100644 index 0000000..06f33e3 --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/validation/field/ScalarFieldValidator.scala @@ -0,0 +1,39 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.validation.field + +import za.co.absa.standardization.ValidationIssue +import za.co.absa.standardization.types.TypedStructField + +import scala.util.Try + + +/** + * Scalar types schema validation against default value + */ +object ScalarFieldValidator extends ScalarFieldValidator + +class ScalarFieldValidator extends FieldValidator { + + private def validateDefaultValue(field: TypedStructField): Try[Any] = { + field.defaultValueWithGlobal + } + + override def validate(field: TypedStructField): Seq[ValidationIssue] = { + super.validate(field) ++ tryToValidationIssues(validateDefaultValue(field)) + } +} diff --git a/src/main/scala/za/co/absa/standardization/validation/field/TimestampFieldValidator.scala b/src/main/scala/za/co/absa/standardization/validation/field/TimestampFieldValidator.scala new file mode 100644 index 0000000..5b5f00f --- /dev/null +++ b/src/main/scala/za/co/absa/standardization/validation/field/TimestampFieldValidator.scala @@ -0,0 +1,81 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.validation.field + +import za.co.absa.standardization._ +import za.co.absa.standardization.time.DateTimePattern +import za.co.absa.standardization.types.parsers.DateTimeParser + +import java.util.Date + +object TimestampFieldValidator extends DateTimeFieldValidator { + import za.co.absa.standardization.implicits.StringImplicits.StringEnhancements + + override protected def patternAnalysisIssues(pattern: DateTimePattern, + defaultValue: Option[String], + defaultTimeZone: Option[String]): Seq[ValidationIssue] = { + + val doubleTimeZoneIssue: Seq[ValidationIssue] = if (pattern.timeZoneInPattern && defaultTimeZone.nonEmpty) { + Seq(ValidationWarning( + "Pattern includes time zone placeholder and default time zone is also defined (will never be used)" + )) + } else { + Nil + } + + val patternIssues: Seq[ValidationIssue] = + if (pattern.pattern == DateTimePattern.EpochNanoKeyword) { + List(ValidationWarning( + "Pattern 'epochnano'. While supported it comes with possible loss of precision beyond microseconds." 
+ )) + } else if (!pattern.isEpoch) { + val placeholders = Set('y', 'M', 'd', 'H', 'm', 's', 'D', 'K', 'h', 'a', 'k', 'n') + val patternChars = pattern.pattern.countUnquoted(placeholders, Set(''')) + patternChars.foldLeft(List.empty[ValidationIssue]) {(acc, item) => item match { + case ('y', 0) => ValidationWarning("No year placeholder 'yyyy' found.") :: acc + case ('M', 0) => ValidationWarning("No month placeholder 'MM' found.") :: acc + case ('d', 0) => ValidationWarning("No day placeholder 'dd' found.") :: acc + case ('H', 0) if patternChars('k') + patternChars('K') + patternChars('h') == 0 => + ValidationWarning("No hour placeholder 'HH' found.") :: acc + case ('m', 0) => ValidationWarning("No minute placeholder 'mm' found.") :: acc + case ('s', 0) => ValidationWarning("No second placeholder 'ss' found.") :: acc + case ('D', x) if x > 0 && patternChars('d') == 0 => + ValidationWarning("Rarely used DayOfYear placeholder 'D' found. Possibly DayOfMonth 'd' intended.") :: acc + case ('h', x) if x > 0 && patternChars('a') == 0 => + ValidationWarning( + "Placeholder for hour 1-12 'h' found, but no am/pm 'a' placeholder. Possibly 0-23 'H' intended." + ) :: acc + case ('K', x) if x > 0 && patternChars('a') == 0 => + ValidationWarning( + "Placeholder for hour 0-11 'K' found, but no am/pm 'a' placeholder. Possibly 1-24 'k' intended." + ) :: acc + case ('n', x) if x > 0 => ValidationWarning( + "Placeholder 'n' for nanoseconds recognized. While supported, it brings possible loss of precision." + ) :: acc + case _ => acc + }} + } else { + Nil + } + doubleTimeZoneIssue ++ patternIssues + } + + override def verifyStringDateTime(dateTime: String)(implicit parser: DateTimeParser): Date = { + parser.parseTimestamp(dateTime) + } + +} diff --git a/src/test/resources/application.conf b/src/test/resources/application.conf new file mode 100644 index 0000000..732c650 --- /dev/null +++ b/src/test/resources/application.conf @@ -0,0 +1,18 @@ +# Copyright 2021 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +standardization.recordId.generation.strategy="none" + +standardization.defaultTimestampTimeZone.default="CET" +standardization.defaultTimestampTimeZone.xml="Africa/Johannesburg" diff --git a/src/test/resources/data/bug.json b/src/test/resources/data/bug.json new file mode 100644 index 0000000..469c733 --- /dev/null +++ b/src/test/resources/data/bug.json @@ -0,0 +1,26 @@ +{ + "type": "struct", + "fields": [ + { + "name": "Conformed_TXN_TIMESTAMP", + "type": "timestamp", + "nullable": true, + "metadata": { + "pattern": "yyyy-MM-ddTHH:mm:ss.SSSX" + } + }, + { + "name": "Conformed_TXN_TIMESTAMP1", + "type": "timestamp", + "nullable": true, + "metadata": { + "pattern": "yyyy-MM-dd HH:mm:ss" + } + }, + { + "name": "Conformed_TXN_TIMESTAMP2", + "type": "timestamp", + "nullable": true + } + ] +} diff --git a/src/test/resources/data/data1Schema.json b/src/test/resources/data/data1Schema.json new file mode 100644 index 0000000..3cab930 --- /dev/null +++ b/src/test/resources/data/data1Schema.json @@ -0,0 +1,83 @@ +{ + "type": "struct", + "fields": [{ + "name": "name", + "type": "string", + "nullable": false, + "metadata": { + + } + }, + { + "name": "surname", + "type": "string", + "nullable": false, + "metadata": { + "default": "Unknown Surname" + } + } + , + { + "name": "hoursWorked", + "type": { + "type": "array", + "elementType": "integer", + "containsNull": false + }, + "nullable": false, + "metadata": { + + } + }, +{ + "name": "employeeNumbers", + "type": { + "type": "array", + "elementType": { + "type": "struct", + "fields": [{ + "name": "numberType", + "type": "string", + "nullable": true, + "metadata": { + + } + }, + { + "name": "numbers", + "type": { + "type": "array", + "elementType": "integer", + "containsNull": true + }, + "nullable": true, + "metadata": { + + } + }] + }, + "containsNull": true + }, + "nullable": true, + "metadata": { + + } + }, + { + "name": "startDate", + "type": "date", + "nullable": false, + "metadata": { + "pattern": "yyyy-MM-dd" + } + }, + { + "name": "updated", + "type": "timestamp", + "nullable": true, + "metadata": { + "pattern": "yyyyMMdd.HHmmss" + } + } + ] +} \ No newline at end of file diff --git a/src/test/resources/data/dateTimestampSchemaOk.json b/src/test/resources/data/dateTimestampSchemaOk.json new file mode 100644 index 0000000..0f1b626 --- /dev/null +++ b/src/test/resources/data/dateTimestampSchemaOk.json @@ -0,0 +1,93 @@ +{ + "type" : "struct", + "fields" : [ { + "name" : "id", + "type" : "long", + "nullable" : false, + "metadata" : { } + }, { + "name" : "dateSampleOk1", + "type" : "date", + "nullable" : true, + "metadata" : { + "pattern" : "dd-MM-yyyy" + } + }, { + "name" : "dateSampleOk2", + "type" : "date", + "nullable" : true, + "metadata" : { + "pattern" : "yyyy-MM-dd" + } + }, { + "name" : "dateSampleOk3", + "type" : "date", + "nullable" : true, + "metadata" : { + "pattern" : "MM/dd/yyyy" + } + }, { + "name" : "dateSampleWrong1", + "type" : "date", + "nullable" : false, + "metadata" : { + "pattern" : "dd-MM-yyyy" + } + }, { + "name" : "dateSampleWrong2", + "type" : "date", + "nullable" : false, + "metadata" : { + "pattern" : "dd-MM-yyyy" + } + }, { + "name" : "dateSampleWrong3", + "type" : "date", + "nullable" : true, + "metadata" : { + "pattern" : "dd-MM-yyyy" + } + }, { + "name" : "timestampSampleOk1", + "type" : "timestamp", + "nullable" : true, + "metadata" : { + "pattern" : "yyyy-MM-dd'T'HH:mm:ss" + } + }, { + "name" : "timestampSampleOk2", + "type" : "timestamp", + "nullable" : true, + "metadata" : { + "pattern" : "yyyy-MM-dd_HH:mm:ss" 
+ } + }, { + "name" : "timestampSampleOk3", + "type" : "timestamp", + "nullable" : true, + "metadata" : { + "pattern" : "yyyyMMddHHmmss" + } + }, { + "name" : "timestampSampleWrong1", + "type" : "timestamp", + "nullable" : true, + "metadata" : { + "pattern" : "yyyy-MM-dd'T'HH:mm:ss" + } + }, { + "name" : "timestampSampleWrong2", + "type" : "timestamp", + "nullable" : false, + "metadata" : { + "pattern" : "yyyy-MM-dd'T'HH:mm:ss" + } + }, { + "name" : "timestampSampleWrong3", + "type" : "timestamp", + "nullable" : false, + "metadata" : { + "pattern" : "yyyy-MM-dd'T'HH:mm:ss" + } + } ] +} diff --git a/src/test/resources/data/dateTimestampSchemaWrong.json b/src/test/resources/data/dateTimestampSchemaWrong.json new file mode 100644 index 0000000..4bd3f19 --- /dev/null +++ b/src/test/resources/data/dateTimestampSchemaWrong.json @@ -0,0 +1,107 @@ +{ + "type" : "struct", + "fields" : [ { + "name" : "id", + "type" : "long", + "nullable" : false, + "metadata" : { } + }, { + "name" : "dateSampleOk1", + "type" : "date", + "nullable" : true, + "metadata" : { + "pattern" : "dd-MM-yyyy" + } + }, { + "name" : "dateSampleOk2", + "type" : "date", + "nullable" : true, + "metadata" : { + "pattern" : "yyyy-MM-dd" + } + }, { + "name" : "dateSampleOk3", + "type" : "date", + "nullable" : true, + "metadata" : { + "pattern" : "MM/dd/yyyy" + } + }, { + "name" : "dateSampleWrong1", + "type" : "date", + "nullable" : true, + "metadata" : { + "pattern" : "DD-MM-yyyy" + } + }, { + "name" : "dateSampleWrong2", + "type" : "date", + "nullable" : true, + "metadata" : { + "pattern" : "Dy" + } + }, { + "name" : "dateSampleWrong3", + "type" : "date", + "nullable" : true, + "metadata" : { + "pattern" : "rrr" + } + }, { + "name" : "timestampSampleOk1", + "type" : "timestamp", + "nullable" : true, + "metadata" : { + "pattern" : "yyyy-MM-dd'T'HH:mm:ss" + } + }, { + "name" : "timestampSampleOk2", + "type" : "timestamp", + "nullable" : true, + "metadata" : { + "pattern" : "yyyy-MM-dd_HH:mm:ss" + } + }, { + "name" : "timestampSampleOk3", + "type" : "timestamp", + "nullable" : true, + "metadata" : { + "pattern" : "yyyyMMddHHmmss" + } + }, { + "name" : "timestampSampleWrong1", + "type" : "timestamp", + "nullable" : true, + "metadata" : { + "pattern" : "yyyyMMddTHHmmss" + } + }, { + "name" : "timestampSampleWrong2", + "type" : "timestamp", + "nullable" : true, + "metadata" : { + "pattern" : "yyyy-MM-dd't'HH:mm:ss" + } + }, { + "name" : "timestampSampleWrong3", + "type" : "timestamp", + "nullable" : true, + "metadata" : { + "pattern" : "yyyy-MM-dd" + } + }, { + "name" : "timestampNullDefaultWrong", + "type" : "timestamp", + "nullable" : false, + "metadata" : { + "default" : null + } + }, { + "name" : "timestampNullDefaultOK", + "type" : "timestamp", + "nullable" : true, + "metadata" : { + "default" : null + } + } ] +} diff --git a/src/test/resources/data/integral_overflow_test.csv b/src/test/resources/data/integral_overflow_test.csv new file mode 100644 index 0000000..0c696ab --- /dev/null +++ b/src/test/resources/data/integral_overflow_test.csv @@ -0,0 +1,11 @@ +description,bytesize,shortsize,integersize,longsize +"One","1","1","1","1" +"Full positive","127","32767","2147483647","9223372036854775807" +"Full negative","-128","-32768","-2147483648","-9223372036854775808" +"With plus sign","+127","+32767","+2147483647","+9223372036854775807" +"Overflow","128","32768","2147483648","9223372036854775808" +"Underflow","-129","-32769","-2147483649","-9223372036854775809" +"With zeros","+0","007","-0001","-00000000" +"Nulls",,,, +"Decimal 
entry","1.0","2.0","3.0","4.0" +"With fractions","3.14","2.71","1.41","1.5" \ No newline at end of file diff --git a/src/test/resources/data/integral_overflow_test_numbers.json b/src/test/resources/data/integral_overflow_test_numbers.json new file mode 100644 index 0000000..cf90e4f --- /dev/null +++ b/src/test/resources/data/integral_overflow_test_numbers.json @@ -0,0 +1 @@ +[{"description":"One","bytesize":1,"shortsize":1,"integersize":1,"longsize":1},{"description":"Full positive","bytesize":127,"shortsize":32767,"integersize":2147483647,"longsize":9223372036854776000},{"description":"Full negative","bytesize":-128,"shortsize":-32768,"integersize":-2147483648,"longsize":-9223372036854776000},{"description":"Overflow","bytesize":128,"shortsize":32768,"integersize":2147483648,"longsize":9223372036854776000},{"description":"Underflow","bytesize":-129,"shortsize":-32769,"integersize":-2147483649,"longsize":-9223372036854776000},{"description":"Nulls","bytesize":null,"shortsize":null,"integersize":null,"longsize":null},{"description":"Decimal entry","bytesize":1.1,"shortsize":2,"integersize":3,"longsize":4}] \ No newline at end of file diff --git a/src/test/resources/data/integral_overflow_test_text.json b/src/test/resources/data/integral_overflow_test_text.json new file mode 100644 index 0000000..8b17a7c --- /dev/null +++ b/src/test/resources/data/integral_overflow_test_text.json @@ -0,0 +1 @@ +[{"description":"One","bytesize":"1","shortsize":"1","integersize":"1","longsize":"1"},{"description":"Full positive","bytesize":"127","shortsize":"32767","integersize":"2147483647","longsize":"9223372036854775807"},{"description":"Full negative","bytesize":"-128","shortsize":"-32768","integersize":"-2147483648","longsize":"-9223372036854775808"},{"description":"With plus sign","bytesize":"+127","shortsize":"+32767","integersize":"+2147483647","longsize":"+9223372036854775807"},{"description":"Overflow","bytesize":"128","shortsize":"32768","integersize":"2147483648","longsize":"9223372036854775808"},{"description":"Underflow","bytesize":"-129","shortsize":"-32769","integersize":"-2147483649","longsize":"-9223372036854775809"},{"description":"With zeros","bytesize":"+0","shortsize":"007","integersize":"-0001","longsize":"-00000000"},{"description":"Nulls","bytesize":null,"shortsize":null,"integersize":null,"longsize":null},{"description":"Decimal entry","bytesize":"1.0","shortsize":"2.0","integersize":"3.0","longsize":"4.0"},{"description":"With fractions","bytesize":"3.14","shortsize":"2.71","integersize":"1.41","longsize":"1.5"}] \ No newline at end of file diff --git a/src/test/resources/data/patients.json b/src/test/resources/data/patients.json new file mode 100644 index 0000000..459dec2 --- /dev/null +++ b/src/test/resources/data/patients.json @@ -0,0 +1,59 @@ +[ + { + "first name": "Jane", + "last name": "Goodall", + "body stats": { + "height": 164, + "weight": 61, + "miscellaneous": { + "eye color": "green", + "glasses": true + }, + "temperature measurements": [ + 36.6, + 36.7, + 37, + 36.6 + ] + } + }, + { + "first name": "Scott", + "last name": "Lang", + "body stats": { + "height": "various", + "weight": 83, + "miscellaneous": { + "eye color": "blue", + "glasses": false + }, + "temperature measurements": [ + 36.6, + 36.7, + 37, + 36.6 + ] + } + }, + { + "first name": "Aldrich", + "last name": "Killian", + "body stats": { + "height": 181, + "weight": 90, + "miscellaneous": { + "eye color": "brown or orange", + "glasses": "not any more" + }, + "temperature measurements": [ + 36.7, + 36.5, + 38, 
+ 48, + 152, + 831, + "exploded" + ] + } + } +] diff --git a/src/test/resources/data/standardizeJsonSrc.json b/src/test/resources/data/standardizeJsonSrc.json new file mode 100644 index 0000000..8e344d7 --- /dev/null +++ b/src/test/resources/data/standardizeJsonSrc.json @@ -0,0 +1 @@ +{"rootField":"rootfieldval","rootStruct":{"subFieldA":123,"subFieldB":"subfieldval"},"rootStruct2":{"subStruct2":{"subSub2FieldA":456,"subSub2FieldB":"subsubfieldval"}},"rootArray":[{"arrayFieldA":789,"arrayFieldB":"arrayfieldval","arrayStruct":{"subFieldA":321,"subFieldB":"xyz"}}]} diff --git a/src/test/scala/za/co/absa/standardization/ArrayTransformationsSuite.scala b/src/test/scala/za/co/absa/standardization/ArrayTransformationsSuite.scala new file mode 100644 index 0000000..6b8a65e --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/ArrayTransformationsSuite.scala @@ -0,0 +1,90 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization + +import org.scalatest.funsuite.AnyFunSuite + +import scala.util.Random + +case class InnerStruct(a: Int, b: String = null) +case class OuterStruct(id: Int, vals: Seq[InnerStruct]) +case class ZippedOuterStruct(id: Int, vals: Seq[(Int, InnerStruct)]) +case class Outer2(z: OuterStruct) +case class ZippedOuter2(z: ZippedOuterStruct) + +case class MyA(b: MyB) +case class MyA2(b: MyB2) +case class MyB(c: MyC) +case class MyB2(c: MyC2) +case class MyC(something: Int) +case class MyC2(something: Int, somethingByTwo: Int) + +case class Nested2Levels(a: List[List[Option[Int]]]) +case class Nested1Level(a: List[Option[Int]]) + +class ArrayTransformationsSuite extends AnyFunSuite with SparkTestBase { + + private val inputData = (0 to 10).toList.map(x => (x, Random.shuffle((0 until x).toList))) + private val inputDataOrig = OuterStruct(-1, null) :: inputData.map({ case (x, vals) => OuterStruct(x, vals.map(InnerStruct(_))) }) + + private val extraNested = inputDataOrig.map(Outer2) + + import spark.implicits._ + + test("Testing nestedWithColumn") { + val df = spark.createDataFrame(extraNested) + + val res = ArrayTransformations.nestedWithColumn(df)("z.id", $"z.id" * 2) + + val actual = res.as[Outer2].collect().sortBy(x => x.z.id) + val expected = extraNested.toArray.map(x => x.copy(x.z.copy(x.z.id * 2))).sortBy(x => x.z.id) + + assertResult(expected)(actual) + } + + test("Testing nestedWithColumn 3 levels deep") { + val df = spark.createDataFrame(Seq( + MyA(MyB(MyC(0))), MyA(MyB(MyC(1))), MyA(MyB(MyC(2))), MyA(MyB(MyC(3))), MyA(MyB(MyC(4))))) + + val expected = Seq( + MyA2(MyB2(MyC2(0, 0))), MyA2(MyB2(MyC2(1, 2))), MyA2(MyB2(MyC2(2, 4))), MyA2(MyB2(MyC2(3, 6))), MyA2(MyB2(MyC2(4, 8)))).sortBy(_.b.c.something).toList + + val res = ArrayTransformations.nestedWithColumn(df)("b.c.somethingByTwo", $"b.c.something" * 2).as[MyA2].collect.sortBy(_.b.c.something).toList + + assertResult(expected)(res) + } + + test("Testing flattenArrays") { + val df = spark.createDataFrame(Seq( + Nested2Levels(List( + 
List(Some(1)), null, List(None), List(Some(2)), + List(Some(3), Some(4)), List(Some(5), Some(6)))), + Nested2Levels(List()), + Nested2Levels(null))) + + val res = ArrayTransformations.flattenArrays(df, "a") + + val exp = List( + Nested1Level(List(Some(1), None, Some(2), Some(3), Some(4), Some(5), Some(6))), + Nested1Level(List()), + Nested1Level(null)).toSeq + + val resLocal = res.as[Nested1Level].collect().toSeq + + assertResult(exp)(resLocal) + } +} diff --git a/src/test/scala/za/co/absa/standardization/ErrorMessageFactory.scala b/src/test/scala/za/co/absa/standardization/ErrorMessageFactory.scala new file mode 100644 index 0000000..05a567b --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/ErrorMessageFactory.scala @@ -0,0 +1,36 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization + +object ErrorMessageFactory { + val errColSchema: String = "\n |-- errCol: array (nullable = true)\n"+ + " | |-- element: struct (containsNull = false)\n"+ + " | | |-- errType: string (nullable = true)\n"+ + " | | |-- errCode: string (nullable = true)\n"+ + " | | |-- errMsg: string (nullable = true)\n"+ + " | | |-- errCol: string (nullable = true)\n"+ + " | | |-- rawValues: array (nullable = true)\n"+ + " | | | |-- element: string (containsNull = true)\n"+ + " | | |-- mappings: array (nullable = true)\n"+ + " | | | |-- element: struct (containsNull = true)\n"+ + " | | | | |-- mappingTableColumn: string (nullable = true)\n"+ + " | | | | |-- mappedDatasetColumn: string (nullable = true)\n" + + def attachErrColToSchemaPrint(schemaPrint: String): String = { + schemaPrint + errColSchema + } +} diff --git a/src/test/scala/za/co/absa/standardization/JsonUtilsSuite.scala b/src/test/scala/za/co/absa/standardization/JsonUtilsSuite.scala new file mode 100644 index 0000000..fcc8740 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/JsonUtilsSuite.scala @@ -0,0 +1,63 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization + +import org.scalatest.funsuite.AnyFunSuite + +class JsonUtilsSuite extends AnyFunSuite with SparkTestBase { + test("Test JSON pretty formatting from a JSON string") { + val inputJson = """[{"id":1,"items":[{"itemid":100,"subitems":[{"elems":[{"numbers":["1","2","3b","4","5c","6"]}],"code":100}]}]}]""" + val expected = """[ { + | "id" : 1, + | "items" : [ { + | "itemid" : 100, + | "subitems" : [ { + | "elems" : [ { + | "numbers" : [ "1", "2", "3b", "4", "5c", "6" ] + | } ], + | "code" : 100 + | } ] + | } ] + |} ]""".stripMargin.replace("\r\n", "\n") + + val actual = JsonUtils.prettyJSON(inputJson) + + assert(expected == actual) + } + + test("Test JSON pretty formatting from a Spark JSON string") { + val inputJsons = Seq("""{"value": 1}""", """{"value": 2}""") + val expected = "[ {\n \"value\" : 1\n}, {\n \"value\" : 2\n} ]" + + val actual = JsonUtils.prettySparkJSON(inputJsons) + + assert(expected == actual) + } + + test("Test a dataframe created from a JSON") { + val inputJson = Seq("""{"value":1}""", """{"value":2}""") + + val df = JsonUtils.getDataFrameFromJson(spark, inputJson) + + val expectedSchema = """root + | |-- value: long (nullable = true) + |""".stripMargin.replace("\r\n", "\n") + val actualSchema = df.schema.treeString + + assert(expectedSchema == actualSchema) + } +} diff --git a/src/test/scala/za/co/absa/standardization/LoggerTestBase.scala b/src/test/scala/za/co/absa/standardization/LoggerTestBase.scala new file mode 100644 index 0000000..2a0f590 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/LoggerTestBase.scala @@ -0,0 +1,47 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization + +import org.apache.spark.sql.DataFrame +import org.slf4j.event.Level +import org.slf4j.event.Level._ +import org.slf4j.{Logger, LoggerFactory} + +trait LoggerTestBase { + + val logger: Logger = LoggerFactory.getLogger(this.getClass) + + def logLevelToLogFunction(logLevel: Level): String => Unit = { + logLevel match { + case TRACE => logger.trace + case DEBUG => logger.debug + case INFO => logger.info + case WARN => logger.warn + case ERROR => logger.error + } + } + + protected def logDataFrameContent(df: DataFrame, logLevel: Level = DEBUG): Unit = { + import za.co.absa.standardization.implicits.DataFrameImplicits.DataFrameEnhancements + + val logFnc = logLevelToLogFunction(logLevel) + logFnc(df.schema.treeString) + + val dfData = df.dataAsString(false) + logFnc(dfData) + } +} diff --git a/src/test/scala/za/co/absa/standardization/RecordIdGenerationSuite.scala b/src/test/scala/za/co/absa/standardization/RecordIdGenerationSuite.scala new file mode 100644 index 0000000..8d3435b --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/RecordIdGenerationSuite.scala @@ -0,0 +1,91 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization + +import za.co.absa.standardization.RecordIdGenerationSuite.{SomeData, SomeDataWithId} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import za.co.absa.standardization.RecordIdGeneration.IdType.{NoId, StableHashId, TrueUuids} +import za.co.absa.standardization.RecordIdGeneration._ + +import java.util.UUID + +class RecordIdGenerationSuite extends AnyFlatSpec with Matchers with SparkTestBase { + import spark.implicits._ + + val data1 = Seq( + SomeData("abc", 12), + SomeData("def", 34), + SomeData("xyz", 56) + ) + + "RecordIdColumnByStrategy" should s"do noop with $NoId" in { + val df1 = spark.createDataFrame(data1) + val updatedDf1 = addRecordIdColumnByStrategy(df1, "idColumnWontBeUsed", NoId) + + df1.collectAsList() shouldBe updatedDf1.collectAsList() + } + + it should s"always yield the same IDs with ${StableHashId}" in { + + val df1 = spark.createDataFrame(data1) + val updatedDf1 = addRecordIdColumnByStrategy(df1, "stableId", StableHashId) + val updatedDf2 = addRecordIdColumnByStrategy(df1, "stableId", StableHashId) + + updatedDf1.as[SomeDataWithId].collect() should contain theSameElementsInOrderAs updatedDf2.as[SomeDataWithId].collect() + + Seq(updatedDf1, updatedDf2).foreach { updatedDf => + val updatedData = updatedDf.as[SomeDataWithId].collect() + updatedData.length shouldBe 3 + } + } + + it should s"yield the different IDs with $TrueUuids" in { + + val df1 = spark.createDataFrame(data1) + val updatedDf1 = addRecordIdColumnByStrategy(df1, "trueId", TrueUuids) + val updatedDf2 = addRecordIdColumnByStrategy(df1, "trueId", TrueUuids) + + updatedDf1.as[SomeDataWithId].collect() shouldNot contain theSameElementsAs updatedDf2.as[SomeDataWithId].collect() + + Seq(updatedDf1, 
updatedDf2).foreach { updatedDf => + val updatedData = updatedDf.as[SomeDataWithId].collect() + updatedData.length shouldBe 3 + updatedData.foreach(entry => UUID.fromString(entry.enceladus_record_id)) + } + } + + "RecordIdGenerationStrategyFromConfig" should "correctly load uuidType from config (case insensitive)" in { + getRecordIdGenerationType("UUiD") shouldBe TrueUuids + getRecordIdGenerationType("StaBleHASHiD") shouldBe StableHashId + getRecordIdGenerationType("nOnE") shouldBe NoId + + val caughtException = the[IllegalArgumentException] thrownBy { + getRecordIdGenerationType("InVaLiD") + } + caughtException.getMessage should include("Invalid value 'InVaLiD' was encountered for id generation strategy, use one of: uuid, stableHashId, none.") + } + +} + +object RecordIdGenerationSuite { + + case class SomeData(value1: String, value2: Int) + + case class SomeDataWithId(value1: String, value2: Int, enceladus_record_id: String) + +} diff --git a/src/test/scala/za/co/absa/standardization/SchemaValidationSuite.scala b/src/test/scala/za/co/absa/standardization/SchemaValidationSuite.scala new file mode 100644 index 0000000..ea7fe5b --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/SchemaValidationSuite.scala @@ -0,0 +1,227 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization + +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite + +/** + * A test suite for validation of scalar data types + */ +//noinspection ZeroIndexToHead +class SchemaValidationSuite extends AnyFunSuite with LoggerTestBase { + + test("Scalar types should be validated") { + val schema = StructType( + Array( + StructField("id", LongType), + StructField("string_good", StringType, nullable = true, Metadata.fromJson(""" { "default": "Unknown" } """)), + StructField("string_bad", StringType, nullable = true, Metadata.fromJson(""" { "default": -1 } """)), + StructField("bool_good", BooleanType, nullable = true, Metadata.fromJson(""" { "default": "true" } """)), + StructField("bool_bad", BooleanType, nullable = true, Metadata.fromJson(""" { "default": "1" } """)), + StructField("tiny_good", ByteType, nullable = true, Metadata.fromJson(""" { "default": "127" } """)), + StructField("tiny_bad", ByteType, nullable = true, Metadata.fromJson(""" { "default": "-127-" } """)), + StructField("short_good", ShortType, nullable = false, Metadata.fromJson(""" { "default": "20000" } """)), + StructField("short_bad", ShortType, nullable = false, Metadata.fromJson(""" { "default": "20000.0" } """)), + StructField("integer_good", IntegerType, nullable = false, Metadata.fromJson(""" { "default": "1000000000" } """)), + StructField("integer_bad", IntegerType, nullable = false, Metadata.fromJson(""" { "default": "number" } """)), + StructField("long_good", LongType, nullable = false, Metadata.fromJson(""" { "default": "1000000000000000000" } """)), + StructField("long_bad", LongType, nullable = false, Metadata.fromJson(""" { "default": "wrong" } """)), + StructField("float_good", FloatType, nullable = false, Metadata.fromJson(""" { "default": "1000.222" } """)), + StructField("float_bad", FloatType, nullable = false, Metadata.fromJson(""" { "default": "1000.2.22" } """)), + StructField("double_good", DoubleType, nullable = false, Metadata.fromJson(""" { "default": "1000000.5544" } """)), + StructField("double_bad", DoubleType, nullable = false, Metadata.fromJson(""" { "default": "++1000000.5544" } """)), + StructField("decimal_good", DecimalType(20,10), nullable = false, Metadata.fromJson(""" { "default": "314159265.314159265"}""")), + StructField("decimal_bad", DecimalType(20,10), nullable = false, Metadata.fromJson(""" { "default": "314159265358882224.3141.59265"}""")) + ) + ) + + val failures = SchemaValidator.validateSchema(schema) + if (failures.lengthCompare(9) != 0) { + logger.error("Validation errors:") + logger.error(failures.mkString("\n")) + } + + assert(failures.lengthCompare(9) == 0) + assert(failures(0).fieldName == "string_bad") + assert(failures(1).fieldName == "bool_bad") + assert(failures(2).fieldName == "tiny_bad") + assert(failures(3).fieldName == "short_bad") + assert(failures(4).fieldName == "integer_bad") + assert(failures(5).fieldName == "long_bad") + assert(failures(6).fieldName == "float_bad") + assert(failures(7).fieldName == "double_bad") + assert(failures(8).fieldName == "decimal_bad") + } + + test("Overflows should generate validation errors") { + val schema = StructType( + Array( + StructField("id", LongType), + StructField("tiny_good", ByteType, nullable = true, Metadata.fromJson(""" { "default": "127" } """)), + StructField("tiny_bad", ByteType, nullable = true, Metadata.fromJson(""" { "default": "128" } """)), + StructField("short_good", ShortType, nullable = false, Metadata.fromJson(""" { "default": "20000" } """)), 
+ StructField("short_bad", ShortType, nullable = false, Metadata.fromJson(""" { "default": "32768" } """)), + StructField("integer_good", IntegerType, nullable = false, Metadata.fromJson(""" { "default": "1000000000" } """)), + StructField("integer_bad", IntegerType, nullable = false, Metadata.fromJson(""" { "default": "6000000000" } """)), + StructField("long_good", LongType, nullable = false, Metadata.fromJson(""" { "default": "1000000000000000000" } """)), + StructField("long_bad", LongType, nullable = false, Metadata.fromJson(""" { "default": "10000000000000000000" } """)), + StructField("float_good", FloatType, nullable = false, Metadata.fromJson(""" { "default": "1000.222" } """)), + StructField("float_bad", FloatType, nullable = false, Metadata.fromJson(""" { "default": "1e40" } """)), + StructField("double_good", DoubleType, nullable = false, Metadata.fromJson(""" { "default": "1000000.5544" } """)), + StructField("double_bad", DoubleType, nullable = false, Metadata.fromJson(""" { "default": "1e310" } """)), + StructField("decimal_good", DecimalType(20,10), nullable = false, Metadata.fromJson(""" { "default": "-9999999999.9999999999"}""")), + StructField("decimal_bad", DecimalType(20,10), nullable = false, Metadata.fromJson(""" { "default": "123456789012345678901.12345678901"}""")) + ) + ) + + val failures = SchemaValidator.validateSchema(schema) + if (failures.lengthCompare(7) != 0) { + logger.error("Validation errors:") + logger.error(failures.mkString("\n")) + } + + assert(failures.lengthCompare(7) == 0) + assert(failures(0).fieldName == "tiny_bad") + assert(failures(1).fieldName == "short_bad") + assert(failures(2).fieldName == "integer_bad") + assert(failures(3).fieldName == "long_bad") + assert(failures(4).fieldName == "float_bad") + assert(failures(5).fieldName == "double_bad") + assert(failures(6).fieldName == "decimal_bad") + } + + test("Date/Time patterns should be validated") { + val schema = StructType( + Array( + StructField("id", LongType), + StructField("name", StringType), + StructField("orderdate", DateType, nullable = false, Metadata.fromJson(""" { "pattern": "kk-MM-yyyy" } """)), + StructField("deliverydate", DateType, nullable = false, Metadata.fromJson(""" { "pattern": "wrong" } """)), + StructField("paymentmade", DateType, nullable = false, Metadata.fromJson(""" { "pattern": "bad" } """)), + StructField("paymentreceived", DateType, nullable = false, Metadata.fromJson(""" { "pattern": "dd-MM-yyyyTHH:mm:ss" } """)) + ) + ) + + val failures = SchemaValidator.validateSchema(schema) + if (failures.lengthCompare(4) != 0) { + logger.error("Validation errors:") + logger.error(failures.mkString("\n")) + } + + assert(failures.lengthCompare(4) == 0) + assert(failures(0).fieldName == "orderdate") + assert(failures(1).fieldName == "deliverydate") + assert(failures(2).fieldName == "paymentmade") + assert(failures(3).fieldName == "paymentreceived") + } + + test("Date/Time default values should be validated") { + val schema = StructType( + Array( + StructField("id", LongType), + StructField("name", StringType), + StructField("orderdate", DateType, nullable = false, Metadata.fromJson(""" { "pattern": "dd-MM-yyyy", "default": "2015-01-01" } """)), + StructField("deliverydate", DateType, nullable = false, Metadata.fromJson(""" { "pattern": "dd-MM-yyyy", "default": "KKK" } """)), + StructField("paymentmade", DateType, nullable = false, Metadata.fromJson(""" { "pattern": "dd-MM-yyyy'T'HH:mm:ss", "default": "2005-01-01T18:00:12" } """)), + StructField("paymentreceived", DateType, 
nullable = false, Metadata.fromJson(""" { "pattern": "dd-MM-yyyy'T'HH:mm:ss", "default": "ZZZ" } """)) + ) + ) + + val failures = SchemaValidator.validateSchema(schema) + if (failures.lengthCompare(4) != 0) { + logger.error("Validation errors:") + logger.error(failures.mkString("\n")) + } + + assert(failures.lengthCompare(4) == 0) + assert(failures(0).fieldName == "orderdate") + assert(failures(1).fieldName == "deliverydate") + assert(failures(2).fieldName == "paymentmade") + assert(failures(3).fieldName == "paymentreceived") + } + + test("Nested struct and array fields should be validated") { + val schema = StructType( + Array( + StructField("id", StringType), + StructField("name", StringType), + StructField("orders", ArrayType(StructType(Array( + StructField("orderdate", DateType, nullable = false, Metadata.fromJson(""" { "pattern": "DD-MM-yyyy" } """)), + StructField("deliverdate", DateType, nullable = false, Metadata.fromJson(""" { "pattern": "dd-MM-yyyy'T'HH:mm:ss" } """)), + StructField("ordermade", TimestampType, nullable = false, Metadata.fromJson(""" { "pattern": "dd-MM-yyyy'T'HH:mm:ss" } """)), + StructField("payment", StructType(Array( + StructField("due", DateType, nullable = false, Metadata.fromJson(""" { "pattern": "dd-MM-yyyy'T'HH:mm:ss" } """)), + StructField("made", TimestampType, nullable = false, Metadata.fromJson(""" { "pattern": "dd-MM-yyyy" } """)), + StructField("expected", TimestampType, nullable = false, Metadata.fromJson(""" { "pattern": "foo" } """)), + StructField("received", TimestampType, nullable = false, Metadata.fromJson(""" { "pattern": "dd-MM-yyyy'T'HH:mm:ss" } """)))) + )))) + ), + StructField("matrix", ArrayType(ArrayType(StructType(Array( + StructField("bar", TimestampType, nullable = true, Metadata.fromJson(""" { "pattern": "DD-MM-yyyy'T'HH:mm:ss" } """))))) + )) + )) + val failures = SchemaValidator.validateSchema(schema) + if (failures.lengthCompare(6) != 0) { + logger.error("Schema:\n") + logger.error(schema.prettyJson) + logger.error("Validation errors:") + logger.error(failures.mkString("\n")) + } + + assert(failures.lengthCompare(6) == 0) + assert(failures(0).fieldName == "orders[].orderdate") + assert(failures(1).fieldName == "orders[].deliverdate") + assert(failures(2).fieldName == "orders[].payment.due") + assert(failures(3).fieldName == "orders[].payment.made") + assert(failures(4).fieldName == "orders[].payment.expected") + assert(failures(5).fieldName == "matrix[][].bar") + } + + test("Column names should not contain dots") { + val schema = StructType( + Array( + StructField("my.id", StringType), + StructField("name", StringType), + StructField("orders", ArrayType(StructType(Array( + StructField("order.date", StringType), + StructField("deliverdate", StringType), + StructField("payment", StructType(Array( + StructField("due.time", StringType), + StructField("made", StringType))) + )))) + ), + StructField("matrix", ArrayType(ArrayType(StructType(Array( + StructField("foo.bar", StringType)))) + )) + )) + val failures = SchemaValidator.validateSchema(schema) + if (failures.lengthCompare(4) != 0) { + logger.error("Schema:\n") + logger.error(schema.prettyJson) + logger.error("Validation errors:") + logger.error(failures.mkString("\n")) + } + + assert(failures.lengthCompare(4) == 0) + assert(failures(0).fieldName == "my.id") + assert(failures(1).fieldName == "orders[].order.date") + assert(failures(2).fieldName == "orders[].payment.due.time") + assert(failures(3).fieldName == "matrix[][].foo.bar") + } + +} diff --git 
a/src/test/scala/za/co/absa/standardization/SparkTestBase.scala b/src/test/scala/za/co/absa/standardization/SparkTestBase.scala new file mode 100644 index 0000000..782049d --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/SparkTestBase.scala @@ -0,0 +1,137 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization + +import com.typesafe.config.{Config, ConfigFactory} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.log4j.{Level, Logger} +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession +import za.co.absa.standardization.time.TimeZoneNormalizer +import scala.collection.JavaConversions._ + +import java.io.File + +trait SparkTestBase { self => + TimeZoneNormalizer.normalizeJVMTimeZone() + + val config: Config = ConfigFactory.load() + val sparkMaster: String = config.getString("standardization.testUtils.sparkTestBaseMaster") + + val sparkBuilder: SparkSession.Builder = SparkSession.builder() + .master(sparkMaster) + .appName(s"Enceladus test - ${self.getClass.getName}") + .config("spark.ui.enabled", "false") + .config("spark.debug.maxToStringFields", 100) // scalastyle:ignore magic.number + // ^- default value is insufficient for some tests, 100 is a compromise between resource consumption and expected need + + implicit val spark: SparkSession = if (sparkMaster == "yarn") { + val confDir = config.getString("enceladus.utils.testUtils.hadoop.conf.dir") + val distJarsDir = config.getString("enceladus.utils.testUtils.spark.distJars.dir") + val sparkHomeDir = config.getString("enceladus.utils.testUtils.spark.home.dir") + + val hadoopConfigs = SparkTestBase.getHadoopConfigurationForSpark(confDir) + val sparkConfigs = SparkTestBase.loadSparkDefaults(sparkHomeDir) + val allConfigs = hadoopConfigs ++ sparkConfigs + + //get a list of all dist jars + val distJars = FileSystem.get(SparkTestBase.getHadoopConfiguration(confDir)).listStatus(new Path(distJarsDir)).map(_.getPath) + val localJars = SparkTestBase.getDepsFromClassPath("absa") + val currentJars = SparkTestBase.getCurrentProjectJars + val deps = (distJars ++ localJars ++currentJars).mkString(",") + + sparkBuilder.config(new SparkConf().setAll(allConfigs)) + .config("spark.yarn.jars", deps) + .config("spark.deploy.mode", "client") + .getOrCreate() + + } else { + sparkBuilder + .config("spark.driver.bindAddress", "127.0.0.1") + .config("spark.driver.host", "127.0.0.1") + .getOrCreate() + } + TimeZoneNormalizer.normalizeSessionTimeZone(spark) + + // Do not display INFO entries for tests + Logger.getLogger("org").setLevel(Level.WARN) + Logger.getLogger("akka").setLevel(Level.WARN) +} + +object SparkTestBase { + /** + * Gets a Hadoop configuration object from the specified hadoopConfDir parameter + * + * @param hadoopConfDir string representation of HADOOP_CONF_DIR + */ + def getHadoopConfiguration(hadoopConfDir: String): Configuration = { + val hadoopConf = new 
Configuration() + hadoopConf.addResource(new Path(s"$hadoopConfDir/hdfs-site.xml")) + hadoopConf.addResource(new Path(s"$hadoopConfDir/yarn-site.xml")) + hadoopConf.addResource(new Path(s"$hadoopConfDir/core-site.xml")) + + hadoopConf + } + + /** + * Converts all entries from a Hadoop configuration to Map, which can be consumed by SparkConf + * + * @param hadoopConf Hadoop Configuration object to be converted into Spark configs + */ + def hadoopConfToSparkMap(hadoopConf: Configuration): Map[String, String] = { + hadoopConf.iterator().map(entry => (s"spark.hadoop.${entry.getKey}", entry.getValue)).toMap + } + + /** + * Get Hadoop configuration consumable by SparkConf + */ + def getHadoopConfigurationForSpark(hadoopConfDir: String): Map[String, String] = { + hadoopConfToSparkMap(getHadoopConfiguration(hadoopConfDir)) + } + + /** + * Loads spark defaults from the specified SPARK_HOME directory + */ + def loadSparkDefaults(sparkHome: String): Map[String, String] = { + val sparkConfigIn = ConfigFactory.parseFile(new File(s"$sparkHome/conf/spark-defaults.conf")) + sparkConfigIn + .entrySet() + .filter(_.getKey != "spark.yarn.jars") + .map(entry => (entry.getKey, entry.getValue.unwrapped().toString)) + .toMap + } + + /** + * Gets the list of jars, which are currently loaded in the classpath and contain the given inclPattern in the file name + */ + def getDepsFromClassPath(inclPattern: String): Seq[String] = { + val cl = this.getClass.getClassLoader + cl.asInstanceOf[java.net.URLClassLoader].getURLs.filter(c => c.toString.contains(inclPattern)).map(_.toString()) + } + + /** + * Get the list of jar(s) of the current project + */ + def getCurrentProjectJars: Seq[String] = { + val targetDir = new File(s"${System.getProperty("user.dir")}/target") + targetDir + .listFiles() + .filter(f => f.getName.split("\\.").last.toLowerCase() == "jar" && f.getName.contains("original")) + .map(_.getAbsolutePath) + } +} diff --git a/src/test/scala/za/co/absa/standardization/StandardizationCsvSuite.scala b/src/test/scala/za/co/absa/standardization/StandardizationCsvSuite.scala new file mode 100644 index 0000000..becd1ce --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/StandardizationCsvSuite.scala @@ -0,0 +1,133 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package za.co.absa.standardization + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.{AnyFunSuite, FixtureAnyFunSuite} +import za.co.absa.standardization.types.{Defaults, GlobalDefaults} +import za.co.absa.standardization.udf.UDFLibrary + + +class StandardizationCsvSuite extends AnyFunSuite with SparkTestBase { + import spark.implicits._ + import za.co.absa.standardization.implicits.DataFrameImplicits.DataFrameEnhancements + + private implicit val udfLib: UDFLibrary = new UDFLibrary + private implicit val defaults: Defaults = GlobalDefaults + + + private val csvContent = spark.sparkContext.parallelize( + """101,102,1,2019-05-04,2019-05-04 + |201,202,2,2019-05-05,2019-05-05 + |301,302,1,2019-05-06,2019-05-06 + |401,402,1,2019-05-07,2019-05-07 + |501,502,,2019-05-08,2019-05-08""" + .stripMargin.lines.toList ).toDS() + + test("Test standardizing a CSV without special columns") { + val schema: StructType = StructType(Seq( + StructField("A1", IntegerType, nullable = true), + StructField("A2", IntegerType, nullable = true), + StructField("A3", IntegerType, nullable = true), + StructField("A4", StringType, nullable = true, + Metadata.fromJson("""{"pattern": "yyyy-MM-dd"}""")), + StructField("A5", StringType, nullable = true) + )) + + val schemaWithStringType: StructType = StructType(Seq( + StructField("A1", StringType, nullable = true), + StructField("A2", StringType, nullable = true), + StructField("A3", StringType, nullable = true), + StructField("A4", StringType, nullable = true), + StructField("A5", StringType, nullable = true) + )) + + val expectedOutput = + """+---+---+----+----------+----------+------+ + ||A1 |A2 |A3 |A4 |A5 |errCol| + |+---+---+----+----------+----------+------+ + ||101|102|1 |2019-05-04|2019-05-04|[] | + ||201|202|2 |2019-05-05|2019-05-05|[] | + ||301|302|1 |2019-05-06|2019-05-06|[] | + ||401|402|1 |2019-05-07|2019-05-07|[] | + ||501|502|null|2019-05-08|2019-05-08|[] | + |+---+---+----+----------+----------+------+ + | + |""".stripMargin.replace("\r\n", "\n") + + val rawDataFrame = spark.read.option("header", false).schema(schemaWithStringType).csv(csvContent) + val stdDf = Standardization.standardize(rawDataFrame, schema).cache() + val actualOutput = stdDf.dataAsString(truncate = false) + + assert(actualOutput == expectedOutput) + } + + test("Test standardizing a CSV with special columns when error column has wrong type") { + val schema: StructType = StructType(Seq( + StructField("A1", IntegerType, nullable = true), + StructField(ErrorMessage.errorColumnName, IntegerType, nullable = true), + StructField("enceladus_info_version", IntegerType, nullable = true), + StructField("enceladus_info_date", DateType, nullable = true, + Metadata.fromJson("""{"pattern": "yyyy-MM-dd"}""")), + StructField("enceladus_info_date_string", StringType, nullable = true) + )) + + val schemaStr: StructType = StructType(Seq( + StructField("A1", StringType, nullable = true), + StructField(ErrorMessage.errorColumnName, StringType, nullable = true), + StructField("enceladus_info_version", StringType, nullable = true), + StructField("enceladus_info_date", StringType, nullable = true), + StructField("enceladus_info_date_string", StringType, nullable = true) + )) + + val rawDataFrame = spark.read.option("header", false).schema(schemaStr).csv(csvContent) + + assertThrows[ValidationException] { + Standardization.standardize(rawDataFrame, schema).cache() + } + } + + test("Test standardizing a CSV with special columns when error column 
has correct type") { + val schema: StructType = StructType(Seq( + StructField("A1", IntegerType, nullable = true), + StructField("A2", IntegerType, nullable = true), + StructField("enceladus_info_version", IntegerType, nullable = false), + StructField("enceladus_info_date", DateType, nullable = true, + Metadata.fromJson("""{"pattern": "yyyy-MM-dd"}""")), + StructField("enceladus_info_date_string", StringType, nullable = true) + )) + + val schemaStr: StructType = StructType(Seq( + StructField("A1", StringType, nullable = true), + StructField("A2", StringType, nullable = true), + StructField("enceladus_info_version", StringType, nullable = true), + StructField("enceladus_info_date", StringType, nullable = true), + StructField("enceladus_info_date_string", StringType, nullable = true) + )) + + val rawDataFrame = spark.read.option("header", false).schema(schemaStr).csv(csvContent) + .withColumn(ErrorMessage.errorColumnName, typedLit(List[ErrorMessage]())) + + val stdDf = Standardization.standardize(rawDataFrame, schema).cache() + val failedRecords = stdDf.filter(size(col(ErrorMessage.errorColumnName)) > 0).count + + assert(stdDf.schema.exists(field => field.name == ErrorMessage.errorColumnName)) + assert(failedRecords == 1) + } + +} diff --git a/src/test/scala/za/co/absa/standardization/StandardizationParquetSuite.scala b/src/test/scala/za/co/absa/standardization/StandardizationParquetSuite.scala new file mode 100644 index 0000000..bda61aa --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/StandardizationParquetSuite.scala @@ -0,0 +1,361 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization + +import com.typesafe.config.{ConfigFactory, ConfigValueFactory} + +import java.util.UUID +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite +import org.apache.spark.sql.functions.{col, to_timestamp} +import za.co.absa.standardization.schema.MetadataKeys +import za.co.absa.standardization.stages.TypeParserException +import za.co.absa.standardization.types.{Defaults, GlobalDefaults} +import za.co.absa.standardization.udf.UDFLibrary + +// For creation of Structs in DF +private case class FooClass(bar: Boolean) + +class StandardizationParquetSuite extends AnyFunSuite with SparkTestBase { + import spark.implicits._ + import za.co.absa.standardization.implicits.DataFrameImplicits.DataFrameEnhancements + + private implicit val udfLibrary:UDFLibrary = new UDFLibrary() + private implicit val defaults: Defaults = GlobalDefaults + + private val tsPattern = "yyyy-MM-dd HH:mm:ss zz" + + private val configPlain = ConfigFactory.load() + private val configWithSchemaValidation = configPlain + .withValue("standardization.failOnInputNotPerSchema", ConfigValueFactory.fromAnyRef(true)) + private val uuidConfig = configPlain + .withValue("standardization.recordId.generation.strategy", ConfigValueFactory.fromAnyRef("uuid")) + private val stableIdConfig = configPlain + .withValue("standardization.recordId.generation.strategy", ConfigValueFactory.fromAnyRef("stablehashid")) + + private val data = Seq ( + (1, Array("A", "B"), FooClass(false), "1970-01-01 00:00:00 UTC"), + (2, Array("C"), FooClass(true), "1970-01-01 00:00:00 CET") + ) + private val sourceDataDF = data.toDF("id", "letters", "struct", "str_ts") + .withColumn("ts", to_timestamp(col("str_ts"), tsPattern)) + + test("All columns standardized") { + val expected = + """+---+-------+-------+------+ + ||id |letters|struct |errCol| + |+---+-------+-------+------+ + ||1 |[A, B] |[false]|[] | + ||2 |[C] |[true] |[] | + |+---+-------+-------+------+ + | + |""".stripMargin.replace("\r\n", "\n") + + val seq = Seq( + StructField("id", LongType, nullable = false), + StructField("letters", ArrayType(StringType), nullable = false), + StructField("struct", StructType(Seq(StructField("bar", BooleanType))), nullable = false) + ) + val schema = StructType(seq) + val destDF = Standardization.standardize(sourceDataDF, schema) + + val actual = destDF.dataAsString(truncate = false) + assert(actual == expected) + } + + + test("Missing non-nullable fields are filled with default values and error appears in error column") { + val expected = + """+---+------------+-------------------+----------+------------+-------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + ||id |string_field|timestamp_field |long_field|double_field|decimal_field|errCol | + 
|+---+------------+-------------------+----------+------------+-------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + ||1 | |1970-01-01 00:00:00|0 |0 |3.14 |[[stdNullError, E00002, Standardization Error - Null detected in non-nullable attribute, string_field, [null], []], [stdNullError, E00002, Standardization Error - Null detected in non-nullable attribute, timestamp_field, [null], []], [stdNullError, E00002, Standardization Error - Null detected in non-nullable attribute, long_field, [null], []], [stdNullError, E00002, Standardization Error - Null detected in non-nullable attribute, double_field, [null], []], [stdNullError, E00002, Standardization Error - Null detected in non-nullable attribute, decimal_field, [null], []]]| + ||2 | |1970-01-01 00:00:00|0 |0 |3.14 |[[stdNullError, E00002, Standardization Error - Null detected in non-nullable attribute, string_field, [null], []], [stdNullError, E00002, Standardization Error - Null detected in non-nullable attribute, timestamp_field, [null], []], [stdNullError, E00002, Standardization Error - Null detected in non-nullable attribute, long_field, [null], []], [stdNullError, E00002, Standardization Error - Null detected in non-nullable attribute, double_field, [null], []], [stdNullError, E00002, Standardization Error - Null detected in non-nullable attribute, decimal_field, [null], []]]| + |+---+------------+-------------------+----------+------------+-------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | + |""".stripMargin.replace("\r\n", "\n") + + val seq = Seq( + StructField("id", IntegerType, nullable = false), + StructField("string_field", StringType, nullable = false), + StructField("timestamp_field", TimestampType, nullable = false), + StructField("long_field", LongType, nullable = false), + StructField("double_field", IntegerType, nullable = false), + StructField("decimal_field", + DecimalType(20,2), + nullable = false, + new MetadataBuilder().putString(MetadataKeys.DefaultValue, "3.14").build()) + ) + val schema = StructType(seq) + val destDF = Standardization.standardize(sourceDataDF, schema) + + val actual = destDF.dataAsString(truncate = false) + assert(actual == expected) + } + + test("Cannot convert int to array, and array to long") { + val expected = + 
"""+----+-------+--------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + ||id |letters|lettersB|errCol | + |+----+-------+--------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + ||null|null |0 |[[stdTypeError, E00006, Standardization Error - Type 'integer' cannot be cast to 'array', id, [], []], [stdTypeError, E00006, Standardization Error - Type 'array' cannot be cast to 'long', letters, [], []], [stdTypeError, E00006, Standardization Error - Type 'array' cannot be cast to 'long', letters, [], []]]| + ||null|null |0 |[[stdTypeError, E00006, Standardization Error - Type 'integer' cannot be cast to 'array', id, [], []], [stdTypeError, E00006, Standardization Error - Type 'array' cannot be cast to 'long', letters, [], []], [stdTypeError, E00006, Standardization Error - Type 'array' cannot be cast to 'long', letters, [], []]]| + |+----+-------+--------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | + |""".stripMargin.replace("\r\n", "\n") + + val seq = Seq( + StructField("id", ArrayType(StringType), nullable = true), + StructField("letters", LongType, nullable = true), + StructField("lettersB", LongType, nullable = false, + new MetadataBuilder().putString(MetadataKeys.SourceColumn, "letters").build()) + ) + val schema = StructType(seq) + val destDF = Standardization.standardize(sourceDataDF, schema) + + val actual = destDF.dataAsString(truncate = false) + assert(actual == expected) + } + + test("Missing nullable fields are considered null") { + val expected = + """+---+------------+---------------+----------+------------+-------------+------+ + ||id |string_field|timestamp_field|long_field|double_field|decimal_field|errCol| + |+---+------------+---------------+----------+------------+-------------+------+ + ||1 |null |null |null |null |null |[] | + ||2 |null |null |null |null |null |[] | + |+---+------------+---------------+----------+------------+-------------+------+ + | + |""".stripMargin.replace("\r\n", "\n") + + val seq = Seq( + StructField("id", IntegerType, nullable = true), + StructField("string_field", StringType, nullable = true), + StructField("timestamp_field", TimestampType, nullable = true), + StructField("long_field", LongType, nullable = true), + StructField("double_field", IntegerType, nullable = true), + StructField("decimal_field", DecimalType(20,4), nullable = true) + ) + val schema = StructType(seq) + val destDF = Standardization.standardize(sourceDataDF, schema) + + val actual = destDF.dataAsString(truncate = false) + assert(actual == expected) + } + + test("Cannot convert int to struct, and struct to long") { + val expected = + 
"""|+----+------+-------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + ||id |struct|structB|errCol | + |+----+------+-------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + ||null|null |-1 |[[stdTypeError, E00006, Standardization Error - Type 'integer' cannot be cast to 'struct', id, [], []], [stdTypeError, E00006, Standardization Error - Type 'struct' cannot be cast to 'long', struct, [], []], [stdTypeError, E00006, Standardization Error - Type 'struct' cannot be cast to 'long', struct, [], []]]| + ||null|null |-1 |[[stdTypeError, E00006, Standardization Error - Type 'integer' cannot be cast to 'struct', id, [], []], [stdTypeError, E00006, Standardization Error - Type 'struct' cannot be cast to 'long', struct, [], []], [stdTypeError, E00006, Standardization Error - Type 'struct' cannot be cast to 'long', struct, [], []]]| + |+----+------+-------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | + |""".stripMargin.replace("\r\n", "\n") + + val seq = Seq( + StructField("id", StructType(Seq(StructField("bar", BooleanType))), nullable = true), + StructField("struct", LongType, nullable = true), + StructField("structB", LongType, nullable = false, new MetadataBuilder() + .putString(MetadataKeys.SourceColumn, "struct") + .putString(MetadataKeys.DefaultValue, "-1") + .build()) + ) + val schema = StructType(seq) + val destDF = Standardization.standardize(sourceDataDF, schema) + + val actual = destDF.dataAsString(truncate = false) + assert(actual == expected) + } + + test("Cannot convert array to struct, and struct to array") { + val expected = + """|+---+-------+------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + ||id |letters|struct|errCol | + |+---+-------+------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + ||1 |null |null |[[stdTypeError, E00006, Standardization Error - Type 'array' cannot be cast to 'struct', letters, [], []], [stdTypeError, E00006, Standardization Error - Type 'struct' cannot be cast to 'array', struct, [], []]]| + ||2 |null |null |[[stdTypeError, E00006, Standardization Error - Type 'array' cannot be cast to 'struct', letters, [], []], [stdTypeError, E00006, Standardization Error - Type 'struct' cannot be cast to 'array', struct, [], []]]| + 
|+---+-------+------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | + |""".stripMargin.replace("\r\n", "\n") + + val seq = Seq( + StructField("id", LongType, nullable = true), + StructField("letters", StructType(Seq(StructField("bar", BooleanType))), nullable = true), + StructField("struct", ArrayType(StringType), nullable = true) + ) + val schema = StructType(seq) + val destDF = Standardization.standardize(sourceDataDF, schema) + + val actual = destDF.dataAsString(truncate = false) + assert(actual == expected) + } + + test("Cannot convert int to array, and array to long, fail fast") { + val seq = Seq( + StructField("id", ArrayType(StringType), nullable = true), + StructField("letters", LongType, nullable = true), + StructField("lettersB", LongType, nullable = false, + new MetadataBuilder().putString(MetadataKeys.SourceColumn, "letters").build()) + ) + val schema = StructType(seq) + + val exception = intercept[TypeParserException] { + Standardization.standardize(sourceDataDF, schema, configWithSchemaValidation) + } + assert(exception.getMessage == "Cannot standardize field 'id' from type integer into array") + } + + test("Cannot convert int to struct, and struct to long, fail fast") { + val seq = Seq( + StructField("id", StructType(Seq(StructField("bar", BooleanType))), nullable = true), + StructField("struct", LongType, nullable = true), + StructField("structB", LongType, nullable = false, new MetadataBuilder() + .putString(MetadataKeys.SourceColumn, "struct") + .putString(MetadataKeys.DefaultValue, "-1") + .build()) + ) + val schema = StructType(seq) + + val exception = intercept[TypeParserException] { + Standardization.standardize(sourceDataDF, schema, configWithSchemaValidation) + } + assert(exception.getMessage == "Cannot standardize field 'id' from type integer into struct") + } + + test("Cannot convert array to struct, and struct to array, fail fast") { + val seq = Seq( + StructField("id", LongType, nullable = true), + StructField("letters", StructType(Seq(StructField("bar", BooleanType))), nullable = true), + StructField("struct", ArrayType(StringType), nullable = true) + ) + val schema = StructType(seq) + + val exception = intercept[TypeParserException] { + Standardization.standardize(sourceDataDF, schema, configWithSchemaValidation) + } + assert(exception.getMessage == "Cannot standardize field 'letters' from type array into struct") + } + + test("PseudoUuids are used") { + val expected = + """+---+-------+-------+------+-------------------+ + ||id |letters|struct |errCol|enceladus_record_id| + |+---+-------+-------+------+-------------------+ + ||1 |[A, B] |[false]|[] |1950798873 | + ||2 |[C] |[true] |[] |-988631025 | + |+---+-------+-------+------+-------------------+ + | + |""".stripMargin.replace("\r\n", "\n") + + val seq = Seq( + StructField("id", LongType, nullable = false), + StructField("letters", ArrayType(StringType), nullable = false), + StructField("struct", StructType(Seq(StructField("bar", BooleanType))), nullable = false) + ) + val schema = StructType(seq) + // stableHashId will always yield the same ids + val destDF = Standardization.standardize(sourceDataDF, schema, stableIdConfig) + + val actual = destDF.dataAsString(truncate = false) + assert(actual == expected) + } + + test("True uuids are used") { + val expected = + """+---+-------+-------+------+ + ||id |letters|struct |errCol| + 
|+---+-------+-------+------+ + ||1 |[A, B] |[false]|[] | + ||2 |[C] |[true] |[] | + |+---+-------+-------+------+ + | + |""".stripMargin.replace("\r\n", "\n") + + val seq = Seq( + StructField("id", LongType, nullable = false), + StructField("letters", ArrayType(StringType), nullable = false), + StructField("struct", StructType(Seq(StructField("bar", BooleanType))), nullable = false) + ) + val schema = StructType(seq) + val destDF = Standardization.standardize(sourceDataDF, schema, uuidConfig) + + // same except for the record id + val actual = destDF.drop("enceladus_record_id").dataAsString(truncate = false) + assert(actual == expected) + + val destIds = destDF.select('enceladus_record_id ).collect().map(_.getAs[String](0)).toSet + assert(destIds.size == 2) + destIds.foreach(UUID.fromString) // check uuid validity + + } + + test("Existing enceladus_record_id is kept") { + val expected = + """+---+-------+-------+-------------------+------+ + ||id |letters|struct |enceladus_record_id|errCol| + |+---+-------+-------+-------------------+------+ + ||1 |[A, B] |[false]|id1 |[] | + ||2 |[C] |[true] |id2 |[] | + |+---+-------+-------+-------------------+------+ + | + |""".stripMargin.replace("\r\n", "\n") + + import org.apache.spark.sql.functions.{concat, lit} + val sourceDfWithExistingIds = sourceDataDF.withColumn("enceladus_record_id", concat(lit("id"), 'id)) + + val seq = Seq( + StructField("id", LongType, nullable = false), + StructField("letters", ArrayType(StringType), nullable = false), + StructField("struct", StructType(Seq(StructField("bar", BooleanType))), nullable = false), + StructField("enceladus_record_id", StringType, nullable = false) + ) + val schema = StructType(seq) + val destDF = Standardization.standardize(sourceDfWithExistingIds, schema, uuidConfig) + + // The TrueUuids strategy does not override the existing values + val actual = destDF.dataAsString(truncate = false) + assert(actual == expected) + } + + test("Timestamps with timezone in metadata are shifted") { + /* This might seem confusing at first glance. The reason why this is the correct result: + the source data has two timestamps, 00:00:00 and 23:00:00, *without* a time zone. + The metadata then signals that the timestamps are to be considered in the CET time zone. The data are ingested with that + time zone and adjusted to the system time zone (UTC). Therefore they appear shifted by one hour. */ + val expected = + """+---+-------------------+------+ + ||id |ts |errCol| + |+---+-------------------+------+ + ||1 |1969-12-31 23:00:00|[] | + ||2 |1969-12-31 22:00:00|[] | + |+---+-------------------+------+ + | + |""".stripMargin.replace("\r\n", "\n") + + val seq = Seq( + StructField("id", LongType, nullable = false), + StructField("ts", TimestampType, nullable = false, new MetadataBuilder().putString(MetadataKeys.DefaultTimeZone, "CET").build()) + ) + val schema = StructType(seq) + val destDF = Standardization.standardize(sourceDataDF, schema) + + val actual = destDF.dataAsString(truncate = false) + assert(actual == expected) + } +} diff --git a/src/test/scala/za/co/absa/standardization/TestSamples.scala b/src/test/scala/za/co/absa/standardization/TestSamples.scala new file mode 100644 index 0000000..0fa8068 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/TestSamples.scala @@ -0,0 +1,88 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization + +import java.sql.Timestamp +import java.sql.Date +import java.math.BigDecimal +import java.time.LocalDate +import java.util.TimeZone +import java.util.Calendar + +case class EmployeeNumber(numberType : String, numbers: Seq[String]) +case class EmployeeNumberStd(numberType : String, numbers: Seq[Int]) +case class Employee(name: String, surname: Option[String], hoursWorked: List[String], employeeNumbers: List[EmployeeNumber] = List(), startDate: String, updated: Option[Double] = None) +case class StdEmployee(name: String, surname: String, hoursWorked: Option[List[Int]], employeeNumbers: List[EmployeeNumberStd] = List(), startDate: java.sql.Date, updated: Option[Timestamp] = None, errCol: List[ErrorMessage]) + +case class Leg( + accruedInterestTxnCcy: BigDecimal, + carry: Option[BigDecimal] = None, cashRepCcy: Option[BigDecimal] = None, cashTxnCcy: Option[BigDecimal] = None, currencyName: String = null, currentRate: Option[BigDecimal] = None, + currentSpread: Option[BigDecimal] = None, dayCountMethod: String = null, endDate: Date = null, faceValueRepCcy: Option[BigDecimal] = None, faceValueTxnCcy: Option[BigDecimal] = None, + fixedRate: Option[BigDecimal] = None, floatRateSpread: Option[BigDecimal] = None, isPayLeg: String = null, floatRateReferenceName: String = null, + legFloatRateFactor: Option[Long] = None, legNumber: String = null, legStartDate: Date = null, legType: String = null, nominalRepCcy: Option[BigDecimal] = None, + nominalTxnCcy: Option[BigDecimal] = None, price: String = null, pvRepCcy: Option[BigDecimal] = None, pvTxnCcy: Option[BigDecimal] = None, repoRate: String = null, rollingPeriod: String = null) + +case class Leg1(leg: Seq[Leg]) + +case class StdTradeInstrument( + batchId: Option[Long] = None, requestId: String = null, + contractSize: Option[BigDecimal] = None, currencyName: String = null, + digital: Option[Boolean] = None, endDateTime: Timestamp = null, + expiryDate: Date = null, expiryTime: Timestamp = null, fxOptionType: String = null, + instrumentAddress: String = null, instrumentName: String = null, instrumentType: String = null, + isExpired: Option[Boolean] = null, + legs: Leg1 = null, optionExoticType: String = null, otc: Option[Boolean] = None, + payDayOffset: Option[Long] = None, payOffsetMethod: String = null, payType: String = null, quoteType: String = null, + realDividendValue: String = null, refValue: Option[Long] = None, referencePrice: Option[BigDecimal] = None, reportDate: Date = null, + settlementType: String = null, spotBankingDayOffset: Option[Long] = None, + startDate: Date = null, strikeCurrencyName: String = null, strikeCurrencyNumber: Option[Long] = None, tradeId: String = null, tradeNumber: String = null, + txnMaturityPeriod: String = null, underlyingInstruments: String = null, valuationGroupName: String = null, + versionId: String = null, errCol: Seq[ErrorMessage] = Seq()) + +case class DateTimestampData( + id: Long, + dateSampleOk1: String, dateSampleOk2: String, dateSampleOk3: String, + dateSampleWrong1: String, dateSampleWrong2: String, dateSampleWrong3: String, + 
timestampSampleOk1: String, timestampSampleOk2: String, timestampSampleOk3: String, + timestampSampleWrong1: String, timestampSampleWrong2: String, timestampSampleWrong3: String) + +object TestSamples { + TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + + val john0 = Employee(name = "John0",surname = None, hoursWorked = List("8", "7", "8", "9", "12", null), employeeNumbers = List(EmployeeNumber("SAP", List("456", "123")), EmployeeNumber("WD", List("00005"))), startDate = "2015-08-01") + val john1 = Employee(name = "John1", surname = Some("Doe1"), hoursWorked = List("99", "99", "76", "12", "12", "24"), startDate = "Two Thousand Something") + val john2 = Employee(name = "John2", surname = None, hoursWorked = null, startDate = "2015-08-01", updated = Some(20150716.133224)) + val john3 = Employee(name = "John3", surname = None, hoursWorked = List(), startDate = "2015-08-01", updated = Some(20150716.103224)) + + val data1 = List(john0, john1, john2, john3) + + val startDate = 1438387200000l //01/08/2015 + + val resData = List( + StdEmployee(name = "John0", surname = "Unknown Surname", hoursWorked = Some(List(8, 7, 8, 9, 12, 0)), + employeeNumbers = List(EmployeeNumberStd("SAP", List(456, 123)), EmployeeNumberStd("WD", List(5))), startDate = new java.sql.Date(startDate), errCol = List(ErrorMessage.stdNullErr("surname"), ErrorMessage.stdNullErr("hoursWorked[*]"))), + StdEmployee(name = "John1", surname = "Doe1", hoursWorked = Some(List(99, 99, 76, 12, 12, 24)), startDate = new java.sql.Date(0), errCol = List(ErrorMessage.stdCastErr("startDate", "Two Thousand Something"))), + StdEmployee(name = "John2", surname = "Unknown Surname", hoursWorked = None, startDate = new java.sql.Date(startDate), updated = Some(Timestamp.valueOf("2015-07-16 13:32:24")), errCol = List(ErrorMessage.stdNullErr("surname"), ErrorMessage.stdNullErr("hoursWorked"))), + StdEmployee(name = "John3", surname = "Unknown Surname", hoursWorked = Some(List()), startDate = new java.sql.Date(startDate), updated = Some(Timestamp.valueOf("2015-07-16 10:32:24")), errCol = List(ErrorMessage.stdNullErr("surname")))) + + val dateSamples = List(DateTimestampData( + 1, + "20-10-2017", "2017-10-20", "12/29/2017", + "10-20-2017", "201711", "", + "2017-10-20T08:11:31", "2017-10-20_08:11:31", "20171020081131", + "20171020T081131", "2017-10-20t081131", "2017-10-20")) +} diff --git a/src/test/scala/za/co/absa/standardization/implicits/DataFrameImplicitsSuite.scala b/src/test/scala/za/co/absa/standardization/implicits/DataFrameImplicitsSuite.scala new file mode 100644 index 0000000..735f388 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/implicits/DataFrameImplicitsSuite.scala @@ -0,0 +1,157 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.implicits + +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.SparkTestBase +import za.co.absa.standardization.implicits.DataFrameImplicits.DataFrameEnhancements + +class DataFrameImplicitsSuite extends AnyFunSuite with SparkTestBase { + import spark.implicits._ + + private val columnName = "data" + private val inputDataSeq = Seq( + "0123456789012345678901234", + "a", + "b", + "c", + "d", + "e", + "f", + "g", + "h", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + "p", + "q", + "r", + "s", + "t", + "u", + "v", + "w", + "x", + "y", + "z" + ) + private val inputData = inputDataSeq.toDF(columnName) + + private def cellText(text: String, width: Int, leftAlign: Boolean): String = { + val pad = " " * (width - text.length) + if (leftAlign) { + text + pad + } else { + pad + text + } + } + + private def line(width: Int): String = { + "+" + "-" * width + "+" + } + + private def header(width: Int, leftAlign: Boolean): String = { + val lineStr = line(width) + val title = cellText(columnName, width, leftAlign) + s"$lineStr\n|$title|\n$lineStr" + } + + private def cell(text: String, width: Int, leftAlign: Boolean): String = { + val inner = if (text.length > width) { + text.substring(0, width - 3) + "..." + } else { + cellText(text, width, leftAlign) + } + s"|$inner|" + } + + private def inputDataToString(width: Int, leftAlign: Boolean, limit: Option[Int] = Option(20)): String = { + val (extraLine, seq) = limit match { + case Some(n) => + val line = if (inputDataSeq.length > n) { + s"only showing top $n rows\n" + } else { + "" + } + (line, inputDataSeq.take(n)) + case None => + ("", inputDataSeq) + } + seq.foldLeft(header(width, leftAlign)) { (acc, item) => + acc + "\n" + cell(item, width, leftAlign) + } + "\n" + line(width) + s"\n$extraLine\n" + } + + test("Like show()") { + val result = inputData.dataAsString() + val leftAlign = false + val cellWidth = 20 + val expected = inputDataToString(cellWidth, leftAlign) + + assert(result == expected) + } + + test("Like show(false)") { + val result = inputData.dataAsString(false) + val leftAlign = true + val cellWidth = 25 + val expected = inputDataToString(cellWidth, leftAlign) + + assert(result == expected) + } + + test("Like show(3, true)") { + val result = inputData.dataAsString(3, true) + val leftAlign = false + val cellWidth = 20 + val expected = inputDataToString(cellWidth, leftAlign, Option(3)) + + assert(result == expected) + } + + test("Like show(30, false)") { + val result = inputData.dataAsString(30, false) + val leftAlign = true + val cellWidth = 25 + val expected = inputDataToString(cellWidth, leftAlign, Option(30)) + + assert(result == expected) + } + + + test("Like show(10, 10)") { + val result = inputData.dataAsString(10, 10) + val leftAlign = false + val cellWidth = 10 + val expected = inputDataToString(cellWidth, leftAlign, Option(10)) + + assert(result == expected) + } + + test("Like show(50, 50, false)") { + val result = inputData.dataAsString(50, 50, false) + val leftAlign = false + val cellWidth = 25 + val expected = inputDataToString(cellWidth, leftAlign, Option(50)) + + assert(result == expected) + } +} diff --git a/src/test/scala/za/co/absa/standardization/implicits/StringImplicitsSuite.scala b/src/test/scala/za/co/absa/standardization/implicits/StringImplicitsSuite.scala new file mode 100644 index 0000000..efe7e44 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/implicits/StringImplicitsSuite.scala @@ -0,0 +1,310 @@ +/* + * Copyright 2021 ABSA Group 
Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.implicits + +import org.scalatest.funsuite.AnyFunSuite + +import org.scalatest.matchers.should.Matchers +import za.co.absa.standardization.implicits.StringImplicits.StringEnhancements +import java.security.InvalidParameterException + +class StringImplicitsSuite extends AnyFunSuite with Matchers { + test("StringEnhancements.replaceChars - empty replacements") { + val s = "supercalifragilisticexpialidocious" + assert(s.replaceChars(Map.empty) == s) + } + + test("StringEnhancements.replaceChars - no hits") { + val s = "supercalifragilisticexpialidocious" + val map = Map('1' -> '5', '2' -> '6', '3' -> '7') + assert(s.replaceChars(map) == s) + } + + test("StringEnhancements.replaceChars - replace all to same char") { + val s: String = "abcba" + val map = Map('a' -> 'x', 'b' -> 'x', 'c' -> 'x', 'd' -> 'x') + assert(s.replaceChars(map) == "xxxxx") + } + + test("StringEnhancements.replaceChars - swap characters") { + val s: String = "abcba" + val map = Map('a' -> 'b', 'b' -> 'a') + assert(s.replaceChars(map) == "bacab") + } + + test("StringEnhancements.findFirstUnquoted - empty string") { + var result = "".findFirstUnquoted(Set.empty, Set.empty) + assert(result.isEmpty) + result = "".findFirstUnquoted(Set('a'), Set.empty) + assert(result.isEmpty) + result = "".findFirstUnquoted(Set('a', 'b', 'c'), Set(''')) + assert(result.isEmpty) + result = "".findFirstUnquoted(Set('a', 'b', 'c'), Set(''', '"')) + assert(result.isEmpty) + } + + test("StringEnhancements.findFirstUnquoted - no quotes") { + var result = "Hello world".findFirstUnquoted(Set('x', 'y', 'z'), Set.empty) + assert(result.isEmpty) + result = "Hello world".findFirstUnquoted(Set('w'), Set.empty) + assert(result.contains(6)) + result = "Hello world".findFirstUnquoted(Set('w', 'e', 'l'), Set.empty) + assert(result.contains(1)) + } + + test("StringEnhancements.findFirstUnquoted - simple quotes") { + val quotes = Set(''') + var result = "Hello 'world'".findFirstUnquoted(Set('w'), quotes) + assert(result.isEmpty) + result = "Hello 'world'".findFirstUnquoted(Set('w', 'e', 'l'), quotes) + assert(result.contains(1)) + result = "'Hello' world".findFirstUnquoted(Set('w', 'e', 'l'), quotes) + assert(result.contains(8)) + } + + test("StringEnhancements.findFirstUnquoted - multiple quotes") { + val charsToFind = Set('w', 'e', 'l') + val quotes = Set(''', '`') + var result = "`Hello` 'world'".findFirstUnquoted(charsToFind, quotes) + assert(result.isEmpty) + result = "`Hello` 'wor'ld".findFirstUnquoted(charsToFind, quotes) + assert(result.contains(13)) + result = "`Hel'lo` 'wor'ld".findFirstUnquoted(charsToFind, quotes) + assert(result.contains(14)) + } + + test("StringEnhancements.findFirstUnquoted - using escape character") { + val charsToFind = Set('w', 'e', 'l') + val quotes = Set(''', '`') + var result = "`Hello` \\'world".findFirstUnquoted(charsToFind, quotes) //hasn't started + assert(result.contains(10)) + result = "`H\\`ello` 
'wor'ld".findFirstUnquoted(charsToFind, quotes) //hasn't ended + assert(result.contains(15)) + result = "`Hello\\`` 'wor'ld".findFirstUnquoted(charsToFind, quotes) //escaped followed by unescaped + assert(result.contains(15)) + result = "\\ `Hello` 'world'".findFirstUnquoted(charsToFind, quotes) //escape elsewhere + assert(result.isEmpty) + result = "H\\e\\l\\lo \\wor\\ld'".findFirstUnquoted(charsToFind, quotes) //hits escaped + assert(result.isEmpty) + } + + test("StringEnhancements.findFirstUnquoted - quote between search characters") { + val charsToFind = Set('w', 'e', 'l', ''') + val quotes = Set(''', '`') + var result = "Hello \\'world".findFirstUnquoted(charsToFind, quotes) //simple + assert(result.contains(1)) + result = "`Hello` \\'world".findFirstUnquoted(charsToFind, quotes) //quote hit + assert(result.contains(9)) + result = "`Hello` 'world'".findFirstUnquoted(charsToFind, quotes) //just quotes + assert(result.isEmpty) + result = "`Hello\\'` 'world'".findFirstUnquoted(charsToFind, quotes) //within other quotes + assert(result.isEmpty) + result = "`Hello` '\\'world'".findFirstUnquoted(charsToFind, quotes) //within same quotes + assert(result.isEmpty) + } + + test("StringEnhancements.findFirstUnquoted - custom escape character") { + val charsToFind = Set('w', 'e', 'l') + val quotes = Set(''', '`') + val escapeChar = '~' + var result = "`Hello` ~'world".findFirstUnquoted(charsToFind, quotes, escapeChar) //hasn't started + assert(result.contains(10)) + result = "`H~`ello` 'wor'ld".findFirstUnquoted(charsToFind, quotes, escapeChar) //hasn't ended + assert(result.contains(15)) + result = "`Hello~`` 'wor'ld".findFirstUnquoted(charsToFind, quotes, escapeChar) //escaped followed by unescaped + assert(result.contains(15)) + result = "~ `Hello` 'world'".findFirstUnquoted(charsToFind, quotes, escapeChar) //escape elsewhere + assert(result.isEmpty) + result = "`Hello~`` 'wor'\\ld".findFirstUnquoted(charsToFind, quotes, escapeChar) //mix-in standard escape + assert(result.contains(16)) + } + + test("StringEnhancements.findFirstUnquoted - many escapes") { //better to do with other then \ + val charsToFind = Set('w') + val quotes = Set(''') + val escapeChar = '~' + var result = "Hello ~~world'".findFirstUnquoted(charsToFind, quotes, escapeChar) //escaped escape -> hit valid + assert(result.contains(8)) + result = "Hello ~~'world'".findFirstUnquoted(charsToFind, quotes, escapeChar) //escaped escape -> quotes valid + assert(result.isEmpty) + result = "Hello ~~~world'".findFirstUnquoted(charsToFind, quotes, escapeChar) //3x -> hit escaped + assert(result.isEmpty) + result = "Hello ~~~'world'".findFirstUnquoted(charsToFind, quotes, escapeChar) //3x -> quote escaped + assert(result.contains(10)) + result = "'Hello ~~~~~'world'".findFirstUnquoted(charsToFind, quotes, escapeChar) //5x -> quote escaped, whole string quoted + assert(result.isEmpty) + } + + test("StringEnhancements.findFirstUnquoted - escape in search chars") { //better to do with other then \ + val escapeChar = '~' + val quotes = Set(''') + val charsToFind = Set('w', escapeChar) + var result = "Hello ~~world'".findFirstUnquoted(charsToFind, quotes, escapeChar) //escaped escape -> hit valid + assert(result.contains(7)) + result = "Hello '~~world'".findFirstUnquoted(charsToFind, quotes, escapeChar) //escaped escape in quotes + assert(result.isEmpty) + result = "Hello ~'~~world".findFirstUnquoted(charsToFind, quotes, escapeChar) //escaped quote + assert(result.contains(9)) + result = "Hello ~world~~".findFirstUnquoted(charsToFind, quotes, 
escapeChar) //escaped normal hit, escaped escape follows + assert(result.contains(13)) + } + + test("StringEnhancements.findFirstUnquoted - escape in quote chars") { //better to do with other then \ + val escapeChar = '~' + val quotes = Set(''', escapeChar) + val charsToFind = Set('w', 'e', 'l') + var result = "~'Hello world'".findFirstUnquoted(charsToFind, quotes, escapeChar) //simple escape + assert(result.contains(3)) + result = "~~Hello ~~pole".findFirstUnquoted(charsToFind, quotes, escapeChar) //escape as quotes + assert(result.contains(12)) + result = "~~Hello ~~'pole'".findFirstUnquoted(charsToFind, quotes, escapeChar) //escape as quotes followed by standard quotes + assert(result.isEmpty) + result = "~~Hello ~~world".findFirstUnquoted(charsToFind, quotes, escapeChar) //escape as quotes directly followed by hit + assert(result.contains(10)) + result = "~~Hello ~~~world".findFirstUnquoted(charsToFind, quotes, escapeChar) //escape as quotes and right after escaped hit + assert(result.contains(14)) + } + + test("StringEnhancements.findFirstUnquoted - escape in search and quote chars") { //better to do with other then \ + val escapeChar = '!' + val quotes = Set(''', escapeChar) + val charsToFind = Set('w', 'e', 'l', escapeChar) + val expectedMessage = s"Escape character '$escapeChar 'is both between charsToFind and quoteChars. That's not allowed." + val caught = intercept[InvalidParameterException] { + "All the jewels of the world!".findFirstUnquoted(charsToFind, quotes, escapeChar) + } + assert(caught.getMessage == expectedMessage) + } + + test("StringEnhancements.hasUnquoted") { + assert(!"".hasUnquoted(Set.empty, Set.empty)) + assert(!"Hello world".hasUnquoted(Set('x'), Set.empty)) + assert("Hello world".hasUnquoted(Set('w', 'e', 'l'), Set('`'))) + assert(!"`Hello world`".hasUnquoted(Set('w', 'e', 'l'), Set('`'))) + } + + test("StringEnhancements.countUnquoted: empty variants") { + val expected = Map( + 'x'->0, + 'y'->0, + 'z'->0 + ) + val charsToFind = Set('x', 'y', 'z') + val empty = Set.empty[Char] + assert("".countUnquoted(charsToFind, Set('"')) == expected) + assert("Lorem ipsum".countUnquoted(charsToFind, empty) == expected) + assert("Hello world".countUnquoted(empty, Set('|')) == Map.empty) + } + + test("StringEnhancements.countUnquoted: simple test") { + val charsToFind = Set('x', 'y', 'z') + val expected1 = Map( + 'x'->0, + 'y'->0, + 'z'->0 + ) + assert("Lorem ipsum".countUnquoted(charsToFind, Set(''')) == expected1) + assert("Hello 'xyz' world".countUnquoted(charsToFind, Set(''')) == expected1) + val expected2 = Map( + 'x'->3, + 'y'->2, + 'z'->1 + ) + assert("xxxyzy".countUnquoted(charsToFind, Set('-')) == expected2) + val expected3 = Map( + 'x'->1, + 'y'->2, + 'z'->5 + ) + assert("x-xxy-yyzzz|zyyy|zz".countUnquoted(charsToFind, Set('-', '|')) == expected3) + } + + test("StringEnhancements.countUnquoted: escape involved") { + val charsToFind = Set('x', 'y', 'z') + val expected = Map( + 'x'->3, + 'y'->2, + 'z'->3 + ) + assert("x~yz~'xxyyzz 'xxxx~'zzzz'~''yyyy".countUnquoted(charsToFind, Set('''),'~') == expected) + } + + test("StringEnhancements.countUnquoted: search and quote chars overlap") { + val charsToFind = Set('a', '#', '$', '%') + val quoteChars = Set('$', '%', '^') + val expected = Map( + 'a'->0, + '#'->1, + '$'->1, + '%'->0 + ) + assert("#^##^|$%%%%".countUnquoted(charsToFind, quoteChars, '|') == expected) + } + + test("StringEnhancements.countUnquoted: escape in search for chars") { + val charsToFind = Set('a', 'b', 'c', 'd', '|') + val quoteChars = Set('%', '^') 
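+    // Note: '|' doubles as the escape character and as one of the counted characters in this case; the unquoted "||"
+    // is an escaped escape and is expected to count as a single literal '|', while "|^" escapes '^' so no quoted
+    // section opens before "cc".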
+ val expected = Map( + 'a'->2, + 'b'->0, + 'c'->2, + 'd'->0, + '|'->1 + ) + assert("aa||%bb%|^cc^a|cd||d|^b^".countUnquoted(charsToFind, quoteChars, '|') == expected) + } + + test("StringEnhancements.countUnquoted: escape in quote chars") { + val charsToFind = Set('a', 'b', 'c') + val quoteChars = Set('$', '%', '^') + val expected = Map( + 'a'->1, + 'b'->0, + 'c'->0 + ) + assert("a$$bc$$".countUnquoted(charsToFind, quoteChars, '$') == expected) + } + + test("string joining general") { + "abc#".joinWithSingleSeparator("#def", "#") shouldBe "abc#def" + "abc###".joinWithSingleSeparator("def", "#") shouldBe "abc###def" + "abcSEP".joinWithSingleSeparator("def", "SEP") shouldBe "abcSEPdef" + "abcSEPSEP".joinWithSingleSeparator("SEPSEPdef", "SEP") shouldBe "abcSEPSEPSEPdef" + } + + test("string joining with /") { + "abc" / "123" shouldBe "abc/123" + "aaa/" / "123" shouldBe "aaa/123" + "bbb" / "/123" shouldBe "bbb/123" + "ccc/" / "/123" shouldBe "ccc/123" + "file:///" / "path" shouldBe "file:///path" + } + + test("getOrElse") { + "a".nonEmpyOrElse("b") shouldBe "a" + "".nonEmpyOrElse("b") shouldBe "b" + "a".nonEmpyOrElse("") shouldBe "a" + } + + test("coalesce") { + "".coalesce() shouldBe "" + "".coalesce("A", "") shouldBe "A" + "".coalesce("", "", "B", "", "C") shouldBe "B" + "X".coalesce("Y", "Z") shouldBe "X" + "X".coalesce("") shouldBe "X" + } +} diff --git a/src/test/scala/za/co/absa/standardization/implicits/StructFieldImplicitsSuite.scala b/src/test/scala/za/co/absa/standardization/implicits/StructFieldImplicitsSuite.scala new file mode 100644 index 0000000..e9de355 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/implicits/StructFieldImplicitsSuite.scala @@ -0,0 +1,67 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.implicits + +import org.apache.spark.sql.types.{DataTypes, Metadata, StructField} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + + +class StructFieldImplicitsSuite extends AnyFlatSpec with Matchers { + + val metadata1 = Metadata.fromJson( + """ + |{ + | "pattern" : "yyyy-MM-dd", + | "default" : "2020-02-02", + | "singleCharThingy" : "s", + | "isAwesome" : "true", + | "hasChildren" : false + |} + |""".stripMargin + ) + val structField1 = StructField("myField1", DataTypes.DateType, nullable = true, metadata1) + + import StructFieldImplicits._ + + "StructFieldEnhancements" should "getMetadataString for (non)existing key" in { + structField1.getMetadataString("pattern") shouldBe Some("yyyy-MM-dd") + structField1.getMetadataString("PaTTerN") shouldBe None // case sensitive + structField1.getMetadataString("somethingElse") shouldBe None + } + + it should "getMetadataChar correctly" in { + structField1.getMetadataChar("singleCharThingy") shouldBe Some('s') + structField1.getMetadataChar("default") shouldBe None + structField1.getMetadataChar("somethingElse") shouldBe None + } + + it should "getMetadataStringAsBoolean correctly" in { + structField1.getMetadataStringAsBoolean("pattern") shouldBe None + structField1.getMetadataStringAsBoolean("isAwesome") shouldBe Some(true) + structField1.getMetadataStringAsBoolean("hasChildren") shouldBe None // interesting: metadata is always string-first + structField1.getMetadataStringAsBoolean("somethingElse") shouldBe None + } + + it should "hasMetadataKey" in { + structField1.hasMetadataKey("default") shouldBe true + structField1.hasMetadataKey("DeFAuLT") shouldBe false // case sensitive + structField1.hasMetadataKey("somethingElse") shouldBe false + + } + +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/CounterPartySuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/CounterPartySuite.scala new file mode 100644 index 0000000..4a1029c --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/CounterPartySuite.scala @@ -0,0 +1,61 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.interpreter + +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.types.{Defaults, GlobalDefaults} +import za.co.absa.standardization.{ErrorMessage, LoggerTestBase, SparkTestBase, Standardization} +import za.co.absa.standardization.udf.UDFLibrary + +case class Root(ConformedParty: Party, errCol: Seq[ErrorMessage] = Seq.empty) +case class Party(key: Integer, clientKeys1: Seq[String], clientKeys2: Seq[String]) + +class CounterPartySuite extends AnyFunSuite with SparkTestBase with LoggerTestBase { + + private implicit val defaults: Defaults = GlobalDefaults + + test("Mimic running standardization twice on counter party") { + import spark.implicits._ + + val desiredSchema = StructType(Seq(StructField("ConformedParty", StructType( + Seq( + StructField("key", IntegerType, nullable = true), + StructField("clientKeys1", ArrayType(StringType, containsNull = true), nullable = true) + , + StructField("clientKeys2", ArrayType(StringType, containsNull = true), nullable = true) + + )), nullable = true))) + + implicit val udfLib: UDFLibrary = new UDFLibrary + + val input = spark.createDataFrame(Seq( + Root(Party(key = 0, clientKeys1 = Seq("a", "b", "c"), clientKeys2 = Seq("d", "e", "f"))), + Root(Party(1, Seq("d"), Seq("e"))), + Root(Party(2, Seq("f"), Seq())), + Root(Party(3, Seq(), Seq())), + Root(Party(4, null, Seq())), + Root(Party(5, Seq(), null)), + Root(Party(6, null, null)))) + + val std = Standardization.standardize(input, desiredSchema).cache() + + logDataFrameContent(std) + + assertResult(input.as[Root].collect.toList)(std.as[Root].collect().sortBy(_.ConformedParty.key).toList) + } +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/DateTimeSuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/DateTimeSuite.scala new file mode 100644 index 0000000..b6d2490 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/DateTimeSuite.scala @@ -0,0 +1,107 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.interpreter + +import java.sql.{Date, Timestamp} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.stages.SchemaChecker +import za.co.absa.standardization.types.{Defaults, GlobalDefaults} +import za.co.absa.standardization.udf.UDFLibrary +import za.co.absa.standardization.validation.field.FieldValidationIssue +import za.co.absa.standardization.{ErrorMessage, FileReader, LoggerTestBase, SchemaValidator, SparkTestBase, Standardization, TestSamples, ValidationError, ValidationException, ValidationWarning} + + +class DateTimeSuite extends AnyFunSuite with SparkTestBase with LoggerTestBase { + import spark.implicits._ + + private implicit val defaults: Defaults = GlobalDefaults + + lazy val data: DataFrame = spark.createDataFrame(TestSamples.dateSamples) + lazy val schemaWrong: StructType = DataType + .fromJson(FileReader.readFileAsString("src/test/resources/data/dateTimestampSchemaWrong.json")) + .asInstanceOf[StructType] + lazy val schemaOk: StructType = DataType + .fromJson(FileReader.readFileAsString("src/test/resources/data/dateTimestampSchemaOk.json")) + .asInstanceOf[StructType] + + private implicit val udfLib: UDFLibrary = new UDFLibrary() + + test("Validation should return critical errors") { + logger.debug(data.schema.prettyJson) + val validationErrors = SchemaValidator.validateSchema(schemaWrong) + val exp = List( + FieldValidationIssue("dateSampleWrong1", "DD-MM-yyyy", List( + ValidationWarning("No day placeholder 'dd' found."), + ValidationWarning("Rarely used DayOfYear placeholder 'D' found. Possibly DayOfMonth 'd' intended."))), + FieldValidationIssue("dateSampleWrong2", "Dy", List( + ValidationWarning("No day placeholder 'dd' found."), + ValidationWarning("Rarely used DayOfYear placeholder 'D' found. 
Possibly DayOfMonth 'd' intended."), + ValidationWarning("No month placeholder 'MM' found."))), + FieldValidationIssue("dateSampleWrong3", "rrr", List( + ValidationError("Illegal pattern character 'r'"))), + FieldValidationIssue("timestampSampleWrong1", "yyyyMMddTHHmmss", List( + ValidationError("Illegal pattern character 'T'"))), + FieldValidationIssue("timestampSampleWrong3", "yyyy-MM-dd", List( + ValidationWarning("No hour placeholder 'HH' found."), + ValidationWarning("No minute placeholder 'mm' found."), + ValidationWarning("No second placeholder 'ss' found."))), + FieldValidationIssue("timestampNullDefaultWrong", "", List( + ValidationError("null is not a valid value for field 'timestampNullDefaultWrong'"))) + ) + assert(validationErrors == exp) + } + + test("Validation for this data should return critical errors") { + val errors = SchemaChecker.validateSchemaAndLog(schemaWrong) + assert(errors._1.nonEmpty) + } + + test("Date Time Standardization Example Test should throw an exception") { + intercept[ValidationException] { + Standardization.standardize(data, schemaWrong) + } + } + + test("Date Time Standardization Example with fixed schema should work") { + val date0 = new Date(0) + val ts = Timestamp.valueOf("2017-10-20 08:11:31") + val ts0 = new Timestamp(0) + val exp = List(( + 1L, + Date.valueOf("2017-10-20"), + Date.valueOf("2017-10-20"), + Date.valueOf("2017-12-29"), + date0, + date0, + null, + ts, ts, ts, null, ts0, ts0, + List( + ErrorMessage.stdCastErr("dateSampleWrong1","10-20-2017"), + ErrorMessage.stdCastErr("dateSampleWrong2","201711"), + ErrorMessage.stdCastErr("dateSampleWrong3",""), + ErrorMessage.stdCastErr("timestampSampleWrong1", "20171020T081131"), + ErrorMessage.stdCastErr("timestampSampleWrong2", "2017-10-20t081131"), + ErrorMessage.stdCastErr("timestampSampleWrong3", "2017-10-20") + ) + )) + val std: Dataset[Row] = Standardization.standardize(data, schemaOk) + logDataFrameContent(std) + assertResult(exp)(std.as[Tuple14[Long, Date, Date, Date, Date, Date, Date, Timestamp, Timestamp, Timestamp, Timestamp, Timestamp,Timestamp, Seq[ErrorMessage]]].collect().toList) + } +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/SampleDataSuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/SampleDataSuite.scala new file mode 100644 index 0000000..334edb0 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/SampleDataSuite.scala @@ -0,0 +1,47 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.interpreter + +import org.apache.spark.sql.types.{DataType, StructType} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.types.{Defaults, GlobalDefaults} +import za.co.absa.standardization.udf.UDFLibrary +import za.co.absa.standardization.{FileReader, LoggerTestBase, SparkTestBase, Standardization, StdEmployee, TestSamples} + +class SampleDataSuite extends AnyFunSuite with SparkTestBase with LoggerTestBase { + private implicit val defaults: Defaults = GlobalDefaults + + test("Simple Example Test") { + import spark.implicits._ + val data = spark.createDataFrame(TestSamples.data1) + + logDataFrameContent(data) + + implicit val udfLib: UDFLibrary = new UDFLibrary() + + val sourceFile = FileReader.readFileAsString("src/test/resources/data/data1Schema.json") + val schema = DataType.fromJson(sourceFile).asInstanceOf[StructType] + val std = Standardization.standardize(data, schema) + logDataFrameContent(std) + val stdList = std.as[StdEmployee].collect.sortBy(_.name).toList + val exp = TestSamples.resData.sortBy(_.name) + + assertResult(exp)(stdList) + + } + +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreterSuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreterSuite.scala new file mode 100644 index 0000000..fe33732 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreterSuite.scala @@ -0,0 +1,375 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.interpreter + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.types.{Defaults, GlobalDefaults} +import za.co.absa.standardization.udf.UDFLibrary +import za.co.absa.standardization.{ErrorMessage, FileReader, JsonUtils, LoggerTestBase, SparkTestBase, Standardization} + +object StandardizationInterpreterSuite { + + case class ErrorPreserve(a: String, b: String, errCol: List[ErrorMessage]) + case class ErrorPreserveStd(a: String, b: Int, errCol: List[ErrorMessage]) + + case class MyWrapper(counterparty: MyHolder) + case class MyHolder(yourRef: String) + case class MyWrapperStd(counterparty: MyHolder, errCol: Seq[ErrorMessage]) + + case class Time(id: Int, date: String, timestamp: String) + case class StdTime(id: Int, date: Date, timestamp: Timestamp, errCol: List[ErrorMessage]) + + case class subCC(subFieldA: Integer, subFieldB: String) + case class sub2CC(subSub2FieldA: Integer, subSub2FieldB: String) + case class sub1CC(subStruct2: sub2CC) + case class subarrayCC(arrayFieldA: Integer, arrayFieldB: String, arrayStruct: subCC) + case class rootCC(rootField: String, rootStruct: subCC, rootStruct2: sub1CC, rootArray: Array[subarrayCC]) + + // used by the last test: + // cannot use case class as the field names contain spaces therefore cast will happen into tuple + type BodyStats = (Int, Int, (String, Option[Boolean]), Seq[Double]) + type PatientRow = (String, String, BodyStats, Seq[ErrorMessage]) + + object BodyStats { + def apply(height: Int, + weight: Int, + eyeColor: String, + glasses: Option[Boolean], + temperatureMeasurements: Seq[Double] + ): BodyStats = (height, weight, (eyeColor, glasses), temperatureMeasurements) + } + + object PatientRow { + def apply(first_name: String, + lastName: String, + bodyStats: BodyStats, + errCol: Seq[ErrorMessage] = Seq.empty + ): PatientRow = (first_name, lastName, bodyStats, errCol) + } +} + +class StandardizationInterpreterSuite extends AnyFunSuite with SparkTestBase with LoggerTestBase { + import StandardizationInterpreterSuite._ + import spark.implicits._ + private implicit val defaults: Defaults = GlobalDefaults + private implicit val udfLib: UDFLibrary = new UDFLibrary + + private val stdExpectedSchema = StructType( + Seq( + StructField("rootField", StringType, nullable = true), + StructField("rootStruct", + StructType( + Seq( + StructField("subFieldA", IntegerType, nullable = true), + StructField("subFieldB", StringType, nullable = true))), nullable = false), + StructField("rootStruct2", + StructType( + Seq( + StructField("subStruct2", + StructType( + Seq( + StructField("subSub2FieldA", IntegerType, nullable = true), + StructField("subSub2FieldB", StringType, nullable = true))), nullable = false))), nullable = false), + StructField("rootArray", + ArrayType( + StructType( + Seq( + StructField("arrayFieldA", IntegerType, nullable = true), + StructField("arrayFieldB", StringType, nullable = true), + StructField("arrayStruct", + StructType( + Seq( + StructField("subFieldA", IntegerType, nullable = true), + StructField("subFieldB", StringType, nullable = true))), nullable = false))), containsNull = false + )))) + + test("Non-null errors produced for non-nullable attribute in a struct") { + val orig = spark.createDataFrame(Seq( + MyWrapper(MyHolder(null)), + MyWrapper(MyHolder("447129")))) + + val exp = Seq( + MyWrapperStd(MyHolder(""), 
Seq(ErrorMessage.stdNullErr("counterparty.yourRef"))), + MyWrapperStd(MyHolder("447129"), Seq())) + + val schema = StructType(Seq( + StructField("counterparty", StructType( + Seq( + StructField("yourRef", StringType, nullable = false))), nullable = false))) + + val standardizedDF = Standardization.standardize(orig, schema) + + assertResult(exp)(standardizedDF.as[MyWrapperStd].collect().toList) + } + + test("Existing error messages should be preserved") { + val df = spark.createDataFrame(Array( + ErrorPreserve("a", "1", null), + ErrorPreserve("b", "2", List()), + ErrorPreserve("c", "3", List(new ErrorMessage("myErrorType", "E-1", "Testing This stuff", "whatEvColumn", Seq("some value")))), + ErrorPreserve("d", "abc", List(new ErrorMessage("myErrorType2", "E-2", "Testing This stuff blabla", "whatEvColumn2", Seq("some other value")))))) + + val exp = Array( + ErrorPreserveStd("a", 1, List()), + ErrorPreserveStd("b", 2, List()), + ErrorPreserveStd("c", 3, List(new ErrorMessage("myErrorType", "E-1", "Testing This stuff", "whatEvColumn", Seq("some value")))), + ErrorPreserveStd("d", 0, List(ErrorMessage.stdCastErr("b", "abc"), + new ErrorMessage("myErrorType2", "E-2", "Testing This stuff blabla", "whatEvColumn2", Seq("some other value"))))) + + val expSchema = spark.emptyDataset[ErrorPreserveStd].schema + val res = Standardization.standardize(df, expSchema) + + assertResult(exp.sortBy(_.a).toList)(res.as[ErrorPreserveStd].collect().sortBy(_.a).toList) + } + + test("Standardize Test") { + val sourceDF = spark.createDataFrame( + Array( + rootCC("rootfieldval", + subCC(123, "subfieldval"), + sub1CC(sub2CC(456, "subsubfieldval")), + Array(subarrayCC(789, "arrayfieldval", subCC(321, "xyz")))))) + + val expectedSchema = stdExpectedSchema.add( + StructField("errCol", + ArrayType( + ErrorMessage.errorColSchema, containsNull = false))) + + val standardizedDF = Standardization.standardize(sourceDF, stdExpectedSchema) + + logger.debug(standardizedDF.schema.treeString) + logger.debug(expectedSchema.treeString) + + assert(standardizedDF.schema.treeString === expectedSchema.treeString) + } + + test("Standardize Test (JSON source)") { + val sourceDF = spark.read.json("src/test/resources/data/standardizeJsonSrc.json") + + val expectedSchema = stdExpectedSchema.add( + StructField("errCol", + ArrayType( + ErrorMessage.errorColSchema, containsNull = false))) + + val standardizedDF = Standardization.standardize(sourceDF, stdExpectedSchema) + + logger.debug(standardizedDF.schema.treeString) + logger.debug(expectedSchema.treeString) + + assert(standardizedDF.schema.treeString === expectedSchema.treeString) + } + + case class OrderCC(orderName: String, deliverName: Option[String]) + case class RootRecordCC(id: Long, name: Option[String], orders: Option[Array[OrderCC]]) + + test("Test standardization of non-nullable field of a contains null array") { + val schema = StructType( + Array( + StructField("id", LongType, nullable = false), + StructField("name", StringType, nullable = true), + StructField("orders", ArrayType(StructType(Array( + StructField("orderName", StringType, nullable = false), + StructField("deliverName", StringType, nullable = true))), containsNull = true), nullable = true))) + + val sourceDF = spark.createDataFrame( + Array( + RootRecordCC(1, Some("Test Name 1"), Some(Array(OrderCC("Order Test Name 1", Some("Deliver Test Name 1"))))), + RootRecordCC(2, Some("Test Name 2"), Some(Array(OrderCC("Order Test Name 2", None)))), + RootRecordCC(3, Some("Test Name 3"), None), + RootRecordCC(4, None, None))) + + 
val standardizedDF = Standardization.standardize(sourceDF, schema) + // 'orders' array is nullable, so it can be omitted + // But orders[].ordername is not nullable, so it must be specified + // But absence of orders should not cause validation errors + val count = standardizedDF.where(size(col("errCol")) > 0).count() + + assert(count == 0) + } + + test ("Test standardization of Date and Timestamp fields with default value and pattern") { + val schema = StructType( + Seq( + StructField("id" ,IntegerType, nullable = false), + StructField("date", DateType, nullable = true, Metadata.fromJson("""{"default": "20250101", "pattern": "yyyyMMdd"}""")), + StructField("timestamp", TimestampType, nullable = true, Metadata.fromJson("""{"default": "20250101.142626", "pattern": "yyyyMMdd.HHmmss"}""")))) + + val sourceDF = spark.createDataFrame( + List ( + Time(1, "20171004", "20171004.111111"), + Time(2, "", "") + ) + ) + + val expected = List ( + StdTime(1, new Date(1507075200000L), new Timestamp(1507115471000L), List()), + StdTime(2, new Date(1735689600000L), new Timestamp(1735741586000L), List(ErrorMessage.stdCastErr("date", ""), ErrorMessage.stdCastErr("timestamp", ""))) + ) + + val standardizedDF = Standardization.standardize(sourceDF, schema) + val result = standardizedDF.as[StdTime].collect().toList + assertResult(expected)(result) + } + + test ("Test standardization of Date and Timestamp fields with default value, without pattern") { + val schema = StructType( + Seq( + StructField("id" ,IntegerType, nullable = false), + StructField("date", DateType, nullable = true, Metadata.fromJson("""{"default": "2025-01-01"}""")), + StructField("timestamp", TimestampType, nullable = true, Metadata.fromJson("""{"default": "2025-01-01 14:26:26"}""")))) + + val sourceDF = spark.createDataFrame( + List ( + Time(1, "2017-10-04", "2017-10-04 11:11:11"), + Time(2, "", "") + ) + ) + + val expected = List ( + StdTime(1, new Date(1507075200000L), new Timestamp(1507115471000L), List()), + StdTime(2, new Date(1735689600000L), new Timestamp(1735741586000L), List(ErrorMessage.stdCastErr("date", ""), ErrorMessage.stdCastErr("timestamp", ""))) + ) + + val standardizedDF = Standardization.standardize(sourceDF, schema) + val result = standardizedDF.as[StdTime].collect().toList + assertResult(expected)(result) + } + + test ("Test standardization of Date and Timestamp fields without default value, with pattern") { + val schema = StructType( + Seq( + StructField("id" ,IntegerType, nullable = false), + StructField("date", DateType, nullable = true, Metadata.fromJson("""{"pattern": "yyyyMMdd"}""")), + StructField("timestamp", TimestampType, nullable = false, Metadata.fromJson("""{"pattern": "yyyyMMdd.HHmmss"}""")))) + + val sourceDF = spark.createDataFrame( + List ( + Time(1, "20171004", "20171004.111111"), + Time(2, "", "") + ) + ) + + val expected = List ( + StdTime(1, new Date(1507075200000L), new Timestamp(1507115471000L), List()), + StdTime(2, null, new Timestamp(0L), List(ErrorMessage.stdCastErr("date", ""), ErrorMessage.stdCastErr("timestamp", ""))) + ) + + val standardizedDF = Standardization.standardize(sourceDF, schema) + val result = standardizedDF.as[StdTime].collect().toList + assertResult(expected)(result) + } + + test ("Test standardization of Date and Timestamp fields without default value, without pattern") { + val schema = StructType( + Seq( + StructField("id" ,IntegerType, nullable = false), + StructField("date", DateType, nullable = false), + StructField("timestamp", TimestampType, nullable = true))) + + val 
sourceDF = spark.createDataFrame( + List ( + Time(1, "2017-10-04", "2017-10-04 11:11:11"), + Time(2, "", "") + ) + ) + + val expected = List ( + StdTime(1, new Date(1507075200000L), new Timestamp(1507115471000L), List()), + StdTime(2, new Date(0L), null, List(ErrorMessage.stdCastErr("date", ""), ErrorMessage.stdCastErr("timestamp", ""))) + ) + + val standardizedDF = Standardization.standardize(sourceDF, schema) + val result = standardizedDF.as[StdTime].collect().toList + assertResult(expected)(result) + } + + test("Errors in fields and having source columns") { + val desiredSchema = StructType(Seq( + StructField("first_name", StringType, nullable = true, + new MetadataBuilder().putString("sourcecolumn", "first name").build), + StructField("last_name", StringType, nullable = false, + new MetadataBuilder().putString("sourcecolumn", "last name").build), + StructField("body_stats", + StructType(Seq( + StructField("height", IntegerType, nullable = false), + StructField("weight", IntegerType, nullable = false), + StructField("miscellaneous", StructType(Seq( + StructField("eye_color", StringType, nullable = true, + new MetadataBuilder().putString("sourcecolumn", "eye color").build), + StructField("glasses", BooleanType, nullable = true) + ))), + StructField("temperature_measurements", ArrayType(DoubleType, containsNull = false), nullable = false, + new MetadataBuilder().putString("sourcecolumn", "temperature measurements").build) + )), + nullable = false, + new MetadataBuilder().putString("sourcecolumn", "body stats").build + ) + )) + + + val srcString:String = FileReader.readFileAsString("src/test/resources/data/patients.json") + val src = JsonUtils.getDataFrameFromJson(spark, Seq(srcString)) + + logDataFrameContent(src) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + val actualSchema = std.schema.treeString + val expectedSchema = "root\n" + + " |-- first_name: string (nullable = true)\n" + + " |-- last_name: string (nullable = true)\n" + + " |-- body_stats: struct (nullable = false)\n" + + " | |-- height: integer (nullable = true)\n" + + " | |-- weight: integer (nullable = true)\n" + + " | |-- miscellaneous: struct (nullable = false)\n" + + " | | |-- eye_color: string (nullable = true)\n" + + " | | |-- glasses: boolean (nullable = true)\n" + + " | |-- temperature_measurements: array (nullable = true)\n" + + " | | |-- element: double (containsNull = true)\n" + + " |-- errCol: array (nullable = true)\n" + + " | |-- element: struct (containsNull = false)\n" + + " | | |-- errType: string (nullable = true)\n" + + " | | |-- errCode: string (nullable = true)\n" + + " | | |-- errMsg: string (nullable = true)\n" + + " | | |-- errCol: string (nullable = true)\n" + + " | | |-- rawValues: array (nullable = true)\n" + + " | | | |-- element: string (containsNull = true)\n" + + " | | |-- mappings: array (nullable = true)\n" + + " | | | |-- element: struct (containsNull = true)\n" + + " | | | | |-- mappingTableColumn: string (nullable = true)\n" + + " | | | | |-- mappedDatasetColumn: string (nullable = true)\n" + assert(actualSchema == expectedSchema) + + val exp = Seq( + PatientRow("Jane", "Goodall", BodyStats(164, 61, "green", Option(true), Seq(36.6, 36.7, 37.0, 36.6))), + PatientRow("Scott", "Lang", BodyStats(0, 83, "blue", Option(false),Seq(36.6, 36.7, 37.0, 36.6)), Seq( + ErrorMessage.stdCastErr("body stats.height", "various") + )), + PatientRow("Aldrich", "Killian", BodyStats(181, 90, "brown or orange", None, Seq(36.7, 36.5, 38.0, 48.0, 152.0, 831.0, 
0.0)), Seq( + ErrorMessage.stdCastErr("body stats.miscellaneous.glasses", "not any more"), + ErrorMessage.stdCastErr("body stats.temperature measurements[*]", "exploded") + )) + ) + + assertResult(exp)(std.as[PatientRow].collect().toList) + } +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_ArraySuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_ArraySuite.scala new file mode 100644 index 0000000..a7b500d --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_ArraySuite.scala @@ -0,0 +1,210 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.interpreter + +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers +import za.co.absa.standardization.implicits.DataFrameImplicits.DataFrameEnhancements +import za.co.absa.standardization.schema.MetadataKeys +import za.co.absa.standardization.types.{Defaults, GlobalDefaults} +import za.co.absa.standardization.udf.UDFLibrary +import za.co.absa.standardization.{ErrorMessageFactory, JsonUtils, LoggerTestBase, SparkTestBase, Standardization, ValidationException} + +class StandardizationInterpreter_ArraySuite extends AnyFunSuite with SparkTestBase with LoggerTestBase with Matchers { + import spark.implicits._ + + private implicit val udfLib: UDFLibrary = new UDFLibrary + private implicit val defaults: Defaults = GlobalDefaults + + private val fieldName = "arrayField" + + private def generateDesiredSchema(arrayElementType: String, metadata: String): StructType = { + val jsonField: String = s"""{"name": "$fieldName", "type": { "type": "array", "elementType": $arrayElementType, "containsNull": true}, "nullable": true, "metadata": {$metadata} }""" + val fullJson = s"""{"type": "struct", "fields": [$jsonField]}""" + DataType.fromJson(fullJson).asInstanceOf[StructType] + } + + private def generateDesiredSchema(arrayElementType: DataType, metadata: String = ""): StructType = { + generateDesiredSchema('"' + arrayElementType.typeName + '"', metadata) + } + + test("Array of timestamps with no pattern") { + val seq = Seq( + Array("00:00:00 01.12.2018", "00:10:00 02.12.2018","00:20:00 03.12.2018"), + Array("00:00:00 01.12.2019", "00:10:00 02.12.2019","00:20:00 03.12.2019"), + Array("2020-01-12 00:00:00" ,"2020-12-02 00:10:00","2020-12-03 00:20:00") + ) + val src = seq.toDF(fieldName) + val desiredSchema = generateDesiredSchema(TimestampType) + + val expectedData = + """+---------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + ||arrayField |errCol | + 
|+---------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + ||[,,] |[[stdCastError, E00000, Standardization Error - Type cast, arrayField[*], [00:00:00 01.12.2018], []], [stdCastError, E00000, Standardization Error - Type cast, arrayField[*], [00:10:00 02.12.2018], []], [stdCastError, E00000, Standardization Error - Type cast, arrayField[*], [00:20:00 03.12.2018], []]]| + ||[,,] |[[stdCastError, E00000, Standardization Error - Type cast, arrayField[*], [00:00:00 01.12.2019], []], [stdCastError, E00000, Standardization Error - Type cast, arrayField[*], [00:10:00 02.12.2019], []], [stdCastError, E00000, Standardization Error - Type cast, arrayField[*], [00:20:00 03.12.2019], []]]| + ||[2020-01-12 00:00:00, 2020-12-02 00:10:00, 2020-12-03 00:20:00]|[] | + |+---------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | + |""".stripMargin.replace("\r\n", "\n") + val expectedSchema = ErrorMessageFactory.attachErrColToSchemaPrint( + "root\n"+ + " |-- arrayField: array (nullable = true)\n" + + " | |-- element: timestamp (containsNull = true)" + ) + + val std = Standardization.standardize(src, desiredSchema).cache() + assert(std.schema.treeString == expectedSchema) + assert(std.dataAsString(false) == expectedData) + } + + test("Array of timestamps with pattern defined") { + val seq = Seq( + Array("00:00:00 01.12.2008", "00:10:00 02.12.2008","00:20:00 03.12.2008"), + Array("00:00:00 01.12.2009", "00:10:00 02.12.2009","00:20:00 03.12.2009"), + Array("2010-01-12 00:00:00" ,"2010-12-02 00:10:00","2010-12-03 00:20:00") + ) + val src = seq.toDF(fieldName) + val desiredSchema = generateDesiredSchema(TimestampType, s""""${MetadataKeys.Pattern}": "HH:mm:ss dd.MM.yyyy"""") + + val expectedData = + """+---------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + ||arrayField |errCol | + |+---------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + ||[2008-12-01 00:00:00, 2008-12-02 00:10:00, 2008-12-03 00:20:00]|[] | + ||[2009-12-01 00:00:00, 2009-12-02 00:10:00, 2009-12-03 00:20:00]|[] | + ||[,,] |[[stdCastError, E00000, Standardization Error - Type cast, arrayField[*], [2010-01-12 00:00:00], []], [stdCastError, E00000, Standardization Error - Type cast, arrayField[*], [2010-12-02 00:10:00], []], [stdCastError, E00000, Standardization Error - Type cast, arrayField[*], [2010-12-03 
00:20:00], []]]| + |+---------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | + |""".stripMargin.replace("\r\n", "\n") + val expectedSchema = ErrorMessageFactory.attachErrColToSchemaPrint( + "root\n"+ + " |-- arrayField: array (nullable = true)\n" + + " | |-- element: timestamp (containsNull = true)" + ) + val std = Standardization.standardize(src, desiredSchema).cache() + assert(std.schema.treeString == expectedSchema) + assert(std.dataAsString(false) == expectedData) + } + + test("Array of timestamps with invalid pattern") { + val seq = Seq( + Array("00:00:00 01.12.2013", "00:10:00 02.12.2013","00:20:00 03.12.2013"), + Array("00:00:00 01.12.2014", "00:10:00 02.12.2014","00:20:00 03.12.2014"), + Array("2015-01-12 00:00:00" ,"2015-12-02 00:10:00","2015-12-03 00:20:00") + ) + val src = seq.toDF(fieldName) + val desiredSchema = generateDesiredSchema(TimestampType, s""""${MetadataKeys.Pattern}": "fubar"""") + val caught = intercept[ValidationException] { + Standardization.standardize(src, desiredSchema).cache() + } + + caught.getMessage should startWith ("A fatal schema validation error occurred.") + caught.errors.head should startWith ("Validation error for column 'arrayField[].arrayField', pattern 'fubar") + } + + test("Array of integers with pattern defined") { + val seq = Seq( + Array("Size: 1", "Size: 2","Size: 3"), + Array("Size: -7", "Size: ~13.13"), + Array("A" , null, "") + ) + val src = seq.toDF(fieldName) + val desiredSchema = generateDesiredSchema(IntegerType, s""""${MetadataKeys.Pattern}": "Size: #;Size: -#"""") + + val expectedData = + """+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + ||arrayField|errCol | + |+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + ||[1, 2, 3] |[] | + ||[-7,] |[[stdCastError, E00000, Standardization Error - Type cast, arrayField[*], [Size: ~13.13], []]] | + ||[,,] |[[stdCastError, E00000, Standardization Error - Type cast, arrayField[*], [A], []], [stdCastError, E00000, Standardization Error - Type cast, arrayField[*], [], []]]| + |+----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | + |""".stripMargin.replace("\r\n", "\n") + val expectedSchema = ErrorMessageFactory.attachErrColToSchemaPrint( + "root\n"+ + " |-- arrayField: array (nullable = true)\n" + + " | |-- element: integer (containsNull = true)" + ) + + val std = Standardization.standardize(src, desiredSchema).cache() + assert(std.schema.treeString == expectedSchema) + assert(std.dataAsString(false) == expectedData) + } + + test("Array of floats with minus sign changed and default defined") { + val seq = Seq( + Array("1.1", "2.2","3.3"), + Array("~7.7", "-13.13"), + Array("A" , null, "") + ) + val src = seq.toDF(fieldName) + val desiredSchema = generateDesiredSchema(FloatType, s""""${MetadataKeys.DefaultValue}": "3.14", "${MetadataKeys.MinusSign}": "~" """) + + val expectedData = + 
"""+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + ||arrayField |errCol | + |+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + ||[1.1, 2.2, 3.3]|[] | + ||[-7.7, 3.14] |[[stdCastError, E00000, Standardization Error - Type cast, arrayField[*], [-13.13], []]] | + ||[3.14,, 3.14] |[[stdCastError, E00000, Standardization Error - Type cast, arrayField[*], [A], []], [stdCastError, E00000, Standardization Error - Type cast, arrayField[*], [], []]]| + |+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | + |""".stripMargin.replace("\r\n", "\n") + val expectedSchema = ErrorMessageFactory.attachErrColToSchemaPrint( + "root\n"+ + " |-- arrayField: array (nullable = true)\n" + + " | |-- element: float (containsNull = true)" + ) + + val std = Standardization.standardize(src, desiredSchema).cache() + assert(std.schema.treeString == expectedSchema) + assert(std.dataAsString(false) == expectedData) + } + + test("Array of arrays of string") { + val seq = Seq( + s"""{"$fieldName": [["a", "bb", "ccc"],["1", "12"],["Hello", null, "World"]]}""" + ) + val src = JsonUtils.getDataFrameFromJson(spark, seq) + + val subArrayJson = """{"type": "array", "elementType": "string", "containsNull": false}""" + val desiredSchema = generateDesiredSchema(subArrayJson, s""""${MetadataKeys.DefaultValue}": "Nope"""") + + val expectedData = + """+---------------------------------------------+--------------------------------------------------------------------------------------------------------------------+ + ||arrayField |errCol | + |+---------------------------------------------+--------------------------------------------------------------------------------------------------------------------+ + ||[[a, bb, ccc], [1, 12], [Hello, Nope, World]]|[[stdNullError, E00002, Standardization Error - Null detected in non-nullable attribute, arrayField[*], [null], []]]| + |+---------------------------------------------+--------------------------------------------------------------------------------------------------------------------+ + | + |""".stripMargin.replace("\r\n", "\n") + val expectedSchema = ErrorMessageFactory.attachErrColToSchemaPrint( + "root\n"+ + " |-- arrayField: array (nullable = true)\n" + + " | |-- element: array (containsNull = true)\n" + + " | | |-- element: string (containsNull = true)" + ) + + val std = Standardization.standardize(src, desiredSchema).cache() + + assert(std.schema.treeString == expectedSchema) + assert(std.dataAsString(false) == expectedData) + } +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_BinarySuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_BinarySuite.scala new file mode 100644 index 0000000..963ea59 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_BinarySuite.scala @@ -0,0 +1,192 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.interpreter + +import org.apache.spark.sql.types.{BinaryType, Metadata, MetadataBuilder, StructField, StructType} +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers +import za.co.absa.standardization.types.{Defaults, GlobalDefaults} +import za.co.absa.standardization.udf.UDFLibrary +import za.co.absa.standardization.{ErrorMessage, LoggerTestBase, SparkTestBase, Standardization, ValidationException} + +class StandardizationInterpreter_BinarySuite extends AnyFunSuite with SparkTestBase with LoggerTestBase with Matchers { + + import spark.implicits._ + + private implicit val udfLib: UDFLibrary = new UDFLibrary + private implicit val defaults: Defaults = GlobalDefaults + + private val fieldName = "binaryField" + + test("byteArray to Binary") { + val seq = Seq( + Array(1, 2, 3).map(_.toByte), + Array('a', 'b', 'c').map(_.toByte) + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, BinaryType, nullable = false) + )) + val expected = Seq( + BinaryRow(Array(1, 2, 3).map(_.toByte)), + BinaryRow(Array(97, 98, 99).map(_.toByte)) + ) + + val src = seq.toDF(fieldName) + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + val result = std.as[BinaryRow].collect().toList + expected.map(_.simpleFields) should contain theSameElementsAs result.map(_.simpleFields) + } + + test("Binary from string with base64 encoding") { + val seq = Seq( + "MTIz", + "YWJjZA==", + "bogus#$%^" // invalid base64 chars + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, BinaryType, nullable = false, + new MetadataBuilder().putString("encoding", "base64").build) + )) + + val expected = Seq( + BinaryRow(Array(49, 50, 51).map(_.toByte)), // "123" + BinaryRow(Array(97, 98, 99, 100).map(_.toByte)), // "abcd" + BinaryRow(Array.emptyByteArray, // default value on error + Seq(ErrorMessage("stdCastError", "E00000", "Standardization Error - Type cast", "binaryField", + rawValues = Seq("bogus#$%^"), mappings = Seq())) + ) + ) + + val src = seq.toDF(fieldName) + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + val result = std.as[BinaryRow].collect().toList + expected.map(_.simpleFields) should contain theSameElementsAs result.map(_.simpleFields) + } + + test("Binary from string with bogus encoding") { + val seq = Seq( + "does not matter" + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, BinaryType, nullable = false, + new MetadataBuilder().putString("encoding", "bogus").build) + )) + + val src = seq.toDF(fieldName) + val caught = intercept[ValidationException]( + Standardization.standardize(src, desiredSchema).cache() + ) + + caught.errors.length shouldBe 1 + caught.errors.head should include("Unsupported encoding for Binary field binaryField: 'bogus'") + } + + // behavior of explicit metadata "none" and lacking metadata should behave identically + Seq(None, Some("none")).foreach { enc => + test(s"Binary from string with ${enc.getOrElse("missing")} encoding") { + val seq = Seq( + "abc", + "1234" + ) + + val 
metadata = enc.fold(Metadata.empty)(e => new MetadataBuilder().putString("encoding", e).build) + val desiredSchema = StructType(Seq( + StructField(fieldName, BinaryType, nullable = false, metadata) + )) + + val expected = Seq( + BinaryRow(Array(97, 98, 99).map(_.toByte)), // "123" + BinaryRow(Array(49, 50, 51, 52).map(_.toByte)) // "abcd" + ) + + val src = seq.toDF(fieldName) + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + val result = std.as[BinaryRow].collect().toList + expected.map(_.simpleFields) should contain theSameElementsAs result.map(_.simpleFields) + } + } + + test("Binary with defaultValue uses base64") { + val seq = Seq[Option[String]]( + Some("MTIz"), + None + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, BinaryType, nullable = false, new MetadataBuilder() + .putString("encoding", "base64") + .putString("default", "ZW1wdHk=") // "empty" + .build) + )) + + val expected = Seq( + BinaryRow(Array(49, 50, 51).map(_.toByte)), // "123" + // ^ std error is written into the errCol and the default (fallback) value "(binary) empty" is used. + BinaryRow(Array('e', 'm', 'p', 't', 'y').map(_.toByte), + Seq(ErrorMessage("stdNullError", "E00002", "Standardization Error - Null detected in non-nullable attribute", + "binaryField", rawValues = Seq("null"), mappings = Seq())) + ) + ) + + val src = seq.toDF(fieldName) + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + val result = std.as[BinaryRow].collect().toList + expected.map(_.simpleFields) should contain theSameElementsAs result.map(_.simpleFields) + } + + // behavior of explicit metadata "none" and lacking metadata should behave identically + Seq(None, Some("none")).foreach { enc => + test(s"Binary with defaultValue ${enc.getOrElse("missing")} encoding") { + val seq = Seq[Option[String]]( + Some("123"), + None + ) + + val metadata = { + val base = new MetadataBuilder().putString("default", "fallback1") + enc.fold(base)(base.putString("encoding", _)) + base.build + } + + val desiredSchema = StructType(Seq(StructField(fieldName, BinaryType, nullable = false, metadata))) + + val expected = Seq( + BinaryRow(Array(49, 50, 51).map(_.toByte)), // "123" + // ^ std error is written into the errCol and the default (fallback) value "(binary) empty" is used. + BinaryRow(Array('f', 'a', 'l', 'l', 'b', 'a', 'c', 'k', '1').map(_.toByte), + Seq(ErrorMessage("stdNullError", "E00002", "Standardization Error - Null detected in non-nullable attribute", + "binaryField", rawValues = Seq("null"), mappings = Seq())) + ) + ) + + val src = seq.toDF(fieldName) + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + val result = std.as[BinaryRow].collect().toList + expected.map(_.simpleFields) should contain theSameElementsAs result.map(_.simpleFields) + } + } + +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_DateSuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_DateSuite.scala new file mode 100644 index 0000000..cd5a48b --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_DateSuite.scala @@ -0,0 +1,361 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.interpreter + +import java.sql.Date +import org.apache.spark.sql.types.{DateType, MetadataBuilder, StructField, StructType} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.types.{Defaults, GlobalDefaults} +import za.co.absa.standardization.udf.UDFLibrary +import za.co.absa.standardization.{ErrorMessage, LoggerTestBase, SparkTestBase, Standardization} + +class StandardizationInterpreter_DateSuite extends AnyFunSuite with SparkTestBase with LoggerTestBase { + import spark.implicits._ + + private implicit val udfLib: UDFLibrary = new UDFLibrary + private implicit val defaults: Defaults = GlobalDefaults + + private val fieldName = "dateField" + + test("epoch") { + val seq = Seq( + 0, + 86399, + 86400, + 978307199, + 1563288103 + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, DateType, nullable = false, + new MetadataBuilder().putString("pattern", "epoch").build) + )) + val exp = Seq( + DateRow(Date.valueOf("1970-01-01")), + DateRow(Date.valueOf("1970-01-01")), + DateRow(Date.valueOf("1970-01-02")), + DateRow(Date.valueOf("2000-12-31")), + DateRow(Date.valueOf("2019-07-16")) + ) + + val src = seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[DateRow].collect().toList) + } + + test("epochmilli") { + val seq = Seq( + 0L, + 86400000, + 978307199999L, + 1563288103123L + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, DateType, nullable = false, + new MetadataBuilder().putString("pattern", "epochmilli").build) + )) + val exp = Seq( + DateRow(Date.valueOf("1970-01-01")), + DateRow(Date.valueOf("1970-01-02")), + DateRow(Date.valueOf("2000-12-31")), + DateRow(Date.valueOf("2019-07-16")) + ) + + val src = seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[DateRow].collect().toList) + } + + test("epochmicro") { + val seq = Seq( + 0.1, + 86400000000.02, + 978307199999999.003, + 1563288103123456.123 + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, DateType, nullable = false, + new MetadataBuilder().putString("pattern", "epochmicro").build) + )) + val exp = Seq( + DateRow(Date.valueOf("1970-01-01")), + DateRow(Date.valueOf("1970-01-02")), + DateRow(Date.valueOf("2000-12-31")), + DateRow(Date.valueOf("2019-07-16")) + ) + + val src = seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[DateRow].collect().toList) + } + + test("epochnano") { + val seq = Seq( + 0, + 86400000000000L, + 978307199999999999L, + 1563288103123456789L + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, DateType, nullable = false, + new MetadataBuilder().putString("pattern", "epochnano").build) + )) + val exp = Seq( + DateRow(Date.valueOf("1970-01-01")), + DateRow(Date.valueOf("1970-01-02")), + DateRow(Date.valueOf("2000-12-31")), + DateRow(Date.valueOf("2019-07-16")) + ) + + val src = 
seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[DateRow].collect().toList) + } + + test("simple date pattern") { + val seq = Seq( + "1970/01/01", + "1970/02/01", + "2000/31/12", + "2019/16/07", + "1970-02-02", + "crash" + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, DateType, nullable = false, + new MetadataBuilder().putString("pattern", "yyyy/dd/MM").build) + )) + val exp = Seq( + DateRow(Date.valueOf("1970-01-01")), + DateRow(Date.valueOf("1970-01-02")), + DateRow(Date.valueOf("2000-12-31")), + DateRow(Date.valueOf("2019-07-16")), + DateRow(Date.valueOf("1970-01-01"), Seq(ErrorMessage.stdCastErr(fieldName, "1970-02-02"))), + DateRow(Date.valueOf("1970-01-01"), Seq(ErrorMessage.stdCastErr(fieldName, "crash"))) + ) + + val src = seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[DateRow].collect().toList) + } + + test("date + time pattern and named time zone") { + val seq = Seq( + "01-00-00 01.01.1970 CET", + "00-00-00 03.01.1970 EET", + "21-45-39 30.12.2000 PST", + "14-25-11 16.07.2019 UTC", + "00-75-00 03.01.1970 EET", + "crash" + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, DateType, nullable = false, + new MetadataBuilder().putString("pattern", "HH-mm-ss dd.MM.yyyy ZZZ").build) + )) + val exp = Seq( + DateRow(Date.valueOf("1970-01-01")), + DateRow(Date.valueOf("1970-01-02")), + DateRow(Date.valueOf("2000-12-31")), + DateRow(Date.valueOf("2019-07-16")), + DateRow(Date.valueOf("1970-01-01"), Seq(ErrorMessage.stdCastErr(fieldName, "00-75-00 03.01.1970 EET"))), + DateRow(Date.valueOf("1970-01-01"), Seq(ErrorMessage.stdCastErr(fieldName, "crash"))) + ) + + val src = seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[DateRow].collect().toList) + } + + test("date + time + second fractions pattern and offset time zone") { + val seq = Seq( + "01:00:00(000000000) 01+01+1970 +01:00", + "00:00:00(001002003) 03+01+1970 +02:00", + "21:45:39(999999999) 30+12+2000 -08:00", + "14:25:11(123456789) 16+07+2019 +00:00", + "00:75:00(001002003) 03+01+1970 +02:00", + "crash" + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, DateType, nullable = false, + new MetadataBuilder().putString("pattern", "HH:mm:ss(SSSnnnnnn) dd+MM+yyyy XXX").build) + )) + val exp = Seq( + DateRow(Date.valueOf("1970-01-01")), + DateRow(Date.valueOf("1970-01-02")), + DateRow(Date.valueOf("2000-12-31")), + DateRow(Date.valueOf("2019-07-16")), + DateRow(Date.valueOf("1970-01-01"), Seq(ErrorMessage.stdCastErr(fieldName, "00:75:00(001002003) 03+01+1970 +02:00"))), + DateRow(Date.valueOf("1970-01-01"), Seq(ErrorMessage.stdCastErr(fieldName, "crash"))) + ) + + val src = seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[DateRow].collect().toList) + } + + test("date with default time zone - EST") { + val seq = Seq( + "1970/01/01", + "1970/02/01", + "2000/31/12", + "2019/16/07", + "1970-02-02", + "crash" + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, DateType, nullable = false, + new MetadataBuilder() + .putString("pattern", "yyyy/dd/MM") + .putString("timezone", "EST") + .build) + )) + val exp = Seq( + DateRow(Date.valueOf("1970-01-01")), + DateRow(Date.valueOf("1970-01-02")), + 
DateRow(Date.valueOf("2000-12-31")), + DateRow(Date.valueOf("2019-07-16")), + DateRow(Date.valueOf("1970-01-01"), Seq(ErrorMessage.stdCastErr(fieldName, "1970-02-02"))), + DateRow(Date.valueOf("1970-01-01"), Seq(ErrorMessage.stdCastErr(fieldName, "crash"))) + ) + + val src = seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[DateRow].collect().toList) + } + + + test("date with default time zone - SAST") { + val seq = Seq( + "1970/01/01", + "1970/02/01", + "2000/31/12", + "2019/16/07", + "1970-02-02", + "crash" + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, DateType, nullable = false, + new MetadataBuilder() + .putString("pattern", "yyyy/dd/MM") + .putString("timezone", "Africa/Johannesburg") + .build) + )) + val exp = Seq( + DateRow(Date.valueOf("1969-12-31")), + DateRow(Date.valueOf("1970-01-01")), + DateRow(Date.valueOf("2000-12-30")), + DateRow(Date.valueOf("2019-07-15")), + DateRow(Date.valueOf("1970-01-01"), Seq(ErrorMessage.stdCastErr(fieldName, "1970-02-02"))), + DateRow(Date.valueOf("1970-01-01"), Seq(ErrorMessage.stdCastErr(fieldName, "crash"))) + ) + + val src = seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[DateRow].collect().toList) + } + + test("date with quoted") { + val seq = Seq( + "January 1 of 1970", + "February 1 of 1970", + "December 31 of 2000", + "July 16 of 2019", + "02 3 of 1970", + "February 4 1970", + "crash" + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, DateType, nullable = false, + new MetadataBuilder().putString("pattern", "MMMM d 'of' yyyy").build) + )) + val exp = Seq( + DateRow(Date.valueOf("1970-01-01")), + DateRow(Date.valueOf("1970-02-01")), + DateRow(Date.valueOf("2000-12-31")), + DateRow(Date.valueOf("2019-07-16")), + DateRow(Date.valueOf("1970-01-01"), Seq(ErrorMessage.stdCastErr(fieldName, "02 3 of 1970"))), + DateRow(Date.valueOf("1970-01-01"), Seq(ErrorMessage.stdCastErr(fieldName, "February 4 1970"))), + DateRow(Date.valueOf("1970-01-01"), Seq(ErrorMessage.stdCastErr(fieldName, "crash"))) + ) + + val src = seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[DateRow].collect().toList) + } + + /* TODO this should work with #7 fixed (originally Enceladus#677) + test("date with quoted and second frations") { + val seq = Seq( + "1970/01/01 insignificant 000000", + "1970/02/01 insignificant 001002", + "2000/31/12 insignificant 999999", + "2019/16/07 insignificant 123456", + "1970/02/02 insignificant ", + "crash" + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, DateType, nullable = false, + new MetadataBuilder().putString("pattern", "yyyy/MM/dd 'insignificant' iiiiii").build) + )) + val exp = Seq( + DateRow(Date.valueOf("1970-01-01")), + DateRow(Date.valueOf("1970-02-01")), + DateRow(Date.valueOf("2000-12-31")), + DateRow(Date.valueOf("2019-07-16")), + DateRow(Date.valueOf("1970-01-01"), Seq(ErrorMessage.stdCastErr(fieldName, "1970/02/02 insignificant "))), + DateRow(Date.valueOf("1970-01-01"), Seq(ErrorMessage.stdCastErr(fieldName, "crash"))) + ) + + val src = seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema) + logDataFrameContent(std) + + assertResult(exp)(std.as[DateRow].collect().toList) + } + */ + +} diff --git 
a/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_DecimalSuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_DecimalSuite.scala new file mode 100644 index 0000000..ac7a3ba --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_DecimalSuite.scala @@ -0,0 +1,280 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.interpreter + +import java.text.{DecimalFormat, NumberFormat} +import java.util.Locale +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.schema.MetadataKeys +import za.co.absa.standardization.types.{Defaults, GlobalDefaults} +import za.co.absa.standardization.{ErrorMessage, LoggerTestBase, SparkTestBase, Standardization} +import za.co.absa.standardization.udf.UDFLibrary + +class StandardizationInterpreter_DecimalSuite extends AnyFunSuite with SparkTestBase with LoggerTestBase { + import spark.implicits._ + + private implicit val udfLib: UDFLibrary = new UDFLibrary + private implicit val defaults: Defaults = GlobalDefaults + + private val desiredSchema = StructType(Seq( + StructField("description", StringType, nullable = false), + StructField("small", DecimalType(5,2), nullable = false), + StructField("big", DecimalType(38,18), nullable = true) + )) + + private val zero = BigDecimal("0E-18") + private val bigDecimalFormat = { + val pattern = "0.000000000000000000" //18 decimal places + val nf = NumberFormat.getNumberInstance(Locale.US) + val df = nf.asInstanceOf[DecimalFormat] + df.applyPattern(pattern) + df + } + + private def bd(number: Double): BigDecimal = { + val s: String = bigDecimalFormat.format(number) + BigDecimal(s) + } + + test("From String") { + val seq = Seq( + ("01-Pi", "3.14", "3.14"), + ("02-Null", null, null), + ("03-Long", Long.MaxValue.toString, Long.MinValue.toString), + ("04-infinity", "-Infinity", "Infinity"), + ("05-Really big", "123456789123456791245678912324789123456789123456789.12", + "12345678912345679124567891232478912345678912345678912345678912345678912345678912345678912345678912345678912345" + + "678912345678912345678912345678912345678912345678912345678912345678912345678912345678912345678912346789123456" + + "789123456789123456789123456791245678912324789123456789123456789123456789123456789123456791245678912324789123" + + "456789123456789123456789123456789123456789123456789123456789.1"), + ("06-Text", "foo", "bar"), + ("07-Exponential notation", "-1.23E2", "+9.8765E-4"), + ("08-Small overflow", "1000", "1000"), + ("09-Loss of precision", "123.456", "123.456") + ) + val src = seq.toDF("description","small", "big") + logDataFrameContent(src) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + val exp = Seq( + DecimalRow("01-Pi", Option(bd(3.14)), Option(bd(3.14))), + DecimalRow("02-Null", Option(zero), None, Seq( + 
ErrorMessage.stdNullErr("small"))), + DecimalRow("03-Long", Option(zero), Option(Long.MinValue), Seq( + ErrorMessage.stdCastErr("small", Long.MaxValue.toString))), + DecimalRow("04-infinity", Option(zero), None, Seq( + ErrorMessage.stdCastErr("small", "-Infinity"), + ErrorMessage.stdCastErr("big", "Infinity"))), + DecimalRow("05-Really big", Option(zero), None, Seq( + ErrorMessage.stdCastErr("small", "123456789123456791245678912324789123456789123456789.12"), + ErrorMessage.stdCastErr("big", "1234567891234567912456789123247891234567891234567891234567891" + + "2345678912345678912345678912345678912345678912345678912345678912345678912345678912345678912345678912345678" + + "9123456789123456789123456789123456789123467891234567891234567891234567891234567912456789123247891234567891" + + "2345678912345678912345678912345679124567891232478912345678912345678912345678912345678912345678912345678912" + + "3456789.1"))), + DecimalRow("06-Text", Option(zero), None, Seq( + ErrorMessage.stdCastErr("small", "foo"), + ErrorMessage.stdCastErr("big", "bar"))), + DecimalRow("07-Exponential notation", Option(bd(-123)), Option(bd(0.00098765))), + DecimalRow("08-Small overflow", Option(zero), Option(bd(1000)), Seq( + ErrorMessage.stdCastErr("small", "1000"))), + DecimalRow("09-Loss of precision", Option(bd(123.46)), Option(bd(123.456))) + ) + + assertResult(exp)(std.as[DecimalRow].collect().sortBy(_.description).toList) + } + + test("From double") { + val reallyBig = Double.MaxValue + val seq = Seq( + new InputRowDoublesForDecimal("01-Pi", Math.PI), + InputRowDoublesForDecimal("02-Null", None, None), + InputRowDoublesForDecimal("03-Long", Option(Long.MaxValue.toFloat), Option(Long.MinValue.toDouble)), + InputRowDoublesForDecimal("04-Infinity", Option(Float.NegativeInfinity), Option(Double.PositiveInfinity)), + new InputRowDoublesForDecimal("05-Really big", reallyBig), + InputRowDoublesForDecimal("06-NaN", Option(Float.NaN), Option(Double.NaN)) + ) + val src = spark.createDataFrame(seq) + logDataFrameContent(src) + + val exp = Seq( + DecimalRow("01-Pi", Option(bd(3.14)), Option(bd(Math.PI))), //NB! 
Note the loss of precision in Pi + DecimalRow("02-Null", Option(zero), None, Seq( + ErrorMessage.stdNullErr("small"))), + DecimalRow("03-Long", Option(zero), Option(-9223372036854776000.0), Seq( // rounding in doubles for large integers + ErrorMessage.stdCastErr("small", "9.223372036854776E18"))), + DecimalRow("04-Infinity", Option(zero), None, Seq( + ErrorMessage.stdCastErr("small", "-Infinity"), + ErrorMessage.stdCastErr("big", "Infinity"))), + DecimalRow("05-Really big", Option(zero), None, Seq( + ErrorMessage.stdCastErr("small", reallyBig.toString), + ErrorMessage.stdCastErr("big", reallyBig.toString))), + DecimalRow("06-NaN", Option(zero), None, Seq( + ErrorMessage.stdCastErr("small", "NaN"), + ErrorMessage.stdCastErr("big", "NaN"))) + ) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[DecimalRow].collect().sortBy(_.description).toList) + } + + test("No pattern, but altered symbols") { + val input = Seq( + ("01-Normal", "123:456"), + ("02-Null", null), + ("03-Far negative", "N100000000:999"), + ("04-Wrong", "hello"), + ("05-Not adhering to pattern", "123456.789") + ) + + val decimalSeparator = ":" + val minusSign = "N" + val srcField = "src" + + val src = input.toDF("description", srcField) + + val desiredSchemaWithAlters = StructType(Seq( + StructField("description", StringType, nullable = false), + StructField("src", StringType, nullable = true), + StructField("small", DecimalType(5,2), nullable = false, new MetadataBuilder() + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + .putString(MetadataKeys.MinusSign, minusSign) + .putString(MetadataKeys.SourceColumn, srcField) + .build()), + StructField("big", DecimalType(38,18), nullable = true, new MetadataBuilder() + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + .putString(MetadataKeys.MinusSign, minusSign) + .putString(MetadataKeys.DefaultValue, "N1:1") + .putString(MetadataKeys.SourceColumn, srcField) + .build()) + )) + + val std = Standardization.standardize(src, desiredSchemaWithAlters).cache() + logDataFrameContent(std) + + val exp = List( + ("01-Normal", "123:456", BigDecimal("123.460000000000000000"), BigDecimal("123.456000000000000000"), Seq.empty), //NB the rounding in the small + ("02-Null", null, BigDecimal("0E-18"), null, Seq(ErrorMessage.stdNullErr(srcField))), + ("03-Far negative", "N100000000:999", BigDecimal("0E-18"), BigDecimal("-100000000.999000000000000000"), Seq(ErrorMessage.stdCastErr(srcField,"N100000000:999"))), + ("04-Wrong", "hello", BigDecimal("0E-18"), BigDecimal("-1.100000000000000000"), Seq(ErrorMessage.stdCastErr(srcField,"hello"), ErrorMessage.stdCastErr(srcField,"hello"))), + ("05-Not adhering to pattern", "123456.789", BigDecimal("0E-18"), BigDecimal("-1.100000000000000000"), Seq(ErrorMessage.stdCastErr(srcField,"123456.789"), ErrorMessage.stdCastErr(srcField,"123456.789"))) + ) + + assertResult(exp)(std.as[(String, String, BigDecimal, BigDecimal, Seq[ErrorMessage])].collect().toList) + } + + test("Using patterns") { + val input = Seq( + ("01-Normal", "123.4‰"), + ("02-Null", null), + ("03-Big", "100,000,000.999‰"), + ("04-Wrong", "hello"), + ("05-Not adhering to pattern", "123456.789") + ) + + val pattern = "#,##0.##‰" + val srcField = "src" + + val src = input.toDF("description", srcField) + + val desiredSchemaWithPatterns = StructType(Seq( + StructField("description", StringType, nullable = false), + StructField("src", StringType, nullable = true), + StructField("small", DecimalType(5,2), nullable 
= false, new MetadataBuilder() + .putString(MetadataKeys.Pattern, pattern) + .putString(MetadataKeys.SourceColumn, srcField) + .build()), + StructField("big", DecimalType(38,18), nullable = true, new MetadataBuilder() + .putString(MetadataKeys.Pattern, pattern) + .putString(MetadataKeys.DefaultValue, "1,000‰") + .putString(MetadataKeys.SourceColumn, srcField) + .build()) + )) + + val std = Standardization.standardize(src, desiredSchemaWithPatterns).cache() + logDataFrameContent(std) + + val exp = List( + ("01-Normal", "123.4‰", BigDecimal("0.120000000000000000"), BigDecimal("0.123400000000000000"), Seq.empty), + ("02-Null", null, BigDecimal("0E-18"), null, Seq(ErrorMessage.stdNullErr(srcField))), + ("03-Big", "100,000,000.999‰", BigDecimal("0E-18"), BigDecimal("100000.000999000000000000"), Seq(ErrorMessage.stdCastErr(srcField,"100,000,000.999‰"))), + ("04-Wrong", "hello", BigDecimal("0E-18"), BigDecimal("1.000000000000000000"), Seq(ErrorMessage.stdCastErr(srcField,"hello"), ErrorMessage.stdCastErr(srcField,"hello"))), + ("05-Not adhering to pattern", "123456.789", BigDecimal("0E-18"), BigDecimal("1.000000000000000000"), Seq(ErrorMessage.stdCastErr(srcField,"123456.789"), ErrorMessage.stdCastErr(srcField,"123456.789"))) + ) + + assertResult(exp)(std.as[(String, String, BigDecimal, BigDecimal, Seq[ErrorMessage])].collect().toList) + } + + test("Pattern with symbols altered") { + val input = Seq( + ("01-Normal", "9 123,4"), + ("02-Null", null), + ("03-Big", "100 000 000,999"), + ("04-Wrong", "hello"), + ("05-Not adhering to pattern", "123456.789"), + ("06-Negative", "~54 123,789") + ) + + val srcField = "src" + val decimalSeparator = "," + val groupingSeparator = " " + val minusSign = "~" + val pattern = "#,##0.#" // NB the default symbols, not the redefined ones + + val src = input.toDF("description", srcField) + + val desiredSchemaWithPatterns = StructType(Seq( + StructField("description", StringType, nullable = false), + StructField("src", StringType, nullable = true), + StructField("small", DecimalType(7,2), nullable = false, new MetadataBuilder() + .putString(MetadataKeys.Pattern, pattern) + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + .putString(MetadataKeys.GroupingSeparator, groupingSeparator) + .putString(MetadataKeys.MinusSign, minusSign) + .putString(MetadataKeys.SourceColumn, srcField) + .build()), + StructField("big", DecimalType(38,18), nullable = true, new MetadataBuilder() + .putString(MetadataKeys.Pattern, pattern) + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + .putString(MetadataKeys.GroupingSeparator, groupingSeparator) + .putString(MetadataKeys.MinusSign, minusSign) + .putString(MetadataKeys.DefaultValue, "1,000") + .putString(MetadataKeys.SourceColumn, srcField) + .build()) + )) + + val std = Standardization.standardize(src, desiredSchemaWithPatterns).cache() + logDataFrameContent(std) + + val exp = List( + ("01-Normal", "9 123,4", BigDecimal("9123.400000000000000000"), BigDecimal("9123.400000000000000000"), Seq.empty), + ("02-Null", null, BigDecimal("0E-18"), null, Seq(ErrorMessage.stdNullErr(srcField))), + ("03-Big", "100 000 000,999", BigDecimal("0E-18"), BigDecimal("100000000.999000000000000000"), Seq(ErrorMessage.stdCastErr(srcField,"100 000 000,999"))), + ("04-Wrong", "hello", BigDecimal("0E-18"), BigDecimal("1.000000000000000000"), Seq(ErrorMessage.stdCastErr(srcField,"hello"), ErrorMessage.stdCastErr(srcField,"hello"))), + ("05-Not adhering to pattern", "123456.789", BigDecimal("0E-18"), BigDecimal("1.000000000000000000"),
Seq(ErrorMessage.stdCastErr(srcField,"123456.789"), ErrorMessage.stdCastErr(srcField,"123456.789"))), + ("06-Negative", "~54 123,789", BigDecimal("-54123.790000000000000000"), BigDecimal("-54123.789000000000000000"), Seq.empty) + ) + + assertResult(exp)(std.as[(String, String, BigDecimal, BigDecimal, Seq[ErrorMessage])].collect().toList) + } + +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_FractionalSuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_FractionalSuite.scala new file mode 100644 index 0000000..fe124b3 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_FractionalSuite.scala @@ -0,0 +1,385 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.interpreter + +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.schema.MetadataKeys +import za.co.absa.standardization.types.{Defaults, GlobalDefaults} +import za.co.absa.standardization.udf.UDFLibrary +import za.co.absa.standardization.{ErrorMessage, LoggerTestBase, SparkTestBase, Standardization} + +class StandardizationInterpreter_FractionalSuite extends AnyFunSuite with SparkTestBase with LoggerTestBase { + import spark.implicits._ + + private implicit val udfLib: UDFLibrary = new UDFLibrary + private implicit val defaults: Defaults = GlobalDefaults + + private def err(value: String, cnt: Int): Seq[ErrorMessage] = { + val item = ErrorMessage.stdCastErr("src",value) + val array = Array.fill(cnt) (item) + array.toList + } + + private val desiredSchema = StructType(Seq( + StructField("description", StringType, nullable = false), + StructField("floatField", FloatType, nullable = false), + StructField("doubleField", DoubleType, nullable = true) + )) + + private val desiredSchemaWithInfinity = StructType(Seq( + StructField("description", StringType, nullable = false), + StructField("floatField", FloatType, nullable = false, + new MetadataBuilder().putString("allow_infinity", value = "true").build), + StructField("doubleField", DoubleType, nullable = true, + new MetadataBuilder().putString("allow_infinity", value = "true").build) + )) + + test("From String") { + val seq = Seq( + ("01-Pi", "3.14", "3.14"), + ("02-Null", null, null), + ("03-Long", Long.MaxValue.toString, Long.MinValue.toString), + ("04-infinity", "-Infinity", "Infinity"), + ("05-Really big", "123456789123456791245678912324789123456789123456789.12", + "12345678912345679124567891232478912345678912345678912345678912345678912345678912345678912345678912345678912345" + + "678912345678912345678912345678912345678912345678912345678912345678912345678912345678912345678912346789123456" + + "789123456789123456789123456791245678912324789123456789123456789123456789123456789123456791245678912324789123" + + "456789123456789123456789123456789123456789123456789123456789.1"), + ("06-Text", "foo", "bar"), 
+ ("07-Exponential notation", "-1.23E4", "+9.8765E-3") + ) + val src = seq.toDF("description","floatField", "doubleField") + logDataFrameContent(src) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + val exp = Seq( + FractionalRow("01-Pi", Option(3.14F), Option(3.14)), + FractionalRow("02-Null", Option(0), None, Seq( + ErrorMessage.stdNullErr("floatField"))), + FractionalRow("03-Long", Option(9.223372E18F), Option(-9.223372036854776E18)), + FractionalRow("04-infinity", Option(0), None, Seq( + ErrorMessage.stdCastErr("floatField", "-Infinity"), + ErrorMessage.stdCastErr("doubleField", "Infinity"))), + FractionalRow("05-Really big", Option(0), None, Seq( + ErrorMessage.stdCastErr("floatField", "123456789123456791245678912324789123456789123456789.12"), + ErrorMessage.stdCastErr("doubleField", "12345678912345679124567891232478912345678912345678912" + + "3456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789" + + "1234567891234567891234567891234567891234567891234678912345678912345678912345678912345679124567891232478912" + + "3456789123456789123456789123456789123456791245678912324789123456789123456789123456789123456789123456789123" + + "456789123456789.1"))), + FractionalRow("06-Text", Option(0), None, Seq( + ErrorMessage.stdCastErr("floatField", "foo"), + ErrorMessage.stdCastErr("doubleField", "bar"))), + FractionalRow("07-Exponential notation", Option(-12300.0f), Option(0.0098765)) + ) + + assertResult(exp)(std.as[FractionalRow].collect().sortBy(_.description).toList) + } + + test("From Long") { + val value = 1984 + val seq = Seq( + InputRowLongsForFractional("01-Null", None, None), + InputRowLongsForFractional("02-Big Long", Option(Long.MaxValue - 1), Option(Long.MinValue + 1)), + InputRowLongsForFractional("03-Long", Option(-value), Option(value)) + ) + val src = spark.createDataFrame(seq) + logDataFrameContent(src) + + val exp = Seq( + FractionalRow("01-Null", Option(0), None, Seq( + ErrorMessage.stdNullErr("floatField"))), + FractionalRow("02-Big Long", Option(9.223372E18F), Option(-9.223372036854776E18)), //NBN! 
the loss of precision + FractionalRow("03-Long", Option(-value.toFloat), Option(value.toDouble)) + ) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[FractionalRow].collect().sortBy(_.description).toList) + } + + test("From Double") { + val reallyBig = Double.MaxValue + val seq = Seq( + new InputRowDoublesForFractional("01-Pi", Math.PI), + InputRowDoublesForFractional("02-Null", None, None), + InputRowDoublesForFractional("03-Long", Option(Long.MaxValue.toFloat), Option(Long.MinValue.toDouble)), + InputRowDoublesForFractional("04-Infinity", Option(Float.NegativeInfinity), Option(Double.PositiveInfinity)), + new InputRowDoublesForFractional("05-Really big", reallyBig), + InputRowDoublesForFractional("06-NaN", Option(Float.NaN), Option(Double.NaN)) + ) + val src = spark.createDataFrame(seq) + logDataFrameContent(src) + + val exp = Seq( + FractionalRow("01-Pi", Option(Math.PI.toFloat), Option(Math.PI)), + FractionalRow("02-Null", Option(0), None, Seq( + ErrorMessage.stdNullErr("floatField"))), + FractionalRow("03-Long", Option(9.223372E18F), Option(-9.223372036854776E18)), + FractionalRow("04-Infinity", Option(0), None, Seq( + ErrorMessage.stdCastErr("floatField", "-Infinity"), + ErrorMessage.stdCastErr("doubleField", "Infinity"))), + FractionalRow("05-Really big", Option(0), Option(reallyBig), Seq( + ErrorMessage.stdCastErr("floatField", reallyBig.toString))), + FractionalRow("06-NaN", Option(0), None, Seq( + ErrorMessage.stdCastErr("floatField", "NaN"), + ErrorMessage.stdCastErr("doubleField", "NaN"))) + ) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[FractionalRow].collect().sortBy(_.description).toList) + } + + test("With infinity from string") { + val seq = Seq( + ("01-Euler", "2.71", "2.71"), + ("02-Null", null, null), + ("03-Long", Long.MaxValue.toString, Long.MinValue.toString), + ("04-infinity", "-∞", "∞"), + ("05-Really big", "123456789123456791245678912324789123456789123456789.12", + "-1234567891234567912456789123247891234567891234567891234567891234567891234567891234567891234567891234567891234" + + "567891234567891234567891234567891234567891234567891234567891234567891234567891234567891234567891234678912345" + + "678912345678912345678912345679124567891232478912345678912345678912345678912345678912345679124567891232478912" + + "3456789123456789123456789123456789123456789123456789123456789.1"), + ("06-Text", "foo", "bar"), + ("07-Exponential notation", "-1.23E4", "+9.8765E-3") + ) + val src = seq.toDF("description","floatField", "doubleField") + logDataFrameContent(src) + + val std = Standardization.standardize(src, desiredSchemaWithInfinity).cache() + logDataFrameContent(std) + + val exp = Seq( + FractionalRow("01-Euler", Option(2.71F), Option(2.71)), + FractionalRow("02-Null", Option(0), None, Seq( + ErrorMessage.stdNullErr("floatField"))), + FractionalRow("03-Long", Option(9.223372E18F), Option(-9.223372036854776E18)), + FractionalRow("04-infinity", Some(Float.NegativeInfinity), Option(Double.PositiveInfinity)), + FractionalRow("05-Really big", Option(Float.PositiveInfinity), Option(Double.NegativeInfinity)), + FractionalRow("06-Text", Option(0), None, Seq( + ErrorMessage.stdCastErr("floatField", "foo"), + ErrorMessage.stdCastErr("doubleField", "bar"))), + FractionalRow("07-Exponential notation", Option(-12300.0f), Option(0.0098765)) + ) + + assertResult(exp)(std.as[FractionalRow].collect().sortBy(_.description).toList) + } + + 
test("With infinity from double") { + val reallyBig = Double.MaxValue + val seq = Seq( + new InputRowDoublesForFractional("01-Euler", Math.E), + InputRowDoublesForFractional("02-Null", None, None), + InputRowDoublesForFractional("03-Long", Option(Long.MaxValue.toFloat), Option(Long.MinValue.toDouble)), + InputRowDoublesForFractional("04-Infinity", Option(Float.NegativeInfinity), Option(Double.PositiveInfinity)), + new InputRowDoublesForFractional("05-Really big", reallyBig), + InputRowDoublesForFractional("06-NaN", Option(Float.NaN), Option(Double.NaN)) + ) + val src = spark.createDataFrame(seq) + logDataFrameContent(src) + + val exp = Seq( + FractionalRow("01-Euler", Option(Math.E.toFloat), Option(Math.E)), + FractionalRow("02-Null", Option(0), None, Seq( + ErrorMessage.stdNullErr("floatField"))), + FractionalRow("03-Long", Option(9.223372E18F), Option(-9.223372036854776E18)), + FractionalRow("04-Infinity", Option(Float.NegativeInfinity), Option(Double.PositiveInfinity)), + FractionalRow("05-Really big", Option(Float.PositiveInfinity), Option(reallyBig)), + FractionalRow("06-NaN", Option(0), None, Seq( + ErrorMessage.stdCastErr("floatField", "NaN"), + ErrorMessage.stdCastErr("doubleField", "NaN"))) + ) + + val std = Standardization.standardize(src, desiredSchemaWithInfinity).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[FractionalRow].collect().sortBy(_.description).toList) + } + + test("No pattern, but altered symbols") { + val input = Seq( + ("01-Positive", "+3"), + ("02-Negative", "~8123,4"), + ("03-Null", null), + ("04-Big", "7899012345678901234567890123456789012346789,123456789"), + ("05-Big II", "+1E40"), + ("06-Big III", "2E308"), + ("07-Small", "~7899012345678901234567890123456789012346789,123456789"), + ("08-Small II", "~1,1E40"), + ("09-Small III", "~3E308"), + ("10-Wrong", "hello"), + ("11-Infinity", "+∞"), + ("12-Negative Infinity", "~∞"), + ("13-Old decimal", "5.5"), + ("14-Old minus", "-10"), + ("15-Infinity as word", "Infinity") + ) + + val src = input.toDF("description", "src") + + val decimalSeparator = "," + val minusSign = "~" + val srcField = "src" + + val desiredSchemaWithAlters = StructType(Seq( + StructField("description", StringType, nullable = false), + StructField("src", StringType, nullable = true), + StructField("small", FloatType, nullable = false, new MetadataBuilder() + .putString(MetadataKeys.SourceColumn, srcField) + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + .putString(MetadataKeys.MinusSign, minusSign) + .build()), + StructField("big", DoubleType, nullable = true, new MetadataBuilder() + .putString(MetadataKeys.DefaultValue, "+1000,001") + .putString(MetadataKeys.SourceColumn, srcField) + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + .putString(MetadataKeys.MinusSign, minusSign) + .build()), + StructField("small_with_infinity", FloatType, nullable = true, new MetadataBuilder() + .putString(MetadataKeys.SourceColumn, srcField) + .putString(MetadataKeys.DefaultValue, "~999999,9999") + .putString(MetadataKeys.AllowInfinity, "True") + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + .putString(MetadataKeys.MinusSign, minusSign) + .build()), + StructField("big_with_infinity", DoubleType, nullable = false, new MetadataBuilder() + .putString(MetadataKeys.SourceColumn, srcField) + .putString(MetadataKeys.AllowInfinity, "True") + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + .putString(MetadataKeys.MinusSign, minusSign) + .build()) + )) + + val std = Standardization.standardize(src, 
desiredSchemaWithAlters).cache() + logDataFrameContent(std) + + val exp = List( + ("01-Positive", "+3", 3.0F, Some(3.0D), Some(3.0F), 3.0D, Seq.empty), + ("02-Negative", "~8123,4", -8123.4F, Some(-8123.4D), Some(-8123.4F), -8123.4D, Seq.empty), + ("03-Null", null, 0F, None, None, 0D, Array.fill(2)(ErrorMessage.stdNullErr("src")).toList), + ("04-Big", "7899012345678901234567890123456789012346789,123456789", 0F, Some(7.899012345678901E42D), Some(Float.PositiveInfinity), 7.899012345678901E42, + err("7899012345678901234567890123456789012346789,123456789", 1) + ), + ("05-Big II", "+1E40", 0F, Some(1.0E40D), Some(Float.PositiveInfinity), 1.0E40D, err("+1E40", 1)), + ("06-Big III", "2E308", 0F, Some(1000.001D), Some(Float.PositiveInfinity), Double.PositiveInfinity, err("2E308", 2)), + ("07-Small", "~7899012345678901234567890123456789012346789,123456789", 0F, Some(-7.899012345678901E42D), Some(Float.NegativeInfinity), -7.899012345678901E42, + err("~7899012345678901234567890123456789012346789,123456789", 1) + ), + ("08-Small II", "~1,1E40", 0F, Some(-1.1E40D), Some(Float.NegativeInfinity), -1.1E40D, err("~1,1E40", 1)), + ("09-Small III", "~3E308", 0F, Some(1000.001D), Some(Float.NegativeInfinity), Double.NegativeInfinity, err("~3E308", 2)), + ("10-Wrong", "hello", 0F, Some(1000.001D), Some(-1000000.0F), 0D, err("hello", 4)), + ("11-Infinity", "+∞", 0F, Some(1000.001D), Some(Float.PositiveInfinity), Double.PositiveInfinity, err("+∞", 2)), + ("12-Negative Infinity", "~∞", 0F, Some(1000.001D), Some(Float.NegativeInfinity), Double.NegativeInfinity, err("~∞", 2)), + ("13-Old decimal", "5.5", 0F, Some(1000.001D), Some(-1000000.0F), 0D, err("5.5", 4)), + ("14-Old minus", "-10", 0F, Some(1000.001D), Some(-1000000.0F), 0D, err("-10", 4)), + ("15-Infinity as word", "Infinity", 0F, Some(1000.001D), Some(-1000000.0F), 0D, err("Infinity", 4)) + ) + assertResult(exp)(std.as[(String, String, Float, Option[Double], Option[Float], Double, Seq[ErrorMessage])].collect().toList) + } + + test("Using patterns") { + val input = Seq( + ("01-Positive", "+3°"), + ("02-Negative", "(8 123,4°)"), + ("03-Null", null), + ("04-Big", "+789 9012 345 678 901 234 567 890 123 456 789 012 346 789,123456789°"), + ("05-Big II", "+1E40°"), + ("06-Big III", "+2E308°"), + ("07-Small", "(789 9012 345 678 901 234 567 890 123 456 789 012 346 789,123456789°)"), + ("08-Small II", "(1,1E40°)"), + ("09-Small III", "(3E308°)"), + ("10-Wrong", "hello"), + ("11-Not adhering to pattern", "(1 234,56)"), + ("12-Not adhering to pattern II","+1,234.56°"), + ("13-Infinity", "+∞°"), + ("14-Negative Infinity", "(∞°)") + ) + + val src = input.toDF("description", "src") + + val pattern = "+#,000.#°;(#,000.#°)" + val decimalSeparator = "," + val groupingSeparator = " " + val srcField = "src" + + val desiredSchemaWithPatterns = StructType(Seq( + StructField("description", StringType, nullable = false), + StructField("src", StringType, nullable = true), + StructField("small", FloatType, nullable = false, new MetadataBuilder() + .putString(MetadataKeys.Pattern, pattern) + .putString(MetadataKeys.SourceColumn, srcField) + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + .putString(MetadataKeys.GroupingSeparator, groupingSeparator) + .build()), + StructField("big", DoubleType, nullable = true, new MetadataBuilder() + .putString(MetadataKeys.Pattern, pattern) + .putString(MetadataKeys.DefaultValue, "+1 000,001°") + .putString(MetadataKeys.SourceColumn, srcField) + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + 
.putString(MetadataKeys.GroupingSeparator, groupingSeparator) + .build()), + StructField("small_with_infinity", FloatType, nullable = true, new MetadataBuilder() + .putString(MetadataKeys.Pattern, pattern) + .putString(MetadataKeys.SourceColumn, srcField) + .putString(MetadataKeys.DefaultValue, "(999 999,9999°)") + .putString(MetadataKeys.AllowInfinity, "True") + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + .putString(MetadataKeys.GroupingSeparator, groupingSeparator) + .build()), + StructField("big_with_infinity", DoubleType, nullable = false, new MetadataBuilder() + .putString(MetadataKeys.Pattern, pattern) + .putString(MetadataKeys.SourceColumn, srcField) + .putString(MetadataKeys.AllowInfinity, "True") + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + .putString(MetadataKeys.GroupingSeparator, groupingSeparator) + .build()) + )) + + val std = Standardization.standardize(src, desiredSchemaWithPatterns).cache() + logDataFrameContent(std) + + val exp = List( + ("01-Positive", "+3°", 3.0F, Some(3.0D), Some(3.0F), 3.0D, Seq.empty), + ("02-Negative", "(8 123,4°)", -8123.4F, Some(-8123.4D), Some(-8123.4F), -8123.4D, Seq.empty), + ("03-Null", null, 0F, None, None, 0D, Array.fill(2)(ErrorMessage.stdNullErr("src")).toList), + ("04-Big", "+789 9012 345 678 901 234 567 890 123 456 789 012 346 789,123456789°", 0F, Some(7.899012345678901E42D), Some(Float.PositiveInfinity), 7.899012345678901E42, + err("+789 9012 345 678 901 234 567 890 123 456 789 012 346 789,123456789°", 1) + ), + ("05-Big II", "+1E40°", 0F, Some(1.0E40D), Some(Float.PositiveInfinity), 1.0E40D, err("+1E40°", 1)), + ("06-Big III", "+2E308°", 0F, Some(1000.001D), Some(Float.PositiveInfinity), Double.PositiveInfinity, err("+2E308°", 2)), + ("07-Small", "(789 9012 345 678 901 234 567 890 123 456 789 012 346 789,123456789°)", 0F, Some(-7.899012345678901E42D), Some(Float.NegativeInfinity), -7.899012345678901E42, + err("(789 9012 345 678 901 234 567 890 123 456 789 012 346 789,123456789°)", 1) + ), + ("08-Small II", "(1,1E40°)", 0F, Some(-1.1E40D), Some(Float.NegativeInfinity), -1.1E40D, err("(1,1E40°)", 1)), + ("09-Small III", "(3E308°)", 0F, Some(1000.001D), Some(Float.NegativeInfinity), Double.NegativeInfinity, err("(3E308°)", 2)), + ("10-Wrong", "hello", 0F, Some(1000.001D), Some(-1000000.0F), 0D, err("hello", 4)), + ("11-Not adhering to pattern", "(1 234,56)", 0F, Some(1000.001D), Some(-1000000.0F), 0D, err("(1 234,56)", 4)), + ("12-Not adhering to pattern II","+1,234.56°", 0F, Some(1000.001D), Some(-1000000.0F), 0D, err("+1,234.56°", 4)), + ("13-Infinity", "+∞°", 0F, Some(1000.001D), Some(Float.PositiveInfinity), Double.PositiveInfinity, err("+∞°", 2)), + ("14-Negative Infinity", "(∞°)", 0F, Some(1000.001D), Some(Float.NegativeInfinity), Double.NegativeInfinity, err("(∞°)", 2)) + ) + + assertResult(exp)(std.as[(String, String, Float, Option[Double], Option[Float], Double, Seq[ErrorMessage])].collect().toList) + } +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_IntegralSuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_IntegralSuite.scala new file mode 100644 index 0000000..0de1006 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_IntegralSuite.scala @@ -0,0 +1,554 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.interpreter + +import java.text.{DecimalFormat, NumberFormat} +import java.util.Locale +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.schema.MetadataKeys +import za.co.absa.standardization.types.{Defaults, GlobalDefaults} +import za.co.absa.standardization.udf.UDFLibrary +import za.co.absa.standardization.{ErrorMessage, LoggerTestBase, SparkTestBase, Standardization} + +class StandardizationInterpreter_IntegralSuite extends AnyFunSuite with SparkTestBase with LoggerTestBase{ + + import spark.implicits._ + + private implicit val udfLib: UDFLibrary = new UDFLibrary + private implicit val defaults: Defaults = GlobalDefaults + + private val pathToTestData = "src/test/resources/data/" + private val bigDecimalFormat = { + val pattern = "0.000000000000000000" //18 decimal places + val nf = NumberFormat.getNumberInstance(Locale.US) + val df = nf.asInstanceOf[DecimalFormat] + df.applyPattern(pattern) + df + } + + private val desiredSchema = StructType(Seq( + StructField("description", StringType, nullable = false), + StructField("bytesize", ByteType, nullable = false), + StructField("shortsize", ShortType, nullable = false), + StructField("integersize", IntegerType, nullable = true), + StructField("longsize", LongType, nullable = true) + )) + + private def err(value: String, cnt: Int): Seq[ErrorMessage] = { + val item = ErrorMessage.stdCastErr("src",value) + val array = Array.fill(cnt) (item) + array.toList + } + + test("Under-/overflow from CSV") { + val src = spark.read + .option("header", "true") + .csv(s"${pathToTestData}integral_overflow_test.csv") + logDataFrameContent(src) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + val exp = Seq( + IntegralRow("Decimal entry", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdCastErr("bytesize", "1.0"), + ErrorMessage.stdCastErr("shortsize", "2.0"), + ErrorMessage.stdCastErr("integersize", "3.0"), + ErrorMessage.stdCastErr("longsize", "4.0"))), + IntegralRow("Full negative", Option(-128), Option(-32768), Option(-2147483648), Option(-9223372036854775808L)), + IntegralRow("Full positive", Option(127), Option(32767), Option(2147483647), Option(9223372036854775807L)), + IntegralRow("Nulls", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdNullErr("bytesize"), + ErrorMessage.stdNullErr("shortsize"))), + IntegralRow("One", Option(1), Option(1), Option(1), Option(1)), + IntegralRow("Overflow", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdCastErr("bytesize", "128"), + ErrorMessage.stdCastErr("shortsize", "32768"), + ErrorMessage.stdCastErr("integersize", "2147483648"), + ErrorMessage.stdCastErr("longsize", "9223372036854775808"))), + IntegralRow("Underflow", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdCastErr("bytesize", "-129"), + ErrorMessage.stdCastErr("shortsize", "-32769"), + ErrorMessage.stdCastErr("integersize", "-2147483649"), + ErrorMessage.stdCastErr("longsize", "-9223372036854775809"))), + IntegralRow("With fractions", Option(0), 
Option(0), None, None, Seq( + ErrorMessage.stdCastErr("bytesize", "3.14"), + ErrorMessage.stdCastErr("shortsize", "2.71"), + ErrorMessage.stdCastErr("integersize", "1.41"), + ErrorMessage.stdCastErr("longsize", "1.5"))), + IntegralRow("With plus sign", Option(127), Option(32767), Option(2147483647), Option(9223372036854775807L)), + IntegralRow("With zeros", Option(0), Option(7), Option(-1), Option(0)) + ) + assertResult(exp)(std.as[IntegralRow].collect().sortBy(_.description).toList) + } + + test("Under-/overflow from JSON text") { + val src = spark.read.json(s"${pathToTestData}integral_overflow_test_text.json") + logDataFrameContent(src) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + val exp = Seq( + IntegralRow("Decimal entry", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdCastErr("bytesize", "1.0"), + ErrorMessage.stdCastErr("shortsize", "2.0"), + ErrorMessage.stdCastErr("integersize", "3.0"), + ErrorMessage.stdCastErr("longsize", "4.0"))), + IntegralRow("Full negative", Option(-128), Option(-32768), Option(-2147483648), Option(-9223372036854775808L)), + IntegralRow("Full positive", Option(127), Option(32767), Option(2147483647), Option(9223372036854775807L)), + IntegralRow("Nulls", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdNullErr("bytesize"), + ErrorMessage.stdNullErr("shortsize"))), + IntegralRow("One", Option(1), Option(1), Option(1), Option(1)), + IntegralRow("Overflow", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdCastErr("bytesize", "128"), + ErrorMessage.stdCastErr("shortsize", "32768"), + ErrorMessage.stdCastErr("integersize", "2147483648"), + ErrorMessage.stdCastErr("longsize", "9223372036854775808"))), + IntegralRow("Underflow", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdCastErr("bytesize", "-129"), + ErrorMessage.stdCastErr("shortsize", "-32769"), + ErrorMessage.stdCastErr("integersize", "-2147483649"), + ErrorMessage.stdCastErr("longsize", "-9223372036854775809"))), + IntegralRow("With fractions", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdCastErr("bytesize", "3.14"), + ErrorMessage.stdCastErr("shortsize", "2.71"), + ErrorMessage.stdCastErr("integersize", "1.41"), + ErrorMessage.stdCastErr("longsize", "1.5"))), + IntegralRow("With plus sign", Option(127), Option(32767), Option(2147483647), Option(9223372036854775807L)), + IntegralRow("With zeros", Option(0), Option(7), Option(-1), Option(0)) + ) + assertResult(exp)(std.as[IntegralRow].collect().sortBy(_.description).toList) + } + + test("Under-/overflow from JSON numeric") { + val src = spark.read.json(s"${pathToTestData}integral_overflow_test_numbers.json") + logDataFrameContent(src) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + val exp = Seq( + IntegralRow("Decimal entry", Option(0), Option(2), Option(3), Option(4), Seq( + ErrorMessage.stdCastErr("bytesize", "1.1"))), + IntegralRow("Full negative", Option(-128), Option(-32768), Option(-2147483648), None, Seq( + ErrorMessage.stdCastErr("longsize", "-9223372036854776000"))), + IntegralRow("Full positive", Option(127), Option(32767), Option(2147483647), None, Seq( + ErrorMessage.stdCastErr("longsize", "9223372036854776000"))), + IntegralRow("Nulls", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdNullErr("bytesize"), + ErrorMessage.stdNullErr("shortsize"))), + IntegralRow("One", Option(1), Option(1), Option(1), Option(1)), + IntegralRow("Overflow", Option(0), Option(0), None, None, Seq( + 
ErrorMessage.stdCastErr("bytesize", "128.0"), + ErrorMessage.stdCastErr("shortsize", "32768"), + ErrorMessage.stdCastErr("integersize", "2147483648"), + ErrorMessage.stdCastErr("longsize", "9223372036854776000"))), + IntegralRow("Underflow", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdCastErr("bytesize", "-129.0"), + ErrorMessage.stdCastErr("shortsize", "-32769"), + ErrorMessage.stdCastErr("integersize", "-2147483649"), + ErrorMessage.stdCastErr("longsize", "-9223372036854776000"))) + ) + assertResult(exp)(std.as[IntegralRow].collect().sortBy(_.description).toList) + } + + test("Under-/overflow from strongly typed input - long") { + val src = spark.createDataFrame(Seq( + new InputRowLongsForIntegral("1-Byte", Byte.MaxValue), + new InputRowLongsForIntegral("2-Short", Short.MaxValue), + new InputRowLongsForIntegral("3-Int", Int.MaxValue), + new InputRowLongsForIntegral("4-Long", Long.MaxValue), + new InputRowLongsForIntegral("5-Byte", Byte.MinValue), + new InputRowLongsForIntegral("6-Short", Short.MinValue), + new InputRowLongsForIntegral("7-Int", Int.MinValue), + new InputRowLongsForIntegral("8-Long", Long.MinValue) + )) + logDataFrameContent(src) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + val exp = Seq( + IntegralRow("1-Byte", Option(Byte.MaxValue), Option(Byte.MaxValue), Option(Byte.MaxValue), Option(Byte.MaxValue)), + IntegralRow("2-Short", Option(0), Option(Short.MaxValue), Option(Short.MaxValue), Option(Short.MaxValue), Seq( + ErrorMessage.stdCastErr("bytesize", Short.MaxValue.toString))), + IntegralRow("3-Int", Option(0), Option(0), Option(Int.MaxValue), Option(Int.MaxValue), Seq( + ErrorMessage.stdCastErr("bytesize", Int.MaxValue.toString), + ErrorMessage.stdCastErr("shortsize", Int.MaxValue.toString))), + IntegralRow("4-Long", Option(0), Option(0), None, Option(Long.MaxValue), Seq( + ErrorMessage.stdCastErr("bytesize", Long.MaxValue.toString), + ErrorMessage.stdCastErr("shortsize", Long.MaxValue.toString), + ErrorMessage.stdCastErr("integersize", Long.MaxValue.toString))), + IntegralRow("5-Byte", Option(Byte.MinValue), Option(Byte.MinValue), Option(Byte.MinValue), Option(Byte.MinValue)), + IntegralRow("6-Short", Option(0), Option(Short.MinValue), Option(Short.MinValue), Option(Short.MinValue), Seq( + ErrorMessage.stdCastErr("bytesize", Short.MinValue.toString))), + IntegralRow("7-Int", Option(0), Option(0), Option(Int.MinValue), Option(Int.MinValue), Seq( + ErrorMessage.stdCastErr("bytesize", Int.MinValue.toString), + ErrorMessage.stdCastErr("shortsize", Int.MinValue.toString))), + IntegralRow("8-Long", Option(0), Option(0), None, Option(Long.MinValue), Seq( + ErrorMessage.stdCastErr("bytesize", Long.MinValue.toString), + ErrorMessage.stdCastErr("shortsize", Long.MinValue.toString), + ErrorMessage.stdCastErr("integersize", Long.MinValue.toString))) + ) + assertResult(exp)(std.as[IntegralRow].collect().sortBy(_.description).toList) + } + + test("Under-/overflow and precision lost from strongly typed input - double") { + + val reallyBig: Double = 24578754548798454658754546785454.0 + val tinyFractionalPart: Double = 1.000000000000001 + val seq: Seq[InputRowDoublesForIntegral] = Seq( + new InputRowDoublesForIntegral("00-One", 1), + new InputRowDoublesForIntegral("01-Byte", Byte.MaxValue.toDouble), + new InputRowDoublesForIntegral("02-Short", Short.MaxValue.toDouble), + new InputRowDoublesForIntegral("03-Int", Int.MaxValue.toDouble), + new InputRowDoublesForIntegral("04-Long", Long.MaxValue.toDouble), + new 
InputRowDoublesForIntegral("05-Byte", Byte.MinValue.toDouble), + new InputRowDoublesForIntegral("06-Short", Short.MinValue.toDouble), + new InputRowDoublesForIntegral("07-Int", Int.MinValue.toDouble), + new InputRowDoublesForIntegral("08-Long", Long.MinValue.toDouble), + new InputRowDoublesForIntegral("09-Pi", Math.PI), + new InputRowDoublesForIntegral("10-Whole", 7.00), + new InputRowDoublesForIntegral("11-Really small", Double.MinPositiveValue), + new InputRowDoublesForIntegral("12-Really big", reallyBig), + new InputRowDoublesForIntegral("13-Tiny fractional part", tinyFractionalPart), + new InputRowDoublesForIntegral("14-NaN", Double.NaN), + InputRowDoublesForIntegral("15-Null", None, None, None, None) + ) + + val src = spark.createDataFrame(seq) + logDataFrameContent(src) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + val exp = Seq( + IntegralRow("00-One", Option(1), Option(1), Option(1), Option(1)), + IntegralRow("01-Byte", Option(Byte.MaxValue), Option(Byte.MaxValue), Option(Byte.MaxValue), Option(Byte.MaxValue)), + IntegralRow("02-Short", Option(0), Option(Short.MaxValue), Option(Short.MaxValue), Option(Short.MaxValue), Seq( + ErrorMessage.stdCastErr("bytesize", Short.MaxValue.toDouble.toString))), + IntegralRow("03-Int", Option(0), Option(0), Option(Int.MaxValue), Option(Int.MaxValue), Seq( + ErrorMessage.stdCastErr("bytesize", Int.MaxValue.toDouble.toString), + ErrorMessage.stdCastErr("shortsize", Int.MaxValue.toDouble.toString))), + IntegralRow("04-Long", Option(0), Option(0), None, Option(Long.MaxValue), Seq( + ErrorMessage.stdCastErr("bytesize", Long.MaxValue.toDouble.toString), + ErrorMessage.stdCastErr("shortsize", Long.MaxValue.toDouble.toString), + ErrorMessage.stdCastErr("integersize", Long.MaxValue.toDouble.toString))), + IntegralRow("05-Byte", Option(Byte.MinValue), Option(Byte.MinValue), Option(Byte.MinValue), Option(Byte.MinValue)), + IntegralRow("06-Short", Option(0), Option(Short.MinValue), Option(Short.MinValue), Option(Short.MinValue), Seq( + ErrorMessage.stdCastErr("bytesize", Short.MinValue.toDouble.toString))), + IntegralRow("07-Int", Option(0), Option(0), Option(Int.MinValue), Option(Int.MinValue), Seq( + ErrorMessage.stdCastErr("bytesize", Int.MinValue.toDouble.toString), + ErrorMessage.stdCastErr("shortsize", Int.MinValue.toDouble.toString))), + IntegralRow("08-Long", Option(0), Option(0), None, Option(Long.MinValue), Seq( + ErrorMessage.stdCastErr("bytesize", Long.MinValue.toDouble.toString), + ErrorMessage.stdCastErr("shortsize", Long.MinValue.toDouble.toString), + ErrorMessage.stdCastErr("integersize", Long.MinValue.toDouble.toString))), + IntegralRow("09-Pi", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdCastErr("bytesize", Math.PI.toString), + ErrorMessage.stdCastErr("shortsize", Math.PI.toString), + ErrorMessage.stdCastErr("integersize", Math.PI.toString), + ErrorMessage.stdCastErr("longsize", Math.PI.toString))), + IntegralRow("10-Whole", Option(7), Option(7), Option(7), Option(7)), + IntegralRow("11-Really small", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdCastErr("bytesize", Double.MinPositiveValue.toString), + ErrorMessage.stdCastErr("shortsize", Double.MinPositiveValue.toString), + ErrorMessage.stdCastErr("integersize", Double.MinPositiveValue.toString), + ErrorMessage.stdCastErr("longsize", Double.MinPositiveValue.toString))), + IntegralRow("12-Really big", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdCastErr("bytesize", reallyBig.toString), + 
ErrorMessage.stdCastErr("shortsize", reallyBig.toString), + ErrorMessage.stdCastErr("integersize", reallyBig.toString), + ErrorMessage.stdCastErr("longsize", reallyBig.toString))), + IntegralRow("13-Tiny fractional part", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdCastErr("bytesize", tinyFractionalPart.toString), + ErrorMessage.stdCastErr("shortsize", tinyFractionalPart.toString), + ErrorMessage.stdCastErr("integersize", tinyFractionalPart.toString), + ErrorMessage.stdCastErr("longsize", tinyFractionalPart.toString))), + IntegralRow("14-NaN", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdCastErr("bytesize", "NaN"), + ErrorMessage.stdCastErr("shortsize", "NaN"), + ErrorMessage.stdCastErr("integersize", "NaN"), + ErrorMessage.stdCastErr("longsize", "NaN"))), + IntegralRow("15-Null", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdNullErr("bytesize"), + ErrorMessage.stdNullErr("shortsize"))) + ) + + assertResult(exp)(std.as[IntegralRow].collect().sortBy(_.description).toList) + } + + test("Under-/overflow and precision lost from strongly typed input - decimal") { + def formatBigDecimal(bd: BigDecimal): String = { + bigDecimalFormat.format(bd) + } + + val pi: BigDecimal = Math.PI + val tinyFractionalPart: BigDecimal = BigDecimal("1.000000000000000001") + val reallyBig: BigDecimal = BigDecimal(Long.MaxValue)*2 + val reallySmall: BigDecimal = BigDecimal(Long.MinValue)*2 + val shortOverflow: BigDecimal = Short.MaxValue + 1 + + //formatting is not precise for these + val tinyFractionalPartStr = "1.000000000000000001" + val reallyBigStr = "18446744073709551614.000000000000000000" + val reallySmallStr = "-18446744073709551616.000000000000000000" + + val seq: Seq[InputRowBigDecimalsForIntegral] = Seq( + new InputRowBigDecimalsForIntegral("00-One", 1.0), + new InputRowBigDecimalsForIntegral("01-Pi", pi), + new InputRowBigDecimalsForIntegral("02-Tiny fractional part", tinyFractionalPart), + new InputRowBigDecimalsForIntegral("03-Really big", reallyBig), + new InputRowBigDecimalsForIntegral("04-Really small", reallySmall), + new InputRowBigDecimalsForIntegral("05-Short", shortOverflow), + new InputRowBigDecimalsForIntegral("06-Null", null) + ) + val src = spark.createDataFrame(seq) + logDataFrameContent(src) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + val exp = Seq( + IntegralRow("00-One", Option(1), Option(1), Option(1), Option(1)), + IntegralRow("01-Pi", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdCastErr("bytesize", formatBigDecimal(pi)), + ErrorMessage.stdCastErr("shortsize", formatBigDecimal(pi)), + ErrorMessage.stdCastErr("integersize", formatBigDecimal(pi)), + ErrorMessage.stdCastErr("longsize", formatBigDecimal(pi)))), + IntegralRow("02-Tiny fractional part", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdCastErr("bytesize", tinyFractionalPartStr), + ErrorMessage.stdCastErr("shortsize", tinyFractionalPartStr), + ErrorMessage.stdCastErr("integersize", tinyFractionalPartStr), + ErrorMessage.stdCastErr("longsize", tinyFractionalPartStr))), + IntegralRow("03-Really big", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdCastErr("bytesize", reallyBigStr), + ErrorMessage.stdCastErr("shortsize", reallyBigStr), + ErrorMessage.stdCastErr("integersize", reallyBigStr), + ErrorMessage.stdCastErr("longsize", reallyBigStr))), + IntegralRow("04-Really small", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdCastErr("bytesize", reallySmallStr), + ErrorMessage.stdCastErr("shortsize", 
reallySmallStr), + ErrorMessage.stdCastErr("integersize", reallySmallStr), + ErrorMessage.stdCastErr("longsize", reallySmallStr))), + IntegralRow("05-Short", Option(0), Option(0), Option(Short.MaxValue + 1), Option(Short.MaxValue + 1), Seq( + ErrorMessage.stdCastErr("bytesize", formatBigDecimal(shortOverflow)), + ErrorMessage.stdCastErr("shortsize", formatBigDecimal(shortOverflow)))), + IntegralRow("06-Null", Option(0), Option(0), None, None, Seq( + ErrorMessage.stdNullErr("bytesize"), + ErrorMessage.stdNullErr("shortsize"))) + ) + + assertResult(exp)(std.as[IntegralRow].collect().sortBy(_.description).toList) + } + + test("No pattern, but altered symbols") { + val input = Seq( + ("01-Normal", "3"), + ("02-Null", null), + ("03-Far negative", "^100000000"), + ("04-Wrong", "hello") + ) + val decimalSeparator = "," + val groupingSeparator = "." + val minusSign = "^" + val srcField = "src" + + val src = input.toDF("description", srcField) + + val desiredSchemaWithAlters = StructType(Seq( + StructField("description", StringType, nullable = false), + StructField("src", StringType, nullable = true), + StructField("bf", ByteType, nullable = false, new MetadataBuilder() + .putString(MetadataKeys.SourceColumn, srcField) + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + .putString(MetadataKeys.GroupingSeparator, groupingSeparator) + .putString(MetadataKeys.MinusSign, minusSign) + .build()), + StructField("sf", ShortType, nullable = true, new MetadataBuilder() + .putString(MetadataKeys.DefaultValue, "^1") + .putString(MetadataKeys.SourceColumn, srcField) + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + .putString(MetadataKeys.GroupingSeparator, groupingSeparator) + .putString(MetadataKeys.MinusSign, minusSign) + .build()), + StructField("if", IntegerType, nullable = true, new MetadataBuilder() + .putString(MetadataKeys.SourceColumn, srcField) + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + .putString(MetadataKeys.GroupingSeparator, groupingSeparator) + .putString(MetadataKeys.MinusSign, minusSign) + .build()), + StructField("lf", LongType, nullable = false, new MetadataBuilder() + .putString(MetadataKeys.SourceColumn, srcField) + .putString(MetadataKeys.DefaultValue, "1000") + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + .putString(MetadataKeys.GroupingSeparator, groupingSeparator) + .putString(MetadataKeys.MinusSign, minusSign) + .build()) + )) + + val std = Standardization.standardize(src, desiredSchemaWithAlters).cache() + logDataFrameContent(std) + + val exp = List( + ("01-Normal", "3", 3, Some(3), Some(3), 3, Seq.empty), + ("02-Null", null, 0, None, None, 1000, Array.fill(2)(ErrorMessage.stdNullErr(srcField)).toList), + ("03-Far negative", "^100000000", 0, Some(-1), Some(-100000000), -100000000, err("^100000000", 2)), + ("04-Wrong", "hello", 0, Some(-1), None, 1000, err("hello", 4)) + ) + + assertResult(exp)(std.as[(String, String, Byte, Option[Short], Option[Int], Long, Seq[ErrorMessage])].collect().toList) + } + + test("Using patterns") { + val input = Seq( + ("01-Normal", "3 feet"), + ("02-Null", null), + ("03-Far negative", "^100.000.000 feet"), + ("04-Wrong", "hello"), + ("05-Not adhering to pattern", "123,456,789 feet") + ) + val pattern = "#,##0 feet" + val decimalSeparator = "," + val groupingSeparator = "." 
+ val minusSign = "^" + val srcField = "src" + + val src = input.toDF("description", srcField) + + val desiredSchemaWithPatterns = StructType(Seq( + StructField("description", StringType, nullable = false), + StructField("src", StringType, nullable = true), + StructField("bf", ByteType, nullable = false, new MetadataBuilder() + .putString(MetadataKeys.Pattern, pattern) + .putString(MetadataKeys.SourceColumn, srcField) + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + .putString(MetadataKeys.GroupingSeparator, groupingSeparator) + .putString(MetadataKeys.MinusSign, minusSign) + .build()), + StructField("sf", ShortType, nullable = true, new MetadataBuilder() + .putString(MetadataKeys.Pattern, pattern) + .putString(MetadataKeys.DefaultValue, "^1 feet") + .putString(MetadataKeys.SourceColumn, srcField) + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + .putString(MetadataKeys.GroupingSeparator, groupingSeparator) + .putString(MetadataKeys.MinusSign, minusSign) + .build()), + StructField("if", IntegerType, nullable = true, new MetadataBuilder() + .putString(MetadataKeys.Pattern, pattern) + .putString(MetadataKeys.SourceColumn, srcField) + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + .putString(MetadataKeys.GroupingSeparator, groupingSeparator) + .putString(MetadataKeys.MinusSign, minusSign) + .build()), + StructField("lf", LongType, nullable = false, new MetadataBuilder() + .putString(MetadataKeys.Pattern, pattern) + .putString(MetadataKeys.SourceColumn, srcField) + .putString(MetadataKeys.DefaultValue, "1.000 feet") + .putString(MetadataKeys.DecimalSeparator, decimalSeparator) + .putString(MetadataKeys.GroupingSeparator, groupingSeparator) + .putString(MetadataKeys.MinusSign, minusSign) + .build()) + )) + + val std = Standardization.standardize(src, desiredSchemaWithPatterns).cache() + logDataFrameContent(std) + + val exp = List( + ("01-Normal", "3 feet", 3, Some(3), Some(3), 3, Seq.empty), + ("02-Null", null, 0, None, None, 1000, Array.fill(2)(ErrorMessage.stdNullErr(srcField)).toList), + ("03-Far negative", "^100.000.000 feet", 0, Some(-1), Some(-100000000), -100000000, err("^100.000.000 feet", 2)), + ("04-Wrong", "hello", 0, Some(-1), None, 1000, err("hello", 4)), + ("05-Not adhering to pattern", "123,456,789 feet", 0, Some(-1), None, 1000, err("123,456,789 feet", 4)) + ) + + assertResult(exp)(std.as[(String, String, Byte, Option[Short], Option[Int], Long, Seq[ErrorMessage])].collect().toList) + } + + test("Changed Radix") { + val input = Seq( + ("00-Null", null), + ("01-Binary", "+1101"), + ("02-Binary negative", "§1001"), + ("03-Septary", "35"), + ("04-Septary negative", "§103"), + ("05-Hex", "FF"), + ("06-Hex negative", "§A1"), + ("07-Hex 0x", "+0xB6"), + ("08-Hex 0x negative", "§0x3c"), + ("09-Radix 27", "Hello"), + ("10-Radix 27 negative", "§Mail"), + ("11-Wrong for all", "0XoXo") + ) + val srcField = "src" + val minusSign = "§" + + val src = input.toDF("description", srcField) + + val desiredSchemaWithAlters = StructType(Seq( + StructField("description", StringType, nullable = false), + StructField("src", StringType, nullable = true), + StructField("bf", ByteType, nullable = false, new MetadataBuilder() + .putString(MetadataKeys.SourceColumn, srcField) + .putString(MetadataKeys.Radix, "2") + .putString(MetadataKeys.MinusSign, minusSign) + .build()), + StructField("sf", ShortType, nullable = true, new MetadataBuilder() + .putString(MetadataKeys.DefaultValue, "§13") //NB 13 is 10 in decimal base + .putString(MetadataKeys.SourceColumn, 
srcField) + .putString(MetadataKeys.Radix, "7") + .putString(MetadataKeys.MinusSign, minusSign) + .build()), + StructField("if", IntegerType, nullable = true, new MetadataBuilder() + .putString(MetadataKeys.SourceColumn, srcField) + .putString(MetadataKeys.Radix, "16") + .putString(MetadataKeys.MinusSign, minusSign) + .build()), + StructField("lf", LongType, nullable = false, new MetadataBuilder() + .putString(MetadataKeys.SourceColumn, srcField) + .putString(MetadataKeys.DefaultValue, "Ada") // NB Ada is 7651 in decimal base + .putString(MetadataKeys.Radix, "27") + .putString(MetadataKeys.MinusSign, minusSign) + .build()) + )) + + val std = Standardization.standardize(src, desiredSchemaWithAlters).cache() + logDataFrameContent(std) + + val exp = List( + ("00-Null" , null , 0 , None , None , 7651 , Array.fill(2)(ErrorMessage.stdNullErr(srcField)).toList), + ("01-Binary" , "+1101", 13, Some(393) , Some(4353) , 20413 , Seq.empty), + ("02-Binary negative" , "§1001", -9, Some(-344), Some(-4097), -19684 , Seq.empty), + ("03-Septary" , "35" , 0 , Some(26) , Some(53) , 86 , err("35", 1)), + ("04-Septary negative" , "§103" , 0 , Some(-52) , Some(-259) , -732 , err("§103", 1)), + ("05-Hex" , "FF" , 0 , Some(-10) , Some(255) , 420 , err("FF", 2)), + ("06-Hex negative" , "§A1" , 0 , Some(-10) , Some(-161) , -271 , err("§A1", 2)), + ("07-Hex 0x" , "+0xB6", 0 , Some(-10) , Some(182) , 7651 , err("+0xB6", 3)), + ("08-Hex 0x negative" , "§0x3c", 0 , Some(-10) , Some(-60) , 7651 , err("§0x3c", 3)), + ("09-Radix 27" , "Hello" , 0, Some(-10) , None , 9325959, err("Hello", 3)), + ("10-Radix 27 negative", "§Mail", 0 , Some(-10) , None , -440823, err("§Mail", 3)), + ("11-Wrong for all" , "0XoXo", 0 , Some(-10) , None , 7651 , err("0XoXo", 4)) + ) + + assertResult(exp)(std.as[(String, String, Byte, Option[Short], Option[Int], Long, Seq[ErrorMessage])].collect().toList) + } + + +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_TimestampSuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_TimestampSuite.scala new file mode 100644 index 0000000..f1808cc --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/StandardizationInterpreter_TimestampSuite.scala @@ -0,0 +1,368 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.interpreter + +import java.sql.Timestamp +import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType, TimestampType} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.types.{Defaults, GlobalDefaults} +import za.co.absa.standardization.udf.UDFLibrary +import za.co.absa.standardization.{ErrorMessage, LoggerTestBase, SparkTestBase, Standardization} + +class StandardizationInterpreter_TimestampSuite extends AnyFunSuite with SparkTestBase with LoggerTestBase { + import spark.implicits._ + + private implicit val udfLib: UDFLibrary = new UDFLibrary + private implicit val defaults: Defaults = GlobalDefaults + + private val fieldName = "tms" + + test("epoch") { + val seq = Seq( + 0, + 86400, + 978307199, + 1563288103 + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, TimestampType, nullable = false, + new MetadataBuilder().putString("pattern", "epoch").build) + )) + val exp = Seq( + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00")), + TimestampRow(Timestamp.valueOf("1970-01-02 00:00:00")), + TimestampRow(Timestamp.valueOf("2000-12-31 23:59:59")), + TimestampRow(Timestamp.valueOf("2019-07-16 14:41:43")) + ) + + val src = seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[TimestampRow].collect().toList) + } + + test("epochmilli") { + val seq = Seq( + "0.0", + "86400000.5", + "978307199999.05", + "1563288103123.005", + "-86400000", + "Fail" + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, TimestampType, nullable = false, + new MetadataBuilder().putString("pattern", "epochmilli").build) + )) + val exp = Seq( + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00")), + TimestampRow(Timestamp.valueOf("1970-01-02 00:00:00.0005")), + TimestampRow(Timestamp.valueOf("2000-12-31 23:59:59.99905")), + TimestampRow(Timestamp.valueOf("2019-07-16 14:41:43.123005")), + TimestampRow(Timestamp.valueOf("1969-12-31 00:00:00")), + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00"), Seq(ErrorMessage.stdCastErr(fieldName, "Fail"))) + ) + + val src = seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[TimestampRow].collect().toList) + } + + test("epochmicro") { + val seq = Seq( + 0L, + 86400000000L, + 978307199999999L, + 1563288103123456L + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, TimestampType, nullable = false, + new MetadataBuilder().putString("pattern", "epochmicro").build) + )) + val exp = Seq( + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00")), + TimestampRow(Timestamp.valueOf("1970-01-02 00:00:00")), + TimestampRow(Timestamp.valueOf("2000-12-31 23:59:59.999999")), + TimestampRow(Timestamp.valueOf("2019-07-16 14:41:43.123456")) + ) + + val src = seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[TimestampRow].collect().toList) + } + + test("epochnano") { + val seq = Seq( + 0, + 86400000000000L, + 978307199999999999L, + 1563288103123456789L + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, TimestampType, nullable = false, + new MetadataBuilder().putString("pattern", "epochnano").build) + )) + val exp = Seq( + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00")), + TimestampRow(Timestamp.valueOf("1970-01-02 00:00:00")), + TimestampRow(Timestamp.valueOf("2000-12-31 
23:59:59.999999000")), + TimestampRow(Timestamp.valueOf("2019-07-16 14:41:43.123456000")) + ) + + val src = seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[TimestampRow].collect().toList) + } + + test("pattern up to seconds precision") { + val seq = Seq( + "01.01.1970 00-00-00", + "02.01.1970 00-00-00", + "31.12.2000 23-59-59", + "16.07.2019 14-41-43", + "02.02.1970_00-00-00", + "nope" + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, TimestampType, nullable = false, + new MetadataBuilder().putString("pattern", "dd.MM.yyyy HH-mm-ss").build) + )) + val exp = Seq( + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00")), + TimestampRow(Timestamp.valueOf("1970-01-02 00:00:00")), + TimestampRow(Timestamp.valueOf("2000-12-31 23:59:59")), + TimestampRow(Timestamp.valueOf("2019-07-16 14:41:43")), + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00"), Seq(ErrorMessage.stdCastErr(fieldName, "02.02.1970_00-00-00"))), + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00"), Seq(ErrorMessage.stdCastErr(fieldName, "nope"))) + ) + + val src = seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[TimestampRow].collect().toList) + } + + test("pattern up to seconds precision with default time zone") { + val seq = Seq( + "31.12.1969 19-00-00", + "01.01.1970 19-00-00", + "31.12.2000 18-59-59", + "29.02.2004 24-00-00", + "16.07.2019 09-41-43", + "02.02.1970_24-00-00", + "nope" + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, TimestampType, nullable = false, + new MetadataBuilder() + .putString("pattern", "dd.MM.yyyy kk-mm-ss") + .putString("timezone", "EST") + .build) + )) + val exp = Seq( + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00")), + TimestampRow(Timestamp.valueOf("1970-01-02 00:00:00")), + TimestampRow(Timestamp.valueOf("2000-12-31 23:59:59")), + TimestampRow(Timestamp.valueOf("2004-02-29 05:00:00")), + TimestampRow(Timestamp.valueOf("2019-07-16 14:41:43")), + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00"), Seq(ErrorMessage.stdCastErr(fieldName, "02.02.1970_24-00-00"))), + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00"), Seq(ErrorMessage.stdCastErr(fieldName, "nope"))) + ) + + val src = seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[TimestampRow].collect().toList) + } + + test("pattern up to milliseconds precision and with offset time zone") { + val seq = Seq( + "1970 01 01 01 00 00 000 +01:00", + "1970 01 02 03 30 00 001 +03:30", + "2000 12 31 23 59 59 999 +00:00", + "2019 07 16 08 41 43 123 -06:00", + "1970 02 02 00 00 00 112", + "nope" + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, TimestampType, nullable = false, + new MetadataBuilder().putString("pattern", "yyyy MM dd HH mm ss SSS XXX").build) + )) + val exp = Seq( + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00")), + TimestampRow(Timestamp.valueOf("1970-01-02 00:00:00.001")), + TimestampRow(Timestamp.valueOf("2000-12-31 23:59:59.999")), + TimestampRow(Timestamp.valueOf("2019-07-16 14:41:43.123")), + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00"), Seq(ErrorMessage.stdCastErr(fieldName, "1970 02 02 00 00 00 112"))), + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00"), Seq(ErrorMessage.stdCastErr(fieldName, "nope"))) + ) + + val src = seq.toDF(fieldName) + + val std = 
Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[TimestampRow].collect().toList) + + } + + test("pattern up to microseconds precision and with default time zone") { + val seq = Seq( + "01011970 010000.000000", + "02011970 010000.000001", + "01012001 005959.999999", + "16072019 164143.123456", + "02011970 010000 000001", + "nope" + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, TimestampType, nullable = false, + new MetadataBuilder() + .putString("pattern", "ddMMyyyy HHmmss.iiiiii") + .putString("timezone", "CET") + .build) + )) + val exp = Seq( + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00")), + TimestampRow(Timestamp.valueOf("1970-01-02 00:00:00.000001")), + TimestampRow(Timestamp.valueOf("2000-12-31 23:59:59.999999")), + TimestampRow(Timestamp.valueOf("2019-07-16 14:41:43.123456")), + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00"), Seq(ErrorMessage.stdCastErr(fieldName, "02011970 010000 000001"))), + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00"), Seq(ErrorMessage.stdCastErr(fieldName, "nope"))) + ) + + val src = seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[TimestampRow].collect().toList) + + } + + test("pattern up to nanoseconds precision, no time zone") { + val seq = Seq( + "(000000) 01/01/1970 AM+00:00:00~000", + "(002003) 02/01/1970 am+00:00:00~001", + "(999999) 31/12/2000 PM+11:59:59~999", + "(456789) 16/07/2019 Pm+02:41:43~123", + "02/01/1970 00:00:00 001", + "nope" + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, TimestampType, nullable = false, + new MetadataBuilder().putString("pattern", "(iiinnn) dd/MM/yyyy aa+KK:mm:ss~SSS").build) + )) + val exp = Seq( + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00")), + TimestampRow(Timestamp.valueOf("1970-01-02 00:00:00.001002")), + TimestampRow(Timestamp.valueOf("2000-12-31 23:59:59.999999")), + TimestampRow(Timestamp.valueOf("2019-07-16 14:41:43.123456")), + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00"), Seq(ErrorMessage.stdCastErr(fieldName, "02/01/1970 00:00:00 001"))), + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00"), Seq(ErrorMessage.stdCastErr(fieldName, "nope"))) + ) + + val src = seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[TimestampRow].collect().toList) + + } + + test("pattern up to nanoseconds precision and named time zone") { + val seq = Seq( + "(000000) 01/01/1970 01:00:00.000 CET", + "(001002) 02/01/1970 08:45:00.003 ACWST", + "(999999) 31/12/2000 15:59:59.999 PST", + "(456789) 16/07/2019 16:41:43.123 EET", + "( ) 02/01/1970 01:00:00.000 CET", + "nope" + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, TimestampType, nullable = false, + new MetadataBuilder().putString("pattern", "(iiinnn) dd/MM/yyyy HH:mm:ss.SSS ZZ").build) + )) + val exp = Seq( + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00")), + TimestampRow(Timestamp.valueOf("1970-01-02 00:00:00.003001")), + TimestampRow(Timestamp.valueOf("2000-12-31 23:59:59.999999")), + TimestampRow(Timestamp.valueOf("2019-07-16 14:41:43.123456")), + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00"), Seq(ErrorMessage.stdCastErr(fieldName, "( ) 02/01/1970 01:00:00.000 CET"))), + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00"), Seq(ErrorMessage.stdCastErr(fieldName, "nope"))) + ) + + val src = seq.toDF(fieldName) + + 
val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + assertResult(exp)(std.as[TimestampRow].collect().toList) + + } + + /* TODO this should work with #7 fixed (originally Enceladus#677) + test("pattern with literal and less common placeholders") { + val seq = Seq( + "70001 star [000] 12:00:00(aM) @000000", + "70002 star [001] 01:00:00(pM) @002003", + "00365 star [999] 11:59:59(pM) @999999", + "80040 star [123] 02:41:43(PM) @456789", + "70002 staT [000] 12:00:00(aM) @000000", + "nope" + ) + val desiredSchema = StructType(Seq( + StructField(fieldName, TimestampType, nullable = false, + new MetadataBuilder().putString("pattern", "yyDDD 'star' [iii] aa hh:mm:ss(aa)@nnnSSS").build) + )) + val exp = Seq( + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00")), + TimestampRow(Timestamp.valueOf("1970-01-02 00:00:00.003001")), + TimestampRow(Timestamp.valueOf("2000-12-31 23:59:59.999999")), + TimestampRow(Timestamp.valueOf("1980-02-09 14:41:43.789123")), + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00"), Seq(ErrorMessage.stdCastErr(fieldName, "70002 staT [000] 12:00:00(aM) @000000"))), + TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00"), Seq(ErrorMessage.stdCastErr(fieldName, "nope"))) + ) + + val src = seq.toDF(fieldName) + + val std = Standardization.standardize(src, desiredSchema).cache() + logDataFrameContent(std) + + std.show(false) + std.printSchema() + assertResult(exp)(std.as[TimestampRow].collect().toList) + } + */ + +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/stages/PlainSchemaGeneratorSuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/stages/PlainSchemaGeneratorSuite.scala new file mode 100644 index 0000000..037a7a4 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/stages/PlainSchemaGeneratorSuite.scala @@ -0,0 +1,71 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.interpreter.stages + +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.SparkTestBase +import za.co.absa.standardization.stages.PlainSchemaGenerator + +class PlainSchemaGeneratorSuite extends AnyFunSuite with SparkTestBase { + private val schema = StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false, new MetadataBuilder().putString("meta", "data").build), + StructField("c", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "override_c").build), + StructField("d", ArrayType(StructType(Seq( + StructField("e", StructType(Seq( + StructField("f", ArrayType(StructType(Seq( + StructField("g", IntegerType, nullable = false), + StructField("h", IntegerType, nullable = false, new MetadataBuilder().putString("meta", "data").build), + StructField("i", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "override_i").build) + )))) + ))) + )))) + )) + + private val expectedSchemaSeq = Seq( + StructField("a", StringType, nullable = true), + StructField("b", StringType, nullable = true, new MetadataBuilder().putString("meta", "data").build), + StructField("override_c", StringType, nullable = true, new MetadataBuilder().putString("sourcecolumn", "override_c").build), + StructField("d", ArrayType(StructType(Seq( + StructField("e", StructType(Seq( + StructField("f", ArrayType(StructType(Seq( + StructField("g", StringType, nullable = true), + StructField("h", StringType, nullable = true, new MetadataBuilder().putString("meta", "data").build), + StructField("override_i", StringType, nullable = true, new MetadataBuilder().putString("sourcecolumn", "override_i").build) + )))) + ))) + )))) + ) + + private val expectedSchema = StructType(expectedSchemaSeq) + + private val expectedSchemaWithErrorColumn = StructType( expectedSchemaSeq ++ Seq( + StructField("_error_column", StringType, nullable = true) + )) + + test("Test generateInputSchema") { + val generatedSchema = PlainSchemaGenerator.generateInputSchema(schema) + assertResult(expectedSchema)(generatedSchema) + } + + test("Test generateInputSchema with error column") { + val generatedSchema = PlainSchemaGenerator.generateInputSchema(schema, Option("_error_column")) + assertResult(expectedSchemaWithErrorColumn)(generatedSchema) + } + +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/stages/SchemaCheckerSuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/stages/SchemaCheckerSuite.scala new file mode 100644 index 0000000..4655df8 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/stages/SchemaCheckerSuite.scala @@ -0,0 +1,37 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.interpreter.stages + +import org.apache.spark.sql.types.{DataType, StructType} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.stages.SchemaChecker +import za.co.absa.standardization.{FileReader, SparkTestBase} + +class SchemaCheckerSuite extends AnyFunSuite with SparkTestBase { + test("Bug") { + val sourceFile = FileReader.readFileAsString("src/test/resources/data/bug.json") + val schema = DataType.fromJson(sourceFile).asInstanceOf[StructType] + val output = SchemaChecker.validateSchemaAndLog(schema) + val expected = ( + List( + "Validation error for column 'Conformed_TXN_TIMESTAMP', pattern 'yyyy-MM-ddTHH:mm:ss.SSSX': Illegal pattern character 'T'" + ), + List() + ) + assert(output == expected) + } +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParserSuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParserSuite.scala new file mode 100644 index 0000000..bbd5662 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParserSuite.scala @@ -0,0 +1,55 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.interpreter.stages + +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.SparkTestBase +import za.co.absa.standardization.stages.TypeParser +import za.co.absa.standardization.types.{Defaults, GlobalDefaults} +import za.co.absa.standardization.udf.UDFLibrary + +class TypeParserSuite extends AnyFunSuite with SparkTestBase { + + private implicit val udfLib: UDFLibrary = new UDFLibrary + private implicit val defaults: Defaults = GlobalDefaults + + test("Test standardize with sourcecolumn metadata") { + val structFieldNoMetadata = StructField("a", StringType) + val structFieldWithMetadataNotSourceColumn = StructField("b", StringType, nullable = false, new MetadataBuilder().putString("meta", "data").build) + val structFieldWithMetadataSourceColumn = StructField("c", StringType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "override_c").build) + val schema = StructType(Array(structFieldNoMetadata, structFieldWithMetadataNotSourceColumn, structFieldWithMetadataSourceColumn)) + //Just Testing field name override + val parseOutputStructFieldNoMetadata = TypeParser.standardize(structFieldNoMetadata, "path", schema) + assertResult(true)(parseOutputStructFieldNoMetadata.stdCol.expr.toString().contains("path.a")) + assertResult(false)(parseOutputStructFieldNoMetadata.stdCol.expr.toString().replaceAll("path.a", "").contains("path")) + assertResult(true)(parseOutputStructFieldNoMetadata.errors.expr.toString().contains("path.a")) + assertResult(false)(parseOutputStructFieldNoMetadata.errors.expr.toString().replaceAll("path.a", "").contains("path")) + val parseOutputStructFieldWithMetadataNotSourceColumn = TypeParser.standardize(structFieldWithMetadataNotSourceColumn, "path", 
schema) + assertResult(true)(parseOutputStructFieldWithMetadataNotSourceColumn.stdCol.expr.toString().contains("path.b")) + assertResult(false)(parseOutputStructFieldWithMetadataNotSourceColumn.stdCol.expr.toString().replaceAll("path.b", "").contains("path")) + assertResult(true)(parseOutputStructFieldWithMetadataNotSourceColumn.errors.expr.toString().contains("path.b")) + assertResult(false)(parseOutputStructFieldWithMetadataNotSourceColumn.errors.expr.toString().replaceAll("path.b", "").contains("path")) + val parseOutputStructFieldWithMetadataSourceColumn = TypeParser.standardize(structFieldWithMetadataSourceColumn, "path",schema) + assertResult(false)(parseOutputStructFieldWithMetadataSourceColumn.stdCol.expr.toString().contains("path.c")) + assertResult(true)(parseOutputStructFieldWithMetadataSourceColumn.stdCol.expr.toString().contains("path.override_c")) + assertResult(false)(parseOutputStructFieldWithMetadataSourceColumn.stdCol.expr.toString().replaceAll("path.override_c", "").contains("path")) + assertResult(false)(parseOutputStructFieldWithMetadataSourceColumn.errors.expr.toString().contains("path.c")) + assertResult(true)(parseOutputStructFieldWithMetadataSourceColumn.errors.expr.toString().contains("path.override_c")) + assertResult(false)(parseOutputStructFieldWithMetadataSourceColumn.errors.expr.toString().replaceAll("path.override_c", "").contains("path")) + } +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParserSuiteTemplate.scala b/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParserSuiteTemplate.scala new file mode 100644 index 0000000..dc62e7a --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParserSuiteTemplate.scala @@ -0,0 +1,263 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.interpreter.stages + +import java.security.InvalidParameterException +import java.sql.{Date, Timestamp} +import org.apache.log4j.{LogManager, Logger} +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.interpreter.stages.TypeParserSuiteTemplate._ +import za.co.absa.standardization.SparkTestBase +import za.co.absa.standardization.stages.TypeParser +import za.co.absa.standardization.time.DateTimePattern +import za.co.absa.standardization.types.{Defaults, GlobalDefaults, ParseOutput, TypedStructField} +import za.co.absa.standardization.udf.UDFLibrary + +trait TypeParserSuiteTemplate extends AnyFunSuite with SparkTestBase { + + private implicit val udfLib: UDFLibrary = new UDFLibrary + private implicit val defaults: Defaults = GlobalDefaults + + protected def createCastTemplate(toType: DataType, pattern: String, timezone: Option[String]): String + protected def createErrorCondition(srcField: String, target: StructField, castS: String):String + + private val sourceFieldName = "sourceField" + + protected val log: Logger = LogManager.getLogger(this.getClass) + + protected def doTestWithinColumnNullable(input: Input): Unit = { + import input._ + val nullable = true + val field = sourceField(baseType, nullable) + val schema = buildSchema(Array(field), path) + testTemplate(field, schema, path) + } + + protected def doTestWithinColumnNotNullable(input: Input): Unit = { + import input._ + val nullable = false + val field = sourceField(baseType, nullable) + val schema = buildSchema(Array(field), path) + testTemplate(field, schema, path) + } + + protected def doTestIntoStringField(input: Input): Unit = { + import input._ + val stringField = StructField("stringField", StringType, nullable = false, + new MetadataBuilder().putString("sourcecolumn",sourceFieldName).build) + val schema = buildSchema(Array(sourceField(baseType), stringField), path) + testTemplate(stringField, schema, path) + } + + protected def doTestIntoFloatField(input: Input): Unit = { + import input._ + val floatField = StructField("floatField", FloatType, nullable = false, + new MetadataBuilder().putString("sourcecolumn", sourceFieldName).build) + val schema = buildSchema(Array(sourceField(baseType), floatField), path) + testTemplate(floatField, schema, path) + } + + protected def doTestIntoIntegerField(input: Input): Unit = { + import input._ + val integerField = StructField("integerField", IntegerType, nullable = true, + new MetadataBuilder().putString("sourcecolumn", sourceFieldName).build) + val schema = buildSchema(Array(sourceField(baseType), integerField), path) + testTemplate(integerField, schema, path) + } + + protected def doTestIntoBooleanField(input: Input): Unit = { + import input._ + val booleanField = StructField("booleanField", BooleanType, nullable = false, + new MetadataBuilder().putString("sourcecolumn", sourceFieldName).build) + val schema = buildSchema(Array(sourceField(baseType), booleanField), path) + testTemplate(booleanField, schema, path) + } + + protected def doTestIntoDateFieldNoPattern(input: Input): Unit = { + import input._ + val dateField = StructField("dateField", DateType, nullable = false, + new MetadataBuilder().putString("sourcecolumn", sourceFieldName).build) + val schema = buildSchema(Array(sourceField(baseType), dateField), path) + + if (datetimeNeedsPattern) { + val errMessage = s"Dates & times represented as ${baseType.typeName} values need specified 'pattern' metadata" + val caughtErr = 
intercept[InvalidParameterException] { + TypeParser.standardize(dateField, path, schema) + } + assert(caughtErr.getMessage == errMessage) + } else { + testTemplate(dateField, schema, path, "yyyy-MM-dd") + } + } + + protected def doTestIntoTimestampFieldNoPattern(input: Input): Unit = { + import input._ + val timestampField = StructField("timestampField", TimestampType, nullable = false, + new MetadataBuilder().putString("sourcecolumn", sourceFieldName).build) + val schema = buildSchema(Array(sourceField(baseType), timestampField), path) + + if (datetimeNeedsPattern) { + val errMessage = s"Dates & times represented as ${baseType.typeName} values need specified 'pattern' metadata" + val caughtErr = intercept[InvalidParameterException] { + TypeParser.standardize(timestampField, path, schema) + } + assert(caughtErr.getMessage == errMessage) + } else { + testTemplate(timestampField, schema, path, "yyyy-MM-dd HH:mm:ss") + } + } + + protected def doTestIntoDateFieldWithPattern(input: Input): Unit = { + import input._ + val dateField = StructField("dateField", DateType, nullable = false, + new MetadataBuilder().putString("sourcecolumn", sourceFieldName).putString("pattern", datePattern).build) + val schema = buildSchema(Array(sourceField(baseType), dateField), path) + testTemplate(dateField, schema, path, datePattern) + } + + protected def doTestIntoTimestampFieldWithPattern(input: Input): Unit = { + import input._ + val timestampField = StructField("timestampField", TimestampType, nullable = false, + new MetadataBuilder().putString("sourcecolumn", sourceFieldName).putString("pattern", timestampPattern).build) + val schema = buildSchema(Array(sourceField(baseType), timestampField), path) + testTemplate(timestampField, schema, path, timestampPattern) + } + + protected def doTestIntoDateFieldWithPatternAndDefault(input: Input): Unit = { + import input._ + val dateField = StructField("dateField", DateType, nullable = false, + new MetadataBuilder().putString("sourcecolumn", sourceFieldName).putString("pattern", datePattern).putString("default", defaultValueDate).build) + val schema = buildSchema(Array(sourceField(baseType), dateField), path) + testTemplate(dateField, schema, path, datePattern) + } + + protected def doTestIntoTimestampFieldWithPatternAndDefault(input: Input): Unit = { + import input._ + val timestampField = StructField("timestampField", TimestampType, nullable = false, + new MetadataBuilder().putString("sourcecolumn", sourceFieldName).putString("pattern", timestampPattern).putString("default", defaultValueTimestamp).build) + val schema = buildSchema(Array(sourceField(baseType), timestampField), path) + testTemplate(timestampField, schema, path, timestampPattern) + } + + protected def doTestIntoDateFieldWithPatternAndTimeZone(input: Input): Unit = { + import input._ + val dateField = StructField("dateField", DateType, nullable = false, + new MetadataBuilder().putString("sourcecolumn", sourceFieldName).putString("pattern", datePattern).putString("timezone", fixedTimezone).build) + val schema = buildSchema(Array(sourceField(baseType), dateField), path) + testTemplate(dateField, schema, path, datePattern, Option(fixedTimezone)) + } + + protected def doTestIntoTimestampFieldWithPatternAndTimeZone(input: Input): Unit = { + import input._ + val timestampField = StructField("timestampField", TimestampType, nullable = false, + new MetadataBuilder().putString("sourcecolumn", sourceFieldName).putString("pattern", timestampPattern).putString("timezone", fixedTimezone).build) + val schema = 
buildSchema(Array(sourceField(baseType), timestampField), path) + testTemplate(timestampField, schema, path, timestampPattern, Option(fixedTimezone)) + } + + protected def doTestIntoDateFieldWithEpochPattern(input: Input): Unit = { + import input._ + val dateField = StructField("dateField", DateType, nullable = false, + new MetadataBuilder().putString("sourcecolumn", sourceFieldName).putString("pattern", DateTimePattern.EpochKeyword).build) + val schema = buildSchema(Array(sourceField(baseType), dateField), path) + testTemplate(dateField, schema, path, DateTimePattern.EpochKeyword) + } + + protected def doTestIntoTimestampFieldWithEpochPattern(input: Input): Unit = { + import input._ + val timestampField = StructField("timestampField", TimestampType, nullable = false, + new MetadataBuilder().putString("sourcecolumn", sourceFieldName).putString("pattern", DateTimePattern.EpochMilliKeyword).build) + val schema = buildSchema(Array(sourceField(baseType), timestampField), path) + testTemplate(timestampField, schema, path, DateTimePattern.EpochMilliKeyword) + } + + private def sourceField(baseType: DataType, nullable: Boolean = true): StructField = StructField(sourceFieldName, baseType, nullable) + + private def buildSchema(fields: Array[StructField], path: String): StructType = { + val innerSchema = StructType(fields) + + if (path.nonEmpty) { + StructType(Array(StructField(path, innerSchema))) + } else { + innerSchema + } + } + + private def testTemplate(target: StructField, schema: StructType, path: String, pattern: String = "", timezone: Option[String] = None): Unit = { + val srcField = fullName(path, sourceFieldName) + val castString = createCastTemplate(target.dataType, pattern, timezone).format(srcField, srcField) + val errColumnExpression = assembleErrorExpression(srcField, target, castString) + val stdCastExpression = assembleCastExpression(srcField, target, castString, errColumnExpression) + val output: ParseOutput = TypeParser.standardize(target, path, schema) + + doAssert(errColumnExpression, output.errors.toString()) + doAssert(stdCastExpression, output.stdCol.toString()) + } + + private def fullName(path: String, fieldName: String): String = { + if (path.nonEmpty) s"$path.$fieldName" else fieldName + } + + private def assembleCastExpression(srcField: String, + target: StructField, + castExpression: String, + errorExpression: String): String = { + val defaultValue = TypedStructField(target).defaultValueWithGlobal.get + val default = defaultValue match { + case Some(d: Date) => s"DATE '${d.toString}'" + case Some(t: Timestamp) => s"TIMESTAMP('${t.toString}')" + case Some(s: String) => s + case Some(x) => x.toString + case None => "NULL" + } + + s"CASE WHEN (size($errorExpression) > 0) THEN $default ELSE CASE WHEN ($srcField IS NOT NULL) THEN $castExpression END END AS `${target.name}`" + } + + private def assembleErrorExpression(srcField: String, target: StructField, castS: String): String = { + val errCond = createErrorCondition(srcField, target, castS) + + if (target.nullable) { + s"CASE WHEN (($srcField IS NOT NULL) AND ($errCond)) THEN array(stdCastErr($srcField, CAST($srcField AS STRING))) ELSE [] END" + } else { + s"CASE WHEN ($srcField IS NULL) THEN array(stdNullErr($srcField)) ELSE CASE WHEN ($errCond) THEN array(stdCastErr($srcField, CAST($srcField AS STRING))) ELSE [] END END" + } + } + + private def doAssert(expectedExpression: String, actualExpression: String): Unit = { + if (actualExpression != expectedExpression) { + // the expressions tend to be rather long, the assert 
most often cuts the beginning and/or end of the string + // showing just the vicinity of the difference, so we log the output of the whole strings + log.error(s"Expected: $expectedExpression") + log.error(s"Actual : $actualExpression") + assert(actualExpression == expectedExpression) + } + } + +} + +object TypeParserSuiteTemplate { + case class Input(baseType: DataType, + defaultValueDate: String, + defaultValueTimestamp: String, + datePattern: String, + timestampPattern: String, + fixedTimezone: String, + path: String, + datetimeNeedsPattern: Boolean = true) +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromBooleanTypeSuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromBooleanTypeSuite.scala new file mode 100644 index 0000000..c604edf --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromBooleanTypeSuite.scala @@ -0,0 +1,119 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.interpreter.stages + +import org.apache.spark.sql.types._ +import TypeParserSuiteTemplate.Input +import za.co.absa.standardization.time.DateTimePattern + +class TypeParser_FromBooleanTypeSuite extends TypeParserSuiteTemplate { + + private val input = Input( + baseType = BooleanType, + defaultValueDate = "1", + defaultValueTimestamp = "1", + datePattern = "u", + timestampPattern = "F", + fixedTimezone = "WST", + path = "Boo" + ) + + override protected def createCastTemplate(toType: DataType, pattern: String, timezone: Option[String]): String = { + val isEpoch = DateTimePattern.isEpoch(pattern) + (toType, isEpoch, timezone) match { + case (DateType, true, _) => s"to_date(CAST((CAST(`%s` AS DECIMAL(30,9)) / ${DateTimePattern.epochFactor(pattern)}L) AS TIMESTAMP))" + case (TimestampType, true, _) => s"CAST((CAST(%s AS DECIMAL(30,9)) / ${DateTimePattern.epochFactor(pattern)}) AS TIMESTAMP)" + case (DateType, _, Some(tz)) => s"to_date(to_utc_timestamp(to_timestamp(CAST(`%s` AS STRING), '$pattern'), '$tz'))" + case (TimestampType, _, Some(tz)) => s"to_utc_timestamp(to_timestamp(CAST(`%s` AS STRING), '$pattern'), $tz)" + case (TimestampType, _, _) => s"to_timestamp(CAST(`%s` AS STRING), '$pattern')" + case (DateType, _, _) => s"to_date(CAST(`%s` AS STRING), '$pattern')" + case _ => s"CAST(%s AS ${toType.sql})" + } + } + + override protected def createErrorCondition(srcField: String, target: StructField, castS: String): String = { + target.dataType match { + case FloatType | DoubleType => s"(($castS IS NULL) OR isnan($castS)) OR ($castS IN (Infinity, -Infinity))" + case _ => s"$castS IS NULL" + } + } + + test("Within the column - type stays, nullable") { + doTestWithinColumnNullable(input) + } + + test("Within the column - type stays, not nullable") { + doTestWithinColumnNotNullable(input) + } + + test("Into string field") { + doTestIntoStringField(input) + } + + test("Into float field") { + 
doTestIntoFloatField(input) + } + + test("Into integer field") { + doTestIntoIntegerField(input) + } + + test("Into boolean field") { + doTestIntoBooleanField(input) + } + + test("Into date field, no pattern") { + doTestIntoDateFieldNoPattern(input) + } + + test("Into timestamp field, no pattern") { + doTestIntoTimestampFieldNoPattern(input) + } + + test("Into date field with pattern") { + doTestIntoDateFieldWithPattern(input) + } + + test("Into timestamp field with pattern") { + doTestIntoTimestampFieldWithPattern(input) + } + + test("Into date field with pattern and default") { + doTestIntoDateFieldWithPatternAndDefault(input) + } + + test("Into timestamp field with pattern and default") { + doTestIntoTimestampFieldWithPatternAndDefault(input) + } + + test("Into date field with pattern and fixed time zone") { + doTestIntoDateFieldWithPatternAndTimeZone(input) + } + + test("Into timestamp field with pattern and fixed time zone") { + doTestIntoTimestampFieldWithPatternAndTimeZone(input) + } + + test("Into date field with epoch pattern") { + doTestIntoDateFieldWithEpochPattern(input) + } + + test("Into timestamp field with epoch pattern") { + doTestIntoTimestampFieldWithEpochPattern(input) + } + +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromDateTypeSuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromDateTypeSuite.scala new file mode 100644 index 0000000..dd1851c --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromDateTypeSuite.scala @@ -0,0 +1,120 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.interpreter.stages + +import org.apache.spark.sql.types._ +import TypeParserSuiteTemplate.Input +import za.co.absa.standardization.time.DateTimePattern + +class TypeParser_FromDateTypeSuite extends TypeParserSuiteTemplate { + + private val input = Input( + baseType = DateType, + defaultValueDate = "2000-01-01", + defaultValueTimestamp = "2010-12-31 23:59:59", + datePattern = "yyyy-MM-dd", + timestampPattern = "yyyy-MM-dd HH:mm:ss", + fixedTimezone = "CST", + path = "Date", + datetimeNeedsPattern = false + ) + + override protected def createCastTemplate(toType: DataType, pattern: String, timezone: Option[String]): String = { + val isEpoch = DateTimePattern.isEpoch(pattern) + (toType, isEpoch, timezone) match { + case (DateType, true, _) => s"to_date(CAST((CAST(`%s` AS DECIMAL(30,9)) / ${DateTimePattern.epochFactor(pattern)}L) AS TIMESTAMP))" + case (TimestampType, true, _) => s"CAST((CAST(%s AS DECIMAL(30,9)) / ${DateTimePattern.epochFactor(pattern)}) AS TIMESTAMP)" + case (DateType, _, Some(tz)) => s"to_date(to_utc_timestamp(`%s`, '$tz'))" + case (TimestampType, _, Some(tz)) => s"to_utc_timestamp(%s, $tz)" + case (DateType, _, _) => "%s" + case (TimestampType, _, _) => "to_timestamp(`%s`)" + case _ => s"CAST(%s AS ${toType.sql})" + } + } + + override protected def createErrorCondition(srcField: String, target: StructField, castS: String): String = { + target.dataType match { + case FloatType | DoubleType => s"(($castS IS NULL) OR isnan($castS)) OR ($castS IN (Infinity, -Infinity))" + case _ => s"$castS IS NULL" + } + } + + test("Within the column - type stays, nullable") { + doTestWithinColumnNullable(input) + } + + test("Within the column - type stays, not nullable") { + doTestWithinColumnNotNullable(input) + } + + test("Into string field") { + doTestIntoStringField(input) + } + + test("Into float field") { + doTestIntoFloatField(input) + } + + test("Into integer field") { + doTestIntoIntegerField(input) + } + + test("Into boolean field") { + doTestIntoBooleanField(input) + } + + test("Into date field, no pattern") { + doTestIntoDateFieldNoPattern(input) + } + + test("Into timestamp field, no pattern") { + doTestIntoTimestampFieldNoPattern(input) + } + + test("Into date field with pattern") { + doTestIntoDateFieldWithPattern(input) + } + + test("Into timestamp field with pattern") { + doTestIntoTimestampFieldWithPattern(input) + } + + test("Into date field with pattern and default") { + doTestIntoDateFieldWithPatternAndDefault(input) + } + + test("Into timestamp field with pattern and default") { + doTestIntoTimestampFieldWithPatternAndDefault(input) + } + + test("Into date field with pattern and fixed time zone") { + doTestIntoDateFieldWithPatternAndTimeZone(input) + } + + test("Into timestamp field with pattern and fixed time zone") { + doTestIntoTimestampFieldWithPatternAndTimeZone(input) + } + + test("Into date field with epoch pattern") { + doTestIntoDateFieldWithEpochPattern(input) + } + + test("Into timestamp field with epoch pattern") { + doTestIntoTimestampFieldWithEpochPattern(input) + } + +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromDecimalTypeSuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromDecimalTypeSuite.scala new file mode 100644 index 0000000..9b40d51 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromDecimalTypeSuite.scala @@ -0,0 +1,121 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed 
under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.interpreter.stages + +import org.apache.spark.sql.types._ +import TypeParserSuiteTemplate.Input +import za.co.absa.standardization.time.DateTimePattern + +class TypeParser_FromDecimalTypeSuite extends TypeParserSuiteTemplate { + + private val input = Input( + baseType = DecimalType(10, 4), + defaultValueDate = "700101", + defaultValueTimestamp = "991231.2359", + datePattern = "yyMMdd", + timestampPattern = "yyMMdd.HHmm", + fixedTimezone = "CST", + path = "hello" + ) + + override protected def createCastTemplate(toType: DataType, pattern: String, timezone: Option[String]): String = { + val isEpoch = DateTimePattern.isEpoch(pattern) + (toType, isEpoch, timezone) match { + case (DateType, true, _) => s"to_date(CAST((CAST(`%s` AS DECIMAL(30,9)) / ${DateTimePattern.epochFactor(pattern)}L) AS TIMESTAMP))" + case (TimestampType, true, _) => s"CAST((CAST(%s AS DECIMAL(30,9)) / ${DateTimePattern.epochFactor(pattern)}) AS TIMESTAMP)" + case (DateType, _, Some(tz)) => s"to_date(to_utc_timestamp(to_timestamp(CAST(`%s` AS STRING), '$pattern'), '$tz'))" + case (TimestampType, _, Some(tz)) => s"to_utc_timestamp(to_timestamp(CAST(`%s` AS STRING), '$pattern'), $tz)" + case (TimestampType, _, _) => s"to_timestamp(CAST(`%s` AS STRING), '$pattern')" + case (DateType, _, _) => s"to_date(CAST(`%s` AS STRING), '$pattern')" + case _ => s"CAST(%s AS ${toType.sql})" + } + } + + override protected def createErrorCondition(srcField: String, target: StructField, castS: String): String = { + target.dataType match { + case FloatType | DoubleType => s"(($castS IS NULL) OR isnan($castS)) OR ($castS IN (Infinity, -Infinity))" + case ByteType | ShortType | IntegerType | LongType => + s"($castS IS NULL) OR (NOT ($srcField = CAST($castS AS ${input.baseType.sql})))" + case _ => s"$castS IS NULL" + } + } + + test("Within the column - type stays, nullable") { + doTestWithinColumnNullable(input) + } + + test("Within the column - type stays, not nullable") { + doTestWithinColumnNotNullable(input) + } + + test("Into string field") { + doTestIntoStringField(input) + } + + test("Into float field") { + doTestIntoFloatField(input) + } + + test("Into integer field") { + doTestIntoIntegerField(input) + } + + test("Into boolean field") { + doTestIntoBooleanField(input) + } + + test("Into date field, no pattern") { + doTestIntoDateFieldNoPattern(input) + } + + test("Into timestamp field, no pattern") { + doTestIntoTimestampFieldNoPattern(input) + } + + test("Into date field with pattern") { + doTestIntoDateFieldWithPattern(input) + } + + test("Into timestamp field with pattern") { + doTestIntoTimestampFieldWithPattern(input) + } + + test("Into date field with pattern and default") { + doTestIntoDateFieldWithPatternAndDefault(input) + } + + test("Into timestamp field with pattern and default") { + doTestIntoTimestampFieldWithPatternAndDefault(input) + } + + test("Into date field with pattern and fixed time zone") { + doTestIntoDateFieldWithPatternAndTimeZone(input) 
+ } + + test("Into timestamp field with pattern and fixed time zone") { + doTestIntoTimestampFieldWithPatternAndTimeZone(input) + } + + test("Into date field with epoch pattern") { + doTestIntoDateFieldWithEpochPattern(input) + } + + test("Into timestamp field with epoch pattern") { + doTestIntoTimestampFieldWithEpochPattern(input) + } + +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromDoubleTypeSuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromDoubleTypeSuite.scala new file mode 100644 index 0000000..d549fb7 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromDoubleTypeSuite.scala @@ -0,0 +1,132 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.interpreter.stages + +import org.apache.spark.sql.types._ +import TypeParserSuiteTemplate.Input +import za.co.absa.standardization.time.DateTimePattern + +class TypeParser_FromDoubleTypeSuite extends TypeParserSuiteTemplate { + + private val input = Input( + baseType = DoubleType, + defaultValueDate = "7001.01", + defaultValueTimestamp = "991231.2359", + datePattern = "yyMM.dd", + timestampPattern = "yyMMdd.HHmm", + fixedTimezone = "CEST", + path = "Double" + ) + + private case class DS(precision: Int, scale: Int) //DecimalSize + private val datePatternDS = DS(6, 2) + private val timestampPatternDS = DS(10, 4) + + override protected def createCastTemplate(toType: DataType, pattern: String, timezone: Option[String]): String = { + val isEpoch = DateTimePattern.isEpoch(pattern) + (toType, isEpoch, timezone) match { + case (DateType, true, _) => s"to_date(CAST((CAST(`%s` AS DECIMAL(30,9)) / ${DateTimePattern.epochFactor(pattern)}L) AS TIMESTAMP))" + case (TimestampType, true, _) => s"CAST((CAST(%s AS DECIMAL(30,9)) / ${DateTimePattern.epochFactor(pattern)}) AS TIMESTAMP)" + case (DateType, _, Some(tz)) => s"to_date(to_utc_timestamp(to_timestamp(CAST(CAST(`%s` AS DECIMAL(${datePatternDS.precision},${datePatternDS.scale})) AS STRING), '$pattern'), '$tz'))" + case (TimestampType, _, Some(tz)) => s"to_utc_timestamp(to_timestamp(CAST(CAST(`%s` AS DECIMAL(${timestampPatternDS.precision},${timestampPatternDS.scale})) AS STRING), '$pattern'), $tz)" + case (DateType, _, _) => s"to_date(CAST(CAST(`%s` AS DECIMAL(${datePatternDS.precision},${datePatternDS.scale})) AS STRING), '$pattern')" + case (TimestampType, _, _) => s"to_timestamp(CAST(CAST(`%s` AS DECIMAL(${timestampPatternDS.precision},${timestampPatternDS.scale})) AS STRING), '$pattern')" + case _ => s"CAST(%s AS ${toType.sql})" + } + } + + override protected def createErrorCondition(srcField: String, target: StructField, castS: String): String = { + val (min, max) = target.dataType match { + case ByteType => (Byte.MinValue, Byte.MaxValue) + case ShortType => (Short.MinValue, Short.MaxValue) + case IntegerType => (Int.MinValue, Int.MaxValue) + case LongType => (Long.MinValue, 
Long.MaxValue) + case _ => (0,0 ) + } + target.dataType match { + case FloatType | DoubleType => s"(($castS IS NULL) OR isnan($castS)) OR ($castS IN (Infinity, -Infinity))" + case ByteType | ShortType | IntegerType | LongType => + s"((($castS IS NULL) OR (NOT (($srcField % 1.0) = 0.0))) OR ($srcField > $max)) OR ($srcField < $min)" + case _ => s"$castS IS NULL" + } + } + + test("Within the column - type stays, nullable") { + doTestWithinColumnNullable(input) + } + + test("Within the column - type stays, not nullable") { + doTestWithinColumnNotNullable(input) + } + + test("Into string field") { + doTestIntoStringField(input) + } + + test("Into float field") { + doTestIntoFloatField(input) + } + + test("Into integer field") { + doTestIntoIntegerField(input) + } + + test("Into boolean field") { + doTestIntoBooleanField(input) + } + + test("Into date field, no pattern") { + doTestIntoDateFieldNoPattern(input) + } + + test("Into timestamp field, no pattern") { + doTestIntoTimestampFieldNoPattern(input) + } + + test("Into date field with pattern") { + doTestIntoDateFieldWithPattern(input) + } + + test("Into timestamp field with pattern") { + doTestIntoDateFieldWithPattern(input) + } + + test("Into date field with pattern and default") { + doTestIntoDateFieldWithPatternAndDefault(input) + } + + test("Into timestamp field with pattern and default") { + doTestIntoTimestampFieldWithPatternAndDefault(input) + } + + test("Into date field with pattern and fixed time zone") { + doTestIntoDateFieldWithPatternAndTimeZone(input) + } + + test("Into timestamp field with pattern and fixed time zone") { + doTestIntoTimestampFieldWithPatternAndTimeZone(input) + } + + test("Into date field with epoch pattern") { + doTestIntoDateFieldWithEpochPattern(input) + } + + test("Into timestamp field with epoch pattern") { + doTestIntoTimestampFieldWithEpochPattern(input) + } + +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromLongTypeSuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromLongTypeSuite.scala new file mode 100644 index 0000000..a7ff0b5 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromLongTypeSuite.scala @@ -0,0 +1,121 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.interpreter.stages + +import org.apache.spark.sql.types._ +import TypeParserSuiteTemplate.Input +import za.co.absa.standardization.time.DateTimePattern + +class TypeParser_FromLongTypeSuite extends TypeParserSuiteTemplate { + + private val input = Input( + baseType = LongType, + defaultValueDate = "20001010", + defaultValueTimestamp = "199912311201", + datePattern = "yyyyMMdd", + timestampPattern = "yyyyMMddHHmm", + fixedTimezone = "EST", + path = "Hey" + ) + + override protected def createCastTemplate(toType: DataType, pattern: String, timezone: Option[String]): String = { + val isEpoch = DateTimePattern.isEpoch(pattern) + (toType, isEpoch, timezone) match { + case (DateType, true, _) => s"to_date(CAST((CAST(`%s` AS DECIMAL(30,9)) / ${DateTimePattern.epochFactor(pattern)}L) AS TIMESTAMP))" + case (TimestampType, true, _) => s"CAST((CAST(%s AS DECIMAL(30,9)) / ${DateTimePattern.epochFactor(pattern)}) AS TIMESTAMP)" + case (DateType, _, Some(tz)) => s"to_date(to_utc_timestamp(to_timestamp(CAST(`%s` AS STRING), '$pattern'), '$tz'))" + case (TimestampType, _, Some(tz)) => s"to_utc_timestamp(to_timestamp(CAST(`%s` AS STRING), '$pattern'), $tz)" + case (TimestampType, _, _) => s"to_timestamp(CAST(`%s` AS STRING), '$pattern')" + case (DateType, _, _) => s"to_date(CAST(`%s` AS STRING), '$pattern')" + case _ => s"CAST(%s AS ${toType.sql})" + } + } + + override protected def createErrorCondition(srcField: String, target: StructField, castS: String): String = { + target.dataType match { + case FloatType | DoubleType => s"(($castS IS NULL) OR isnan($castS)) OR ($castS IN (Infinity, -Infinity))" + case ByteType | ShortType | IntegerType => + s"($castS IS NULL) OR (NOT (CAST(Hey.sourceField AS INT) = CAST(Hey.sourceField AS BIGINT)))" + case _ => s"$castS IS NULL" + } + } + + test("Within the column - type stays, nullable") { + doTestWithinColumnNullable(input) + } + + test("Within the column - type stays, not nullable") { + doTestWithinColumnNotNullable(input) + } + + test("Into string field") { + doTestIntoStringField(input) + } + + test("Into float field") { + doTestIntoFloatField(input) + } + + test("Into integer field") { + doTestIntoIntegerField(input) + } + + test("Into boolean field") { + doTestIntoBooleanField(input) + } + + test("Into date field, no pattern") { + doTestIntoDateFieldNoPattern(input) + } + + test("Into timestamp field, no pattern") { + doTestIntoTimestampFieldNoPattern(input) + } + + test("Into date field with pattern") { + doTestIntoDateFieldWithPattern(input) + } + + test("Into timestamp field with pattern") { + doTestIntoDateFieldWithPattern(input) + } + + test("Into date field with pattern and default") { + doTestIntoDateFieldWithPatternAndDefault(input) + } + + test("Into timestamp field with pattern and default") { + doTestIntoTimestampFieldWithPatternAndDefault(input) + } + + test("Into date field with pattern and fixed time zone") { + doTestIntoDateFieldWithPatternAndTimeZone(input) + } + + test("Into timestamp field with pattern and fixed time zone") { + doTestIntoTimestampFieldWithPatternAndTimeZone(input) + } + + test("Into date field with epoch pattern") { + doTestIntoDateFieldWithEpochPattern(input) + } + + test("Into timestamp field with epoch pattern") { + doTestIntoTimestampFieldWithEpochPattern(input) + } + +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromStringTypeSuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromStringTypeSuite.scala 
new file mode 100644 index 0000000..9e9dd7f --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromStringTypeSuite.scala @@ -0,0 +1,121 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.interpreter.stages + +import org.apache.spark.sql.types._ +import TypeParserSuiteTemplate.Input +import za.co.absa.standardization.time.DateTimePattern + +class TypeParser_FromStringTypeSuite extends TypeParserSuiteTemplate { + + private val input = Input( + baseType = StringType, + defaultValueDate = "01.01.1970", + defaultValueTimestamp = "01.01.1970 00:00:00", + datePattern = "dd.MM.yyyy", + timestampPattern = "dd.MM.yyyy HH:mm:ss", + fixedTimezone = "CET", + path = "", + datetimeNeedsPattern = false + ) + + override protected def createCastTemplate(toType: DataType, pattern: String, timezone: Option[String]): String = { + val isEpoch = DateTimePattern.isEpoch(pattern) + (toType, isEpoch, timezone) match { + case (DateType, true, _) => s"to_date(CAST((CAST(`%s` AS DECIMAL(30,9)) / ${DateTimePattern.epochFactor(pattern)}L) AS TIMESTAMP))" + case (TimestampType, true, _) => s"CAST((CAST(%s AS DECIMAL(30,9)) / ${DateTimePattern.epochFactor(pattern)}) AS TIMESTAMP)" + case (DateType, _, Some(tz)) => s"to_date(to_utc_timestamp(to_timestamp(`%s`, '$pattern'), '$tz'))" + case (TimestampType, _, Some(tz)) => s"to_utc_timestamp(to_timestamp(`%s`, '$pattern'), $tz)" + case (DateType, _, _) => s"to_date(`%s`, '$pattern')" + case (TimestampType, _, _) => s"to_timestamp(`%s`, '$pattern')" + case _ => s"CAST(%s AS ${toType.sql})" + } + } + + override protected def createErrorCondition(srcField: String, target: StructField, castS: String): String = { + target.dataType match { + case FloatType | DoubleType => s"(($castS IS NULL) OR isnan($castS)) OR ($castS IN (Infinity, -Infinity))" + case ByteType | ShortType | IntegerType | LongType => s"($castS IS NULL) OR contains($srcField, .)" + case _ => s"$castS IS NULL" + } + } + + test("Within the column - type stays, nullable") { + doTestWithinColumnNullable(input) + } + + test("Within the column - type stays, not nullable") { + doTestWithinColumnNotNullable(input) + } + + test("Into string field") { + doTestIntoStringField(input) + } + + test("Into float field") { + doTestIntoFloatField(input) + } + + test("Into integer field") { + doTestIntoIntegerField(input) + } + + test("Into boolean field") { + doTestIntoBooleanField(input) + } + + test("Into date field, no pattern") { + doTestIntoDateFieldNoPattern(input) + } + + test("Into timestamp field, no pattern") { + doTestIntoTimestampFieldNoPattern(input) + } + + test("Into date field with pattern") { + doTestIntoDateFieldWithPattern(input) + } + + test("Into timestamp field with pattern") { + doTestIntoDateFieldWithPattern(input) + } + + test("Into date field with pattern and default") { + doTestIntoDateFieldWithPatternAndDefault(input) + } + + test("Into timestamp field with pattern and 
default") { + doTestIntoTimestampFieldWithPatternAndDefault(input) + } + + test("Into date field with pattern and fixed time zone") { + doTestIntoDateFieldWithPatternAndTimeZone(input) + } + + test("Into timestamp field with pattern and fixed time zone") { + doTestIntoTimestampFieldWithPatternAndTimeZone(input) + } + + test("Into date field with epoch pattern") { + doTestIntoDateFieldWithEpochPattern(input) + } + + test("Into timestamp field with epoch pattern") { + doTestIntoTimestampFieldWithEpochPattern(input) + } + +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromTimestampTypeSuite.scala b/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromTimestampTypeSuite.scala new file mode 100644 index 0000000..aa57b05 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/stages/TypeParser_FromTimestampTypeSuite.scala @@ -0,0 +1,120 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.interpreter.stages + +import org.apache.spark.sql.types._ +import TypeParserSuiteTemplate.Input +import za.co.absa.standardization.time.DateTimePattern + +class TypeParser_FromTimestampTypeSuite extends TypeParserSuiteTemplate { + + private val input = Input( + baseType = TimestampType, + defaultValueDate = "2000-01-01", + defaultValueTimestamp = "2010-12-31 23:59:59", + datePattern = "yyyy-MM-dd", + timestampPattern = "yyyy-MM-dd HH:mm:ss", + fixedTimezone = "CST", + path = "timestamp", + datetimeNeedsPattern = false + ) + + override protected def createCastTemplate(toType: DataType, pattern: String, timezone: Option[String]): String = { + val isEpoch = DateTimePattern.isEpoch(pattern) + (toType, isEpoch, timezone) match { + case (DateType, true, _) => s"to_date(CAST((CAST(`%s` AS DECIMAL(30,9)) / ${DateTimePattern.epochFactor(pattern)}L) AS TIMESTAMP))" + case (TimestampType, true, _) => s"CAST((CAST(%s AS DECIMAL(30,9)) / ${DateTimePattern.epochFactor(pattern)}) AS TIMESTAMP)" + case (TimestampType, _, Some(tz)) => s"to_utc_timestamp(%s, $tz)" + case (DateType, _, Some(tz)) => s"to_date(to_utc_timestamp(`%s`, '$tz'))" + case (TimestampType, _, _) => "%s" + case (DateType, _, _) => "to_date(`%s`)" + case _ => s"CAST(%s AS ${toType.sql})" + } + } + + override protected def createErrorCondition(srcField: String, target: StructField, castS: String): String = { + target.dataType match { + case FloatType | DoubleType => s"(($castS IS NULL) OR isnan($castS)) OR ($castS IN (Infinity, -Infinity))" + case _ => s"$castS IS NULL" + } + } + + test("Within the column - type stays, nullable") { + doTestWithinColumnNullable(input) + } + + test("Within the column - type stays, not nullable") { + doTestWithinColumnNotNullable(input) + } + + test("Into string field") { + doTestIntoStringField(input) + } + + test("Into float field") { + doTestIntoFloatField(input) + } + + test("Into integer field") { + doTestIntoIntegerField(input) + } + + test("Into boolean 
field") { + doTestIntoBooleanField(input) + } + + test("Into date field, no pattern") { + doTestIntoDateFieldNoPattern(input) + } + + test("Into timestamp field, no pattern") { + doTestIntoTimestampFieldNoPattern(input) + } + + test("Into date field with pattern") { + doTestIntoDateFieldWithPattern(input) + } + + test("Into timestamp field with pattern") { + doTestIntoDateFieldWithPattern(input) + } + + test("Into date field with pattern and default") { + doTestIntoDateFieldWithPatternAndDefault(input) + } + + test("Into timestamp field with pattern and default") { + doTestIntoTimestampFieldWithPatternAndDefault(input) + } + + test("Into date field with pattern and fixed time zone") { + doTestIntoDateFieldWithPatternAndTimeZone(input) + } + + test("Into timestamp field with pattern and fixed time zone") { + doTestIntoTimestampFieldWithPatternAndTimeZone(input) + } + + test("Into date field with epoch pattern") { + doTestIntoDateFieldWithEpochPattern(input) + } + + test("Into timestamp field with epoch pattern") { + doTestIntoTimestampFieldWithEpochPattern(input) + } + +} diff --git a/src/test/scala/za/co/absa/standardization/interpreter/standardizationInterpreter_RowTypes.scala b/src/test/scala/za/co/absa/standardization/interpreter/standardizationInterpreter_RowTypes.scala new file mode 100644 index 0000000..8eca9f2 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/interpreter/standardizationInterpreter_RowTypes.scala @@ -0,0 +1,132 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.interpreter + +import java.sql.{Date, Timestamp} +import za.co.absa.standardization.ErrorMessage + +//Decimal Suite +case class DecimalRow( + description: String, + small: Option[BigDecimal], + big: Option[BigDecimal], + errCol: Seq[ErrorMessage] = Seq.empty + ) + +case class InputRowDoublesForDecimal( + description: String, + small: Option[Double], + big: Option[Double] + ) { + def this(description: String, value: Double) = { + this(description, Option(value), Option(value)) + } +} + +//Fractional Suite +case class FractionalRow( + description: String, + floatField: Option[Float], + doubleField: Option[Double], + errCol: Seq[ErrorMessage] = Seq.empty + ) + +case class InputRowLongsForFractional( + description: String, + floatField: Option[Double], + doubleField: Option[Double] + ) { + def this(description: String, value: Double) = { + this(description, Option(value), Option(value)) + } +} + +case class InputRowDoublesForFractional( + description: String, + floatField: Option[Double], + doubleField: Option[Double] + ) { + def this(description: String, value: Double) = { + this(description, Option(value), Option(value)) + } +} + +//Integral Suite +case class IntegralRow( + description: String, + byteSize: Option[Byte], + shortSize: Option[Short], + integerSize: Option[Int], + longSize: Option[Long], + errCol: Seq[ErrorMessage] = Seq.empty + ) + +case class InputRowLongsForIntegral( + description: String, + bytesize: Long, + shortsize: Long, + integersize: Long, + longsize: Long + ) { + def this(description: String, value: Long) = { + this(description, value, value, value, value) + } +} + +case class InputRowDoublesForIntegral( + description: String, + bytesize: Option[Double], + shortsize: Option[Double], + integersize: Option[Double], + longsize: Option[Double] + ) { + def this(description: String, value: Double) = { + this(description, Option(value), Option(value), Option(value), Option(value)) + } +} + +case class InputRowBigDecimalsForIntegral( + description: String, + bytesize: Option[BigDecimal], + shortsize: Option[BigDecimal], + integersize: Option[BigDecimal], + longsize: Option[BigDecimal] + ) { + def this(description: String, value: BigDecimal) = { + this(description, Option(value), Option(value), Option(value), Option(value)) + } +} + +//Timestamp Suite +case class TimestampRow( + tms: Timestamp, + errCol: Seq[ErrorMessage] = Seq.empty + ) + +//Date Suite +case class DateRow( + dateField: Date, + errCol: Seq[ErrorMessage] = Seq.empty + ) + +case class BinaryRow( + binaryField: Array[Byte], + errCol: Seq[ErrorMessage] = Seq.empty + ) { + + def simpleFields = (binaryField.toSeq, errCol) +} diff --git a/src/test/scala/za/co/absa/standardization/schema/SchemaUtilsSuite.scala b/src/test/scala/za/co/absa/standardization/schema/SchemaUtilsSuite.scala new file mode 100644 index 0000000..a5f1276 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/schema/SchemaUtilsSuite.scala @@ -0,0 +1,477 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.schema + +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers +import za.co.absa.standardization.schema.SchemaUtils._ + +class SchemaUtilsSuite extends AnyFunSuite with Matchers { + // scalastyle:off magic.number + + private val schema = StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StructType(Seq( + StructField("c", IntegerType), + StructField("d", StructType(Seq( + StructField("e", IntegerType))), nullable = true)))), + StructField("f", StructType(Seq( + StructField("g", ArrayType.apply(StructType(Seq( + StructField("h", IntegerType)))))))))) + + private val nestedSchema = StructType(Seq( + StructField("a", IntegerType), + StructField("b", ArrayType(StructType(Seq( + StructField("c", StructType(Seq( + StructField("d", ArrayType(StructType(Seq( + StructField("e", IntegerType)))))))))))))) + + private val arrayOfArraysSchema = StructType(Seq( + StructField("a", ArrayType(ArrayType(IntegerType)), nullable = false), + StructField("b", ArrayType(ArrayType(StructType(Seq( + StructField("c", StringType, nullable = false) + )) + )), nullable = true) + )) + + private val structFieldNoMetadata = StructField("a", IntegerType) + + private val structFieldWithMetadataNotSourceColumn = StructField("a", IntegerType, nullable = false, new MetadataBuilder().putString("meta", "data").build) + private val structFieldWithMetadataSourceColumn = StructField("a", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "override_a").build) + + test("Testing getFieldType") { + + val a = getFieldType("a", schema) + val b = getFieldType("b", schema) + val c = getFieldType("b.c", schema) + val d = getFieldType("b.d", schema) + val e = getFieldType("b.d.e", schema) + val f = getFieldType("f", schema) + val g = getFieldType("f.g", schema) + val h = getFieldType("f.g.h", schema) + + assert(a.get.isInstanceOf[IntegerType]) + assert(b.get.isInstanceOf[StructType]) + assert(c.get.isInstanceOf[IntegerType]) + assert(d.get.isInstanceOf[StructType]) + assert(e.get.isInstanceOf[IntegerType]) + assert(f.get.isInstanceOf[StructType]) + assert(g.get.isInstanceOf[ArrayType]) + assert(h.get.isInstanceOf[IntegerType]) + assert(getFieldType("z", schema).isEmpty) + assert(getFieldType("x.y.z", schema).isEmpty) + assert(getFieldType("f.g.h.a", schema).isEmpty) + } + + test("Testing fieldExists") { + assert(fieldExists("a", schema)) + assert(fieldExists("b", schema)) + assert(fieldExists("b.c", schema)) + assert(fieldExists("b.d", schema)) + assert(fieldExists("b.d.e", schema)) + assert(fieldExists("f", schema)) + assert(fieldExists("f.g", schema)) + assert(fieldExists("f.g.h", schema)) + assert(!fieldExists("z", schema)) + assert(!fieldExists("x.y.z", schema)) + assert(!fieldExists("f.g.h.a", schema)) + } + + test ("Test isColumnArrayOfStruct") { + assert(!isColumnArrayOfStruct("a", schema)) + assert(!isColumnArrayOfStruct("b", schema)) + assert(!isColumnArrayOfStruct("b.c", schema)) + assert(!isColumnArrayOfStruct("b.d", schema)) + assert(!isColumnArrayOfStruct("b.d.e", schema)) + assert(!isColumnArrayOfStruct("f", schema)) + assert(isColumnArrayOfStruct("f.g", schema)) + assert(!isColumnArrayOfStruct("f.g.h", schema)) + assert(!isColumnArrayOfStruct("a", nestedSchema)) + assert(isColumnArrayOfStruct("b", nestedSchema)) + 
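+ // In nestedSchema "b" is an array of structs and "b.c.d" is an array of structs nested inside the array element, so the remaining check below is expected to hold as well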
assert(isColumnArrayOfStruct("b.c.d", nestedSchema)) + } + + test("getRenamesInSchema - no renames") { + val result = getRenamesInSchema(StructType(Seq( + structFieldNoMetadata, + structFieldWithMetadataNotSourceColumn))) + assert(result.isEmpty) + } + + test("getRenamesInSchema - simple rename") { + val result = getRenamesInSchema(StructType(Seq(structFieldWithMetadataSourceColumn))) + assert(result == Map("a" -> "override_a")) + + } + + test("getRenamesInSchema - complex with includeIfPredecessorChanged set") { + val sub = StructType(Seq( + StructField("d", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "o").build), + StructField("e", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "e").build), + StructField("f", IntegerType) + )) + val schema = StructType(Seq( + StructField("a", sub, nullable = false, new MetadataBuilder().putString("sourcecolumn", "x").build), + StructField("b", sub, nullable = false, new MetadataBuilder().putString("sourcecolumn", "b").build), + StructField("c", sub) + )) + + val includeIfPredecessorChanged = true + val result = getRenamesInSchema(schema, includeIfPredecessorChanged) + val expected = Map( + "a" -> "x" , + "a.d" -> "x.o", + "a.e" -> "x.e", + "a.f" -> "x.f", + "b.d" -> "b.o", + "c.d" -> "c.o" + ) + + assert(result == expected) + } + + test("getRenamesInSchema - complex with includeIfPredecessorChanged not set") { + val sub = StructType(Seq( + StructField("d", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "o").build), + StructField("e", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "e").build), + StructField("f", IntegerType) + )) + val schema = StructType(Seq( + StructField("a", sub, nullable = false, new MetadataBuilder().putString("sourcecolumn", "x").build), + StructField("b", sub, nullable = false, new MetadataBuilder().putString("sourcecolumn", "b").build), + StructField("c", sub) + )) + + val includeIfPredecessorChanged = false + val result = getRenamesInSchema(schema, includeIfPredecessorChanged) + val expected = Map( + "a" -> "x", + "a.d" -> "x.o", + "b.d" -> "b.o", + "c.d" -> "c.o" + ) + + assert(result == expected) + } + + + test("getRenamesInSchema - array") { + val sub = StructType(Seq( + StructField("renamed", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "rename source").build), + StructField("same", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "same").build), + StructField("f", IntegerType) + )) + val schema = StructType(Seq( + StructField("array1", ArrayType(sub)), + StructField("array2", ArrayType(ArrayType(ArrayType(sub)))), + StructField("array3", ArrayType(IntegerType), nullable = false, new MetadataBuilder().putString("sourcecolumn", "array source").build) + )) + + val includeIfPredecessorChanged = false + val result = getRenamesInSchema(schema, includeIfPredecessorChanged) + val expected = Map( + "array1.renamed" -> "array1.rename source", + "array2.renamed" -> "array2.rename source", + "array3" -> "array source" + ) + + assert(result == expected) + } + + + test("getRenamesInSchema - source column used multiple times") { + val sub = StructType(Seq( + StructField("x", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "src").build), + StructField("y", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "src").build) + )) + val schema = StructType(Seq( + StructField("a", sub), + 
StructField("b", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "src").build) + )) + + val result = getRenamesInSchema(schema) + val expected = Map( + "a.x" -> "a.src", + "a.y" -> "a.src", + "b" -> "src" + ) + + assert(result == expected) + } + + test("Testing getFirstArrayPath") { + assertResult("f.g")(getFirstArrayPath("f.g.h", schema)) + assertResult("f.g")(getFirstArrayPath("f.g", schema)) + assertResult("")(getFirstArrayPath("z.x.y", schema)) + assertResult("")(getFirstArrayPath("b.c.d.e", schema)) + } + + test("Testing getAllArrayPaths") { + assertResult(Seq("f.g"))(getAllArrayPaths(schema)) + assertResult(Seq())(getAllArrayPaths(schema("b").dataType.asInstanceOf[StructType])) + } + + test("Testing getAllArraysInPath") { + assertResult(Seq("b", "b.c.d"))(getAllArraysInPath("b.c.d.e", nestedSchema)) + } + + test("Testing getFieldNameOverriddenByMetadata") { + assertResult("a")(getFieldNameOverriddenByMetadata(structFieldNoMetadata)) + assertResult("a")(getFieldNameOverriddenByMetadata(structFieldWithMetadataNotSourceColumn)) + assertResult("override_a")(getFieldNameOverriddenByMetadata(structFieldWithMetadataSourceColumn)) + } + + test("Testing getFieldNullability") { + assert(!getFieldNullability("a", schema).get) + assert(getFieldNullability("b.d", schema).get) + assert(getFieldNullability("x.y.z", schema).isEmpty) + } + + test ("Test isCastAlwaysSucceeds()") { + assert(!isCastAlwaysSucceeds(StructType(Seq()), StringType)) + assert(!isCastAlwaysSucceeds(ArrayType(StringType), StringType)) + assert(!isCastAlwaysSucceeds(StringType, ByteType)) + assert(!isCastAlwaysSucceeds(StringType, ShortType)) + assert(!isCastAlwaysSucceeds(StringType, IntegerType)) + assert(!isCastAlwaysSucceeds(StringType, LongType)) + assert(!isCastAlwaysSucceeds(StringType, DecimalType(10,10))) + assert(!isCastAlwaysSucceeds(StringType, DateType)) + assert(!isCastAlwaysSucceeds(StringType, TimestampType)) + assert(!isCastAlwaysSucceeds(StructType(Seq()), StructType(Seq()))) + assert(!isCastAlwaysSucceeds(ArrayType(StringType), ArrayType(StringType))) + + assert(!isCastAlwaysSucceeds(ShortType, ByteType)) + assert(!isCastAlwaysSucceeds(IntegerType, ByteType)) + assert(!isCastAlwaysSucceeds(IntegerType, ShortType)) + assert(!isCastAlwaysSucceeds(LongType, ByteType)) + assert(!isCastAlwaysSucceeds(LongType, ShortType)) + assert(!isCastAlwaysSucceeds(LongType, IntegerType)) + + assert(isCastAlwaysSucceeds(StringType, StringType)) + assert(isCastAlwaysSucceeds(ByteType, StringType)) + assert(isCastAlwaysSucceeds(ShortType, StringType)) + assert(isCastAlwaysSucceeds(IntegerType, StringType)) + assert(isCastAlwaysSucceeds(LongType, StringType)) + assert(isCastAlwaysSucceeds(DecimalType(10,10), StringType)) + assert(isCastAlwaysSucceeds(DateType, StringType)) + assert(isCastAlwaysSucceeds(TimestampType, StringType)) + assert(isCastAlwaysSucceeds(StringType, StringType)) + + assert(isCastAlwaysSucceeds(ByteType, ByteType)) + assert(isCastAlwaysSucceeds(ByteType, ShortType)) + assert(isCastAlwaysSucceeds(ByteType, IntegerType)) + assert(isCastAlwaysSucceeds(ByteType, LongType)) + assert(isCastAlwaysSucceeds(ShortType, ShortType)) + assert(isCastAlwaysSucceeds(ShortType, IntegerType)) + assert(isCastAlwaysSucceeds(ShortType, LongType)) + assert(isCastAlwaysSucceeds(IntegerType, IntegerType)) + assert(isCastAlwaysSucceeds(IntegerType, LongType)) + assert(isCastAlwaysSucceeds(LongType, LongType)) + assert(isCastAlwaysSucceeds(DateType, TimestampType)) + } + + test("Test isCommonSubPath()") { 
+ assert (isCommonSubPath()) + assert (isCommonSubPath("a")) + assert (isCommonSubPath("a.b.c.d.e.f", "a.b.c.d", "a.b.c", "a.b", "a")) + assert (!isCommonSubPath("a.b.c.d.e.f", "a.b.c.x", "a.b.c", "a.b", "a")) + } + + test("Test getDeepestCommonArrayPath() for a path without an array") { + val schema = StructType(Seq[StructField]( + StructField("a", + StructType(Seq[StructField]( + StructField("b", StringType)) + )))) + + assert (getDeepestCommonArrayPath(schema, Seq("a", "a.b")).isEmpty) + } + + test("Test getDeepestCommonArrayPath() for a path with a single array at top level") { + val schema = StructType(Seq[StructField]( + StructField("a", ArrayType(StructType(Seq[StructField]( + StructField("b", StringType))) + )))) + + val deepestPath = getDeepestCommonArrayPath(schema, Seq("a", "a.b")) + + assert (deepestPath.nonEmpty) + assert (deepestPath.get == "a") + } + + test("Test getDeepestCommonArrayPath() for a path with a single array at nested level") { + val schema = StructType(Seq[StructField]( + StructField("a", StructType(Seq[StructField]( + StructField("b", ArrayType(StringType)))) + ))) + + val deepestPath = getDeepestCommonArrayPath(schema, Seq("a", "a.b")) + + assert (deepestPath.nonEmpty) + assert (deepestPath.get == "a.b") + } + + test("Test getDeepestCommonArrayPath() for a path with several nested arrays of struct") { + val schema = StructType(Seq[StructField]( + StructField("a", ArrayType(StructType(Seq[StructField]( + StructField("b", StructType(Seq[StructField]( + StructField("c", ArrayType(StructType(Seq[StructField]( + StructField("d", StructType(Seq[StructField]( + StructField("e", StringType)) + ))) + )))) + ))) + ))))) + + val deepestPath = getDeepestCommonArrayPath(schema, Seq("a", "a.b", "a.b.c.d.e", "a.b.c.d")) + + assert (deepestPath.nonEmpty) + assert (deepestPath.get == "a.b.c") + } + + test("Test getDeepestArrayPath() for a path without an array") { + val schema = StructType(Seq[StructField]( + StructField("a", + StructType(Seq[StructField]( + StructField("b", StringType)) + )))) + + assert (getDeepestArrayPath(schema, "a.b").isEmpty) + } + + test("Test getDeepestArrayPath() for a path with a single array at top level") { + val schema = StructType(Seq[StructField]( + StructField("a", ArrayType(StructType(Seq[StructField]( + StructField("b", StringType))) + )))) + + val deepestPath = getDeepestArrayPath(schema, "a.b") + + assert (deepestPath.nonEmpty) + assert (deepestPath.get == "a") + } + + test("Test getDeepestArrayPath() for a path with a single array at nested level") { + val schema = StructType(Seq[StructField]( + StructField("a", StructType(Seq[StructField]( + StructField("b", ArrayType(StringType)))) + ))) + + val deepestPath = getDeepestArrayPath(schema, "a.b") + val deepestPath2 = getDeepestArrayPath(schema, "a") + + assert (deepestPath.nonEmpty) + assert (deepestPath.get == "a.b") + assert (deepestPath2.isEmpty) + } + + test("Test getDeepestArrayPath() for a path with several nested arrays of struct") { + val schema = StructType(Seq[StructField]( + StructField("a", ArrayType(StructType(Seq[StructField]( + StructField("b", StructType(Seq[StructField]( + StructField("c", ArrayType(StructType(Seq[StructField]( + StructField("d", StructType(Seq[StructField]( + StructField("e", StringType)) + ))) + )))) + ))) + ))))) + + val deepestPath = getDeepestArrayPath(schema, "a.b.c.d.e") + + assert (deepestPath.nonEmpty) + assert (deepestPath.get == "a.b.c") + } + + + test("Test getClosestUniqueName() is working properly") { + val schema = 
StructType(Seq[StructField]( + StructField("value", StringType))) + + // A column name that does not exist + val name1 = SchemaUtils.getClosestUniqueName("v", schema) + // A column that exists + val name2 = SchemaUtils.getClosestUniqueName("value", schema) + + assert(name1 == "v") + assert(name2 == "value_1") + } + + test("Test isOnlyField()") { + val schema = StructType(Seq[StructField]( + StructField("a", StringType), + StructField("b", StructType(Seq[StructField]( + StructField("e", StringType), + StructField("f", StringType) + ))), + StructField("c", StructType(Seq[StructField]( + StructField("d", StringType) + ))) + )) + + assert(!isOnlyField(schema, "a")) + assert(!isOnlyField(schema, "b.e")) + assert(!isOnlyField(schema, "b.f")) + assert(isOnlyField(schema, "c.d")) + } + + test("Test getStructField on array of arrays") { + assert(getField("a", arrayOfArraysSchema).contains(StructField("a",ArrayType(ArrayType(IntegerType)),nullable = false))) + assert(getField("b", arrayOfArraysSchema).contains(StructField("b",ArrayType(ArrayType(StructType(Seq(StructField("c",StringType,nullable = false))))), nullable = true))) + assert(getField("b.c", arrayOfArraysSchema).contains(StructField("c",StringType,nullable = false))) + assert(getField("b.d", arrayOfArraysSchema).isEmpty) + } + + test("Test fieldExists") { + assert(fieldExists("a", schema)) + assert(fieldExists("b", schema)) + assert(fieldExists("b.c", schema)) + assert(fieldExists("b.d", schema)) + assert(fieldExists("b.d.e", schema)) + assert(fieldExists("f", schema)) + assert(fieldExists("f.g", schema)) + assert(fieldExists("f.g.h", schema)) + assert(!fieldExists("z", schema)) + assert(!fieldExists("x.y.z", schema)) + assert(!fieldExists("f.g.h.a", schema)) + + assert(fieldExists("a", arrayOfArraysSchema)) + assert(fieldExists("b", arrayOfArraysSchema)) + assert(fieldExists("b.c", arrayOfArraysSchema)) + assert(!fieldExists("b.d", arrayOfArraysSchema)) + } + + test("unpath - empty string remains empty") { + val result = unpath("") + val expected = "" + assert(result == expected) + } + + test("unpath - underscores get doubled") { + val result = unpath("one_two__three") + val expected = "one__two____three" + assert(result == expected) + } + + test("unpath - dot notation conversion") { + val result = unpath("grand_parent.parent.first_child") + val expected = "grand__parent_parent_first__child" + assert(result == expected) + } + +} diff --git a/src/test/scala/za/co/absa/standardization/schema/SparkUtilsSuite.scala b/src/test/scala/za/co/absa/standardization/schema/SparkUtilsSuite.scala new file mode 100644 index 0000000..a5662c0 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/schema/SparkUtilsSuite.scala @@ -0,0 +1,119 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.schema + +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{BooleanType, LongType, StructField, StructType} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.SparkTestBase + +class SparkUtilsSuite extends AnyFunSuite with SparkTestBase { + + import za.co.absa.standardization.implicits.DataFrameImplicits.DataFrameEnhancements + + private def getDummyDataFrame: DataFrame = { + import spark.implicits._ + + Seq(1, 1, 1, 2, 1).toDF("value") + } + + test("Test setUniqueColumnNameOfCorruptRecord") { + val expected1 = "_corrupt_record" + val schema1 = StructType(Seq(StructField("id", LongType))) + val result1 = SparkUtils.setUniqueColumnNameOfCorruptRecord(spark, schema1) + assert(result1 == expected1) + assert(spark.conf.get(SparkUtils.ColumnNameOfCorruptRecordConf) == expected1) + //two tests in series as the function has side-effects (on provided spark session) and it might collide in parallel run + val expected2 = "_corrupt_record_1" + val schema2 = StructType(Seq(StructField("id", LongType), StructField(expected1, BooleanType))) + val result2 = SparkUtils.setUniqueColumnNameOfCorruptRecord(spark, schema2) + assert(result2 == expected2) + assert(spark.conf.get(SparkUtils.ColumnNameOfCorruptRecordConf) == expected2) + } + + test("Test withColumnIfNotExist() when the column does not exist") { + val expectedOutput = + """+-----+---+ + ||value|foo| + |+-----+---+ + ||1 |1 | + ||1 |1 | + ||1 |1 | + ||2 |1 | + ||1 |1 | + |+-----+---+ + | + |""".stripMargin.replace("\r\n", "\n") + + val dfIn = getDummyDataFrame + val dfOut = SparkUtils.withColumnIfDoesNotExist(dfIn, "foo", lit(1)) + val actualOutput = dfOut.dataAsString(truncate = false) + + assert(dfOut.schema.length == 2) + assert(dfOut.schema.head.name == "value") + assert(dfOut.schema(1).name == "foo") + assert(actualOutput == expectedOutput) + } + + test("Test withColumnIfNotExist() when the column exists") { + val expectedOutput = + """+-----+----------------------------------------------------------------------------------------------+ + ||value|errCol | + |+-----+----------------------------------------------------------------------------------------------+ + ||1 |[] | + ||1 |[] | + ||1 |[] | + ||1 |[[confLitError, E00005, Conformance Error - Special column value has changed, value, [2], []]]| + ||1 |[] | + |+-----+----------------------------------------------------------------------------------------------+ + | + |""".stripMargin.replace("\r\n", "\n") + + val dfIn = getDummyDataFrame + val dfOut = SparkUtils.withColumnIfDoesNotExist(dfIn, "value", lit(1)) + val actualOutput = dfOut.dataAsString(truncate = false) + + assert(dfIn.schema.length == 1) + assert(dfIn.schema.head.name == "value") + assert(actualOutput == expectedOutput) + } + + test("Test withColumnIfNotExist() when the column exists, but has a different case") { + val expectedOutput = + """+-----+----------------------------------------------------------------------------------------------+ + ||vAlUe|errCol | + |+-----+----------------------------------------------------------------------------------------------+ + ||1 |[] | + ||1 |[] | + ||1 |[] | + ||1 |[[confLitError, E00005, Conformance Error - Special column value has changed, vAlUe, [2], []]]| + ||1 |[] | + |+-----+----------------------------------------------------------------------------------------------+ + | + |""".stripMargin.replace("\r\n", "\n") + + val dfIn = getDummyDataFrame + 
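+ // "vAlUe" differs from the existing "value" column only by case, so the expected output above mirrors the previous test: the column is overwritten with the literal and the row whose original value differed is flagged in errCol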
val dfOut = SparkUtils.withColumnIfDoesNotExist(dfIn, "vAlUe", lit(1)) + val actualOutput = dfOut.dataAsString(truncate = false) + + assert(dfIn.schema.length == 1) + assert(dfIn.schema.head.name == "value") + assert(actualOutput == expectedOutput) + } + +} diff --git a/src/test/scala/za/co/absa/standardization/time/DateTimePatternSuite.scala b/src/test/scala/za/co/absa/standardization/time/DateTimePatternSuite.scala new file mode 100644 index 0000000..6945d8b --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/time/DateTimePatternSuite.scala @@ -0,0 +1,271 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.time + +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.types.Section + +class DateTimePatternSuite extends AnyFunSuite { + + test("Pattern for timestamp") { + val pattern: String = "yyyy~mm~dd_HH.mm.ss" + val dateTimePattern = DateTimePattern(pattern) + assert(!dateTimePattern.isDefault) + assert(dateTimePattern.pattern == pattern) + assert(!dateTimePattern.isEpoch) + assert(0 == dateTimePattern.epochFactor) + } + + test("Pattern for date") { + val pattern: String = "yyyy~mm~dd" + val dateTimePattern = DateTimePattern(pattern) + assert(!dateTimePattern.isDefault) + assert(dateTimePattern.pattern == pattern) + assert(!dateTimePattern.isEpoch) + assert(dateTimePattern.epochFactor == 0) + } + + test("DateTimePattern.isEpoch should return true for known keywords, regardless of case") { + val result1 = DateTimePattern.isEpoch("epoch") + assert(result1) + val result2 = DateTimePattern.isEpoch("epochmilli") + assert(result2) + val result3 = DateTimePattern.isEpoch(" epoch ") + assert(!result3) + val result4 = DateTimePattern.isEpoch("add 54") + assert(!result4) + val result5 = DateTimePattern.isEpoch("") + assert(!result5) + val result6 = DateTimePattern.isEpoch("epochMicro") + assert(result6) + val result7 = DateTimePattern.isEpoch("EPOCHNANO") + assert(result7) + } + + test("DateTimePattern.epochFactor returns appropriate power of ten corresponding the keyword") { + var result = DateTimePattern.epochFactor("Epoch") + assert(result == 1L) + result = DateTimePattern.epochFactor("EpOcHmIlLi") + assert(result == 1000L) + result = DateTimePattern.epochFactor("EpochMICRO") + assert(result == 1000000L) + result = DateTimePattern.epochFactor("epochnano") + assert(result == 1000000000L) + result = DateTimePattern.epochFactor("zoom") + assert(result == 0L) + } + + test("Time zone in epoch pattern") { + val dateTimePattern1 = DateTimePattern("epoch") + assert(dateTimePattern1.timeZoneInPattern) + val dateTimePattern2 = DateTimePattern("epochmilli") + assert(dateTimePattern2.timeZoneInPattern) + val dateTimePattern3 = DateTimePattern("epochmicro") + assert(dateTimePattern3.timeZoneInPattern) + val dateTimePattern4 = DateTimePattern("epochnano") + assert(dateTimePattern4.timeZoneInPattern) + } + + test("Time zone NOT in pattern") { + val dateTimePattern1 = 
DateTimePattern("yyyy-MM-dd HH:mm:ss") + assert(!dateTimePattern1.timeZoneInPattern) + val dateTimePattern2 = DateTimePattern("") + assert(!dateTimePattern2.timeZoneInPattern) + } + + test("Standard time zone in pattern") { + val dateTimePattern1 = DateTimePattern("ZZ yyyy-MM-dd HH:mm:ss") + assert(dateTimePattern1.timeZoneInPattern) + val dateTimePattern2 = DateTimePattern(" HH:mm:ss ZZZZ yyyy-MM-dd") + assert(dateTimePattern2.timeZoneInPattern) + } + + test("Offset time zone in pattern") { + val dateTimePattern1 = DateTimePattern("yyyy-MM-dd HH:mm:ssXX") + assert(dateTimePattern1.timeZoneInPattern) + val dateTimePattern2 = DateTimePattern("HH:mm:ss XX yyyy-MM-dd") + assert(dateTimePattern2.timeZoneInPattern) + val dateTimePattern3 = DateTimePattern("XXX HH:mm:ss yyyy-MM-dd") + assert(dateTimePattern3.timeZoneInPattern) + } + + test("Time zone with literals in the pattern") { + val dateTimePattern1 = DateTimePattern("yyyy-MM-dd HH:mm:ss'zz'") + assert(!dateTimePattern1.timeZoneInPattern) + val dateTimePattern2 = DateTimePattern("'XXX: 'HH:mm:ss XX yyyy-MM-dd") + assert(dateTimePattern2.timeZoneInPattern) + val dateTimePattern3 = DateTimePattern("""'Date:'yyyy-MM-dd HH:mm:ss\'ZZ\'""") + assert(dateTimePattern3.timeZoneInPattern) + } + + test("Default time zone - not present") { + val dateTimePattern1 = DateTimePattern("yyyy-MM-dd HH:mm:ss") + assert(dateTimePattern1.defaultTimeZone.isEmpty) + val dateTimePattern2 = DateTimePattern("yyyy-MM-dd", assignedDefaultTimeZone = None) + assert(dateTimePattern2.defaultTimeZone.isEmpty) + val dateTimePattern3 = DateTimePattern("") + assert(dateTimePattern3.defaultTimeZone.isEmpty) + } + + test("Default time zone - present") { + val dateTimePattern1 = DateTimePattern("yyyy-MM-dd HH:mm:ss", assignedDefaultTimeZone = Some("CET")) + assert(dateTimePattern1.defaultTimeZone.contains("CET")) + val dateTimePattern2 = DateTimePattern("", assignedDefaultTimeZone = Some("")) + assert(dateTimePattern2.defaultTimeZone.contains("")) + } + + test("Default time zone - overridden by time zone in pattern") { + val dateTimePattern1 = DateTimePattern("yyyy-MM-dd HH:mm:ss zz", Some("CST")) //Standard time zone + assert(dateTimePattern1.defaultTimeZone.isEmpty) + val dateTimePattern2 = DateTimePattern("yyyy-MM-dd HH:mm:ssXX", Some("WST")) //Offset time zone + assert(dateTimePattern2.defaultTimeZone.isEmpty) + } + + test("Default time zone - epoch") { + val dateTimePattern1 = DateTimePattern("epochmilli", Some("WST")) + assert(dateTimePattern1.defaultTimeZone.isEmpty) + val dateTimePattern2 = DateTimePattern("epoch", Some("CET")) + assert(dateTimePattern2.defaultTimeZone.isEmpty) + val dateTimePattern3 = DateTimePattern("epochmicro", Some("WST")) + assert(dateTimePattern3.defaultTimeZone.isEmpty) + val dateTimePattern4 = DateTimePattern("epochnano", Some("CET")) + assert(dateTimePattern4.defaultTimeZone.isEmpty) + } + + test("Is NOT time-zoned ") { + val dateTimePattern1 = DateTimePattern("yyyy-MM-dd HH:mm:ss") + assert(!dateTimePattern1.isTimeZoned) + val dateTimePattern2 = DateTimePattern("yyyy-MM-dd", assignedDefaultTimeZone = None) + assert(!dateTimePattern2.isTimeZoned) + } + + test("Is time-zoned - default time zone") { + val dateTimePattern = DateTimePattern("yyyy-MM-dd HH:mm:ss", Some("EST")) + assert(dateTimePattern.isTimeZoned) + } + + test("Is time-zoned - standard time zone in pattern") { + val dateTimePattern = DateTimePattern("yyyy-MM-dd HH:mm:ss zz") //Standard time zone + assert(dateTimePattern.isTimeZoned) + } + + test("Is time-zoned - offset time zone in 
pattern") { + val dateTimePattern = DateTimePattern("yyyy-MM-dd HH:mm:ssXX") //Offset time zone + assert(dateTimePattern.isTimeZoned) + } + + test("Is time-zoned - epoch") { + val dateTimePattern = DateTimePattern("epoch") + assert(dateTimePattern.isTimeZoned) + } + + test("Second fractions detection in epoch") { + val dtp = DateTimePattern("epoch") + assert(dtp.millisecondsPosition.isEmpty) + assert(dtp.microsecondsPosition.isEmpty) + assert(dtp.nanosecondsPosition.isEmpty) + assert(dtp.secondFractionsSections.isEmpty) + assert(dtp.patternWithoutSecondFractions == "epoch") + assert(!dtp.containsSecondFractions) + } + + test("Second fractions detection in epochmilli") { + val dtp = DateTimePattern("epochmilli") + assert(dtp.millisecondsPosition.contains(Section(-3,3))) + assert(dtp.microsecondsPosition.isEmpty) + assert(dtp.nanosecondsPosition.isEmpty) + assert(dtp.secondFractionsSections == Seq(Section(-3,3))) + assert(dtp.patternWithoutSecondFractions == "epoch") + assert(dtp.containsSecondFractions) + } + + test("Second fractions detection in epochmicro") { + val dtp = DateTimePattern("epochmicro") + assert(dtp.millisecondsPosition.contains(Section(-6,3))) + assert(dtp.microsecondsPosition.contains(Section(-3,3))) + assert(dtp.nanosecondsPosition.isEmpty) + assert(dtp.secondFractionsSections == Seq(Section(-6,6))) + assert(dtp.patternWithoutSecondFractions == "epoch") + assert(dtp.containsSecondFractions) + } + + test("Second fractions detection in epochnano") { + val dtp = DateTimePattern("epochnano") + assert(dtp.millisecondsPosition.contains(Section(-9,3))) + assert(dtp.microsecondsPosition.contains(Section(-6,3))) + assert(dtp.nanosecondsPosition.contains(Section(-3,3))) + assert(dtp.secondFractionsSections == Seq(Section(-9,9))) + assert(dtp.patternWithoutSecondFractions == "epoch") + assert(dtp.containsSecondFractions) + } + + test("Second fractions detection in regular pattern - milliseconds") { + val pattern = "yyyy-MM-dd HH:mm:ss.SSS" + val dtp = DateTimePattern(pattern) + assert(dtp.millisecondsPosition.contains(Section(20,3))) + assert(dtp.microsecondsPosition.isEmpty) + assert(dtp.nanosecondsPosition.isEmpty) + assert(dtp.secondFractionsSections == Seq(Section(20,3))) + assert(dtp.patternWithoutSecondFractions == "yyyy-MM-dd HH:mm:ss.") + assert(dtp.containsSecondFractions) + } + + test("Second fractions detection in regular pattern - microseconds") { + val pattern = "yyyy-MM-dd HH:mm:ss.iiiiii" + val dtp = DateTimePattern(pattern) + assert(dtp.millisecondsPosition.isEmpty) + assert(dtp.microsecondsPosition.contains(Section(20,6))) + assert(dtp.nanosecondsPosition.isEmpty) + assert(dtp.secondFractionsSections == Seq(Section(20,6))) + assert(dtp.patternWithoutSecondFractions == "yyyy-MM-dd HH:mm:ss.") + assert(dtp.containsSecondFractions) + } + + test("Second fractions detection in regular pattern - nanoseconds") { + val pattern = "yyyy-MM-dd HH:mm:ss.nnnnnnnnn" + val dtp = DateTimePattern(pattern) + assert(dtp.millisecondsPosition.isEmpty) + assert(dtp.microsecondsPosition.isEmpty) + assert(dtp.nanosecondsPosition.contains(Section(20,9))) + assert(dtp.secondFractionsSections == Seq(Section(20,9))) + assert(dtp.patternWithoutSecondFractions == "yyyy-MM-dd HH:mm:ss.") + assert(dtp.containsSecondFractions) + } + + test("Second fractions detection in regular pattern - milli-, micro-, nanosecond combined") { + val pattern = "nnniii|yyyy-MM-dd SSS HH:mm:ss" + val dtp = DateTimePattern(pattern) + assert(dtp.millisecondsPosition.contains(Section(18,3))) + 
assert(dtp.microsecondsPosition.contains(Section(3,3))) + assert(dtp.nanosecondsPosition.contains(Section(0,3))) + assert(dtp.secondFractionsSections == Seq(Section(18,3), Section(0, 6))) + assert(dtp.patternWithoutSecondFractions == "|yyyy-MM-dd HH:mm:ss") + assert(dtp.containsSecondFractions) + } + + + test("Second fractions detection in regular pattern - not present") { + val pattern = "yyyy-MM-dd HH:mm:ss" + val dtp = DateTimePattern(pattern) + assert(dtp.millisecondsPosition.isEmpty) + assert(dtp.microsecondsPosition.isEmpty) + assert(dtp.nanosecondsPosition.isEmpty) + assert(dtp.secondFractionsSections.isEmpty) + assert(dtp.patternWithoutSecondFractions == "yyyy-MM-dd HH:mm:ss") + assert(!dtp.containsSecondFractions) + } +} diff --git a/src/test/scala/za/co/absa/standardization/types/DefaultsByFormatSuite.scala b/src/test/scala/za/co/absa/standardization/types/DefaultsByFormatSuite.scala new file mode 100644 index 0000000..8e5c8fe --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/types/DefaultsByFormatSuite.scala @@ -0,0 +1,92 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.types + +import com.typesafe.config.{ConfigFactory, ConfigValueFactory} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.ConfigReader + +class DefaultsByFormatSuite extends AnyFunSuite { + + private val customTimestampConfig = new ConfigReader( + ConfigFactory.empty() + .withValue("defaultTimestampTimeZone", ConfigValueFactory.fromAnyRef("UTC")) // fallback to "obsolete" + .withValue("standardization.defaultTimestampTimeZone.json", ConfigValueFactory.fromAnyRef("WrongTimeZone")) + ) + + test("Format specific timestamp time zone override exists") { + val default = new DefaultsByFormat("xml") + assert(default.getDefaultTimestampTimeZone.contains("Africa/Johannesburg")) + } + + test("Format specific timestamp time zone override does not exist") { + val default = new DefaultsByFormat("txt") + assert(default.getDefaultTimestampTimeZone.contains("CET")) + } + + test("Format specific timestamp zone falls back to obsolete") { + val defaults = new DefaultsByFormat("xml", config = customTimestampConfig) + assert(defaults.getDefaultTimestampTimeZone.contains("UTC")) + } + + test("Format specific timestamp time zone override is not a valid time zone id") { + intercept[IllegalStateException] { + new DefaultsByFormat("json", config = customTimestampConfig) + } + } + + test("Date time zone does not exist at all") { + val default = new DefaultsByFormat("testFormat") + assert(default.getDefaultDateTimeZone.isEmpty) + } + + private val customDateConfig = new ConfigReader( + ConfigFactory.empty() + .withValue("defaultDateTimeZone", ConfigValueFactory.fromAnyRef("UTC")) // fallback to "obsolete" + .withValue("standardization.defaultDateTimeZone.default", ConfigValueFactory.fromAnyRef("PST")) + .withValue("standardization.defaultDateTimeZone.csv", ConfigValueFactory.fromAnyRef("JST")) +
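+ // "Gibberish" below is deliberately not a valid time zone ID; the corresponding "parquet" test further down expects construction of DefaultsByFormat to fail with an IllegalStateException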
.withValue("standardization.defaultDateTimeZone.parquet", ConfigValueFactory.fromAnyRef("Gibberish")) + ) + + test("Format specific date time zone override exists") { + val defaults = new DefaultsByFormat("csv", config = customDateConfig) + assert(defaults.getDefaultDateTimeZone.contains("JST")) + } + + test("Format specific date time zone override does not exists") { + val defaults = new DefaultsByFormat("testFormat", config = customDateConfig) + assert(defaults.getDefaultDateTimeZone.contains("PST")) + } + + test("Format specific date time zone override is not a valid time zone id") { + intercept[IllegalStateException] { + new DefaultsByFormat("parquet", config = customDateConfig) + } + } + + test("Getting the obsolete settings") { + val localConfig = new ConfigReader( + ConfigFactory.empty() + .withValue("defaultTimestampTimeZone", ConfigValueFactory.fromAnyRef("PST")) + .withValue("defaultDateTimeZone", ConfigValueFactory.fromAnyRef("JST")) + ) + val defaults = new DefaultsByFormat("csv", config = localConfig) + assert(defaults.getDefaultTimestampTimeZone.contains("PST")) + assert(defaults.getDefaultDateTimeZone.contains("JST")) + } + +} diff --git a/src/test/scala/za/co/absa/standardization/types/DefaultsSuite.scala b/src/test/scala/za/co/absa/standardization/types/DefaultsSuite.scala new file mode 100644 index 0000000..87dace2 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/types/DefaultsSuite.scala @@ -0,0 +1,94 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.types + +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite + +import java.sql.{Date, Timestamp} +import java.util.TimeZone +import scala.util.Success + +class DefaultsSuite extends AnyFunSuite { + TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + + test("ByteType") { + assert(GlobalDefaults.getDataTypeDefaultValueWithNull(ByteType, nullable = false) === Success(Some(0.toByte))) + } + + test("ShortType") { + assert(GlobalDefaults.getDataTypeDefaultValueWithNull(ShortType, nullable = false) === Success(Some(0.toShort))) + } + + test("IntegerType") { + assert(GlobalDefaults.getDataTypeDefaultValueWithNull(IntegerType, nullable = false) === Success(Some(0))) + } + + test("LongType") { + assert(GlobalDefaults.getDataTypeDefaultValueWithNull(LongType, nullable = false) === Success(Some(0L))) + } + + test("FloatType") { + assert(GlobalDefaults.getDataTypeDefaultValueWithNull(FloatType, nullable = false) === Success(Some(0F))) + } + + test("DoubleType") { + assert(GlobalDefaults.getDataTypeDefaultValueWithNull(DoubleType, nullable = false) === Success(Some(0D))) + } + + test("StringType") { + assert(GlobalDefaults.getDataTypeDefaultValueWithNull(StringType, nullable = false) === Success(Some(""))) + } + + test("DateType") { + assert(GlobalDefaults.getDataTypeDefaultValueWithNull(DateType, nullable = false) === Success(Some(new Date(0)))) + } + + test("TimestampType") { + assert(GlobalDefaults.getDataTypeDefaultValueWithNull(TimestampType, nullable = false) === Success(Some(new Timestamp(0)))) + } + + test("BooleanType") { + assert(GlobalDefaults.getDataTypeDefaultValueWithNull(BooleanType, nullable = false) === (Success(Some(false)))) + } + + test("DecimalType") { + assert(GlobalDefaults.getDataTypeDefaultValueWithNull(DecimalType(6, 3), nullable = false) === Success(Some(BigDecimal("000.000")))) + } + + test("ArrayType") { + val dataType = ArrayType(StringType) + val result = GlobalDefaults.getDataTypeDefaultValueWithNull(dataType, nullable = false) + val e = intercept[IllegalStateException] { + result.get + } + assert(e.getMessage == s"No default value defined for data type ${dataType.typeName}") + } + + test("Nullable default is None") { + assert(GlobalDefaults.getDataTypeDefaultValueWithNull(BooleanType, nullable = true) === Success(None)) + } + + test("Default time zone for timestamps does not exist") { + assert(GlobalDefaults.getDefaultTimestampTimeZone.isEmpty) + } + + test("Default time zone for dates does not exist") { + assert(GlobalDefaults.getDefaultDateTimeZone.isEmpty) + } +} + diff --git a/src/test/scala/za/co/absa/standardization/types/SectionSuite.scala b/src/test/scala/za/co/absa/standardization/types/SectionSuite.scala new file mode 100644 index 0000000..5c12b6e --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/types/SectionSuite.scala @@ -0,0 +1,713 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.types + +import java.security.InvalidParameterException + +import org.scalatest.funsuite.AnyFunSuite + +import scala.util.{Failure, Try} + +class SectionSuite extends AnyFunSuite { + + private def checkSectionRemoveExtractInject( + section: Section, + fullString: String, + remainder: String, + extracted: String + ): Boolean = { + assert(section.removeFrom(fullString) == remainder) + assert(section.extractFrom(fullString) == extracted) + section.injectInto(remainder, extracted).toOption.contains(fullString) + } + + private def circularCheck(start: Integer, length: Int, string: String): Boolean = { + val section = Section(start, length) + val removeResult = section.removeFrom(string) + val extractResult = section.extractFrom(string) + val injectResult = section.injectInto(removeResult, extractResult) + injectResult.toOption.contains(string) + } + + private def checkTryOnFailure(result: Try[String], failureMessage: String): Boolean = { + result match { + case Failure(e: InvalidParameterException) => e.getMessage == failureMessage + case _ => false + } + } + + private val invalidParameterExceptionMessageTemplate = + "The length of the string to inject (%d) doesn't match Section(%d, %d) for string of length %d." + + + test("Negative length is turned into length 0") { + assert(Section(3, -1) == Section(3, 0)) + } + + test("Section end doesn't overflow integer") { + val start = Int.MaxValue - 2 + assert(Section(start, 3) == Section(start, 2)) + } + + //sorting + test("Sorting") { + val inputSeq = Seq( + Section(-11, 3), + Section(-13, 5), + Section(-13, 1), + Section(-12, 2), + Section(6, 6), + Section(6, 1), + Section(2, 3), + Section(4, 1), + Section(0, 1) + ) + val expectedSeq = Seq( + Section(0, 1), + Section(2, 3), + Section(4, 1), + Section(6, 1), + Section(6, 6), + Section(-11, 3), + Section(-12, 2), + Section(-13, 1), + Section(-13, 5) + ) + + assert(inputSeq.sorted == expectedSeq) + } + + //toSubstringParameters + test("toSubstringParameters: simple case with positive value of Start") { + val start = 3 + val length = 5 + val after = 8 + + val section = Section(start, length) + + val result1 = section.toSubstringParameters("Hello world") + assert(result1 == (start, after)) + val result2 = section.toSubstringParameters("") + assert(result2 == (0, 0)) + } + + test("toSubstringParameters: with negative value of Start, within bounds of the string") { + val start = -6 + val length = 3 + + val section = Section(start, length) + + val result = section.toSubstringParameters("Hello world") + assert(result == (5, 8)) + } + + test("toSubstringParameters: with negative value of Start, on a too short string") { + val start = -5 + val length = 3 + + val section = Section(start, length) + + val result = section.toSubstringParameters("foo") + assert(result == (0, 1)) + } + + //extract, remove, inject + test("extractFrom, removeFrom, injectInto: with positive value of Start within the input string") { + val section = Section(2,4) + val fullString = "abcdefghi" + val remainder = "abghi" + val extracted = "cdef" + assert(checkSectionRemoveExtractInject(section, fullString, remainder, extracted)) + } + + test("extractFrom, removeFrom, injectInto: with positive value of Start till the end of the input string") { + val section = Section(4,2) + val fullString = "abcdef" + val remainder = "abcd" + val extracted = "ef" + assert(checkSectionRemoveExtractInject(section, fullString, remainder, extracted)) + } + + test("extractFrom, removeFrom, injectInto: with positive 
value of Start, extending over the end of the input string") { + val section = Section(4,4) + val fullString = "abcdef" + val remainder = "abcd" + val extracted = "ef" + assert(checkSectionRemoveExtractInject(section, fullString, remainder, extracted)) + } + + test("extractFrom, removeFrom, injectInto: with positive value of Start, Start beyond the end of input string") { + val section = Section(10,7) + val fullString = "abcdef" + val remainder = "abcdef" + val extracted = "" + assert(checkSectionRemoveExtractInject(section, fullString, remainder, extracted)) + } + + test("extractFrom, removeFrom, injectInto: with negative value of Start, within the input string") { + val section = Section (-4,2) + val fullString = "abcdef" + val remainder = "abef" + val extracted = "cd" + assert(checkSectionRemoveExtractInject(section, fullString, remainder, extracted)) + } + + test("extractFrom, removeFrom, injectInto: with negative value of Start, before beginning of the input string") { + val section = Section (-8,5) + val fullString = "abcdef" + val remainder = "def" + val extracted = "abc" + assert(checkSectionRemoveExtractInject(section, fullString, remainder, extracted)) + } + + test("extractFrom, removeFrom, injectInto: with negative value of Start, far before beginning of the input string") { + val section = Section (-10,3) + val fullString = "abcdef" + val remainder = "abcdef" + val extracted = "" + assert(checkSectionRemoveExtractInject(section, fullString, remainder, extracted)) + } + + test("extractFrom, removeFrom, injectInto: with negative value of Start, extending over end of the input string") { + val section = Section (-2,4) + val fullString = "abcdef" + val remainder = "abcd" + val extracted = "ef" + assert(checkSectionRemoveExtractInject(section, fullString, remainder, extracted)) + } + + test("extractFrom, removeFrom, injectInto: zero length sections") { + val section1 = Section (2,0) + val section2 = Section (0,0) + val section3 = Section (-2,0) + val section4 = Section (20,0) + val section5 = Section (-20,0) + val fullString = "abcdef" + val remainder = "abcdef" + val extracted = "" + assert(checkSectionRemoveExtractInject(section1, fullString, remainder, extracted)) + assert(checkSectionRemoveExtractInject(section2, fullString, remainder, extracted)) + assert(checkSectionRemoveExtractInject(section3, fullString, remainder, extracted)) + assert(checkSectionRemoveExtractInject(section4, fullString, remainder, extracted)) + assert(checkSectionRemoveExtractInject(section5, fullString, remainder, extracted)) + } + + test("extractFrom, removeFrom, injectInto (automated): whole spectrum of Start, length 0") { + val playString = "abcdefghij" + val length = 0 + for (i <- -15 to 15) { + assert(circularCheck(i, length, playString)) + } + } + + test("extractFrom, removeFrom, injectInto (automated): whole spectrum of Start, length 3") { + val playString = "abcdefghij" + val length = 3 + for (i <- -15 to 15) { + assert(circularCheck(i, length, playString)) + } + } + //inject + test("injectInto: injected string too long compared to Section") { + val what = "what" + val into = "This is long enough" + val sec1 = Section(0, 3) + val sec2 = Section(2,3) + val sec3 = Section(-4, 3) + + var result = sec1.injectInto(into, what) + assert(checkTryOnFailure(result, invalidParameterExceptionMessageTemplate.format(4, 0, 3, 19))) + result = sec2.injectInto(into, what) + assert(checkTryOnFailure(result, invalidParameterExceptionMessageTemplate.format(4, 2, 3, 19))) + result = sec3.injectInto(into, what) + 
assert(checkTryOnFailure(result, invalidParameterExceptionMessageTemplate.format(4, -4, 3, 19))) + } + + test("injectInto: injected string too short compared to Section length") { + val into = "abcdef" + + //within but short + val section1 = Section(3, 3) + val result1 = section1.injectInto(into, "xx") + assert(checkTryOnFailure(result1, invalidParameterExceptionMessageTemplate.format(2, 3, 3, 6))) + //within from behind, but short + val section2 = Section(-5, 3) + val result2 = section2.injectInto(into, "xx") + assert(checkTryOnFailure(result2, invalidParameterExceptionMessageTemplate.format(2, -5, 3, 6))) + //too far behind + val section3 = Section(10, 3) + val result3 = section3.injectInto(into, "xx") + assert(checkTryOnFailure(result3, invalidParameterExceptionMessageTemplate.format(2, 10, 3, 6))) + //too far ahead + val section4 = Section(-7, 3) + val result4 = section4.injectInto(into, "xx") + assert(checkTryOnFailure(result4, invalidParameterExceptionMessageTemplate.format(2, -7, 3, 6))) + } + + test("injectInto: empty string") { + val into = "abcdef" + val what = "" + // ok for section length 0 + assert(Section(3, 0).injectInto(into, what).toOption.contains(into)) + // ok for section behind + assert(Section(6, 3).injectInto(into, what).toOption.contains(into)) + // ok for section far enough ahead + assert(Section(-8, 2).injectInto(into, what).toOption.contains(into)) + // fails otherwise + val section1 = Section(2, 2) + val result = section1.injectInto(into, what) + assert(checkTryOnFailure(result, invalidParameterExceptionMessageTemplate.format(0, 2, 2, 6))) + } + + test("injectInto: Special fail on seemingly correct input, but not if considered as reverse to remove and extract") { + val into = "abcdef" + val what = "xxx" + val section = Section(-2, 3) + val result = section.injectInto(into, what) + assert(checkTryOnFailure(result, invalidParameterExceptionMessageTemplate.format(3, -2, 3, 6))) + } + + //distance + test("distance: two Sections with one negative and other positive value of Start") { + val a = Section(-1, 12) + val b = Section(3, 2) + assert((a distance b) == (b distance a)) + assert((a distance b).isEmpty) + } + + test("distance: two Sections with same positive values of Start") { + val a = Section(5, 3) + val b = Section(5, 3) + assert((a distance b) == (b distance a)) + assert((a distance b).contains(-3)) + } + + test("distance: two Sections with positive value of Start, gap between them") { + val a = Section(3, 2) + val b = Section(6, 2) + assert((a distance b) == (b distance a)) + assert((a distance b).contains(1)) + } + + test("distance: two Sections with positive value of Start, adjacent to each other") { + val a = Section(3, 2) + val b = Section(5, 2) + assert((a distance b) == (b distance a)) + assert((a distance b).contains(0)) + } + + test("distance: two Sections with positive value of Start, overlapping") { + val a = Section(3, 4) + val b = Section(5, 3) + assert((a distance b) == (b distance a)) + assert((a distance b).contains(-2)) + } + + test("distance: two Sections with positive value of Start, one within other") { + val a = Section(3, 3) + val b = Section(4, 1) + assert((a distance b) == (b distance a)) + assert((a distance b).contains(-2)) + } + + test("distance: two Sections with same negative values of Start") { + val a = Section(-5, 3) + val b = Section(-5, 3) + assert((a distance b) == (b distance a)) + assert((a distance b).contains(-3)) + } + + test("distance: two Sections with negative value of Start, gap between them") { + val a = Section(-3, 
2) + val b = Section(-6, 2) + assert((a distance b) == (b distance a)) + assert((a distance b).contains(1)) + } + + test("distance: two Sections with negative value of Start, adjacent to each other") { + val a = Section(-3, 2) + val b = Section(-5, 2) + assert((a distance b) == (b distance a)) + assert((a distance b).contains(0)) + } + + test("distance: two Sections with negative value of Start, overlapping") { + val a = Section(-3, 4) + val b = Section(-5, 3) + assert((a distance b) == (b distance a)) + assert((a distance b).contains(-1)) + } + + test("distance: two Sections with negative value of Start, one within other") { + val a = Section(-5, 3) + val b = Section(-4, 1) + assert((a distance b) == (b distance a)) + assert((a distance b).contains(-2)) + } + + //overlaps + test("overlaps: no overlap - with positive values of Start") { + val a = Section(1, 2) + val b = Section(4, 2) + assert((a overlaps b) == (b overlaps a)) + assert(!(a overlaps b)) + } + + test("overlaps: touching - with positive values of Start") { + val a = Section(1, 2) + val b = Section(3, 2) + assert((a overlaps b) == (b overlaps a)) + assert(!(a overlaps b)) + } + + test("overlaps: overlap - with positive values of Start") { + val a = Section(1, 3) + val b = Section(3, 3) + assert((a overlaps b) == (b overlaps a)) + assert(a overlaps b) + } + + test("overlaps: no overlap - with negative values of Start") { + val a = Section(-3, 2) + val b = Section(-6, 2) + assert((a overlaps b) == (b overlaps a)) + assert(!(a overlaps b)) + } + + test("overlaps: touching - with negative values of Start") { + val a = Section(-3, 2) + val b = Section(-6, 3) + assert((a overlaps b) == (b overlaps a)) + assert(!(a overlaps b)) + } + + test("overlaps: overlap - with negative values of Start") { + val a = Section(-3, 2) + val b = Section(-6, 4) + assert((a overlaps b) == (b overlaps a)) + assert(a overlaps b) + } + + test("overlaps: overlap - one with negative value of Start and one with positive value of Start") { + val a = Section(-1, 2) + val b = Section(4, 2) + assert((a overlaps b) == (b overlaps a)) + assert(!(a overlaps b)) + } + + //touches + test("touches: no overlap - with positive values of Start") { + val a = Section(1, 2) + val b = Section(4, 2) + assert((a touches b) == (b touches a)) + assert(!(a touches b)) + } + + test("touches: touching - with positive values of Start") { + val a = Section(1, 2) + val b = Section(3, 2) + assert((a touches b) == (b touches a)) + assert(a touches b) + } + + test("touches: overlap - with positive values of Start") { + val a = Section(1, 3) + val b = Section(3, 3) + assert((a touches b) == (b touches a)) + assert(a touches b) + } + + test("touches: no overlap - with negative values of Start") { + val a = Section(-3, 2) + val b = Section(-6, 2) + assert((a touches b) == (b touches a)) + assert(!(a touches b)) + } + + test("touches: touching - with negative values of Start") { + val a = Section(-3, 2) + val b = Section(-6, 3) + assert((a touches b) == (b touches a)) + assert(a touches b) + } + + test("touches: overlap - with negative values of Start") { + val a = Section(-3, 2) + val b = Section(-6, 4) + assert((a touches b) == (b touches a)) + assert(a touches b) + } + + test("touches: overlap - one with negative value of Start and one with positive value of Start") { + val a = Section(-1, 2) + val b = Section(4, 2) + assert((a touches b) == (b touches a)) + assert(!(a touches b)) + } + + //ofSameChars + test("ofSameChars: single character") { + val section = 
Section.ofSameChars("abcdefghijkl", 4) + val expected = Section(4, 1) + assert(section == expected) + } + + test("ofSameChars: more characters") { + val section = Section.ofSameChars("aabbbccccddddd", 5) + val expected = Section(5, 4) + assert(section == expected) + } + + test("ofSameChars: more characters, fromIndex start within the sequence of same chars") { + val section = Section.ofSameChars("aabbbbbccccddddd", 4) + val expected = Section(4, 3) + assert(section == expected) + } + + test("ofSameChars: start out of input string range") { + val section = Section.ofSameChars("xxxxyyyzz", 20) + val expected = Section(20, 0) + assert(section == expected) + } + + test("ofSameChars: with negative value of Start") { + val section = Section.ofSameChars("xxxxyyyzz", -5) + val expected = Section(-5, 3) + assert(section == expected) + } + + test("ofSameChars: with negative value of Start, 'in front' of input string") { + val section = Section.ofSameChars("xxxxyyyzz", -15) + val expected = Section(-15, 0) + assert(section == expected) + } + + //removeMultiple + test("removeMultiple: two Sections") { + val sections = Seq(Section(2, 2), Section(6, 3)) + val result = Section.removeMultipleFrom("abcdefghijkl", sections) + val expected = "abefjkl" + assert(result == expected) + } + + test("removeMultiple: three Sections, unordered") { + val sections = Seq(Section(6, 3), Section(2, 2), Section(11, 2)) + val result = Section.removeMultipleFrom("abcdefghijklmnop", sections) + val expected = "abefjknop" + assert(result == expected) + } + + test("removeMultiple: adjacent Sections") { + val sections = Seq(Section(2, 2), Section(11, 2), Section(6, 5)) + val result = Section.removeMultipleFrom("abcdefghijklmnop", sections) + val expected = "abefnop" + assert(result == expected) + } + + test("removeMultiple: overlapping Sections") { + val sections = Seq(Section(2, 2), Section(10, 3), Section(6, 5)) + val result = Section.removeMultipleFrom("abcdefghijklmnop", sections) + val expected = "abefnop" + assert(result == expected) + } + + test("removeMultiple: one with negative value of Start, other positive") { + val sections = Seq(Section(1, 1), Section(-2, 1)) + val result = Section.removeMultipleFrom("abcdefghijklmnop", sections) + val expected = "acdefghijklmnp" + assert(result == expected) + } + + test("removeMultiple: two Sections with negative value of Start, overlapping") { + val sections = Seq(Section(-3, 3), Section(-5, 3)) + val result = Section.removeMultipleFrom("abcdefghijklmnop", sections) + val expected = "abcdefghijk" + assert(result == expected) + } + + test("removeMultiple: two Sections running out of bounds of input string") { + val sections = Seq(Section(-6, 2), Section(4, 2)) + val result = Section.removeMultipleFrom("abcde", sections) + val expected = "bcd" + assert(result == expected) + } + + test("removeMultiple: two Sections totally out of bounds of input string") { + val sections = Seq(Section(-10, 2), Section(10, 2)) + val result = Section.removeMultipleFrom("abcde", sections) + val expected = "abcde" + assert(result == expected) + } + + test("removeMultiple: zero length") { + val sections = Seq(Section(1, 0), Section(-2, 0)) + val result = Section.removeMultipleFrom("abcdefghijklmnop", sections) + val expected = "abcdefghijklmnop" + assert(result == expected) + } + + //mergeTouchingSectionsAndSort + test("mergeTouchingSectionsAndSort: empty sequence of sections") { + val sections = Seq.empty[Section] + val result = Section.mergeTouchingSectionsAndSort(sections) + assert(result.isEmpty) + } 
+ + test("mergeTouchingSectionsAndSort: one item in input sequence") { + val sections = Seq(Section(42, 7)) + val result = Section.mergeTouchingSectionsAndSort(sections) + assert(result == sections) + } + + /* For a string: + * 01234567890ACDFEFGHIJKLMNOPQUSTUVWXYZ + * ^ ^^^ ^-^^^ ^ + * | | | | | | + * | | Section(5,1) | | Section(-1,1) + * | Section(3,2) | Section(-4,2) + * Section(1,1) Section(-7,3) + * Output of the merge: + * 01234567890ACDFEFGHIJKLMNOPQUSTUVWXYZ + * ^ ^-^ ^---^ ^ + * | | | | + * | Section(3,3) | Section(-1,1) + * Section(1,1) Section(-7,5) + */ + test("mergeTouchingSectionsAndSort: two pairs of touching sections, ordering checked too") { + val sections = Seq( + Section( 1, 1), //D + Section( 3, 2), //B1 + Section(-4, 2), //A2 + Section( 5, 1), //B2 + Section(-7, 3), //A1 + Section(-1, 1) //B + ) + val expected = Seq( + Section(-7, 5), //->A + Section(-1, 1), //->B + Section( 3, 3), //->C + Section( 1, 1) //->D + ) + val result = Section.mergeTouchingSectionsAndSort(sections) + assert(result == expected) + } + + test("mergeTouchingSectionsAndSort: two pairs of overlapping sections, ordering checked too") { + val sections = Seq( + //overlapping + Section( 11, 4), //B1 + Section(-20, 6), //A2 + Section( 12, 5), //B2 + Section(-23, 6), //A1 + Section( -1, 1), //B + Section( 1, 1) //D + ) + + val expected = Seq( + Section(-23, 9), //->A + Section( -1, 1), //->B + Section( 11, 6), //->C + Section( 1, 1) //->D + ) + val result = Section.mergeTouchingSectionsAndSort(sections) + assert(result == expected) + } + + test("mergeTouchingSectionsAndSort: two pairs of same sections") { + val sections = Seq( + //overlapping + Section( 7, 5), //C + Section(-7, 5), //A + Section( 7, 5), //C + Section(-7, 5), //A + Section(-1, 1), //B + Section( 1, 1) //D + ) + + val expected = Seq( + Section(-7, 5), //->A + Section(-1, 1), //->B + Section( 7, 5), //->C + Section( 1, 1) //->D + ) + val result = Section.mergeTouchingSectionsAndSort(sections) + assert(result == expected) + } + + test("mergeTouchingSectionsAndSort: one section withing other") { + val sections = Seq( + //overlapping + Section( 12, 1), // E + Section(-20, 6), // B! + Section( 12, 5), // E! + Section(-20, 3), // B + Section( -1, 1), // C! + Section( 1, 1), // F! + Section( 23, 2), // D + Section( 20, 5), // D! + Section(-30, 9), // A! 
+ Section(-27, 6) // A + ) + + val expected = Seq( + Section(-30, 9), // ->A + Section(-20, 6), // ->B + Section( -1, 1), // ->C + Section( 20, 5), // ->D + Section( 12, 5), // ->E + Section( 1, 1) // ->F + ) + val result = Section.mergeTouchingSectionsAndSort(sections) + assert(result == expected) + } + + test("mergeTouchingSectionsAndSort: sequence of 3 sections") { + val sections = Seq( + //overlapping + Section( 22, 10), //C3 + Section(-42, 10), //A1 + Section( 10, 10), //C1 + Section(-35, 10), //A2 + Section( 15, 10), //C2 + Section(-30, 10), //A3 + Section( -1, 1), //B + Section( 1, 1) //D + ) + + val expected = Seq( + Section(-42, 22), //->A + Section( -1, 1), //->B + Section( 10, 22), //->C + Section( 1, 1) //->D + ) + val result = Section.mergeTouchingSectionsAndSort(sections) + assert(result == expected) + } + + test("copy: change start") { + val s1 = Section(1, 3) + val s2 = s1.copy(start = 3) + assert(s2 == Section(3, 3)) + } + + test("copy: change length") { + val s1 = Section(-4, 3) + val s2 = s1.copy(length = 2) + assert(s2 == Section(-4, 2)) + } + +} diff --git a/src/test/scala/za/co/absa/standardization/types/TypedStructFieldSuite.scala b/src/test/scala/za/co/absa/standardization/types/TypedStructFieldSuite.scala new file mode 100644 index 0000000..349e66d --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/types/TypedStructFieldSuite.scala @@ -0,0 +1,281 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.types + +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.{ValidationError, ValidationIssue, ValidationWarning} +import za.co.absa.standardization.schema.MetadataKeys +import za.co.absa.standardization.types.TypedStructField._ + +import java.text.ParseException +import scala.util.{Failure, Success, Try} + +class TypedStructFieldSuite extends AnyFunSuite { + private implicit val defaults: Defaults = GlobalDefaults + private val fieldName = "test_field" + private def createField(dataType: DataType, + nullable: Boolean = false, + default: Option[Any] = None, + otherMetadata: Map[String, Any] = Map.empty + ): StructField = { + def addMetadata(builder: MetadataBuilder, key: String, value: Option[Any]): MetadataBuilder = { + value match { + case None => builder + case Some(null) => builder.putNull(key) // scalastyle:ignore null + case Some(s: String) => builder.putString(key, s) + case Some(i: Int) => builder.putLong(key, i) + case Some(l: Long) => builder.putLong(key, l) + case Some(b: Boolean) => builder.putBoolean(key, b) + case Some(f: Float) => builder.putDouble(key, f) + case Some(d: Double) => builder.putDouble(key, d) + case Some(x) => builder.putString(key, x.toString) + } + } + val metadataBuilder: MetadataBuilder = otherMetadata.foldLeft(new MetadataBuilder()) ( + (builder, data) => addMetadata(builder, key = data._1, value = Some(data._2))) + val metadata = addMetadata(metadataBuilder, MetadataKeys.DefaultValue, default).build() + StructField(fieldName, dataType, nullable,metadata) + } + + def checkField(field: TypedStructField, + dataType: DataType, + ownDefaultValue: Try[Option[Option[Any]]], + defaultValueWithGlobal: Try[Option[Any]], + nullable: Boolean = false, + validationIssues: Seq[ValidationIssue] = Nil): Unit = { + + def assertTry(got: Try[Any], expected:Try[Any]): Unit = { + expected match { + case Success(_) => assert(got == expected) + case Failure(e) => + val caught = intercept[Exception] { + got.get + } + assert(caught.getClass == e.getClass) + assert(caught.getMessage == e.getMessage) + } + } + + assert(field.name == fieldName) + assert(field.dataType == dataType) + val (correctType, expectedTypeName) = dataType match { + case ByteType => (field.isInstanceOf[ByteTypeStructField], "ByteTypeStructField") + case ShortType => (field.isInstanceOf[ShortTypeStructField], "ShortTypeStructField") + case IntegerType => (field.isInstanceOf[IntTypeStructField], "IntTypeStructField") + case LongType => (field.isInstanceOf[LongTypeStructField], "LongTypeStructField") + case FloatType => (field.isInstanceOf[FloatTypeStructField], "FloatTypeStructField") + case DoubleType => (field.isInstanceOf[DoubleTypeStructField], "DoubleTypeStructField") + case StringType => (field.isInstanceOf[StringTypeStructField], "StringTypeStructField") + case BinaryType => (field.isInstanceOf[BinaryTypeStructField], "BinaryTypeStructField") + case BooleanType => (field.isInstanceOf[BooleanTypeStructField], "BooleanTypeStructField") + case DateType => (field.isInstanceOf[DateTypeStructField], "DateTypeStructField") + case TimestampType => (field.isInstanceOf[TimestampTypeStructField], "TimestampTypeStructField") + case _: DecimalType => (field.isInstanceOf[DecimalTypeStructField], "DecimalTypeStructField") + case _: ArrayType => (field.isInstanceOf[ArrayTypeStructField], "ArrayTypeStructField") + case _: StructType => (field.isInstanceOf[StructTypeStructField], "StructTypeStructField") + case _ => 
(field.isInstanceOf[GeneralTypeStructField], "GeneralTypeStructField") + } + assert(correctType, s"\nWrong TypedStructField type. Expected: '$expectedTypeName', but got: '${field.getClass.getSimpleName}'") + assert(field.nullable == nullable) + assertTry(field.ownDefaultValue, ownDefaultValue) + assertTry(field.defaultValueWithGlobal, defaultValueWithGlobal) + assert(field.validate() == validationIssues) + } + + test("String type without default defined") { + val fieldType = StringType + val field = createField(fieldType) + val typed = TypedStructField(field) + checkField(typed, fieldType, Success(None), Success(Some(""))) + } + + test("Integer type without default defined, nullable") { + val fieldType = IntegerType + val nullable = true + val field = createField(fieldType, nullable) + val typed = TypedStructField(field) + checkField(typed, fieldType, Success(None), Success(None), nullable) + } + + test("Double type with default defined, not-nullable") { + val fieldType = DoubleType + val nullable = false + val field = createField(fieldType, nullable, Some("3.14")) + val typed = TypedStructField(field) + checkField(typed, fieldType, Success(Some(Some(3.14))), Success(Some(3.14)), nullable) + } + + test("Date type with default defined as null, nullable") { + val fieldType = DateType + val nullable = true + val field = createField(fieldType, nullable, Some(null)) // scalastyle:ignore null + val typed = TypedStructField(field) + checkField(typed, fieldType, Success(Some(None)), Success(None), nullable) + } + + test("StructType, not nullable") { + val innerField = createField(FloatType) + val fieldType = StructType(Seq(innerField)) + val nullable = false + val field = createField(fieldType, nullable) + val typed = TypedStructField(field) + checkField(typed, fieldType, Success(None), Failure(new IllegalStateException("No default value defined for data type struct")), nullable) + } + + test("String type not nullable, with default defined as null") { + val fieldType = StringType + val nullable = false + val field = createField(fieldType, nullable, Some(null)) // scalastyle:ignore null + val typed = TypedStructField(field) + val errMsg = s"null is not a valid value for field '$fieldName'" + val fail = Failure(new IllegalArgumentException(errMsg)) + checkField(typed, fieldType, fail, fail, nullable, Seq(ValidationError(errMsg))) + } + + test("Binary type not nullable, with default defined as null") { + val fieldType = BinaryType + val nullable = false + val field = createField(fieldType, nullable, Some(null)) // scalastyle:ignore null + val typed = TypedStructField(field) + val errMsg = s"null is not a valid value for field '$fieldName'" + val warnMsg ="Default value of 'null' found, but no encoding is specified. Assuming 'none'." 
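+    // Two validation issues are expected for this field: an error because a null default is not
+    // acceptable on a non-nullable field, and, per the warning text above, a warning because the
+    // binary default carries no explicit encoding metadata (so 'none' is assumed).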
+ val fail = Failure(new IllegalArgumentException(errMsg)) + checkField(typed, fieldType, fail, fail, nullable, Seq(ValidationError(errMsg), ValidationWarning(warnMsg))) + } + + test("Byte type not nullable, with default defined as a non-numeric string") { + val fieldType = ByteType + val nullable = false + val field = createField(fieldType, nullable, Some("seven")) + val typed = TypedStructField(field) + val errMsg = "'seven' cannot be cast to byte" + val fail = Failure(new NumberFormatException(errMsg)) + checkField(typed, fieldType, fail, fail, nullable, Seq(ValidationError(errMsg))) + } + + test("Long type not nullable, with default defined as binary integer") { + val fieldType = LongType + val nullable = false + val field = createField(fieldType, nullable, Some(-1L)) + val typed = TypedStructField(field) + val errMsg = "java.lang.Long cannot be cast to java.lang.String" + val fail = Failure(new ClassCastException(errMsg)) + checkField(typed, fieldType, fail, fail, nullable, Seq(ValidationError(errMsg))) + } + + test("Float type nullable, with default defined in exponential notation, allowInfinity is set to true") { + val fieldType = FloatType + val nullable = true + val field = createField(fieldType, nullable, Some("314e-2"), Map(MetadataKeys.AllowInfinity -> "true")) + val typed = TypedStructField(field) + checkField(typed, fieldType, Success(Some(Some(3.14F))), Success(Some(3.14F)), nullable) + } + + test("Boolean type nullable, with default defined as wrong keyword") { + val fieldType = BooleanType + val nullable = true + val field = createField(fieldType, nullable, Some("Nope")) + val typed = TypedStructField(field) + val errMsg = "'Nope' cannot be cast to boolean" + val fail = Failure(new IllegalArgumentException(errMsg)) + checkField(typed, fieldType, fail, fail, nullable, Seq(ValidationError(errMsg))) + } + + test("Timestamp type not nullable, with default not adhering to pattern") { + val fieldType = TimestampType + val nullable = false + val field = createField(fieldType, nullable, Some("00:00:00 01.01.2000"), Map("pattern" -> "yyyy-MM-dd HH:mm:ss X")) + val typed = TypedStructField(field) + val errMsg = """Unparseable date: "00:00:00 01.01.2000"""" + val fail = Failure(new ParseException(errMsg, 0)) + checkField(typed, fieldType, fail, fail, nullable, Seq(ValidationError(errMsg))) + } + + test("Float type not nullable, with default defined as Long and allowInfinity as binary Boolean") { + val fieldType = FloatType + val nullable = false + val field = createField(fieldType, nullable, Some(1000L), Map( MetadataKeys.AllowInfinity->false )) + val typed = TypedStructField(field) + val errMsg = "java.lang.Long cannot be cast to java.lang.String" + val fail = Failure(new ClassCastException(errMsg)) + checkField(typed, fieldType, fail, fail, nullable, Seq( + ValidationError(errMsg), + ValidationError(s"${MetadataKeys.AllowInfinity} metadata value of field 'test_field' is not Boolean in String format") + )) + } + + test("Decimal type with strictParsing enabled, incorrect value") { + val fieldType = DecimalType(10, 2) + val nullable = false + val field = createField(fieldType, nullable, Some("1000.89899"), Map( MetadataKeys.StrictParsing -> "true" )) + val typed = TypedStructField(field) + val errMsg = "'1000.89899' cannot be cast to decimal(10,2)" + val fail = Failure(new IllegalArgumentException(errMsg)) + checkField(typed, fieldType, fail, fail, nullable, Seq(ValidationError(errMsg))) + } + + test("Decimal type with strictParsing enabled, correct value") { + val fieldType = 
DecimalType(10, 2) + val nullable = false + val field = createField(fieldType, nullable, Some("1000.9"), Map( MetadataKeys.StrictParsing -> "true" )) + val typed = TypedStructField(field) + checkField(typed, fieldType, Success(Some(Some(1000.9))), Success(Some(1000.9)), nullable, Seq()) + } + + test("Decimal type with incorrect strictParsing value") { + val fieldType = DecimalType(10, 2) + val nullable = false + val field = createField(fieldType, nullable, Some("1000.889"), Map( MetadataKeys.StrictParsing -> "t" )) + val typed = TypedStructField(field) + checkField(typed, fieldType, Success(Some(Some(1000.889))), Success(Some(1000.889)), nullable, Seq( + ValidationError(s"${MetadataKeys.StrictParsing} metadata value of field 'test_field' is not Boolean in String format")) + ) + } + + test("Decimal type with false strictParsing") { + val fieldType = DecimalType(10, 2) + val nullable = false + val field = createField(fieldType, nullable, Some("1000.8889"), Map( MetadataKeys.StrictParsing -> "false" )) + val typed = TypedStructField(field) + checkField(typed, fieldType, Success(Some(Some(1000.8889))), Success(Some(1000.8889)), nullable, Seq()) + } + + test("Decimal type with no set strictParsing") { + val fieldType = DecimalType(10, 2) + val nullable = false + val field = createField(fieldType, nullable, Some("1000.8889"), Map()) + val typed = TypedStructField(field) + checkField(typed, fieldType, Success(Some(Some(1000.8889))), Success(Some(1000.8889)), nullable, Seq()) + } + + test("Array type with long element data type and correct associated metadata") { + val fieldType = ArrayType(LongType) + val nullable = true + val field = createField(fieldType, nullable, Some("9999"), Map( MetadataKeys.MinusSign->"$")) + val typed = TypedStructField(field) + checkField(typed, fieldType, Success(None), Success(None), nullable) + } + + test("Array type with date element data type and incorrect associated metadata") { + val fieldType = ArrayType(DateType) + val nullable = true + val field = createField(fieldType, nullable, None, Map( MetadataKeys.Pattern->"Fubar")) + val typed = TypedStructField(field) + val errMsg = "Illegal pattern character 'b'" + checkField(typed, fieldType, Success(None), Success(None), nullable, Seq(ValidationError(errMsg))) + } +} diff --git a/src/test/scala/za/co/absa/standardization/types/parsers/DateTimeParserSuite.scala b/src/test/scala/za/co/absa/standardization/types/parsers/DateTimeParserSuite.scala new file mode 100644 index 0000000..3d7771f --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/types/parsers/DateTimeParserSuite.scala @@ -0,0 +1,203 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.types.parsers + +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.time.TimeZoneNormalizer + +import java.sql.{Date, Timestamp} +import java.text.{ParseException, SimpleDateFormat} + +case class TestInputRow(id: Int, stringField: String) + +class DateTimeParserSuite extends AnyFunSuite{ + TimeZoneNormalizer.normalizeJVMTimeZone() + + test("EnceladusDateParser class epoch") { + val parser = DateTimeParser("epoch") + + val value: String = "1547553153" + val resultDate: Date = parser.parseDate(value) + val expectedDate: Date = Date.valueOf("2019-01-15") + assert(resultDate == expectedDate) + + val resultTimestamp: Timestamp = parser.parseTimestamp(value) + val expectedTimestamp: Timestamp = Timestamp.valueOf("2019-01-15 11:52:33") + assert(resultTimestamp == expectedTimestamp) + } + + test("EnceladusDateParser class epochmilli") { + val parser = DateTimeParser("epochmilli") + + val value: String = "1547553153198" + val resultDate: Date = parser.parseDate(value) + val expectedDate: Date = Date.valueOf("2019-01-15") + assert(resultDate == expectedDate) + + val resultTimestamp: Timestamp = parser.parseTimestamp(value) + val expectedTimestamp: Timestamp = Timestamp.valueOf("2019-01-15 11:52:33.198") + assert(resultTimestamp == expectedTimestamp) + } + + test("EnceladusDateParser class epochmicro") { + val parser = DateTimeParser("epochmicro") + + val value: String = "1547553153198765" + val resultDate: Date = parser.parseDate(value) + val expectedDate: Date = Date.valueOf("2019-01-15") + assert(resultDate == expectedDate) + + val resultTimestamp: Timestamp = parser.parseTimestamp(value) + val expectedTimestamp: Timestamp = Timestamp.valueOf("2019-01-15 11:52:33.198765") + assert(resultTimestamp == expectedTimestamp) + } + + test("EnceladusDateParser class epochnano") { + val parser = DateTimeParser("epochnano") + + val value: String = "1547553153198765432" + val resultDate: Date = parser.parseDate(value) + val expectedDate: Date = Date.valueOf("2019-01-15") + assert(resultDate == expectedDate) + + val resultTimestamp: Timestamp = parser.parseTimestamp(value) + val expectedTimestamp: Timestamp = Timestamp.valueOf("2019-01-15 11:52:33.198765432") + assert(resultTimestamp == expectedTimestamp) + } + + test("EnceladusDateParser class actual pattern without time zone") { + val parser = DateTimeParser("yyyy_MM_dd:HH.mm.ss") + + val value: String = "2019_01_15:11.52.33" + val resultDate: Date = parser.parseDate(value) + val expectedDate: Date = Date.valueOf("2019-01-15") + assert(resultDate == expectedDate) + + val resultTimestamp: Timestamp = parser.parseTimestamp(value) + val expectedTimestamp: Timestamp = Timestamp.valueOf("2019-01-15 11:52:33") + assert(resultTimestamp == expectedTimestamp) + } + + test("EnceladusDateParser class actual pattern with standard time zone") { + val parser = DateTimeParser("yyyy-MM-dd-HH-mm-ss-zz") + + val value: String = "2011-01-31-22-52-33-EST" + val resultDate: Date = parser.parseDate(value) + val expectedDate: Date = Date.valueOf("2011-02-01") + assert(resultDate == expectedDate) + + val resultTimestamp: Timestamp = parser.parseTimestamp(value) + val expectedTimestamp: Timestamp = Timestamp.valueOf("2011-02-01 03:52:33") + assert(resultTimestamp == expectedTimestamp) + } + + test("EnceladusDateParser class actual pattern with offset time zone") { + val parser = DateTimeParser("yyyy/MM/dd HH:mm:ssXXX") + + val value: String = "1990/01/31 22:52:33+01:00" + val resultDate: Date = 
parser.parseDate(value) + val expectedDate: Date = Date.valueOf("1990-01-31") + assert(resultDate == expectedDate) + + val resultTimestamp: Timestamp = parser.parseTimestamp(value) + val expectedTimestamp: Timestamp = Timestamp.valueOf("1990-01-31 21:52:33") + assert(resultTimestamp == expectedTimestamp) + } + + test("EnceladusDateParser class actual pattern without time zone with milliseconds") { + val parser = DateTimeParser("SSS|yyyy_MM_dd:HH.mm.ss") + + val value: String = "123|2019_01_15:11.52.33" + val resultDate: Date = parser.parseDate(value) + val expectedDate: Date = Date.valueOf("2019-01-15") + assert(resultDate == expectedDate) + + val resultTimestamp: Timestamp = parser.parseTimestamp(value) + val expectedTimestamp: Timestamp = Timestamp.valueOf("2019-01-15 11:52:33.123") + assert(resultTimestamp == expectedTimestamp) + } + + test("EnceladusDateParser class actual pattern without time zone and microseconds") { + val parser = DateTimeParser("yyyy_MM_dd:HH.mm.ss.iiiiii") + + val value: String = "2019_01_15:11.52.33.123456" + val resultDate: Date = parser.parseDate(value) + val expectedDate: Date = Date.valueOf("2019-01-15") + assert(resultDate == expectedDate) + + val resultTimestamp: Timestamp = parser.parseTimestamp(value) + val expectedTimestamp: Timestamp = Timestamp.valueOf("2019-01-15 11:52:33.123456") + assert(resultTimestamp == expectedTimestamp) + } + + test("EnceladusDateParser class actual pattern with standard time zone and nanoseconds") { + val parser = DateTimeParser("yyyy-MM-dd-HH-mm-ss.nnnnnnnnn-zz") + + val value: String = "2011-01-31-22-52-33.123456789-EST" + val resultDate: Date = parser.parseDate(value) + val expectedDate: Date = Date.valueOf("2011-02-01") + assert(resultDate == expectedDate) + + val resultTimestamp: Timestamp = parser.parseTimestamp(value) + val expectedTimestamp: Timestamp = Timestamp.valueOf("2011-02-01 03:52:33.123456789") + assert(resultTimestamp == expectedTimestamp) + } + + test("EnceladusDateParser class actual pattern with offset time zone and all second fractions") { + val parser = DateTimeParser("nnnSSSyyyy/MM/dd iii HH:mm:ssXXX") + + val value: String = "1234561990/01/31 789 22:52:33+01:00" + val resultDate: Date = parser.parseDate(value) + val expectedDate: Date = Date.valueOf("1990-01-31") + assert(resultDate == expectedDate) + + val resultTimestamp: Timestamp = parser.parseTimestamp(value) + val expectedTimestamp: Timestamp = Timestamp.valueOf("1990-01-31 21:52:33.456789123") + assert(resultTimestamp == expectedTimestamp) + } + + test("format") { + val t: Timestamp = Timestamp.valueOf("1970-01-02 01:00:00.123456789") //25 hours to epoch with some second fractions + val parser1 = DateTimeParser("yyyy-MM-dd HH:mm:ss") + assert(parser1.format(t) == "1970-01-02 01:00:00") + val parser2 = DateTimeParser("epoch") + assert(parser2.format(t) == "90000") + val parser3 = DateTimeParser("epochmilli") + assert(parser3.format(t) == "90000123") + val parser4 = DateTimeParser("epochmicro") + assert(parser4.format(t) == "90000123456") + val parser5 = DateTimeParser("epochnano") + assert(parser5.format(t) == "90000123456789") + val parser6 = DateTimeParser("yyyy-MM-dd HH:mm:ss.iiiiii") + assert(parser6.format(t) == "1970-01-02 01:00:00.123456") + val parser7 = DateTimeParser("(nnn) yyyy-MM-dd (SSS) HH:mm:ss (iii)") + assert(parser7.format(t) == "(789) 1970-01-02 (123) 01:00:00 (456)") + } + + test("Lenient interpretation is not accepted") { + //first lenient interpretation + val pattern = "dd-MM-yyyy" + val dateString = "2015-01-01" + val sdf = 
new SimpleDateFormat(pattern) + sdf.parse(dateString) + //non lenient within DateTimeParser + val parser = DateTimeParser(pattern) + intercept[ParseException] { + parser.parseDate(dateString) + } + } +} diff --git a/src/test/scala/za/co/absa/standardization/types/parsers/DecimalParserSuite.scala b/src/test/scala/za/co/absa/standardization/types/parsers/DecimalParserSuite.scala new file mode 100644 index 0000000..099ef81 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/types/parsers/DecimalParserSuite.scala @@ -0,0 +1,132 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.types.parsers + +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.numeric.{DecimalSymbols, NumericPattern} +import za.co.absa.standardization.types.GlobalDefaults + +import scala.util.Success + +class DecimalParserSuite extends AnyFunSuite { + test("No pattern, no limitations") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols + val pattern = NumericPattern(decimalSymbols) + val parser = DecimalParser(pattern) + assert(parser.parse("3.14") == Success(BigDecimal("3.14"))) + assert(parser.parse("1.") == Success(BigDecimal.valueOf(1))) + assert(parser.parse("-7") == Success(BigDecimal.valueOf(-7))) + assert(parser.parse(".271E1") == Success(BigDecimal("2.71"))) + assert(parser.parse("271E-2") == Success(BigDecimal("2.71"))) + } + + test("No pattern, strict parsing") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols + val pattern = NumericPattern(decimalSymbols) + val parser = DecimalParser(pattern,maxScale = Some(2)) + assert(parser.parse("3.14") == Success(BigDecimal("3.14"))) + assert(parser.parse("1.") == Success(BigDecimal.valueOf(1))) + assert(parser.parse("-7") == Success(BigDecimal.valueOf(-7))) + assert(parser.parse("12.123455").isFailure) + assert(parser.parse(".271E1") == Success(BigDecimal("2.71"))) + + assert(parser.parse("271E-2") == Success(BigDecimal("2.71"))) + } + + test("No pattern, no limitations, minus sign and decimal separator altered") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols.copy(minusSign = 'N', decimalSeparator = ',') + val pattern = NumericPattern(decimalSymbols) + val parser = DecimalParser(pattern) + assert(parser.parse("3,14") == Success(BigDecimal("3.14"))) + assert(parser.parse("1,") == Success(BigDecimal.valueOf(1))) + assert(parser.parse("N7") == Success(BigDecimal.valueOf(-7))) + assert(parser.parse(",271E1") == Success(BigDecimal("2.71"))) + assert(parser.parse("271EN2") == Success(BigDecimal("2.71"))) + assert(parser.parse("-11.1").isFailure) + } + + test("Simple pattern, some limitations") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols + val pattern = NumericPattern("0.#", decimalSymbols) + val parser = DecimalParser(pattern, Some(BigDecimal.valueOf(-1000)), Some(BigDecimal.valueOf(1000))) + assert(parser.parse("3.14") == Success(BigDecimal("3.14"))) 
//NB! number of hashes and 0 in pattern is not reliable + assert(parser.parse("1.") == Success(BigDecimal.valueOf(1))) + assert(parser.parse("-7") == Success(BigDecimal.valueOf(-7))) + assert(parser.parse(".271E1") == Success(BigDecimal("2.71"))) //NB! number of hashes and 0 in pattern is not reliable + assert(parser.parse("271E-2") == Success(BigDecimal("2.71"))) + assert(parser.parse("1000.0000000000000000000000000000000000000000000000000001").isFailure) + assert(parser.parse("-1000.0000000000000000000000000000000000000000000000000001").isFailure) + } + + test("pattern with altered decimal symbols") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols.copy( + decimalSeparator = ',', + groupingSeparator = ' ', + minusSign = '~' + ) + val pattern = NumericPattern("#,##0.000",decimalSymbols) //NB! that the standard grouping and decimal separators are used + val parser = DecimalParser(pattern) + + assert(parser.parse("100") == Success(BigDecimal.valueOf(100))) + assert(parser.parse("~1") == Success(BigDecimal.valueOf(-1))) + assert(parser.parse("1 000,3") == Success(BigDecimal("1000.3"))) + assert(parser.parse("~2 000,003") == Success(BigDecimal("-2000.003"))) + assert(parser.parse("3 0000,000001") == Success(BigDecimal("30000.000001"))) // grouping size is not reliable for parsing + assert(parser.parse("31,4E4") == Success(BigDecimal.valueOf(314000))) + assert(parser.parse("-4").isFailure) + assert(parser.parse("3,14E3") == Success(BigDecimal("3140"))) + assert(parser.parse("0.000 1").isFailure) // NB! grouping separator is not supported + assert(parser.parse("3.14E3").isFailure) + assert(parser.parse("~1 ").isFailure) + assert(parser.parse(" ~1 ").isFailure) + } + + test("grouping separator is not supported in decimal places") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols + val pattern1 = NumericPattern("0.000,#",decimalSymbols) + val exception = intercept[IllegalArgumentException] { + DecimalParser(pattern1) + } + assert(exception.getMessage == """Malformed pattern "0.000,#"""") + + val pattern2 = NumericPattern("0.000#",decimalSymbols) + val parser2 = DecimalParser(pattern2) + assert(parser2.parse("0.000,1").isFailure) + } + + test("Prefix, suffix and different negative pattern") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols + val pattern = NumericPattern("Alt: 0.#Feet;Alt: (0.#)Feet",decimalSymbols) + val parser = DecimalParser(pattern) + + assert(parser.parse("Alt: 10000.5Feet") == Success(BigDecimal("10000.5"))) + assert(parser.parse("Alt: (100)Feet") == Success(BigDecimal.valueOf(-100))) + assert(parser.parse("Alt: 612E-2Feet") == Success(BigDecimal("6.12"))) + assert(parser.parse("Alt: 10,000Feet").isFailure) + assert(parser.parse("100").isFailure) + } + + test("Percent") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols + val pattern = NumericPattern("#,##0.#%",decimalSymbols) + val parser = DecimalParser(pattern) + + assert(parser.parse("113.8%") == Success(BigDecimal("1.138"))) + assert(parser.parse("-5,000.1%") == Success(BigDecimal("-50.001"))) + assert(parser.parse("113.8").isFailure) + } +} diff --git a/src/test/scala/za/co/absa/standardization/types/parsers/FractionalParserSuite.scala b/src/test/scala/za/co/absa/standardization/types/parsers/FractionalParserSuite.scala new file mode 100644 index 0000000..f3fdde4 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/types/parsers/FractionalParserSuite.scala @@ -0,0 +1,184 @@ +/* + * Copyright 2021 ABSA Group Limited + 
* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.types.parsers + +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.numeric.{DecimalSymbols, NumericPattern} +import za.co.absa.standardization.types.GlobalDefaults + +import scala.util.Success + +class FractionalParserSuite extends AnyFunSuite { + private val reallyBigNumberString = "12345678901234567890123456789012345678901234567890123456789012345678901234567890" + + "12345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890" + + "12345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890" + + "12345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890" + private val reallySmallNumberString = s"-$reallyBigNumberString" + + test("No pattern") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols + val pattern = NumericPattern(decimalSymbols) + val parserFloat = FractionalParser[Float](pattern, Float.MinValue, Float.MaxValue) + val parserDouble = FractionalParser[Double](pattern, Double.MinValue, Double.MaxValue) + assert(parserFloat.parse("3.14") == Success(3.14F)) + assert(parserDouble.parse("3.14") == Success(3.14D)) + assert(parserFloat.parse("+1.") == Success(1F)) + assert(parserDouble.parse("1.") == Success(1D)) + assert(parserFloat.parse("-7") == Success(-7F)) + assert(parserDouble.parse("-7") == Success(-7D)) + assert(parserFloat.parse(".271E1") == Success(2.71F)) + assert(parserDouble.parse(".271E1") == Success(2.71D)) + assert(parserFloat.parse("271E-2") == Success(2.71F)) + assert(parserDouble.parse("+271E-2") == Success(2.71D)) + assert(parserFloat.parse("1E40").isFailure) + assert(parserDouble.parse("1E40") == Success(1.0E40)) + assert(parserDouble.parse("1E360").isFailure) + } + + test("Simple pattern, some limitations") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols + val pattern = NumericPattern("0.#", decimalSymbols) + val parserFloat = FractionalParser[Float](pattern, Float.MinValue, Float.MaxValue) + val parserDouble = FractionalParser[Double](pattern, Double.MinValue, Double.MaxValue) + assert(parserFloat.parse("3.14") == Success(3.14F)) + assert(parserDouble.parse("3.14") == Success(3.14D)) + assert(parserFloat.parse("1.") == Success(1F)) + assert(parserDouble.parse("1.") == Success(1D)) + assert(parserFloat.parse("-7") == Success(-7F)) + assert(parserDouble.parse("-7") == Success(-7D)) + assert(parserFloat.parse(".271E1") == Success(2.71F)) //NB! number of hashes and 0 in pattern is not reliable + assert(parserDouble.parse(".271E1") == Success(2.71D)) //NB! 
number of hashes and 0 in pattern is not reliable + assert(parserFloat.parse("271E-2") == Success(2.71F)) + assert(parserDouble.parse("271E-2") == Success(2.71D)) + assert(parserFloat.parse("1E40").isFailure) + assert(parserDouble.parse("1E40") == Success(1.0E40)) + assert(parserDouble.parse("1E360").isFailure) + } + + test("plus doesn't work if pattern is specified") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols + val pattern = NumericPattern("0", decimalSymbols) + val parserFloat = FractionalParser[Float](pattern, Float.MinValue, Float.MaxValue) + val parserDouble = FractionalParser[Double](pattern, Double.MinValue, Double.MaxValue) + assert(parserFloat.parse("+2.71").isFailure) + assert(parserDouble.parse("+2.71").isFailure) + } + + test("infinities") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols + val pattern1 = NumericPattern(decimalSymbols) + val pattern2 = NumericPattern("0.#", decimalSymbols) + val parserFloatStd1 = FractionalParser.withInfinity[Float](pattern1) + val parserDoubleStd1 = FractionalParser.withInfinity[Double](pattern1) + val parserFloatStd2 = FractionalParser.withInfinity[Float](pattern2) + val parserDoubleStd2 = FractionalParser.withInfinity[Double](pattern2) + assert(parserFloatStd1.parse("∞") == Success(Float.PositiveInfinity)) + assert(parserFloatStd1.parse("-∞") == Success(Float.NegativeInfinity)) + assert(parserDoubleStd1.parse("∞") == Success(Double.PositiveInfinity)) + assert(parserDoubleStd1.parse("-∞") == Success(Double.NegativeInfinity)) + assert(parserFloatStd2.parse("∞") == Success(Float.PositiveInfinity)) + assert(parserFloatStd2.parse("-∞") == Success(Float.NegativeInfinity)) + assert(parserDoubleStd2.parse("∞") == Success(Double.PositiveInfinity)) + assert(parserDoubleStd2.parse("-∞") == Success(Double.NegativeInfinity)) + assert(parserFloatStd1.parse("3E40") == Success(Float.PositiveInfinity)) + assert(parserFloatStd1.parse("-7699980973893499984399399999999999999998976876999") == Success(Float.NegativeInfinity)) + assert(parserDoubleStd1.parse(reallyBigNumberString) == Success(Double.PositiveInfinity)) + assert(parserDoubleStd1.parse(reallySmallNumberString) == Success(Double.NegativeInfinity)) + assert(parserDoubleStd1.parse("2E308") == Success(Double.PositiveInfinity)) + assert(parserDoubleStd1.parse("-6.6E666") == Success(Double.NegativeInfinity)) + assert(parserFloatStd2.parse("1276493809384398420983098239843298980977679008") == Success(Float.PositiveInfinity)) + assert(parserFloatStd2.parse("-2.71E55") == Success(Float.NegativeInfinity)) + assert(parserDoubleStd2.parse(reallyBigNumberString) == Success(Double.PositiveInfinity)) + assert(parserDoubleStd2.parse(reallySmallNumberString) == Success(Double.NegativeInfinity)) + assert(parserDoubleStd2.parse("2E308") == Success(Double.PositiveInfinity)) + assert(parserDoubleStd2.parse("-1E1000") == Success(Double.NegativeInfinity)) + } + + test("infinities redefined") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols.copy(minusSign = '&', infinityValue = "Infinity") + val pattern1 = NumericPattern(decimalSymbols) + val pattern2 = NumericPattern("#", decimalSymbols) + val pattern3 = NumericPattern("#;Negative#", decimalSymbols) + val parser1 = FractionalParser.withInfinity[Double](pattern1) + val parser2 = FractionalParser.withInfinity[Float](pattern2) + val parser3 = FractionalParser.withInfinity[Double](pattern3) + assert(parser1.parse("Infinity") == Success(Double.PositiveInfinity)) + assert(parser1.parse("&Infinity") == 
Success(Double.NegativeInfinity)) + assert(parser2.parse("Infinity") == Success(Float.PositiveInfinity)) + assert(parser2.parse("&Infinity") == Success(Float.NegativeInfinity)) + assert(parser3.parse("Infinity") == Success(Double.PositiveInfinity)) + assert(parser3.parse("NegativeInfinity") == Success(Double.NegativeInfinity)) + assert(parser3.parse("&Infinity").isFailure) + } + + test("No pattern, no limitations, minus sign and decimal separator altered") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols.copy(minusSign = 'N', decimalSeparator = ',') + val pattern = NumericPattern(decimalSymbols) + val parser = FractionalParser(pattern) + assert(parser.parse("6,28") == Success(6.28D)) + assert(parser.parse("10000,") == Success(10000D)) + assert(parser.parse("N7") == Success(-7D)) + assert(parser.parse(",271E1") == Success(2.71D)) + assert(parser.parse("271EN2") == Success(2.71D)) + assert(parser.parse("-11.1").isFailure) + } + + test("pattern with altered decimal symbols") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols.copy( + decimalSeparator = ',', + groupingSeparator = '\'', + minusSign = '@' + ) + val pattern = NumericPattern("#,##0",decimalSymbols) //NB! that the standard grouping separator is used + val parser = FractionalParser(pattern) + + assert(parser.parse("100") == Success(100)) + assert(parser.parse("@,1") == Success(-0.1D)) + assert(parser.parse("1'032,") == Success(1032D)) + assert(parser.parse("@2'000,55") == Success(-2000.55D)) + assert(parser.parse("3'0000,001") == Success(30000.001D)) // grouping size is not reliable for parsing + assert(parser.parse("314E@2") == Success(3.14D)) + assert(parser.parse("-4").isFailure) + assert(parser.parse("3.14E3").isFailure) + assert(parser.parse("@1 ").isFailure) + assert(parser.parse(" @1 ").isFailure) + } + + test("Prefix, suffix and different negative pattern") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols + val pattern = NumericPattern("Temperature #,##0C;Freezing -0",decimalSymbols) + val parser = FractionalParser(pattern) + + assert(parser.parse("Temperature 100C") == Success(100D)) + assert(parser.parse("Temperature 36.8C") == Success(36.8D)) + assert(parser.parse("Freezing -12") == Success(-12D)) + assert(parser.parse("Temperature 1,234C") == Success(1234)) + assert(parser.parse("Freezing 300.0C").isFailure) + assert(parser.parse("100.2").isFailure) + } + + test("Percent") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols + val pattern = NumericPattern("#,##0.#%",decimalSymbols) + val parser = FractionalParser(pattern) + + assert(parser.parse("113.8%") == Success(1.138D)) + assert(parser.parse("-5,000.1%") == Success(-5000.1D / 100)) // -5000.1D / 100 = -50.001000000000005 + assert(parser.parse("113.8").isFailure) + } + +} diff --git a/src/test/scala/za/co/absa/standardization/types/parsers/IntegralParser_PatternIntegralParserSuite.scala b/src/test/scala/za/co/absa/standardization/types/parsers/IntegralParser_PatternIntegralParserSuite.scala new file mode 100644 index 0000000..792aa16 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/types/parsers/IntegralParser_PatternIntegralParserSuite.scala @@ -0,0 +1,130 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.types.parsers + +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.numeric.{DecimalSymbols, NumericPattern} +import za.co.absa.standardization.types.GlobalDefaults + +import scala.util.Success + +class IntegralParser_PatternIntegralParserSuite extends AnyFunSuite { + test("No pattern, no limitations") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols + val pattern = NumericPattern(decimalSymbols) + val ipLong = IntegralParser[Long](pattern, None, None) + val ipInt = IntegralParser[Int](pattern, None, None) + val ipShort = IntegralParser[Short](pattern, None, None) + val ipByte = IntegralParser[Byte](pattern, None, None) + assert(ipLong.parse("98987565664") == Success(98987565664L)) + assert(ipLong.parse("-31225927393149") == Success(-31225927393149L)) + assert(ipInt.parse("2100000") == Success(2100000)) + assert(ipInt.parse("-1000") == Success(-1000)) + assert(ipShort.parse("16000") == Success(16000)) + assert(ipShort.parse("-16000") == Success(-16000)) + assert(ipByte.parse("127") == Success(127)) + assert(ipByte.parse("-17") == Success(-17)) + } + + test("No pattern, no limitations, minus sign altered") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols.copy(minusSign = 'N') + val pattern = NumericPattern(decimalSymbols) + val ipLong = IntegralParser[Long](pattern, None, None) + val ipInt = IntegralParser[Int](pattern, None, None) + val ipShort = IntegralParser[Short](pattern, None, None) + val ipByte = IntegralParser[Byte](pattern, None, None) + assert(ipLong.parse("98987565664") == Success(98987565664L)) + assert(ipLong.parse("N31225927393149") == Success(-31225927393149L)) + assert(ipLong.parse("-31225927393149").isFailure) + assert(ipInt.parse("2100000") == Success(2100000)) + assert(ipInt.parse("N1000") == Success(-1000)) + assert(ipInt.parse("-1000").isFailure) + assert(ipShort.parse("16000") == Success(16000)) + assert(ipShort.parse("N16000") == Success(-16000)) + assert(ipShort.parse("-16000").isFailure) + assert(ipByte.parse("127") == Success(127)) + assert(ipByte.parse("N17") == Success(-17)) + assert(ipByte.parse("-17").isFailure) + } + + test("Limit breaches") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols + val pattern = NumericPattern(decimalSymbols) + val ipLong = IntegralParser[Long](pattern, Some(10000000000L), Some(10000000010L)) + val ipInt = IntegralParser[Int](pattern, Some(-700000), None) + val ipShort = IntegralParser[Short](pattern, None, Some(5000)) + val ipByte = IntegralParser[Byte](pattern, None, None) + assert(ipLong.parse("10000000011").isFailure) + assert(ipLong.parse("9999999999").isFailure) + assert(ipInt.parse("2147483648").isFailure) + assert(ipInt.parse("-800000").isFailure) + assert(ipShort.parse("5001").isFailure) + assert(ipShort.parse("-32769").isFailure) + assert(ipByte.parse("128").isFailure) + assert(ipByte.parse("-129").isFailure) + } + + test("pattern with standard decimal symbols") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols + val pattern = 
NumericPattern("0,000",decimalSymbols) + val parser = IntegralParser(pattern) + + assert(parser.parse("100") == Success(100)) + assert(parser.parse("-1") == Success(-1)) + assert(parser.parse("1,000") == Success(1000)) + assert(parser.parse("-2000") == Success(-2000)) + assert(parser.parse("3,0000") == Success(30000)) // grouping size is not reliable for parsing + assert(parser.parse("314E3") == Success(314000)) + assert(parser.parse("3.14E3").isFailure) + assert(parser.parse("-1 ").isFailure) + assert(parser.parse(" -1 ").isFailure) + } + + test("pattern with altered decimal symbols") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols.copy( + decimalSeparator = ',', + groupingSeparator = ' ', + minusSign = '~' + ) + val pattern = NumericPattern("#,##0",decimalSymbols) //NB! that the standard grouping separator is used + val parser = IntegralParser(pattern) + + assert(parser.parse("100") == Success(100)) + assert(parser.parse("~1") == Success(-1)) + assert(parser.parse("1 000") == Success(1000)) + assert(parser.parse("~2 000") == Success(-2000)) + assert(parser.parse("3 0000") == Success(30000)) // grouping size is not reliable for parsing + assert(parser.parse("314E3") == Success(314000)) + assert(parser.parse("-4").isFailure) + assert(parser.parse("3,14E3").isFailure) + assert(parser.parse("3.14E3").isFailure) + assert(parser.parse("~1 ").isFailure) + assert(parser.parse(" ~1 ").isFailure) + } + + test("Prefix, suffix and different negative pattern") { + val decimalSymbols: DecimalSymbols = GlobalDefaults.getDecimalSymbols + val pattern = NumericPattern("Price: 0'EUR';Price: -0'EUR'",decimalSymbols) + val parser = IntegralParser(pattern) + + assert(parser.parse("Price: 100EUR") == Success(100)) + assert(parser.parse("Price: -12EUR") == Success(-12)) + assert(parser.parse("Price: 1,234EUR").isFailure) + assert(parser.parse("Price: 100.0EUR").isFailure) + assert(parser.parse("100").isFailure) + } +} diff --git a/src/test/scala/za/co/absa/standardization/types/parsers/IntegralParser_RadixIntegralParserSuite.scala b/src/test/scala/za/co/absa/standardization/types/parsers/IntegralParser_RadixIntegralParserSuite.scala new file mode 100644 index 0000000..0c94bdd --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/types/parsers/IntegralParser_RadixIntegralParserSuite.scala @@ -0,0 +1,173 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.types.parsers + +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.numeric.Radix +import za.co.absa.standardization.numeric.Radix.RadixFormatException +import za.co.absa.standardization.types.parsers.NumericParser.NumericParserException + +import scala.util.Success + +class IntegralParser_RadixIntegralParserSuite extends AnyFunSuite { + + test("base 10 parsing succeeds") { + val parser = IntegralParser.ofRadix(Radix(10)) + assert(parser.parse("1111") == Success(1111)) + assert(parser.parse("-1") == Success(-1)) + } + + test("base 10 parsing fails on too big number") { + val parser = IntegralParser.ofRadix(Radix(10)) + val tooBig = "45455782147845454874654658875324" + val fail = parser.parse(tooBig).failed.get + assert(fail.isInstanceOf[NumberFormatException]) + assert(fail.getMessage == """For input string: """" + tooBig + """"""") + } + + test("base 10 parsing fails on wrong input") { + val parser = IntegralParser.ofRadix(Radix(10)) + val wrong = "Hello" + val fail = parser.parse(wrong).failed.get + assert(fail.isInstanceOf[NumberFormatException]) + assert(fail.getMessage == """For input string: """" + wrong + """"""") + } + + test("base 16 parsing succeeds") { + val parser = IntegralParser.ofRadix(Radix(16)) + assert(parser.parse("CAFE") == Success(51966)) + assert(parser.parse("-a") == Success(-10)) + assert(parser.parse("0xFFFfF") == Success(1048575)) + assert(parser.parse("+0X1A") == Success(26)) + assert(parser.parse("-0X1") == Success(-1)) + assert(parser.parse("7FFFFFFFFFFFFFFF") == Success(9223372036854775807L)) + } + + test("base 16 parsing fails on incomplete input") { + val parser = IntegralParser.ofRadix(Radix(16)) + val fail1 = parser.parse("").failed.get + assert(fail1.isInstanceOf[NumberFormatException]) + assert(fail1.getMessage == "Zero length BigInteger") + val fail2 = parser.parse("0x").failed.get + assert(fail2.isInstanceOf[NumberFormatException]) + assert(fail2.getMessage == "Zero length BigInteger") + } + + test("base 16 parsing fails on too big input") { + val parser = IntegralParser.ofRadix(Radix(16)) + val tooBig = "8FFFFFFFFFFFFFFF" + val fail = parser.parse(tooBig).failed.get + assert(fail.isInstanceOf[NumericParserException]) + assert(fail.getMessage == s"The number '$tooBig' is out of range ") + } + + test("base 16 parsing fails on bad input") { + val parser = IntegralParser.ofRadix(Radix(16)) + val wrong = "g" + val fail = parser.parse(wrong).failed.get + assert(fail.isInstanceOf[NumberFormatException]) + assert(fail.getMessage == """For input string: "g"""") + } + + test("base 2 parsing succeeds") { + val parser = IntegralParser.ofRadix(Radix(2)) + assert(parser.parse("1" * 63) == Success(Long.MaxValue)) + assert(parser.parse("-10101") == Success(-21)) + } + + test("base 2 parsing fails on too big number") { + val parser = IntegralParser.ofRadix(Radix(2)) + val tooBig = "1" * 64 + val result = parser.parse(tooBig) + val fail = result.failed.get + assert(fail.isInstanceOf[NumericParserException]) + assert(fail.getMessage == s"The number '$tooBig' is out of range ") + } + + test("base 2 parsing fails on wrong input") { + val parser = IntegralParser.ofRadix(Radix(2)) + val wrong = "3" + val fail = parser.parse(wrong).failed.get + assert(fail.isInstanceOf[NumberFormatException]) + assert(fail.getMessage == """For input string: """" + wrong + """"""") + } + + + test("base 36 parsing succeeds") { + val parser = IntegralParser.ofRadix(Radix(36)) + assert(parser.parse("Zardoz1") == 
Success(76838032045L)) + assert(parser.parse("-Wick3") == Success(-54603795)) + } + + test("base 36 parsing fails on too big number") { + val parser = IntegralParser.ofRadix(Radix(36)) + val tooBig = "DowningStreet10" + val result = parser.parse(tooBig) + val fail = result.failed.get + assert(fail.isInstanceOf[NumericParserException]) + assert(fail.getMessage == s"The number '$tooBig' is out of range ") + } + + test("base 36 parsing fails on wrong input") { + val parser = IntegralParser.ofRadix(Radix(36)) + val wrong = "__" + val fail = parser.parse(wrong).failed.get + assert(fail.isInstanceOf[NumberFormatException]) + assert(fail.getMessage == """For input string: """" + wrong + """"""") + } + + test("string base inputs") { + assert(IntegralParser.ofStringRadix("").radix.value == 10) + assert(IntegralParser.ofStringRadix("DEC").radix.value == 10) + assert(IntegralParser.ofStringRadix("decImal").radix.value == 10) + assert(IntegralParser.ofStringRadix("Hex").radix.value == 16) + assert(IntegralParser.ofStringRadix("HexaDecimal").radix.value == 16) + assert(IntegralParser.ofStringRadix("bin").radix.value == 2) + assert(IntegralParser.ofStringRadix("binarY").radix.value == 2) + assert(IntegralParser.ofStringRadix("oct").radix.value == 8) + assert(IntegralParser.ofStringRadix("OCTAL").radix.value == 8) + assert(IntegralParser.ofStringRadix("23").radix.value == 23) + } + + test("base out of range") { + val exception1 = intercept[RadixFormatException] { + IntegralParser.ofRadix(Radix(0)) + } + assert(exception1.getMessage == "Radix has to be greater then 0, 0 was entered") + val exception2 = intercept[RadixFormatException] { + IntegralParser.ofRadix(Radix(37)) + } + assert(exception2.getMessage == "Maximum supported radix is 36, 37 was entered") + } + + test("base not recognized") { + val exception = intercept[RadixFormatException] { + IntegralParser.ofStringRadix("hello") + } + assert(exception.getMessage == "'hello' was not recognized as a Radix value") + } + + test("base is smaller then 10 and minus is a higher digit") { + val decimalSymbols = NumericParser.defaultDecimalSymbols.copy(minusSign = '5') + + val parser = IntegralParser.ofRadix(Radix(5), decimalSymbols) + assert(parser.parse("4321") == Success(586)) + assert(parser.parse("54321") == Success(-586)) + assert(parser.parse("-4321").isFailure) + } + +} diff --git a/src/test/scala/za/co/absa/standardization/udf/UDFBuilderSuite.scala b/src/test/scala/za/co/absa/standardization/udf/UDFBuilderSuite.scala new file mode 100644 index 0000000..9d13767 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/udf/UDFBuilderSuite.scala @@ -0,0 +1,126 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.udf + +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.schema.MetadataKeys +import za.co.absa.standardization.types.TypedStructField.NumericTypeStructField +import za.co.absa.standardization.types.parsers.IntegralParser.{PatternIntegralParser, RadixIntegralParser} +import za.co.absa.standardization.types.parsers.{DecimalParser, FractionalParser} +import za.co.absa.standardization.types.{Defaults, GlobalDefaults, TypedStructField} + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} + +class UDFBuilderSuite extends AnyFunSuite { + private implicit val defaults: Defaults = GlobalDefaults + + test("Serialization and deserialization of stringUdfViaNumericParser (FractionalParser)") { + val fieldName = "test" + val field: StructField = StructField(fieldName, DoubleType, nullable = false) + val typedField = TypedStructField(field) + + + val numericTypeField = typedField.asInstanceOf[NumericTypeStructField[Double]] + val defaultValue: Option[Double] = typedField.defaultValueWithGlobal.get.map(_.asInstanceOf[Double]) + val parser = numericTypeField.parser.get.asInstanceOf[FractionalParser[Double]] + val udfFnc = UDFBuilder.stringUdfViaNumericParser(parser, numericTypeField.nullable, fieldName, defaultValue) + //write + val baos = new ByteArrayOutputStream + val oos = new ObjectOutputStream(baos) + oos.writeObject(udfFnc) + oos.flush() + val serialized = baos.toByteArray + assert(serialized.nonEmpty) + //read + val ois = new ObjectInputStream(new ByteArrayInputStream(serialized)) + (ois readObject ()).asInstanceOf[UserDefinedFunction] + } + + test("Serialization and deserialization of stringUdfViaNumericParser (DecimalParser)") { + val fieldName = "test" + val field: StructField = StructField(fieldName, DecimalType(20,5), nullable = false) + val typedField = TypedStructField(field) + + + val numericTypeField = typedField.asInstanceOf[NumericTypeStructField[BigDecimal]] + val defaultValue = typedField.defaultValueWithGlobal.get.map(_.asInstanceOf[BigDecimal]) + val parser = numericTypeField.parser.get.asInstanceOf[DecimalParser] + val udfFnc = UDFBuilder.stringUdfViaNumericParser(parser, numericTypeField.nullable, fieldName, defaultValue) + //write + val baos = new ByteArrayOutputStream + val oos = new ObjectOutputStream(baos) + oos.writeObject(udfFnc) + oos.flush() + val serialized = baos.toByteArray + assert(serialized.nonEmpty) + //read + val ois = new ObjectInputStream(new ByteArrayInputStream(serialized)) + (ois readObject ()).asInstanceOf[UserDefinedFunction] + } + + test("Serialization and deserialization of stringUdfViaNumericParser (RadixIntegralParser)") { + val fieldName = "test" + val field: StructField = StructField(fieldName, LongType, nullable = false, new MetadataBuilder() + .putString(MetadataKeys.Radix, "hex") + .putString(MetadataKeys.DefaultValue, "FF") + .build) + val typedField = TypedStructField(field) + + + val numericTypeField = typedField.asInstanceOf[NumericTypeStructField[Long]] + val defaultValue: Option[Long] = typedField.defaultValueWithGlobal.get.map(_.asInstanceOf[Long]) + val parser = numericTypeField.parser.get.asInstanceOf[RadixIntegralParser[Long]] + val udfFnc = UDFBuilder.stringUdfViaNumericParser(parser, numericTypeField.nullable, fieldName, defaultValue) + //write + val baos = new ByteArrayOutputStream + val oos = new ObjectOutputStream(baos) 
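+ // the UDF must survive a Java serialization round trip, since Spark serializes UDFs when shipping tasks to executors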
+ oos.writeObject(udfFnc) + oos.flush() + val serialized = baos.toByteArray + assert(serialized.nonEmpty) + //read + val ois = new ObjectInputStream(new ByteArrayInputStream(serialized)) + (ois readObject ()).asInstanceOf[UserDefinedFunction] + } + + test("Serialization and deserialization of stringUdfViaNumericParser (PatternIntegralParser)") { + val fieldName = "test" + val field: StructField = StructField(fieldName, ShortType, nullable = true, new MetadataBuilder() + .putString(MetadataKeys.Pattern, "0 feet") + .build) + val typedField = TypedStructField(field) + + + val numericTypeField = typedField.asInstanceOf[NumericTypeStructField[Short]] + val defaultValue: Option[Short] = typedField.defaultValueWithGlobal.get.map(_.asInstanceOf[Short]) + val parser = numericTypeField.parser.get.asInstanceOf[PatternIntegralParser[Short]] + val udfFnc = UDFBuilder.stringUdfViaNumericParser(parser, numericTypeField.nullable, fieldName, defaultValue) + //write + val baos = new ByteArrayOutputStream + val oos = new ObjectOutputStream(baos) + oos.writeObject(udfFnc) + oos.flush() + val serialized = baos.toByteArray + assert(serialized.nonEmpty) + //read + val ois = new ObjectInputStream(new ByteArrayInputStream(serialized)) + (ois readObject ()).asInstanceOf[UserDefinedFunction] + } + +} diff --git a/src/test/scala/za/co/absa/standardization/validation/field/BinaryValidatorSuite.scala b/src/test/scala/za/co/absa/standardization/validation/field/BinaryValidatorSuite.scala new file mode 100644 index 0000000..e56c65e --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/validation/field/BinaryValidatorSuite.scala @@ -0,0 +1,61 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.validation.field + +import org.apache.spark.sql.types.{BinaryType, MetadataBuilder, StructField} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.{ValidationError, ValidationWarning} +import za.co.absa.standardization.schema.MetadataKeys +import za.co.absa.standardization.types.{Defaults, GlobalDefaults, TypedStructField} + +class BinaryValidatorSuite extends AnyFunSuite { + private implicit val defaults: Defaults = GlobalDefaults + + private def field(defaultValue: Option[String] = None, encoding: Option[String] = None, nullable: Boolean = true): TypedStructField = { + val base = new MetadataBuilder() + val builder2 = defaultValue.map(base.putString(MetadataKeys.DefaultValue, _)).getOrElse(base) + val builder3 = encoding.map(builder2.putString(MetadataKeys.Encoding, _)).getOrElse(builder2) + val result = StructField("test_field", BinaryType, nullable, builder3.build()) + TypedStructField(result) + } + + test("field with no meta validates") { + assert(BinaryFieldValidator.validate(field()).isEmpty) + } + + test("field with explicit default or explicit non-base64 encoding validates") { + assert(BinaryFieldValidator.validate(field(defaultValue = Some("abc"), encoding = Some("none"))).isEmpty) + assert(BinaryFieldValidator.validate(field(encoding = Some("none"))).isEmpty) + + assert(BinaryFieldValidator.validate(field(defaultValue = Some("abc"))) == Seq( + ValidationWarning("Default value of 'abc' found, but no encoding is specified. Assuming 'none'.") + )) + } + + test("field with base64 encoding and with no or correct defaultValue validates") { + assert(BinaryFieldValidator.validate(field(encoding = Some("base64"))).isEmpty) + // base64("test") => "dGVzdA==" + assert(BinaryFieldValidator.validate(field(defaultValue = Some("dGVzdA=="), encoding = Some("base64"))).isEmpty) + } + + test("field with base64 encoding and non-base64 default has issues ") { + val result = BinaryFieldValidator.validate(field(defaultValue = Some("bogus!@#"), encoding = Some("base64"))) + assert(result.contains(ValidationError("'bogus!@#' cannot be cast to binary"))) + assert(result.contains(ValidationError("Invalid default value bogus!@# for Base64 encoding (cannot be decoded)!"))) + } + +} diff --git a/src/test/scala/za/co/absa/standardization/validation/field/DateFieldValidatorSuite.scala b/src/test/scala/za/co/absa/standardization/validation/field/DateFieldValidatorSuite.scala new file mode 100644 index 0000000..045d72f --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/validation/field/DateFieldValidatorSuite.scala @@ -0,0 +1,213 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.validation.field + +import org.apache.spark.sql.types.{DateType, MetadataBuilder, StructField} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.{ValidationError, ValidationIssue, ValidationWarning} +import za.co.absa.standardization.schema.MetadataKeys +import za.co.absa.standardization.time.TimeZoneNormalizer +import za.co.absa.standardization.types.{Defaults, GlobalDefaults, TypedStructField} + +class DateFieldValidatorSuite extends AnyFunSuite { + TimeZoneNormalizer.normalizeJVMTimeZone() + private implicit val defaults: Defaults = GlobalDefaults + + private def field(pattern: String, defaultValue: Option[String] = None, defaultTimeZone: Option[String] = None): TypedStructField = { + val builder = new MetadataBuilder().putString(MetadataKeys.Pattern, pattern) + val builder2 = defaultValue.map(builder.putString(MetadataKeys.DefaultValue, _)).getOrElse(builder) + val builder3 = defaultTimeZone.map(builder2.putString(MetadataKeys.DefaultTimeZone, _)).getOrElse(builder2) + val result = StructField("test_field", DateType, nullable = false, builder3.build()) + TypedStructField(result) + } + + test("epoch pattern") { + assert(DateFieldValidator.validate(field("epoch")).isEmpty) + //with default + assert(DateFieldValidator.validate(field("epoch", Option("5545556"))).isEmpty) + } + + test("epochmilli pattern") { + assert(DateFieldValidator.validate(field("epochmilli")).isEmpty) + //with default + assert(DateFieldValidator.validate(field("epochmilli", Option("5545556000"))).isEmpty) + } + + test("epochmicro pattern") { + assert(DateFieldValidator.validate(field("epochmicro")).isEmpty) + //with default + assert(DateFieldValidator.validate(field("epochmicro", Option("5545556000111"))).isEmpty) + } + + test("epochnano pattern") { + assert(DateFieldValidator.validate(field("epochnano")).isEmpty) + //with default + assert(DateFieldValidator.validate(field("epochnano", Option("5545556000111222"))).isEmpty) + } + + test("date pattern") { + //no default + assert(DateFieldValidator.validate(field("yyyy-MM-dd")).isEmpty) + //default as date + assert(DateFieldValidator.validate(field("dd.MM.yy", Option("01.05.18"))).isEmpty) + //default as timestamp + assert(DateFieldValidator.validate(field("yyyy/dd/MM", Option("2010/21/11 04:00:00"))).isEmpty) + } + + test("date with time zone in pattern") { + val expected = Set( + ValidationWarning("Time zone is defined in pattern for date. 
While it's valid, it can lead to unexpected outcomes.") + ) + //no default + assert(DateFieldValidator.validate(field("yyyy-MM-dd zz")).toSet == expected) + //default as timestamp + assert(DateFieldValidator.validate(field("dd.MM.yyyy+zz", Option("23.10.2000+CET"))).toSet == expected) + //extra chars in default + assert(DateFieldValidator.validate(field("yyMMdd_zz", Option("190301_EST!!!!"))).toSet == expected) + //timestamp with offset time zone + assert(DateFieldValidator.validate(field("yyyy/MM/dd XXX", Option("2019/01/31 -11:00"))).toSet == expected) + } + + test("invalid pattern") { + val expected1 = Set( + ValidationError("Illegal pattern character 'f'") + ) + assert(DateFieldValidator.validate(field("fubar")).toSet == expected1) + val expected2 = Set( + ValidationError("Illegal pattern character 'x'") + ) + assert(DateFieldValidator.validate(field("yyMMdd_xx")).toSet == expected2) + } + + test("invalid default") { + //empty default + val expected1 = Set( + ValidationError("""Unparseable date: """""), + ValidationWarning("Time zone is defined in pattern for date. While it's valid, it can lead to unexpected outcomes.") + ) + assert(DateFieldValidator.validate(field("yyMMdd_zz", Option(""))).toSet == expected1) + //wrong default + val expected2 = Set( + ValidationError("""Unparseable date: "1999-12-31"""") + ) + assert(DateFieldValidator.validate(field("yyyy/MM/dd", Option("1999-12-31"))).toSet == expected2) + //invalid epoch default + val expected3 = Set( + ValidationError("'2019-01-01' cannot be cast to date") + ) + assert(DateFieldValidator.validate(field("epoch", Option("2019-01-01"))).toSet == expected3) + //epoch overflow + val expected5 = Set( + ValidationError("'8748743743948390823948239084294938231122123' cannot be cast to date") + ) + assert(DateFieldValidator.validate(field("epoch", Option("8748743743948390823948239084294938231122123"))).toSet == expected5) + } + + test("utilizing default time zone") { + val pattern = "yyyy-MM-dd" + val value = Option("2000-01-01") + val expected = Set( + ValidationWarning("Time zone is defined in pattern for date. While it's valid, it can lead to unexpected outcomes.") + ) + // full name + assert(DateFieldValidator.validate(field(pattern, value, Option("Africa/Johannesburg"))).toSet == expected) + // abbreviation + assert(DateFieldValidator.validate(field(pattern, value, Option("CET"))).toSet == expected) + // offset to GMT + assert(DateFieldValidator.validate(field(pattern, value, Option("Etc/GMT-6"))).toSet == expected) + } + + test("issues with default time zone") { + def expected(timeZone: String): Set[ValidationIssue] = { + val q ="\"" + Set( + ValidationError(s"$q$timeZone$q is not a valid time zone designation"), + ValidationWarning("Time zone is defined in pattern for date. While it's valid, it can lead to unexpected outcomes.") + ) + } + val pattern = "yyyy-MM-dd" + val value = Option("2000-01-01") + // offset + val tz1 = "-03:00" + assert(DateFieldValidator.validate(field(pattern, value, Option(tz1))).toSet == expected(tz1)) + // empty + val tz2 = "" + assert(DateFieldValidator.validate(field(pattern, value, Option(tz2))).toSet == expected(tz2)) + // gibberish + val tz3 = "Gjh878-++_?" 
+ assert(DateFieldValidator.validate(field(pattern, value, Option(tz3))).toSet == expected(tz3)) + // non-existing + val tz4 = "Africa/New York" + assert(DateFieldValidator.validate(field(pattern, value, Option(tz4))).toSet == expected(tz4)) + } + + test("warning issues: double time zone") { + val expected = Set( + ValidationWarning("Pattern includes time zone placeholder and default time zone is also defined (will never be used)"), + ValidationWarning("Time zone is defined in pattern for date. While it's valid, it can lead to unexpected outcomes.") + ) + assert(DateFieldValidator.validate(field("yyyy-MM-dd XX", None, Option("CET"))).toSet == expected) + assert(DateFieldValidator.validate(field("yyyy-MM-dd zz", None, Option("CET"))).toSet == expected) + } + + test("warning issues: missing placeholders") { + val expected = Set( + ValidationWarning("No year placeholder 'yyyy' found."), + ValidationWarning("No month placeholder 'MM' found."), + ValidationWarning("No day placeholder 'dd' found.") + ) + assert(DateFieldValidator.validate(field("GG")).toSet == expected) + } + + test("warning issues: redundant placeholders") { + val expected = Set( + ValidationWarning("Redundant hour placeholder 'H' found."), + ValidationWarning("Redundant minute placeholder 'm' found."), + ValidationWarning("Redundant second placeholder 's' found."), + ValidationWarning("Redundant millisecond placeholder 'S' found."), + ValidationWarning("Redundant microsecond placeholder 'i' found."), + ValidationWarning("Redundant nanosecond placeholder 'n' found."), + ValidationWarning("Redundant am/pm placeholder 'a' found."), + ValidationWarning("Redundant hour placeholder 'k' found."), + ValidationWarning("Redundant hour placeholder 'h' found."), + ValidationWarning("Redundant hour placeholder 'H' found.") + ) + assert(DateFieldValidator.validate(field("yyyy-MM-dd HH:mm:ss.SSSiiinnn (aakkhhKK)", None, None)).toSet == expected) + } + + test("warning issues: missing placeholders with default time zone") { + val expected = Set( + ValidationWarning("No year placeholder 'yyyy' found."), + ValidationWarning("No month placeholder 'MM' found."), + ValidationWarning("No day placeholder 'dd' found."), + ValidationWarning("Time zone is defined in pattern for date. While it's valid, it can lead to unexpected outcomes.") + ) + assert(DateFieldValidator.validate(field("GG", None, Option("CET"))).toSet == expected) + } + + test("warning issues: day placeholder wrong case") { + val expected = Set( + ValidationWarning("No day placeholder 'dd' found."), + ValidationWarning("Rarely used DayOfYear placeholder 'D' found. Possibly DayOfMonth 'd' intended.") + ) + assert(DateFieldValidator.validate(field("yyyy/MM/DD")).toSet == expected) + } + + test("all relevant placeholders") { + assert(DateFieldValidator.validate(field("GG yyyy MM ww W DDD dd F E")).isEmpty) + } +} diff --git a/src/test/scala/za/co/absa/standardization/validation/field/FieldValidatorSuite.scala b/src/test/scala/za/co/absa/standardization/validation/field/FieldValidatorSuite.scala new file mode 100644 index 0000000..8968a3f --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/validation/field/FieldValidatorSuite.scala @@ -0,0 +1,32 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.validation.field + +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers + +class FieldValidatorSuite extends AnyFunSuite with Matchers { + + test("strip type name prefixes where they exist") { + FieldValidator.simpleTypeName("za.co.absa.standardization.validation.field.FieldValidator") shouldBe "FieldValidator" + FieldValidator.simpleTypeName("scala.Boolean") shouldBe "Boolean" + } + + test("be no-op for no prefixes") { + FieldValidator.simpleTypeName("Boolean") shouldBe "Boolean" + } +} diff --git a/src/test/scala/za/co/absa/standardization/validation/field/FractionalFieldValidatorSuite.scala b/src/test/scala/za/co/absa/standardization/validation/field/FractionalFieldValidatorSuite.scala new file mode 100644 index 0000000..b8f4c6f --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/validation/field/FractionalFieldValidatorSuite.scala @@ -0,0 +1,91 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package za.co.absa.standardization.validation.field + +import org.apache.spark.sql.types.{DataType, DoubleType, FloatType, MetadataBuilder, StructField} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.ValidationError +import za.co.absa.standardization.schema.MetadataKeys +import za.co.absa.standardization.types.TypedStructField.FractionalTypeStructField +import za.co.absa.standardization.types.{Defaults, GlobalDefaults, TypedStructField} + +class FractionalFieldValidatorSuite extends AnyFunSuite { + private implicit val defaults: Defaults = GlobalDefaults + + private def field(dataType: DataType, metadataBuilder: MetadataBuilder): FractionalTypeStructField[_] = { + val result = StructField("test_field", dataType, nullable = false, metadataBuilder.build()) + TypedStructField(result).asInstanceOf[FractionalTypeStructField[_]] + } + + test("No allow_infinity metadata") { + val builder = new MetadataBuilder + val f = field(FloatType, builder) + assert(FractionalFieldValidator.validate(f).isEmpty) + } + + test("allow_infinity metadata defined") { + val builder1 = new MetadataBuilder().putString(MetadataKeys.AllowInfinity, "false") + val f1 = field(FloatType, builder1) + assert(FractionalFieldValidator.validate(f1).isEmpty) + val builder2 = new MetadataBuilder().putString(MetadataKeys.AllowInfinity, "True") + val f2 = field(DoubleType, builder2) + assert(FractionalFieldValidator.validate(f2).isEmpty) + } + + test("allow_infinity not boolean") { + val builder = new MetadataBuilder().putString(MetadataKeys.AllowInfinity, "23") + val f = field(FloatType, builder) + assert(FractionalFieldValidator.validate(f) == Seq( + ValidationError(s"${MetadataKeys.AllowInfinity} metadata value of field 'test_field' is not Boolean in String format") + )) + } + + test("allow_infinity boolean in binary form") { + val builder = new MetadataBuilder().putBoolean(MetadataKeys.AllowInfinity, value = true) + val f = field(FloatType, builder) + assert(FractionalFieldValidator.validate(f) == Seq( + ValidationError(s"${MetadataKeys.AllowInfinity} metadata value of field 'test_field' is not Boolean in String format") + )) + } + + test("Decimal symbols redefined wrongly, invalid pattern") { + val builder = new MetadataBuilder() + .putString(MetadataKeys.GroupingSeparator, "") + .putString(MetadataKeys.DecimalSeparator, "xxx") + .putLong(MetadataKeys.MinusSign, 1) + .putString(MetadataKeys.Pattern, "0.###,#") + val f = field(DoubleType, builder) + val exp = Set( + ValidationError(s"${MetadataKeys.GroupingSeparator} metadata value of field 'test_field' is not Char in String format"), + ValidationError(s"${MetadataKeys.DecimalSeparator} metadata value of field 'test_field' is not Char in String format"), + ValidationError(s"${MetadataKeys.MinusSign} metadata value of field 'test_field' is not Char in String format"), + ValidationError("""Malformed pattern "0.###,#"""") + ) + assert(NumericFieldValidator.validate(f).toSet == exp) + } + + test("Pattern defined, default value doesn't adhere to it") { + val builder = new MetadataBuilder() + .putString(MetadataKeys.Pattern, "0.#MPH") + .putString(MetadataKeys.DefaultValue, "0.0") + val f = field(FloatType, builder) + assert(NumericFieldValidator.validate(f) == Seq( + ValidationError("Parsing of '0.0' failed.") + )) + } + +} diff --git a/src/test/scala/za/co/absa/standardization/validation/field/IntegralFieldValidatorSuite.scala b/src/test/scala/za/co/absa/standardization/validation/field/IntegralFieldValidatorSuite.scala new file mode 100644 
index 0000000..9159538 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/validation/field/IntegralFieldValidatorSuite.scala @@ -0,0 +1,94 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.validation.field + +import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, LongType, MetadataBuilder, ShortType, StructField} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.{ValidationError, ValidationWarning} +import za.co.absa.standardization.numeric.Radix +import za.co.absa.standardization.schema.MetadataKeys +import za.co.absa.standardization.types.TypedStructField.IntegralTypeStructField +import za.co.absa.standardization.types.{Defaults, GlobalDefaults, TypedStructField} + +class IntegralFieldValidatorSuite extends AnyFunSuite { + private implicit val defaults: Defaults = GlobalDefaults + + private def field(dataType: DataType, metadataBuilder: MetadataBuilder): IntegralTypeStructField[_] = { + val result = StructField("test_field", dataType, nullable = false, metadataBuilder.build()) + TypedStructField(result).asInstanceOf[IntegralTypeStructField[_]] + } + + private def field(dataType: DataType, pattern: Option[String], radix: Option[Radix]): IntegralTypeStructField[_] = { + val builder = new MetadataBuilder() + val builder2 = pattern.map(builder.putString(MetadataKeys.Pattern, _)).getOrElse(builder) + val builder3 = radix.map(r => builder2.putString(MetadataKeys.Radix, r.value.toString)).getOrElse(builder2) + field(dataType, builder3) + } + + test("No Radix nor Pattern defined") { + val f = field(LongType, None, None) + assert(IntegralFieldValidator.validate(f).isEmpty) + } + + test("Radix and Pattern collide") { + val f = field(IntegerType, Option("##0"), Option(Radix(3))) + assert(IntegralFieldValidator.validate(f) == Seq( + ValidationWarning("Both Radix and Pattern defined for field test_field, for Radix different from Radix(10) Pattern is ignored")) + ) + } + + test("Radix defined as default, Pattern defined non-default") { + val f = field(ByteType, Option("##0"), Option(Radix(10))) + assert(IntegralFieldValidator.validate(f).isEmpty) + } + + test("Radix defined is non-default, Pattern defined as default") { + val f = field(ShortType, Option(""), Option(Radix(16))) + assert(IntegralFieldValidator.validate(f).isEmpty) + } + + test("Decimal symbols redefined wrongly, invalid pattern") { + val builder = new MetadataBuilder() + .putString(MetadataKeys.GroupingSeparator, "Hello") + .putLong(MetadataKeys.DecimalSeparator, 789) + .putString(MetadataKeys.MinusSign, "") + .putString(MetadataKeys.Pattern, "%0.###,#") + val f = field(LongType, builder) + val exp = Set( + ValidationError(s"${MetadataKeys.GroupingSeparator} metadata value of field 'test_field' is not Char in String format"), + ValidationError(s"${MetadataKeys.DecimalSeparator} metadata value of field 'test_field' is not Char in String format"), + ValidationError(s"${MetadataKeys.MinusSign} 
metadata value of field 'test_field' is not Char in String format"), + ValidationError("""Malformed pattern "%0.###,#"""") + ) + assert(NumericFieldValidator.validate(f).toSet == exp) + } + + test("Pattern defined, default value doesn't adhere to it") { + val builder = new MetadataBuilder() + .putString(MetadataKeys.Pattern, "0XP") + .putString(MetadataKeys.DefaultValue, "0.0XP") + val f = field(IntegerType, builder) + assert(NumericFieldValidator.validate(f) == Seq( + ValidationError("Parsing of '0.0XP' failed.") + )) + } +} diff --git a/src/test/scala/za/co/absa/standardization/validation/field/NumericFieldValidatorSuite.scala b/src/test/scala/za/co/absa/standardization/validation/field/NumericFieldValidatorSuite.scala new file mode 100644 index 0000000..0487dd0 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/validation/field/NumericFieldValidatorSuite.scala @@ -0,0 +1,92 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.standardization.validation.field + +import org.apache.spark.sql.types.{DecimalType, MetadataBuilder, StructField} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.ValidationError +import za.co.absa.standardization.schema.MetadataKeys +import za.co.absa.standardization.types.TypedStructField.NumericTypeStructField +import za.co.absa.standardization.types.{Defaults, GlobalDefaults, TypedStructField} + +class NumericFieldValidatorSuite extends AnyFunSuite { + private implicit val defaults: Defaults = GlobalDefaults + + private def field(metadataBuilder: MetadataBuilder): NumericTypeStructField[_] = { + val result = StructField("test_field", DecimalType(15, 5), nullable = false, metadataBuilder.build()) + TypedStructField(result).asInstanceOf[NumericTypeStructField[_]] + } + + + test("No extra metadata") { + val builder = new MetadataBuilder + val f = field(builder) + assert(NumericFieldValidator.validate(f).isEmpty) + } + + test("Decimal symbols redefined") { + val builder = new MetadataBuilder() + .putString(MetadataKeys.GroupingSeparator, " ") + .putString(MetadataKeys.DecimalSeparator, ",") + .putString(MetadataKeys.MinusSign, "N") + val f = field(builder) + assert(NumericFieldValidator.validate(f).isEmpty) + } + + test("Pattern defined") { + val builder = new MetadataBuilder() + .putString(MetadataKeys.Pattern, "#,##0.#%") + .putString(MetadataKeys.DefaultValue, "100%") + val f = field(builder) + assert(NumericFieldValidator.validate(f).isEmpty) + } + + test("Pattern not string") { + val builder = new MetadataBuilder() + .putLong(MetadataKeys.Pattern, 0) + val f = field(builder) + assert(NumericFieldValidator.validate(f) == Seq( + ValidationError(s"${MetadataKeys.Pattern} metadata value of field 'test_field' is not String in String format") + )) + } + + test("Decimal symbols redefined wrongly, invalid pattern") { + val builder = new MetadataBuilder() +
.putBoolean(MetadataKeys.GroupingSeparator, value = false) + .putString(MetadataKeys.DecimalSeparator, "") + .putString(MetadataKeys.MinusSign, "xyz") + .putString(MetadataKeys.Pattern, "0.0.0.0") + val f = field(builder) + val exp = Set( + ValidationError(s"${MetadataKeys.GroupingSeparator} metadata value of field 'test_field' is not Char in String format"), + ValidationError(s"${MetadataKeys.DecimalSeparator} metadata value of field 'test_field' is not Char in String format"), + ValidationError(s"${MetadataKeys.MinusSign} metadata value of field 'test_field' is not Char in String format"), + ValidationError("""Multiple decimal separators in pattern "0.0.0.0"""") + ) + assert(NumericFieldValidator.validate(f).toSet == exp) + } + + test("Pattern defined, default value doesn't adhere to it") { + val builder = new MetadataBuilder() + .putString(MetadataKeys.Pattern, "#,##0.#%") + .putString(MetadataKeys.DefaultValue, "100") + val f = field(builder) + assert(NumericFieldValidator.validate(f) == Seq( + ValidationError("Parsing of '100' failed.") + )) + } +} diff --git a/src/test/scala/za/co/absa/standardization/validation/field/ScalarFieldValidatorSuite.scala b/src/test/scala/za/co/absa/standardization/validation/field/ScalarFieldValidatorSuite.scala new file mode 100644 index 0000000..7000711 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/validation/field/ScalarFieldValidatorSuite.scala @@ -0,0 +1,52 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.validation.field + +import org.apache.spark.sql.types.{MetadataBuilder, StringType, StructField} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.ValidationError +import za.co.absa.standardization.schema.MetadataKeys +import za.co.absa.standardization.types.{Defaults, GlobalDefaults, TypedStructField} + +class ScalarFieldValidatorSuite extends AnyFunSuite { + + private implicit val defaults: Defaults = GlobalDefaults + + test("Default value is set") { + val field = StructField("test_field", StringType, nullable = false, new MetadataBuilder().putString(MetadataKeys.DefaultValue, "foo").build()) + val testResult = ScalarFieldValidator.validate(TypedStructField(field)) + assert(testResult.isEmpty) + } + + test("Default value is not set") { + val field = StructField("test_field", StringType, nullable = true) + val testResult = ScalarFieldValidator.validate(TypedStructField(field)) + assert(testResult.isEmpty) + } + + test("Default value is set to NULL") { + val field = StructField("test_field", StringType, nullable = true, new MetadataBuilder().putString(MetadataKeys.DefaultValue, null).build()) + val testResult = ScalarFieldValidator.validate(TypedStructField(field)) + assert(testResult.isEmpty) + } + + test("Default value is set to non string value fails") { + val field = StructField("test_field", StringType, nullable = false, new MetadataBuilder().putBoolean(MetadataKeys.DefaultValue, value = true).build()) + val testResult = ScalarFieldValidator.validate(TypedStructField(field)) + assert(testResult == Seq(ValidationError("java.lang.Boolean cannot be cast to java.lang.String"))) + } +} diff --git a/src/test/scala/za/co/absa/standardization/validation/field/TimestampFieldValidatorSuite.scala b/src/test/scala/za/co/absa/standardization/validation/field/TimestampFieldValidatorSuite.scala new file mode 100644 index 0000000..f08afa7 --- /dev/null +++ b/src/test/scala/za/co/absa/standardization/validation/field/TimestampFieldValidatorSuite.scala @@ -0,0 +1,233 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.standardization.validation.field + +import org.apache.spark.sql.types.{MetadataBuilder, StructField, TimestampType} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.standardization.{ValidationError, ValidationIssue, ValidationWarning} +import za.co.absa.standardization.schema.MetadataKeys +import za.co.absa.standardization.time.TimeZoneNormalizer +import za.co.absa.standardization.types.{Defaults, GlobalDefaults, TypedStructField} + +class TimestampFieldValidatorSuite extends AnyFunSuite { + TimeZoneNormalizer.normalizeJVMTimeZone() + private implicit val defaults: Defaults = GlobalDefaults + + private def field(pattern: String, defaultValue: Option[String] = None, defaultTimeZone: Option[String] = None): TypedStructField = { + val builder = new MetadataBuilder().putString(MetadataKeys.Pattern, pattern) + val builder2 = defaultValue.map(builder.putString(MetadataKeys.DefaultValue, _)).getOrElse(builder) + val builder3 = defaultTimeZone.map(builder2.putString(MetadataKeys.DefaultTimeZone, _)).getOrElse(builder2) + val result = StructField("test_field", TimestampType, nullable = false, builder3.build()) + TypedStructField(result) + } + + test("epoch pattern") { + assert(TimestampFieldValidator.validate(field("epoch")).isEmpty) + //with default + assert(TimestampFieldValidator.validate(field("epoch", Option("5545556"))).isEmpty) + } + + test("epochmilli pattern") { + assert(TimestampFieldValidator.validate(field("epochmilli")).isEmpty) + //with default + assert(TimestampFieldValidator.validate(field("epochmilli", Option("5545556000"))).isEmpty) + } + + test("epochmicro pattern") { + assert(DateFieldValidator.validate(field("epochmicro")).isEmpty) + //with default + assert(DateFieldValidator.validate(field("epochmicro", Option("5545556000111"))).isEmpty) + } + + test("epochnano pattern") { + assert(DateFieldValidator.validate(field("epochnano")).isEmpty) + //with default + assert(DateFieldValidator.validate(field("epochnano", Option("5545556000111222"))).isEmpty) + } + + test("timestamp pattern") { + //no default + assert(TimestampFieldValidator.validate(field("yyyy-MM-dd HH:mm:ss")).isEmpty) + //default as timestamp + assert(TimestampFieldValidator.validate(field("HH-mm-ss~~dd.MM.yyyy", Option("23-10-11~~31.12.2004"))).isEmpty) + //extra chars in default + assert(TimestampFieldValidator.validate(field("HH-mm-ss~~dd.MM.yyyy", Option("23-10-11~~31.12.2004kkkkk"))).isEmpty) + } + + test("timestamp with time zone in pattern") { + //no default + assert(TimestampFieldValidator.validate(field("yyyy-MM-dd HH:mm:ss zz")).isEmpty) + //default as timestamp + assert(TimestampFieldValidator.validate(field("HH-mm-ss~~dd.MM.yyyy+zz", Option("23-10-11~~31.12.2004+CET"))).isEmpty) + //extra chars in default + assert(TimestampFieldValidator.validate(field("yyMMdd_HHmmss_zz", Option("190301_194533_EST!!!!"))).isEmpty) + //timestamp with offset time zone + assert(TimestampFieldValidator.validate(field("yyyy/MM/dd HH:mm:ssXXX", Option("2019/01/31 23:59:59-11:00"))).isEmpty) + } + + test("invalid pattern") { + val expected1 = Set( + ValidationError("Illegal pattern character 'f'") + ) + assert(TimestampFieldValidator.validate(field("fubar")).toSet == expected1) + val expected2 = Set( + ValidationError("Illegal pattern character 'x'") + ) + assert(TimestampFieldValidator.validate(field("yyMMdd_hhmmss_zz_xx")).toSet == expected2) + } + + test("invalid default") { + //empty default + val expected1 = Set( + ValidationError("""Unparseable date: """""), + 
ValidationWarning("Placeholder for hour 1-12 'h' found, but no am/pm 'a' placeholder. Possibly 0-23 'H' intended.") + ) + assert(TimestampFieldValidator.validate(field("yyMMdd_hhmmss_zz", Option(""))).toSet == expected1) + //wrong default + val expected2 = Set( + ValidationError("""Unparseable date: "1999-12-31""""), + ValidationWarning("No hour placeholder 'HH' found."), + ValidationWarning("No minute placeholder 'mm' found."), + ValidationWarning("No second placeholder 'ss' found.") + ) + assert(TimestampFieldValidator.validate(field("yyyy/MM/dd", Option("1999-12-31"))).toSet == expected2) + //invalid epoch default + val expected3 = Set( + ValidationError("'2019-01-01' cannot be cast to timestamp") + ) + assert(TimestampFieldValidator.validate(field("epoch", Option("2019-01-01"))).toSet == expected3) + //timestamp pattern, date default + val expected4 = Set( + ValidationError("""Unparseable date: "31.12.2004""""), + ValidationWarning("Placeholder for hour 1-12 'h' found, but no am/pm 'a' placeholder. Possibly 0-23 'H' intended.") + ) + assert(TimestampFieldValidator.validate(field("dd.MM.yyyy hh-mm-ss", Option("31.12.2004"))).toSet == expected4) + //epoch overflow + val expected5 = Set( + ValidationError("'8748743743948390823948239084294938231122123' cannot be cast to timestamp") + ) + assert(TimestampFieldValidator.validate(field("epoch", Option("8748743743948390823948239084294938231122123"))).toSet == expected5) + } + + test("utilizing default time zone") { + val pattern = "yyyy-MM-dd HH:mm:ss" + val value = Option("2000-01-01 00:00:00") + // full name + assert(TimestampFieldValidator.validate(field(pattern, value, Option("Africa/Johannesburg"))).isEmpty) + // abbreviation + assert(TimestampFieldValidator.validate(field(pattern, value, Option("CET"))).isEmpty) + // offset to GMT + assert(TimestampFieldValidator.validate(field(pattern, value, Option("Etc/GMT-6"))).isEmpty) + } + + test("issues with default time zone") { + def expected(timeZone: String): Set[ValidationIssue] = { + val q ="\"" + Set(ValidationError(s"$q$timeZone$q is not a valid time zone designation")) + } + val pattern = "yyyy-MM-dd HH:mm:ss" + val value = Option("2000-01-01 00:00:00") + // offset + val tz1 = "-03:00" + assert(TimestampFieldValidator.validate(field(pattern, value, Option(tz1))).toSet == expected(tz1)) + // empty + val tz2 = "" + assert(TimestampFieldValidator.validate(field(pattern, value, Option(tz2))).toSet == expected(tz2)) + // gibberish + val tz3 = "Gjh878-++_?" 
+    assert(TimestampFieldValidator.validate(field(pattern, value, Option(tz3))).toSet == expected(tz3))
+    // non-existing
+    val tz4 = "Africa/New York"
+    assert(TimestampFieldValidator.validate(field(pattern, value, Option(tz4))).toSet == expected(tz4))
+  }
+
+  test("warning issues: double time zone") {
+    val expected = Set(
+      ValidationWarning("Pattern includes time zone placeholder and default time zone is also defined (will never be used)")
+    )
+    assert(TimestampFieldValidator.validate(field("yyyy-MM-dd HH:mm:ss XX", None, Option("CET"))).toSet == expected)
+    assert(TimestampFieldValidator.validate(field("yyyy-MM-dd HH:mm:ss zz", None, Option("CET"))).toSet == expected)
+  }
+
+  test("warning issues: missing placeholders") {
+    val expected = Set(
+      ValidationWarning("No year placeholder 'yyyy' found."),
+      ValidationWarning("No month placeholder 'MM' found."),
+      ValidationWarning("No day placeholder 'dd' found."),
+      ValidationWarning("No hour placeholder 'HH' found."),
+      ValidationWarning("No minute placeholder 'mm' found."),
+      ValidationWarning("No second placeholder 'ss' found.")
+    )
+    assert(TimestampFieldValidator.validate(field("GG")).toSet == expected)
+  }
+
+  test("warning issues: missing placeholders with default time zone") {
+    val expected = Set(
+      ValidationWarning("No year placeholder 'yyyy' found."),
+      ValidationWarning("No month placeholder 'MM' found."),
+      ValidationWarning("No day placeholder 'dd' found."),
+      ValidationWarning("No hour placeholder 'HH' found."),
+      ValidationWarning("No minute placeholder 'mm' found."),
+      ValidationWarning("No second placeholder 'ss' found.")
+    )
+    assert(TimestampFieldValidator.validate(field("GG", None, Option("CET"))).toSet == expected)
+  }
+
+  test("warning issues: day placeholder wrong case") {
+    val expected = Set(
+      ValidationWarning("No day placeholder 'dd' found."),
+      ValidationWarning("Rarely used DayOfYear placeholder 'D' found. Possibly DayOfMonth 'd' intended.")
+    )
+    assert(TimestampFieldValidator.validate(field("yyyy/MM/DD HH-mm-ss")).toSet == expected)
+  }
+
+  test("warning issues: h instead of H") {
+    val expected = Set(
+      ValidationWarning("Placeholder for hour 1-12 'h' found, but no am/pm 'a' placeholder. Possibly 0-23 'H' intended.")
+    )
+    assert(TimestampFieldValidator.validate(field("yyyy/MM/dd hh-mm-ss")).toSet == expected)
+  }
+
+  test("warning issues: K instead of k") {
+    val expected = Set(
+      ValidationWarning("Placeholder for hour 0-11 'K' found, but no am/pm 'a' placeholder. Possibly 1-24 'k' intended.")
+    )
+    assert(TimestampFieldValidator.validate(field("yyyy/MM/dd KK-mm-ss")).toSet == expected)
+  }
+
+  test("warning issues: k and no H is ok") {
+    assert(TimestampFieldValidator.validate(field("yyyy/MM/dd kk-mm-ss")).isEmpty)
+
+  }
+
+  test("all placeholders") {
+    assert(TimestampFieldValidator.validate(field("X GG yyyy MM ww W DDD dd F E a HH kk KK hh mm ss SSS ZZ zzz")).isEmpty)
+  }
+
+  test("nano seconds precision lost") {
+    val expected1 = Set(
+      ValidationWarning("Pattern 'epochnano'. While supported it comes with possible loss of precision beyond microseconds.")
+    )
+    assert(TimestampFieldValidator.validate(field("epochnano")).toSet == expected1)
+
+    val expected2 = Set(
+      ValidationWarning("Placeholder 'n' for nanoseconds recognized. While supported, it brings possible loss of precision.")
+    )
+    assert(TimestampFieldValidator.validate(field("yyyy-MM-dd HH:mm:ss.nnnnnnnnn")).toSet == expected2)
+
+  }
+}
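
For orientation, a minimal usage sketch of the validator exercised by the suite above, relying only on the API surface the tests import; the object name, the `event_time` field name, and the chosen pattern are illustrative, not part of the patch:

import org.apache.spark.sql.types.{MetadataBuilder, StructField, TimestampType}
import za.co.absa.standardization.schema.MetadataKeys
import za.co.absa.standardization.types.{Defaults, GlobalDefaults, TypedStructField}
import za.co.absa.standardization.validation.field.TimestampFieldValidator

object TimestampPatternCheckSketch {
  // the TypedStructField factory used throughout the suite expects an implicit Defaults instance
  private implicit val defaults: Defaults = GlobalDefaults

  def main(args: Array[String]): Unit = {
    // describe a timestamp column with a parsing pattern and a default time zone,
    // mirroring the metadata keys the tests exercise
    val metadata = new MetadataBuilder()
      .putString(MetadataKeys.Pattern, "yyyy-MM-dd HH:mm:ss")
      .putString(MetadataKeys.DefaultTimeZone, "Africa/Johannesburg")
      .build()
    val field = StructField("event_time", TimestampType, nullable = false, metadata)

    // an empty result means the field metadata passed validation; otherwise the returned
    // issues carry the ValidationError/ValidationWarning messages asserted on in the suite
    val issues = TimestampFieldValidator.validate(TypedStructField(field))
    issues.foreach(println)
  }
}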