diff --git a/docs/data/sql_functions.yml b/docs/data/sql_functions.yml index b715039160678..795a4d8f904b6 100644 --- a/docs/data/sql_functions.yml +++ b/docs/data/sql_functions.yml @@ -673,6 +673,9 @@ collection: - sql: MAP_KEYS(map) table: MAP.mapKeys() description: Returns the keys of the map as array. No order guaranteed. + - sql: MAP_UNION(map1, ...) + table: map1.mapUnion(...) + description: Returns a map created by merging one or more maps. These maps should have a common map type. If there are overlapping keys, the value from 'map2' will overwrite the value from 'map1', the value from 'map3' will overwrite the value from 'map2', the value from 'mapn' will overwrite the value from 'map(n-1)'. If any of the maps is null, returns null. - sql: MAP_VALUES(map) table: MAP.mapValues() description: Returns the values of the map as array. No order guaranteed. diff --git a/docs/data/sql_functions_zh.yml b/docs/data/sql_functions_zh.yml index b67ac9498353b..fbb3e5d8415db 100644 --- a/docs/data/sql_functions_zh.yml +++ b/docs/data/sql_functions_zh.yml @@ -814,6 +814,9 @@ collection: - sql: MAP_FROM_ARRAYS(array_of_keys, array_of_values) table: mapFromArrays(array_of_keys, array_of_values) description: 返回由 key 的数组 keys 和 value 的数组 values 创建的 map。请注意两个数组的长度应该相等。 + - sql: MAP_UNION(map1, ...) + table: map1.mapUnion(...) + description: 返回一个通过合并一个或多个 map 创建的 map。这些 map 应该具有共同的 map 类型。如果有重叠的键,'map2' 的值将覆盖 'map1' 的值,'map3' 的值将覆盖 'map2' 的值,'mapn' 的值将覆盖 'map(n-1)' 的值。如果任一 map 为 null,则返回 null。 json: - sql: IS JSON [ { VALUE | SCALAR | ARRAY | OBJECT } ] diff --git a/flink-python/docs/reference/pyflink.table/expressions.rst b/flink-python/docs/reference/pyflink.table/expressions.rst index dbc69682a3b37..e6ae0d9292173 100644 --- a/flink-python/docs/reference/pyflink.table/expressions.rst +++ b/flink-python/docs/reference/pyflink.table/expressions.rst @@ -240,6 +240,7 @@ advanced type helper functions Expression.array_union Expression.map_entries Expression.map_keys + Expression.map_union Expression.map_values Expression.array_except diff
--git a/flink-python/pyflink/table/expression.py b/flink-python/pyflink/table/expression.py index 8c892b52e3508..648ad62bfb541 100644 --- a/flink-python/pyflink/table/expression.py +++ b/flink-python/pyflink/table/expression.py @@ -1627,6 +1627,17 @@ def map_keys(self) -> 'Expression': """ return _unary_op("mapKeys")(self) + def map_union(self, *maps) -> 'Expression': + """ + Returns a map created by merging one or more maps. These maps should have a common map type. + If there are overlapping keys, the value from 'map2' will overwrite the value from 'map1', + the value from 'map3' will overwrite the value from 'map2', the value from 'mapn' will + overwrite the value from 'map(n-1)'. If any of the maps is null, returns null. + + .. seealso:: :py:attr:`~Expression.map_union` + """ + return _binary_op("mapUnion")(self, *maps) + @property def map_values(self) -> 'Expression': """ diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/BaseExpressions.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/BaseExpressions.java index 090274d89a59c..bf5df75126fbe 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/BaseExpressions.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/BaseExpressions.java @@ -132,6 +132,7 @@ import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.LTRIM; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.MAP_ENTRIES; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.MAP_KEYS; +import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.MAP_UNION; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.MAP_VALUES; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.MAX; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.MD5; @@ -1548,6 +1549,21 @@
public OutType mapEntries() { return toApiSpecificExpression(unresolvedCall(MAP_ENTRIES, toExpr())); } + /** + * Returns a map created by merging one or more maps. These maps should have a common map type. + * If there are overlapping keys, the value from 'map2' will overwrite the value from 'map1', + * the value from 'map3' will overwrite the value from 'map2', the value from 'mapn' will + * overwrite the value from 'map(n-1)'. If any of the maps is null, returns null. + */ + public OutType mapUnion(InType... inputs) { + Expression[] args = + Stream.concat( + Stream.of(toExpr()), + Arrays.stream(inputs).map(ApiExpressionUtils::objectToExpression)) + .toArray(Expression[]::new); + return toApiSpecificExpression(unresolvedCall(MAP_UNION, args)); + } + // Time definition /** diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java index 07de1eb2c9c64..3c7b2c6680bfb 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java @@ -78,6 +78,7 @@ import static org.apache.flink.table.types.inference.InputTypeStrategies.TYPE_LITERAL; import static org.apache.flink.table.types.inference.InputTypeStrategies.and; import static org.apache.flink.table.types.inference.InputTypeStrategies.commonArrayType; +import static org.apache.flink.table.types.inference.InputTypeStrategies.commonMapType; import static org.apache.flink.table.types.inference.InputTypeStrategies.commonMultipleArrayType; import static org.apache.flink.table.types.inference.InputTypeStrategies.commonType; import static org.apache.flink.table.types.inference.InputTypeStrategies.comparable; @@ -167,6 +168,16 @@ ANY, and(logical(LogicalTypeRoot.BOOLEAN), LITERAL)
"org.apache.flink.table.runtime.functions.scalar.MapValuesFunction") .build(); + public static final BuiltInFunctionDefinition MAP_UNION = + BuiltInFunctionDefinition.newBuilder() + .name("MAP_UNION") + .kind(SCALAR) + .inputTypeStrategy(commonMapType(1)) + .outputTypeStrategy(COMMON) + .runtimeClass( + "org.apache.flink.table.runtime.functions.scalar.MapUnionFunction") + .build(); + public static final BuiltInFunctionDefinition MAP_ENTRIES = BuiltInFunctionDefinition.newBuilder() .name("MAP_ENTRIES") diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/InputTypeStrategies.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/InputTypeStrategies.java index ef6b1b20f59be..5bea0c20e16bd 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/InputTypeStrategies.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/InputTypeStrategies.java @@ -26,6 +26,7 @@ import org.apache.flink.table.types.inference.strategies.CommonArgumentTypeStrategy; import org.apache.flink.table.types.inference.strategies.CommonArrayInputTypeStrategy; import org.apache.flink.table.types.inference.strategies.CommonInputTypeStrategy; +import org.apache.flink.table.types.inference.strategies.CommonMapInputTypeStrategy; import org.apache.flink.table.types.inference.strategies.ComparableTypeStrategy; import org.apache.flink.table.types.inference.strategies.CompositeArgumentTypeStrategy; import org.apache.flink.table.types.inference.strategies.ConstraintArgumentTypeStrategy; @@ -368,6 +369,14 @@ public static InputTypeStrategy commonMultipleArrayType(int minCount) { /** @see ItemAtIndexArgumentTypeStrategy */ public static final ArgumentTypeStrategy ITEM_AT_INDEX = new ItemAtIndexArgumentTypeStrategy(); + /** + * An {@link InputTypeStrategy} that expects {@code minCount} arguments that have a common map + * type. 
+ */ + public static InputTypeStrategy commonMapType(int minCount) { + return new CommonMapInputTypeStrategy(ConstantArgumentCount.from(minCount)); + } + // -------------------------------------------------------------------------------------------- private InputTypeStrategies() { diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/CommonArrayInputTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/CommonArrayInputTypeStrategy.java index a3963673e78a9..6e07f16e61d14 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/CommonArrayInputTypeStrategy.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/CommonArrayInputTypeStrategy.java @@ -19,96 +19,15 @@ package org.apache.flink.table.types.inference.strategies; import org.apache.flink.annotation.Internal; -import org.apache.flink.table.functions.FunctionDefinition; -import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.inference.ArgumentCount; -import org.apache.flink.table.types.inference.CallContext; import org.apache.flink.table.types.inference.InputTypeStrategy; -import org.apache.flink.table.types.inference.Signature; -import org.apache.flink.table.types.inference.Signature.Argument; -import org.apache.flink.table.types.logical.LegacyTypeInformationType; -import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.LogicalTypeRoot; -import org.apache.flink.table.types.logical.utils.LogicalTypeMerging; -import org.apache.flink.table.types.utils.TypeConversions; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; -import java.util.stream.IntStream; /** An {@link InputTypeStrategy} that expects that all arguments have a common 
array type. */ @Internal -public final class CommonArrayInputTypeStrategy implements InputTypeStrategy { - - private static final Argument COMMON_ARGUMENT = Argument.ofGroup("COMMON"); - - private final ArgumentCount argumentCount; +public final class CommonArrayInputTypeStrategy extends CommonCollectionInputTypeStrategy { public CommonArrayInputTypeStrategy(ArgumentCount argumentCount) { - this.argumentCount = argumentCount; - } - - @Override - public ArgumentCount getArgumentCount() { - return argumentCount; - } - - @Override - public Optional> inferInputTypes( - CallContext callContext, boolean throwOnFailure) { - List argumentDataTypes = callContext.getArgumentDataTypes(); - List argumentTypes = - argumentDataTypes.stream() - .map(DataType::getLogicalType) - .collect(Collectors.toList()); - - if (!argumentTypes.stream() - .allMatch(logicalType -> logicalType.is(LogicalTypeRoot.ARRAY))) { - return callContext.fail(throwOnFailure, "All arguments requires to be a ARRAY type"); - } - - if (argumentTypes.stream().anyMatch(CommonArrayInputTypeStrategy::isLegacyType)) { - return Optional.of(argumentDataTypes); - } - - Optional commonType = LogicalTypeMerging.findCommonType(argumentTypes); - - if (!commonType.isPresent()) { - return callContext.fail( - throwOnFailure, - "Could not find a common type for arguments: %s", - argumentDataTypes); - } - - return commonType.map( - type -> - Collections.nCopies( - argumentTypes.size(), TypeConversions.fromLogicalToDataType(type))); - } - - @Override - public List getExpectedSignatures(FunctionDefinition definition) { - Optional minCount = argumentCount.getMinCount(); - Optional maxCount = argumentCount.getMaxCount(); - - int numberOfMandatoryArguments = minCount.orElse(0); - - if (maxCount.isPresent()) { - return IntStream.range(numberOfMandatoryArguments, maxCount.get() + 1) - .mapToObj(count -> Signature.of(Collections.nCopies(count, COMMON_ARGUMENT))) - .collect(Collectors.toList()); - } - - List arguments = - new 
ArrayList<>(Collections.nCopies(numberOfMandatoryArguments, COMMON_ARGUMENT)); - arguments.add(Argument.ofGroupVarying("COMMON")); - return Collections.singletonList(Signature.of(arguments)); - } - - private static boolean isLegacyType(LogicalType type) { - return type instanceof LegacyTypeInformationType; + super(argumentCount, "All arguments requires to be a ARRAY type", LogicalTypeRoot.ARRAY); } } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/CommonCollectionInputTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/CommonCollectionInputTypeStrategy.java new file mode 100644 index 0000000000000..3226db66468aa --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/CommonCollectionInputTypeStrategy.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.functions.FunctionDefinition; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.ArgumentCount; +import org.apache.flink.table.types.inference.CallContext; +import org.apache.flink.table.types.inference.InputTypeStrategy; +import org.apache.flink.table.types.inference.Signature; +import org.apache.flink.table.types.inference.Signature.Argument; +import org.apache.flink.table.types.logical.LegacyTypeInformationType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.LogicalTypeRoot; +import org.apache.flink.table.types.logical.utils.LogicalTypeMerging; +import org.apache.flink.table.types.utils.TypeConversions; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** An {@link InputTypeStrategy} that expects that all arguments have a common type. 
*/ +@Internal +public class CommonCollectionInputTypeStrategy implements InputTypeStrategy { + + private static final Argument COMMON_ARGUMENT = Argument.ofGroup("COMMON"); + + private final ArgumentCount argumentCount; + + private final LogicalTypeRoot logicalTypeRoot; + private final String errorMessage; + + public CommonCollectionInputTypeStrategy( + ArgumentCount argumentCount, String errorMessage, LogicalTypeRoot logicalTypeRoot) { + this.argumentCount = argumentCount; + this.errorMessage = errorMessage; + this.logicalTypeRoot = logicalTypeRoot; + } + + @Override + public ArgumentCount getArgumentCount() { + return argumentCount; + } + + @Override + public Optional> inferInputTypes( + CallContext callContext, boolean throwOnFailure) { + List argumentDataTypes = callContext.getArgumentDataTypes(); + List argumentTypes = + argumentDataTypes.stream() + .map(DataType::getLogicalType) + .collect(Collectors.toList()); + + if (!argumentTypes.stream().allMatch(logicalType -> logicalType.is(logicalTypeRoot))) { + return callContext.fail(throwOnFailure, errorMessage); + } + + if (argumentTypes.stream().anyMatch(CommonCollectionInputTypeStrategy::isLegacyType)) { + return Optional.of(argumentDataTypes); + } + + Optional commonType = LogicalTypeMerging.findCommonType(argumentTypes); + + if (!commonType.isPresent()) { + return callContext.fail( + throwOnFailure, + "Could not find a common type for arguments: %s", + argumentDataTypes); + } + + return commonType.map( + type -> + Collections.nCopies( + argumentTypes.size(), TypeConversions.fromLogicalToDataType(type))); + } + + @Override + public List getExpectedSignatures(FunctionDefinition definition) { + Optional minCount = argumentCount.getMinCount(); + Optional maxCount = argumentCount.getMaxCount(); + + int numberOfMandatoryArguments = minCount.orElse(0); + + if (maxCount.isPresent()) { + return IntStream.range(numberOfMandatoryArguments, maxCount.get() + 1) + .mapToObj(count -> Signature.of(Collections.nCopies(count, 
COMMON_ARGUMENT))) + .collect(Collectors.toList()); + } + + List arguments = + new ArrayList<>(Collections.nCopies(numberOfMandatoryArguments, COMMON_ARGUMENT)); + arguments.add(Argument.ofGroupVarying("COMMON")); + return Collections.singletonList(Signature.of(arguments)); + } + + private static boolean isLegacyType(LogicalType type) { + return type instanceof LegacyTypeInformationType; + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/CommonMapInputTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/CommonMapInputTypeStrategy.java new file mode 100644 index 0000000000000..ad35b44caa2ce --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/CommonMapInputTypeStrategy.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.inference.ArgumentCount; +import org.apache.flink.table.types.inference.InputTypeStrategy; +import org.apache.flink.table.types.logical.LogicalTypeRoot; + +/** An {@link InputTypeStrategy} that expects that all arguments have a common map type. */ +@Internal +public final class CommonMapInputTypeStrategy extends CommonCollectionInputTypeStrategy { + + public CommonMapInputTypeStrategy(ArgumentCount argumentCount) { + super(argumentCount, "All arguments requires to be a MAP type", LogicalTypeRoot.MAP); + } +} diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/CommonArrayInputTypeStrategyTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/CommonArrayInputTypeStrategyTest.java deleted file mode 100644 index 1e95c884f0879..0000000000000 --- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/CommonArrayInputTypeStrategyTest.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.table.types.inference.strategies; - -import org.apache.flink.table.api.DataTypes; -import org.apache.flink.table.types.inference.InputTypeStrategies; -import org.apache.flink.table.types.inference.InputTypeStrategiesTestBase; - -import java.util.stream.Stream; - -/** Tests for {@link CommonArrayInputTypeStrategy}. */ -class CommonArrayInputTypeStrategyTest extends InputTypeStrategiesTestBase { - - @Override - protected Stream testData() { - return Stream.of( - TestSpec.forStrategy(InputTypeStrategies.commonArrayType(2)) - .expectSignature("f(, )") - .calledWithArgumentTypes( - DataTypes.ARRAY(DataTypes.INT()), - DataTypes.ARRAY(DataTypes.DOUBLE().notNull()).notNull()) - .expectArgumentTypes( - DataTypes.ARRAY(DataTypes.DOUBLE()), - DataTypes.ARRAY(DataTypes.DOUBLE())), - TestSpec.forStrategy( - "Strategy fails if not all of the argument types are ARRAY", - InputTypeStrategies.commonArrayType(2)) - .calledWithArgumentTypes(DataTypes.INT(), DataTypes.ARRAY(DataTypes.INT())) - .expectErrorMessage("All arguments requires to be a ARRAY type"), - TestSpec.forStrategy( - "Strategy fails if can not find a common type", - InputTypeStrategies.commonArrayType(2)) - .calledWithArgumentTypes( - DataTypes.ARRAY(DataTypes.INT()), - DataTypes.ARRAY(DataTypes.STRING())) - .expectErrorMessage( - "Could not find a common type for arguments: [ARRAY, ARRAY]")); - } -} diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/CommonCollectionInputTypeStrategyTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/CommonCollectionInputTypeStrategyTest.java new file mode 100644 index 0000000000000..8d9656614af50 --- /dev/null +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/strategies/CommonCollectionInputTypeStrategyTest.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or 
more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.types.inference.InputTypeStrategies; +import org.apache.flink.table.types.inference.InputTypeStrategiesTestBase; + +import java.util.stream.Stream; + +/** Tests for {@link CommonCollectionInputTypeStrategy}. 
*/ +class CommonCollectionInputTypeStrategyTest extends InputTypeStrategiesTestBase { + + @Override + protected Stream testData() { + return Stream.of( + TestSpec.forStrategy(InputTypeStrategies.commonArrayType(2)) + .expectSignature("f(, )") + .calledWithArgumentTypes( + DataTypes.ARRAY(DataTypes.INT()), + DataTypes.ARRAY(DataTypes.DOUBLE().notNull()).notNull()) + .expectArgumentTypes( + DataTypes.ARRAY(DataTypes.DOUBLE()), + DataTypes.ARRAY(DataTypes.DOUBLE())), + TestSpec.forStrategy( + "Strategy fails if not all of the argument types are ARRAY", + InputTypeStrategies.commonArrayType(2)) + .calledWithArgumentTypes(DataTypes.INT(), DataTypes.ARRAY(DataTypes.INT())) + .expectErrorMessage("All arguments requires to be a ARRAY type"), + TestSpec.forStrategy( + "Strategy fails if can not find a common type", + InputTypeStrategies.commonArrayType(2)) + .calledWithArgumentTypes( + DataTypes.ARRAY(DataTypes.INT()), + DataTypes.ARRAY(DataTypes.STRING())) + .expectErrorMessage( + "Could not find a common type for arguments: [ARRAY, ARRAY]"), + TestSpec.forStrategy(InputTypeStrategies.commonMultipleArrayType(2)) + .expectSignature("f(, , ...)") + .calledWithArgumentTypes( + DataTypes.ARRAY(DataTypes.INT()), + DataTypes.ARRAY(DataTypes.DOUBLE().notNull()).notNull()) + .expectArgumentTypes( + DataTypes.ARRAY(DataTypes.DOUBLE()), + DataTypes.ARRAY(DataTypes.DOUBLE())), + TestSpec.forStrategy( + "Strategy fails if not all of the argument types are ARRAY", + InputTypeStrategies.commonMultipleArrayType(2)) + .calledWithArgumentTypes(DataTypes.INT(), DataTypes.ARRAY(DataTypes.INT())) + .expectErrorMessage("All arguments requires to be a ARRAY type"), + TestSpec.forStrategy( + "Strategy fails if can not find a common type", + InputTypeStrategies.commonMultipleArrayType(2)) + .calledWithArgumentTypes( + DataTypes.ARRAY(DataTypes.INT()), + DataTypes.ARRAY(DataTypes.STRING())) + .expectErrorMessage( + "Could not find a common type for arguments: [ARRAY, ARRAY]"), + 
TestSpec.forStrategy(InputTypeStrategies.commonMapType(2)) + .expectSignature("f(, , ...)") + .calledWithArgumentTypes( + DataTypes.MAP(DataTypes.INT(), DataTypes.INT()), + DataTypes.MAP(DataTypes.DOUBLE(), DataTypes.DOUBLE())) + .expectArgumentTypes( + DataTypes.MAP(DataTypes.DOUBLE(), DataTypes.DOUBLE()), + DataTypes.MAP(DataTypes.DOUBLE(), DataTypes.DOUBLE())), + TestSpec.forStrategy( + "Strategy fails if can not find a common type", + InputTypeStrategies.commonMapType(2)) + .calledWithArgumentTypes( + DataTypes.MAP(DataTypes.STRING(), DataTypes.INT()), + DataTypes.MAP(DataTypes.INT(), DataTypes.INT())) + .expectErrorMessage( + "Could not find a common type for arguments: [MAP, MAP]"), + TestSpec.forStrategy( + "Strategy fails if not all of the argument types are MAP", + InputTypeStrategies.commonMapType(2)) + .calledWithArgumentTypes( + DataTypes.ARRAY(DataTypes.INT()), + DataTypes.MAP(DataTypes.STRING(), DataTypes.INT())) + .expectErrorMessage("All arguments requires to be a MAP type")); + } +} diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/MapFunctionITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/MapFunctionITCase.java index 991e33ff0ed6d..4684447b0a898 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/MapFunctionITCase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/MapFunctionITCase.java @@ -47,6 +47,7 @@ import static org.apache.flink.table.api.DataTypes.TIMESTAMP; import static org.apache.flink.table.api.Expressions.$; import static org.apache.flink.table.api.Expressions.call; +import static org.apache.flink.table.api.Expressions.lit; import static org.apache.flink.table.api.Expressions.map; import static org.apache.flink.table.api.Expressions.mapFromArrays; import static org.apache.flink.util.CollectionUtil.entry; @@ -71,7 +72,8 @@ Stream 
getTestSetSpecs() { mapKeysTestCases(), mapValuesTestCases(), mapEntriesTestCases(), - mapFromArraysTestCases()) + mapFromArraysTestCases(), + mapUnionTestCases()) .flatMap(s -> s); } @@ -396,4 +398,200 @@ private Stream mapFromArraysTestCases() { DataTypes.MAP( DataTypes.STRING(), DataTypes.ARRAY(DataTypes.INT())))); } + + private Stream mapUnionTestCases() { + return Stream.of( + TestSetSpec.forFunction(BuiltInFunctionDefinitions.MAP_UNION) + .onFieldsWithData( + null, + "item", + CollectionUtil.map( + entry("one", new Integer[] {1, 2}), + entry("two", new Integer[] {3, 4})), + CollectionUtil.map( + entry("one", new Integer[] {2, 2}), + entry("two", new Integer[] {8, 4})), + CollectionUtil.map( + entry(2, new Integer[] {1, 2}), + entry(7, new Integer[] {3, 4})), + CollectionUtil.map(entry("one", 2), entry("two", 5)), + new Integer[] {1, 2, 3, 4, 5, null}, + new String[] {"1", "3", "5", "7", "9", null}, + null, + CollectionUtil.map(entry(1, 2)), + CollectionUtil.map( + entry(1, 3), + entry(2, 4), + entry(lit(null, DataTypes.INT()), 3)), + lit(null, DataTypes.MAP(DataTypes.INT(), DataTypes.INT()))) + .andDataTypes( + DataTypes.BOOLEAN().nullable(), + DataTypes.STRING(), + DataTypes.MAP(DataTypes.STRING(), DataTypes.ARRAY(DataTypes.INT())), + DataTypes.MAP(DataTypes.STRING(), DataTypes.ARRAY(DataTypes.INT())), + DataTypes.MAP(DataTypes.INT(), DataTypes.ARRAY(DataTypes.INT())), + DataTypes.MAP(DataTypes.STRING(), DataTypes.INT()), + DataTypes.ARRAY(DataTypes.INT()), + DataTypes.ARRAY(DataTypes.STRING()), + DataTypes.MAP(DataTypes.INT(), DataTypes.INT()), + DataTypes.MAP(DataTypes.INT(), DataTypes.INT()), + DataTypes.MAP(DataTypes.INT(), DataTypes.INT()), + DataTypes.MAP(DataTypes.INT(), DataTypes.INT())) + .testResult( + $("f10").mapUnion( + CollectionUtil.map( + entry(lit(null, DataTypes.INT()), 8))), + "MAP_UNION(f10, MAP[CAST(NULL AS INT), 8])", + CollectionUtil.map(entry(1, 3), entry(2, 4), entry(null, 8)), + DataTypes.MAP(DataTypes.INT(), DataTypes.INT())) + 
.testResult( + $("f9").mapUnion( + CollectionUtil.map( + entry(lit(null, DataTypes.INT()), 3))), + "MAP_UNION(f9, MAP[CAST(NULL AS INT), 3])", + CollectionUtil.map(entry(null, 3), entry(1, 2)), + DataTypes.MAP(DataTypes.INT(), DataTypes.INT())) + .testResult( + $("f8").mapUnion( + lit( + null, + DataTypes.MAP( + DataTypes.INT(), DataTypes.INT()))), + "MAP_UNION(f8, CAST(NULL AS MAP))", + null, + DataTypes.MAP(DataTypes.INT(), DataTypes.INT())) + .testResult( + $("f9").mapUnion( + lit( + null, + DataTypes.MAP( + DataTypes.INT(), DataTypes.INT()))), + "MAP_UNION(f9, CAST(NULL AS MAP))", + null, + DataTypes.MAP(DataTypes.INT(), DataTypes.INT())) + .testResult( + $("f11").mapUnion(CollectionUtil.map(entry(1, 2))), + "MAP_UNION(f11, MAP[1, 2])", + null, + DataTypes.MAP(DataTypes.INT(), DataTypes.INT())) + .testResult( + $("f2").mapUnion( + CollectionUtil.map( + entry("one", new Integer[] {2, 2}), + entry("two", new Integer[] {8, 4}), + entry("three", new Integer[] {1, 2}))), + "MAP_UNION(f2, MAP['one', ARRAY[2,2], 'two', ARRAY[8, 4], 'three', ARRAY[1, 2]])", + CollectionUtil.map( + entry("one", new Integer[] {2, 2}), + entry("two", new Integer[] {8, 4}), + entry("three", new Integer[] {1, 2})), + DataTypes.MAP(DataTypes.STRING(), DataTypes.ARRAY(DataTypes.INT()))) + .testResult( + $("f2").mapUnion( + CollectionUtil.map( + entry("one", new Integer[] {2, 2}), + entry("two", new Integer[] {8, 4}), + entry("three", new Integer[] {1, 2})), + CollectionUtil.map( + entry("one", new Integer[] {2, 9}), + entry("four", new Integer[] {8, 4}), + entry("five", new Integer[] {1, 2}))), + "MAP_UNION(f2, MAP['one', ARRAY[2,2], 'two', ARRAY[8, 4], 'three', ARRAY[1, 2]], MAP['one', ARRAY[2,9], 'four', ARRAY[8, 4], 'five', ARRAY[1, 2]])", + CollectionUtil.map( + entry("one", new Integer[] {2, 9}), + entry("two", new Integer[] {8, 4}), + entry("three", new Integer[] {1, 2}), + entry("four", new Integer[] {8, 4}), + entry("five", new Integer[] {1, 2})), + DataTypes.MAP(DataTypes.STRING(), 
DataTypes.ARRAY(DataTypes.INT()))) + .testResult( + $("f4").mapUnion( + CollectionUtil.map( + entry(1, new Integer[] {2, 2}), + entry(2, new Integer[] {8, 4}), + entry(3, new Integer[] {1, 2}))), + "MAP_UNION(f4, MAP[1, ARRAY[2,2], 2, ARRAY[8, 4], 3, ARRAY[1, 2]])", + CollectionUtil.map( + entry(1, new Integer[] {2, 2}), + entry(2, new Integer[] {8, 4}), + entry(3, new Integer[] {1, 2}), + entry(7, new Integer[] {3, 4})), + DataTypes.MAP(DataTypes.INT(), DataTypes.ARRAY(DataTypes.INT()))) + .testTableApiValidationError( + $("f2").mapUnion( + CollectionUtil.map( + entry(1, new Integer[] {2, 2}), + entry(2, new Integer[] {8, 4}), + entry(3, new Integer[] {1, 2}))), + "Invalid function call:\n" + + "MAP_UNION(MAP>, MAP NOT NULL> NOT NULL)") + .testSqlValidationError( + "MAP_UNION(f2, MAP[1, ARRAY[2,2], 2, ARRAY[8, 4], 3, ARRAY[1, 2]])", + "SQL validation failed. Invalid function call:\n" + + "MAP_UNION(MAP>, MAP NOT NULL> NOT NULL)") + .testTableApiValidationError( + $("f0").mapUnion( + CollectionUtil.map( + entry(1, new Integer[] {2, 2}), + entry(2, new Integer[] {8, 4}), + entry(3, new Integer[] {1, 2}))), + "Invalid function call:\n" + + "MAP_UNION(BOOLEAN, MAP NOT NULL> NOT NULL)") + .testSqlValidationError( + "MAP_UNION(f0, MAP[1, ARRAY[2,2], 2, ARRAY[8, 4], 3, ARRAY[1, 2]])", + "SQL validation failed. Invalid function call:\n" + + "MAP_UNION(BOOLEAN, MAP NOT NULL> NOT NULL)") + .testTableApiValidationError( + $("f1").mapUnion( + CollectionUtil.map( + entry(1, new Integer[] {2, 2}), + entry(2, new Integer[] {8, 4}), + entry(3, new Integer[] {1, 2}))), + "Invalid function call:\n" + + "MAP_UNION(STRING, MAP NOT NULL> NOT NULL)") + .testSqlValidationError( + "MAP_UNION(f1, MAP[1, ARRAY[2,2], 2, ARRAY[8, 4], 3, ARRAY[1, 2]])", + "SQL validation failed. 
Invalid function call:\n" + + "MAP_UNION(STRING, MAP NOT NULL> NOT NULL)") + .testTableApiValidationError( + $("f2").mapUnion( + CollectionUtil.map( + entry("1", 1), + entry("2", 2), + entry("3", 3))), + "Invalid input arguments. Expected signatures are:\n" + + "MAP_UNION(, ...)") + .testSqlValidationError( + "MAP_UNION(f2, MAP['1', 1, '2', 2, '3', 3])", + "Invalid input arguments. Expected signatures are:\n" + + "MAP_UNION(, ...)") + .testTableApiValidationError( + $("f5").mapUnion(new String[] {"123"}), + "Invalid input arguments. Expected signatures are:\n" + + "MAP_UNION(, ...)") + .testSqlValidationError( + "MAP_UNION(f5, ARRAY['123'])", + "Invalid input arguments. Expected signatures are:\n" + + "MAP_UNION(, ...)") + .testTableApiValidationError( + $("f6").mapUnion( + CollectionUtil.map( + entry("1", 1), + entry("2", 2), + entry("3", 3))), + "Invalid input arguments. Expected signatures are:\n" + + "MAP_UNION(, ...)") + .testSqlValidationError( + "MAP_UNION(f6, MAP['1', 1, '2', 2, '3', 3])", + "Invalid input arguments. Expected signatures are:\n" + + "MAP_UNION(, ...)") + .testTableApiValidationError( + $("f7").mapUnion(new Integer[] {1, 2, 3, 4}), + "Invalid input arguments. Expected signatures are:\n" + + "MAP_UNION(, ...)") + .testSqlValidationError( + "MAP_UNION(f7, ARRAY[1, 2, 3, 4])", + "Invalid input arguments. Expected signatures are:\n" + + "MAP_UNION(, ...)")); + } } diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/scalar/MapUnionFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/scalar/MapUnionFunction.java new file mode 100644 index 0000000000000..fef1f730e0eca --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/scalar/MapUnionFunction.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions.scalar; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.GenericArrayData; +import org.apache.flink.table.data.MapData; +import org.apache.flink.table.functions.BuiltInFunctionDefinitions; +import org.apache.flink.table.functions.FunctionContext; +import org.apache.flink.table.functions.SpecializedFunction; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.KeyValueDataType; +import org.apache.flink.util.FlinkRuntimeException; + +import javax.annotation.Nullable; + +import java.lang.invoke.MethodHandle; +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.table.api.Expressions.$; + +/** Implementation of {@link BuiltInFunctionDefinitions#MAP_UNION}. 
*/ +@Internal +public class MapUnionFunction extends BuiltInScalarFunction { + private final ArrayData.ElementGetter keyElementGetter; + private final ArrayData.ElementGetter valueElementGetter; + + private final SpecializedFunction.ExpressionEvaluator keyEqualityEvaluator; + private transient MethodHandle keyEqualityHandle; + + public MapUnionFunction(SpecializedFunction.SpecializedContext context) { + super(BuiltInFunctionDefinitions.MAP_UNION, context); + KeyValueDataType outputType = + ((KeyValueDataType) context.getCallContext().getOutputDataType().get()); + final DataType keyDataType = outputType.getKeyDataType(); + final DataType valueDataType = outputType.getValueDataType(); + keyElementGetter = + ArrayData.createElementGetter(outputType.getKeyDataType().getLogicalType()); + valueElementGetter = + ArrayData.createElementGetter(outputType.getValueDataType().getLogicalType()); + keyEqualityEvaluator = + context.createEvaluator( + $("element1").isEqual($("element2")), + DataTypes.BOOLEAN(), + DataTypes.FIELD("element1", keyDataType.notNull().toInternal()), + DataTypes.FIELD("element2", keyDataType.notNull().toInternal())); + } + + @Override + public void open(FunctionContext context) throws Exception { + keyEqualityHandle = keyEqualityEvaluator.open(context); + } + + public @Nullable MapData eval(@Nullable MapData... 
maps) { + try { + if (maps == null || maps.length == 0) { + return null; + } + if (maps.length == 1) { + return maps[0]; + } + MapData result = maps[0]; + if (result == null) { + return null; + } + for (int i = 1; i < maps.length; ++i) { + MapData map = maps[i]; + if (map == null) { + return null; + } + if (map.size() > 0) { + result = new MapDataForMapUnion(result, map); + } + } + return result; + } catch (Throwable t) { + throw new FlinkRuntimeException(t); + } + } + + private class MapDataForMapUnion implements MapData { + private final GenericArrayData keysArray; + private final GenericArrayData valuesArray; + + public MapDataForMapUnion(MapData map1, MapData map2) throws Throwable { + List keysList = new ArrayList<>(); + List valuesList = new ArrayList<>(); + boolean isKeyNullExist = false; + ArrayData keyArray2 = map2.keyArray(); + ArrayData valueArray2 = map2.valueArray(); + for (int i = 0; i < map2.size(); i++) { + Object key = keyElementGetter.getElementOrNull(keyArray2, i); + if (key == null) { + isKeyNullExist = true; + } + keysList.add(key); + valuesList.add(valueElementGetter.getElementOrNull(valueArray2, i)); + } + ArrayData keyArray1 = map1.keyArray(); + ArrayData valueArray1 = map1.valueArray(); + for (int i = 0; i < map1.size(); i++) { + final Object key1 = keyElementGetter.getElementOrNull(keyArray1, i); + + boolean keyExists = false; + if (key1 != null) { + for (int j = 0; j < keysList.size(); j++) { + final Object key2 = keysList.get(j); + if (key2 != null && (boolean) keyEqualityHandle.invoke(key1, key2)) { + // If key exists in map2, skip this key-value pair + keyExists = true; + break; + } + } + } + + if (isKeyNullExist && key1 == null) { + continue; + } + + // If key doesn't exist in map2, add the key-value pair from map1 + if (!keyExists) { + final Object value1 = valueElementGetter.getElementOrNull(valueArray1, i); + keysList.add(key1); + valuesList.add(value1); + } + } + this.keysArray = new GenericArrayData(keysList.toArray()); + 
this.valuesArray = new GenericArrayData(valuesList.toArray()); + } + + @Override + public int size() { + return keysArray.size(); + } + + @Override + public ArrayData keyArray() { + return keysArray; + } + + @Override + public ArrayData valueArray() { + return valuesArray; + } + } + + @Override + public void close() throws Exception { + keyEqualityEvaluator.close(); + } +}