Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Skip conversion of shape.shapeof with 0-ranked tensor operand (#2107)
Currently `--shape-legalize-to-stablehlo` fails on the following code: ``` func.func @test1() -> tensor<0xindex> { %0 = arith.constant dense<0> : tensor<i32> %1 = shape.shape_of %0 : tensor<i32> -> tensor<0xindex> func.return %1 : tensor<0xindex> } ``` at `stablehlo/dialect/TypeInference.cpp:1700`: ``` // concatenate_c5 auto elementType = inputTypes[0].cast<ShapedType>().getElementType(); ``` as `inputTypes.size()` is zero. I have checked how it works on non-0 ranked tensor type: ``` func.func @test2() -> tensor<2xindex> { %1 = arith.constant dense<0> : tensor<2x128xi32> %3 = shape.shape_of %1 : tensor<2x128xi32> -> tensor<2xindex> func.return %3 : tensor<2xindex> } ``` produces: ``` func.func @test2() -> tensor<2xindex> { %cst = arith.constant dense<0> : tensor<2x128xi32> %0 = stablehlo.get_dimension_size %cst, dim = 0 : (tensor<2x128xi32>) -> tensor<i32> %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32> %2 = stablehlo.get_dimension_size %cst, dim = 1 : (tensor<2x128xi32>) -> tensor<i32> %3 = stablehlo.reshape %2 : (tensor<i32>) -> tensor<1xi32> %4 = stablehlo.concatenate %1, %3, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> %5 = builtin.unrealized_conversion_cast %4 : tensor<2xi32> to tensor<2xindex> return %5 : tensor<2xindex> } ``` I suggest considering an alternative approach instead of simply bailing out; one option could be generating a constant tensor with zero dimension: ``` func.func @test1() -> tensor<0xindex> { %cst = arith.constant dense<0> : tensor<i32> %0 = stablehlo.constant dense<> : tensor<0xi32> %1 = builtin.unrealized_conversion_cast %0 : tensor<0xi32> to tensor<0xindex> return %1 : tensor<0xindex> } ``` but i am not entirely certain in semantic equivalence.
- Loading branch information