From c4f32d3b0ffe491f6d004dde788265e7caacb173 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Fri, 19 Apr 2024 19:57:37 +0000 Subject: [PATCH] Remove "eval" prefix from ops --- stablehlo/reference/Ops.cpp | 657 ++++++++++++++++++------------------ stablehlo/reference/Ops.h | 305 ++++++++--------- 2 files changed, 471 insertions(+), 491 deletions(-) diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp index df501966ed..f564066ebe 100644 --- a/stablehlo/reference/Ops.cpp +++ b/stablehlo/reference/Ops.cpp @@ -55,9 +55,9 @@ Index evalIndex(Tensor tensor) { return result; } -Tensor evalDotGeneralOp(const Tensor &lhs, const Tensor &rhs, - const Axes &lhsContractingDimensions, - const Axes &rhsContractingDimensions) { +Tensor dotGeneralOp(const Tensor &lhs, const Tensor &rhs, + const Axes &lhsContractingDimensions, + const Axes &rhsContractingDimensions) { SmallVector inferredDotGeneralType; if (failed(hlo::inferDotGeneralOp( /*location=*/{}, lhs.getType(), rhs.getType(), @@ -67,16 +67,16 @@ Tensor evalDotGeneralOp(const Tensor &lhs, const Tensor &rhs, report_fatal_error( invalidArgument("Could not infer DotGeneralOp's return type")); - return evalDotGeneralOp( - lhs, rhs, /*lhsBatchingDimensions=*/{}, /*rhsBatchingDimensions*/ {}, - lhsContractingDimensions, rhsContractingDimensions, - RankedTensorType::get(inferredDotGeneralType[0].getDims(), - lhs.getElementType())); + return dotGeneralOp(lhs, rhs, /*lhsBatchingDimensions=*/{}, + /*rhsBatchingDimensions*/ {}, lhsContractingDimensions, + rhsContractingDimensions, + RankedTensorType::get(inferredDotGeneralType[0].getDims(), + lhs.getElementType())); } -Tensor evalPadOp(const Tensor &operand, const Tensor &paddingValue, - const Sizes &edgePaddingLow, const Sizes &edgePaddingHigh, - const Sizes &interiorPadding) { +Tensor padOp(const Tensor &operand, const Tensor &paddingValue, + const Sizes &edgePaddingLow, const Sizes &edgePaddingHigh, + const Sizes &interiorPadding) { SmallVector inferredTypes; Builder builder(operand.getType().getContext()); auto inferStatus = hlo::inferPadOp( @@ -84,14 +84,14 @@ Tensor evalPadOp(const Tensor &operand, const Tensor &paddingValue, edgePaddingHigh, interiorPadding, inferredTypes); if (failed(inferStatus)) report_fatal_error(invalidArgument("Could not infer PadOp's return type")); - return evalPadOp(operand, paddingValue, edgePaddingLow, interiorPadding, - cast(inferredTypes[0])); + return padOp(operand, paddingValue, edgePaddingLow, interiorPadding, + cast(inferredTypes[0])); } -SmallVector evalReduceOp(ArrayRef inputs, - ArrayRef initValues, - const Axes &dimensions, Region &body, - Process *process, Scope &scope) { +SmallVector reduceOp(ArrayRef inputs, + ArrayRef initValues, + const Axes &dimensions, Region &body, + Process *process, Scope &scope) { SmallVector inputTypes; for (const auto &input : inputs) inputTypes.push_back(input.getType()); @@ -114,12 +114,12 @@ SmallVector evalReduceOp(ArrayRef inputs, llvm::report_fatal_error("Could not infer ReduceOp's return type"); resultTypes.push_back(shapedType); } - return evalReduceOp(inputs, initValues, dimensions, body, process, scope, - resultTypes); + return reduceOp(inputs, initValues, dimensions, body, process, scope, + resultTypes); } -Tensor evalSliceOp(const Tensor &operand, const Sizes &startIndices, - const Sizes &limitIndices, const Sizes &strides) { +Tensor sliceOp(const Tensor &operand, const Sizes &startIndices, + const Sizes &limitIndices, const Sizes &strides) { SmallVector inferredTypes; Builder builder(operand.getType().getContext()); auto inferStatus = hlo::inferSliceOp({}, operand.getType(), startIndices, @@ -127,14 +127,14 @@ Tensor evalSliceOp(const Tensor &operand, const Sizes &startIndices, if (failed(inferStatus)) report_fatal_error( invalidArgument("Could not infer SliceOp's return type")); - return evalSliceOp(operand, startIndices, strides, - cast(inferredTypes[0])); + return sliceOp(operand, startIndices, strides, + cast(inferredTypes[0])); } -SmallVector evalCallOp(ArrayRef inputs, - InterpreterFallback *fallback, - Process *process, Operation *op, - StringRef funcName) { +SmallVector callOp(ArrayRef inputs, + InterpreterFallback *fallback, + Process *process, Operation *op, + StringRef funcName) { SymbolTableCollection symbolTableCollection; auto symbolTable = symbolTableCollection.getSymbolTable(op->getParentOfType()); @@ -148,7 +148,7 @@ SmallVector evalCallOp(ArrayRef inputs, // Experimental notation for slices, roughly following the spec notation. // TODO(#1401): Might evolve in the future together with the spec. constexpr int64_t kColon = -1; -Tensor evalSliceOp(const Tensor &operand, const Index &index) { +Tensor sliceOp(const Tensor &operand, const Index &index) { Sizes start, limit; for (auto i = 0; i < operand.getRank(); ++i) { if (index[i] == -1) { @@ -160,7 +160,7 @@ Tensor evalSliceOp(const Tensor &operand, const Index &index) { } } Sizes strides(operand.getRank(), 1); - return evalSliceOp(operand, start, limit, strides); + return sliceOp(operand, start, limit, strides); } Sizes extractElements(ArrayRef arr, ArrayRef indices) { @@ -201,7 +201,7 @@ SmallVector> getReplicaGroups( return replicaGroups; } -Tensor evalConvolutionOp( +Tensor convolutionOp( const Tensor &lhs, const Tensor &rhs, ArrayRef windowStrides, ArrayRef> padding, ArrayRef lhsDilation, ArrayRef rhsDilation, @@ -235,7 +235,7 @@ Tensor evalConvolutionOp( report_fatal_error( invalidArgument("Could not infer ConvolutionOp's return type")); - return evalConvolutionOp( + return convolutionOp( lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation, windowReversal, inputBatchDimension, inputFeatureDimension, inputSpatialDimensions, kernelInputFeatureDimension, @@ -297,9 +297,9 @@ SmallVector split(const Tensor &x, int64_t numResults, Axis axis, inputStartIndices[axis] = constant(i * resultShape[axis], IntegerType::get(context, 64)); - auto result = evalDynamicSliceOp( - x, inputStartIndices, resultShape, - RankedTensorType::get(resultShape, x.getElementType())); + auto result = + dynamicSliceOp(x, inputStartIndices, resultShape, + RankedTensorType::get(resultShape, x.getElementType())); results.push_back(result); } return results; @@ -323,16 +323,16 @@ SmallVector eval(Region ®ion, for (Operation &operation : block) { if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalAbsOp(operand, op.getType()); + auto result = absOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto lhs = scope.findTensor(op.getLhs()); auto rhs = scope.findTensor(op.getRhs()); - auto result = evalAddOp(lhs, rhs, op.getType()); + auto result = addOp(lhs, rhs, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto inputs = scope.findTokens(op.getInputs()); - auto result = evalAfterAllOp(inputs, op->getContext()); + auto result = afterAllOp(inputs, op->getContext()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); @@ -349,9 +349,9 @@ SmallVector eval(Region ®ion, if (auto channelHandle = op.getChannelHandle()) channelId = channelHandle->getHandle(); - auto result = evalAllGatherOp( - operand, op.getAllGatherDim(), replicaGroups, channelId, - op.getUseGlobalDeviceIds(), process, op.getType()); + auto result = + allGatherOp(operand, op.getAllGatherDim(), replicaGroups, channelId, + op.getUseGlobalDeviceIds(), process, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); @@ -361,9 +361,9 @@ SmallVector eval(Region ®ion, if (auto channelHandle = op.getChannelHandle()) channelId = channelHandle->getHandle(); - auto result = evalAllReduceOp( - operand, replicaGroups, channelId, op.getUseGlobalDeviceIds(), - op.getComputation(), process, scope, op.getType()); + auto result = allReduceOp(operand, replicaGroups, channelId, + op.getUseGlobalDeviceIds(), op.getComputation(), + process, scope, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); @@ -379,19 +379,19 @@ SmallVector eval(Region ®ion, if (auto channelHandle = op.getChannelHandle()) channelId = channelHandle->getHandle(); - auto result = evalAllToAllOp( - operand, op.getSplitDimension(), op.getConcatDimension(), - op.getSplitCount(), replicaGroups, channelId, process, op.getType()); + auto result = allToAllOp(operand, op.getSplitDimension(), + op.getConcatDimension(), op.getSplitCount(), + replicaGroups, channelId, process, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto lhs = scope.findTensor(op.getLhs()); auto rhs = scope.findTensor(op.getRhs()); - auto result = evalAndOp(lhs, rhs, op.getType()); + auto result = andOp(lhs, rhs, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto lhs = scope.findTensor(op.getLhs()); auto rhs = scope.findTensor(op.getRhs()); - auto result = evalAtan2Op(lhs, rhs, op.getType()); + auto result = atan2Op(lhs, rhs, op.getType()); scope.add(op.getResult(), result); } else if (isa(operation)) { failOnDecomposableOp(operation); @@ -401,33 +401,33 @@ SmallVector eval(Region ®ion, failOnDecomposableOp(operation); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalBitcastConvertOp(operand, op.getType()); + auto result = bitcastConvertOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); auto broadcastDimensions = Axes(op.getBroadcastDimensions()); auto result = - evalBroadcastInDimOp(operand, broadcastDimensions, op.getType()); + broadcastInDimOp(operand, broadcastDimensions, op.getType()); scope.add(op.getResult(), result); } else if (isa(operation)) { failOnDecomposableOp(operation); } else if (auto op = dyn_cast(operation)) { auto operands = scope.findTensors(op.getOperands()); auto results = - evalCallOp(operands, fallback, process, &operation, op.getCallee()); + callOp(operands, fallback, process, &operation, op.getCallee()); scope.add(op.getResults(), results); } else if (auto op = dyn_cast(operation)) { auto index = scope.findTensor(op.getIndex()); auto branches = op.getBranches(); - auto results = evalCaseOp(index, branches, process, scope); + auto results = caseOp(index, branches, process, scope); scope.add(op.getResults(), results); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalCbrtOp(operand, op.getType()); + auto result = cbrtOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalCeilOp(operand, op.getType()); + auto result = ceilOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (isa(operation)) { failOnDecomposableOp(operation); @@ -435,11 +435,11 @@ SmallVector eval(Region ®ion, auto min = scope.findTensor(op.getMin()); auto operand = scope.findTensor(op.getOperand()); auto max = scope.findTensor(op.getMax()); - auto result = evalClampOp(min, operand, max, op.getType()); + auto result = clampOp(min, operand, max, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalClzOp(operand, op.getType()); + auto result = clzOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); @@ -457,7 +457,7 @@ SmallVector eval(Region ®ion, channelId = channelHandle->getHandle(); auto result = - evalCollectiveBroadcastOp(operand, replicaGroups, channelId, process); + collectiveBroadcastOp(operand, replicaGroups, channelId, process); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); @@ -476,36 +476,35 @@ SmallVector eval(Region ®ion, if (auto channelHandle = op.getChannelHandle()) channelId = channelHandle->getHandle(); - auto result = evalCollectivePermuteOp(operand, sourceTargetPairs, - channelId, process); + auto result = + collectivePermuteOp(operand, sourceTargetPairs, channelId, process); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto lhs = scope.findTensor(op.getLhs()); auto rhs = scope.findTensor(op.getRhs()); auto comparisonDirection = op.getComparisonDirection(); - auto result = evalCompareOp(lhs, rhs, comparisonDirection, op.getType()); + auto result = compareOp(lhs, rhs, comparisonDirection, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto lhs = scope.findTensor(op.getLhs()); auto rhs = scope.findTensor(op.getRhs()); - auto result = evalComplexOp(lhs, rhs, op.getType()); + auto result = complexOp(lhs, rhs, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operands = scope.findTensors(op.getOperands()); - auto results = evalCallOp(operands, fallback, process, &operation, - op.getDecomposition()); + auto results = callOp(operands, fallback, process, &operation, + op.getDecomposition()); scope.add(op.getResults(), results); } else if (auto op = dyn_cast(operation)) { auto operands = scope.findTensors(op.getOperands()); - auto result = - evalConcatenateOp(operands, op.getDimension(), op.getType()); + auto result = concatenateOp(operands, op.getDimension(), op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { - auto result = evalConstantOp(op.getValue()); + auto result = constantOp(op.getValue()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalConvertOp(operand, op.getType()); + auto result = convertOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto lhs = scope.findTensor(op.getLhs()); @@ -537,7 +536,7 @@ SmallVector eval(Region ®ion, windowReversal = SmallVector(windowReversalAttr.value()); auto dimensionNumbers = op.getDimensionNumbers(); - auto result = evalConvolutionOp( + auto result = convolutionOp( lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation, windowReversal, dimensionNumbers.getInputBatchDimension(), dimensionNumbers.getInputFeatureDimension(), @@ -552,7 +551,7 @@ SmallVector eval(Region ®ion, scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalCosineOp(operand, op.getType()); + auto result = cosineOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (isa(operation)) { failOnDecomposableOp(operation); @@ -561,7 +560,7 @@ SmallVector eval(Region ®ion, } else if (auto op = dyn_cast(operation)) { auto lhs = scope.findTensor(op.getLhs()); auto rhs = scope.findTensor(op.getRhs()); - auto result = evalDivideOp(lhs, rhs, op.getType()); + auto result = divideOp(lhs, rhs, op.getType()); scope.add(op.getResult(), result); } else if (isa(operation)) { failOnDecomposableOp(operation); @@ -576,7 +575,7 @@ SmallVector eval(Region ®ion, Axes(op.getDotDimensionNumbers().getLhsContractingDimensions()); auto rhsContractingDimensions = Axes(op.getDotDimensionNumbers().getRhsContractingDimensions()); - auto result = evalDotGeneralOp( + auto result = dotGeneralOp( lhs, rhs, lhsBatchingDimensions, rhsBatchingDimensions, lhsContractingDimensions, rhsContractingDimensions, op.getType()); scope.add(op.getResult(), result); @@ -585,33 +584,33 @@ SmallVector eval(Region ®ion, auto startIndices = scope.findTensors(op.getStartIndices()); auto sliceSizes = Sizes(op.getSliceSizes()); auto result = - evalDynamicSliceOp(operand, startIndices, sliceSizes, op.getType()); + dynamicSliceOp(operand, startIndices, sliceSizes, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); auto update = scope.findTensor(op.getUpdate()); auto startIndices = scope.findTensors(op.getStartIndices()); auto result = - evalDynamicUpdateSliceOp(operand, update, startIndices, op.getType()); + dynamicUpdateSliceOp(operand, update, startIndices, op.getType()); scope.add(op.getResult(), result); } else if (isa(operation)) { failOnDecomposableOp(operation); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalExponentialOp(operand, op.getType()); + auto result = exponentialOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalExpm1Op(operand, op.getType()); + auto result = expm1Op(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalFloorOp(operand, op.getType()); + auto result = floorOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); auto startIndices = scope.findTensor(op.getStartIndices()); - auto result = evalGatherOp( + auto result = gatherOp( operand, startIndices, Axes(op.getDimensionNumbers().getOffsetDims()), Axes(op.getDimensionNumbers().getCollapsedSliceDims()), Axes(op.getDimensionNumbers().getStartIndexMap()), @@ -621,119 +620,118 @@ SmallVector eval(Region ®ion, } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); auto dimension = op.getDimension(); - auto result = evalGetDimensionSizeOp(operand, dimension, op.getType()); + auto result = getDimensionSizeOp(operand, dimension, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTuple(op.getOperand()); - auto result = evalGetTupleElementOp(operand, op.getIndex()); + auto result = getTupleElementOp(operand, op.getIndex()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto pred = scope.findTensor(op.getPred()); auto &trueBranch = op.getTrueBranch(); auto &falseBranch = op.getFalseBranch(); - auto results = evalIfOp(pred, trueBranch, falseBranch, process, scope); + auto results = ifOp(pred, trueBranch, falseBranch, process, scope); scope.add(op.getResults(), results); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalImagOp(operand, op.getType()); + auto result = imagOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto token = scope.findToken(op.getToken()); - auto results = evalInfeedOp(token, process, region, scope); + auto results = infeedOp(token, process, region, scope); scope.add(op.getResults(), results); } else if (auto op = dyn_cast(operation)) { auto iotaDimension = op.getIotaDimension(); - auto result = evalIotaOp(iotaDimension, op.getType()); + auto result = iotaOp(iotaDimension, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalIsFiniteOp(operand, op.getType()); + auto result = isFiniteOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalLog1pOp(operand, op.getType()); + auto result = log1pOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalLogOp(operand, op.getType()); + auto result = logOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalLogisticOp(operand, op.getType()); + auto result = logisticOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto inputs = scope.findTensors(op.getInputs()); auto &computation = op.getComputation(); - auto result = - evalMapOp(inputs, computation, process, scope, op.getType()); + auto result = mapOp(inputs, computation, process, scope, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto lhs = scope.findTensor(op.getLhs()); auto rhs = scope.findTensor(op.getRhs()); - auto result = evalMaxOp(lhs, rhs, op.getType()); + auto result = maxOp(lhs, rhs, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto lhs = scope.findTensor(op.getLhs()); auto rhs = scope.findTensor(op.getRhs()); - auto result = evalMinOp(lhs, rhs, op.getType()); + auto result = minOp(lhs, rhs, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto lhs = scope.findTensor(op.getLhs()); auto rhs = scope.findTensor(op.getRhs()); - auto result = evalMultiplyOp(lhs, rhs, op.getType()); + auto result = multiplyOp(lhs, rhs, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalNegOp(operand, op.getType()); + auto result = negOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalNotOp(operand, op.getType()); + auto result = notOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.find(op.getOperand()); - auto results = evalOptimizationBarrierOp(operand); + auto results = optimizationBarrierOp(operand); scope.add(op.getResults(), results); } else if (auto op = dyn_cast(operation)) { auto lhs = scope.findTensor(op.getLhs()); auto rhs = scope.findTensor(op.getRhs()); - auto result = evalOrOp(lhs, rhs, op.getType()); + auto result = orOp(lhs, rhs, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto inputs = scope.findTensors(op.getInputs()); auto token = scope.findToken(op.getToken()); - auto result = evalOutfeedOp(inputs, token, process); + auto result = outfeedOp(inputs, token, process); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); auto paddingValue = scope.findTensor(op.getPaddingValue()); auto edgePaddingLow = Sizes(op.getEdgePaddingLow()); auto interiorPadding = Sizes(op.getInteriorPadding()); - auto result = evalPadOp(operand, paddingValue, edgePaddingLow, - interiorPadding, op.getType()); + auto result = padOp(operand, paddingValue, edgePaddingLow, + interiorPadding, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { - auto result = evalPartitionIdOp(process, op.getContext()); + auto result = partitionIdOp(process, op.getContext()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalPopulationCountOp(operand, op.getType()); + auto result = populationCountOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto lhs = scope.findTensor(op.getLhs()); auto rhs = scope.findTensor(op.getRhs()); - auto result = evalPowerOp(lhs, rhs, op.getType()); + auto result = powerOp(lhs, rhs, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalRealOp(operand, op.getType()); + auto result = realOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto token = scope.findToken(op.getToken()); ChannelId channelId = 0; if (auto channelHandle = op.getChannelHandle()) channelId = channelHandle.getHandle(); - auto results = evalRecvOp(token, channelId, process); + auto results = recvOp(token, channelId, process); scope.add(op.getResults(), results); } else if (auto op = dyn_cast(operation)) { auto inputs = scope.findTensors(op.getInputs()); @@ -741,15 +739,15 @@ SmallVector eval(Region ®ion, SmallVector resultTypes; for (auto resultType : op.getResultTypes()) resultTypes.push_back(cast(resultType)); - auto results = evalReduceOp(inputs, initValues, Axes(op.getDimensions()), - op.getBody(), process, scope, resultTypes); + auto results = reduceOp(inputs, initValues, Axes(op.getDimensions()), + op.getBody(), process, scope, resultTypes); scope.add(op.getResults(), results); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); int32_t exponentBits = op.getExponentBits(); int32_t mantissaBits = op.getMantissaBits(); - auto result = evalReducePrecisionOp(operand, exponentBits, mantissaBits, - op.getType()); + auto result = + reducePrecisionOp(operand, exponentBits, mantissaBits, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); @@ -760,10 +758,10 @@ SmallVector eval(Region ®ion, if (auto channelHandle = op.getChannelHandle()) channelId = channelHandle->getHandle(); - auto result = evalReduceScatterOp( - operand, scatterDimension, replicaGroups, channelId, - op.getUseGlobalDeviceIds(), op.getComputation(), process, scope, - op.getType()); + auto result = + reduceScatterOp(operand, scatterDimension, replicaGroups, channelId, + op.getUseGlobalDeviceIds(), op.getComputation(), + process, scope, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto inputs = scope.findTensors(op.getInputs()); @@ -797,7 +795,7 @@ SmallVector eval(Region ®ion, for (auto resultType : op.getResultTypes()) resultTypes.push_back(cast(resultType)); - auto results = evalReduceWindowOp( + auto results = reduceWindowOp( inputs, initValues, Sizes(op.getWindowDimensions()), windowStrides, baseDilations, windowDilations, paddingLow, paddingHigh, op.getBody(), process, scope, resultTypes); @@ -805,14 +803,14 @@ SmallVector eval(Region ®ion, } else if (auto op = dyn_cast(operation)) { auto lhs = scope.findTensor(op.getLhs()); auto rhs = scope.findTensor(op.getRhs()); - auto result = evalRemOp(lhs, rhs, op.getType()); + auto result = remOp(lhs, rhs, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { - auto result = evalReplicaIdOp(process, op.getContext()); + auto result = replicaIdOp(process, op.getContext()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalReshapeOp(operand, op.getType()); + auto result = reshapeOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { return scope.find(op.getOperands()); @@ -821,7 +819,7 @@ SmallVector eval(Region ®ion, } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); auto dimensions = Axes(op.getDimensions()); - auto result = evalReverseOp(operand, dimensions, op.getType()); + auto result = reverseOp(operand, dimensions, op.getType()); scope.add(op.getResult(), result); } else if (isa(operation)) { failOnDecomposableOp(operation); @@ -829,15 +827,15 @@ SmallVector eval(Region ®ion, failOnDecomposableOp(operation); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalRoundNearestEvenOp(operand, op.getType()); + auto result = roundNearestEvenOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalRoundOp(operand, op.getType()); + auto result = roundOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalRsqrtOp(operand, op.getType()); + auto result = rsqrtOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto inputs = scope.findTensors(op.getInputs()); @@ -851,10 +849,10 @@ SmallVector eval(Region ®ion, Axis indexVectorDim(scatterDimensionNumbers.getIndexVectorDim()); auto &updateComputation = op.getUpdateComputation(); SmallVector resultTypes(op->getResultTypes()); - auto results = evalScatterOp( - inputs, scatterIndices, updates, updateWindowDims, insertedWindowDims, - scatterDimsToOperandDims, indexVectorDim, updateComputation, process, - scope, resultTypes); + auto results = scatterOp(inputs, scatterIndices, updates, + updateWindowDims, insertedWindowDims, + scatterDimsToOperandDims, indexVectorDim, + updateComputation, process, scope, resultTypes); scope.add(op.getResults(), results); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); @@ -883,15 +881,15 @@ SmallVector eval(Region ®ion, } auto result = - evalSelectAndScatterOp(operand, source, initValue, windowDimensions, - windowStrides, paddingLow, op.getSelect(), - op.getScatter(), process, scope, op.getType()); + selectAndScatterOp(operand, source, initValue, windowDimensions, + windowStrides, paddingLow, op.getSelect(), + op.getScatter(), process, scope, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto pred = scope.findTensor(op.getPred()); auto onTrue = scope.findTensor(op.getOnTrue()); auto onFalse = scope.findTensor(op.getOnFalse()); - auto result = evalSelectOp(pred, onTrue, onFalse, op.getType()); + auto result = selectOp(pred, onTrue, onFalse, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto inputs = scope.findTensors(op.getInputs()); @@ -899,36 +897,36 @@ SmallVector eval(Region ®ion, ChannelId channelId = 0; if (auto channelHandle = op.getChannelHandle()) channelId = channelHandle.getHandle(); - auto result = evalSendOp(inputs, token, channelId, process); + auto result = sendOp(inputs, token, channelId, process); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto lhs = scope.findTensor(op.getLhs()); auto rhs = scope.findTensor(op.getRhs()); - auto result = evalShiftLeftOp(lhs, rhs, op.getType()); + auto result = shiftLeftOp(lhs, rhs, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto lhs = scope.findTensor(op.getLhs()); auto rhs = scope.findTensor(op.getRhs()); - auto result = evalShiftRightArithmeticOp(lhs, rhs, op.getType()); + auto result = shiftRightArithmeticOp(lhs, rhs, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto lhs = scope.findTensor(op.getLhs()); auto rhs = scope.findTensor(op.getRhs()); - auto result = evalShiftRightLogicalOp(lhs, rhs, op.getType()); + auto result = shiftRightLogicalOp(lhs, rhs, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalSignOp(operand, op.getType()); + auto result = signOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalSineOp(operand, op.getType()); + auto result = sineOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); auto startIndices = Sizes(op.getStartIndices()); auto strides = Sizes(op.getStrides()); - auto result = evalSliceOp(operand, startIndices, strides, op.getType()); + auto result = sliceOp(operand, startIndices, strides, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operands = scope.findTensors(op.getInputs()); @@ -936,20 +934,20 @@ SmallVector eval(Region ®ion, auto isStable = op.getIsStable(); auto &comparator = op.getComparator(); auto results = - evalSortOp(operands, dimension, isStable, comparator, process, scope); + sortOp(operands, dimension, isStable, comparator, process, scope); scope.add(op.getResults(), results); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalSqrtOp(operand, op.getType()); + auto result = sqrtOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto lhs = scope.findTensor(op.getLhs()); auto rhs = scope.findTensor(op.getRhs()); - auto result = evalSubtractOp(lhs, rhs, op.getType()); + auto result = subtractOp(lhs, rhs, op.getType()); scope.add(op.getResult(), result); } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); - auto result = evalTanhOp(operand, op.getType()); + auto result = tanhOp(operand, op.getType()); scope.add(op.getResult(), result); } else if (isa(operation)) { failOnDecomposableOp(operation); @@ -958,13 +956,13 @@ SmallVector eval(Region ®ion, } else if (auto op = dyn_cast(operation)) { auto operand = scope.findTensor(op.getOperand()); auto permutation = Axes(op.getPermutation()); - auto result = evalTransposeOp(operand, permutation, op.getType()); + auto result = transposeOp(operand, permutation, op.getType()); scope.add(op.getResult(), result); } else if (isa(operation)) { failOnDecomposableOp(operation); } else if (auto op = dyn_cast(operation)) { auto val = scope.find(op.getVal()); - auto result = evalTupleOp(val, cast(op.getType())); + auto result = tupleOp(val, cast(op.getType())); scope.add(op.getResult(), result); } else if (isa(operation)) { failOnDecomposableOp(operation); @@ -972,12 +970,12 @@ SmallVector eval(Region ®ion, auto operand = scope.find(op.getOperand()); auto &cond = op.getCond(); auto &body = op.getBody(); - auto results = evalWhileOp(operand, cond, body, fallback, process, scope); + auto results = whileOp(operand, cond, body, fallback, process, scope); scope.add(op.getResults(), results); } else if (auto op = dyn_cast(operation)) { auto lhs = scope.findTensor(op.getLhs()); auto rhs = scope.findTensor(op.getRhs()); - auto result = evalXorOp(lhs, rhs, op.getType()); + auto result = xorOp(lhs, rhs, op.getType()); scope.add(op.getResult(), result); } else { if (!fallback) @@ -991,28 +989,28 @@ SmallVector eval(Region ®ion, llvm::report_fatal_error("Expected a terminator when evaluating a region"); } -Tensor evalAbsOp(const Tensor &operand, ShapedType resultType) { +Tensor absOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, abs(operand.get(*it))); return result; } -Tensor evalAddOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { +Tensor addOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, lhs.get(*it) + rhs.get(*it)); return result; } -Token evalAfterAllOp(ArrayRef inputs, MLIRContext *context) { +Token afterAllOp(ArrayRef inputs, MLIRContext *context) { return Token(context); } -Tensor evalAllGatherOp(const Tensor &operand, int64_t allGatherDim, - SmallVector> replicaGroups, - ChannelId channelId, bool useGlobalDeviceIds, - Process *process, ShapedType resultType) { +Tensor allGatherOp(const Tensor &operand, int64_t allGatherDim, + SmallVector> replicaGroups, + ChannelId channelId, bool useGlobalDeviceIds, + Process *process, ShapedType resultType) { if (!process) llvm::report_fatal_error( "all_gather is only supported when run via interpreter.run_parallel"); @@ -1037,14 +1035,14 @@ Tensor evalAllGatherOp(const Tensor &operand, int64_t allGatherDim, *processGroup, [&](const ProcessId &id) { return rendezvousResult.lookup(id); })); - return evalConcatenateOp(groupOperands, allGatherDim, resultType); + return concatenateOp(groupOperands, allGatherDim, resultType); } -Tensor evalAllReduceOp(const Tensor &operand, - SmallVector> replicaGroups, - ChannelId channelId, bool useGlobalDeviceIds, - Region &computation, Process *process, Scope &scope, - ShapedType resultType) { +Tensor allReduceOp(const Tensor &operand, + SmallVector> replicaGroups, + ChannelId channelId, bool useGlobalDeviceIds, + Region &computation, Process *process, Scope &scope, + ShapedType resultType) { if (!process) llvm::report_fatal_error( "all_reduce is only supported when run via interpreter.run_parallel"); @@ -1085,11 +1083,11 @@ Tensor evalAllReduceOp(const Tensor &operand, return result; } -Tensor evalAllToAllOp(const Tensor &operand, Axis splitDimension, - Axis concatDimension, int64_t splitCount, - SmallVector> replicaGroups, - ChannelId channelId, Process *process, - ShapedType resultType) { +Tensor allToAllOp(const Tensor &operand, Axis splitDimension, + Axis concatDimension, int64_t splitCount, + SmallVector> replicaGroups, + ChannelId channelId, Process *process, + ShapedType resultType) { if (!process) llvm::report_fatal_error( "all_to_all is only supported when run via interpreter.run_parallel"); @@ -1115,25 +1113,24 @@ Tensor evalAllToAllOp(const Tensor &operand, Axis splitDimension, if (processId == process->getId()) scatteredParts.push_back(splitParts[i]); } - return evalConcatenateOp(scatteredParts, concatDimension, resultType); + return concatenateOp(scatteredParts, concatDimension, resultType); } -Tensor evalAndOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { +Tensor andOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, lhs.get(*it) & rhs.get(*it)); return result; } -Tensor evalAtan2Op(const Tensor &lhs, const Tensor &rhs, - ShapedType resultType) { +Tensor atan2Op(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, atan2(lhs.get(*it), rhs.get(*it))); return result; } -Tensor evalBitcastConvertOp(const Tensor &operand, ShapedType resultType) { +Tensor bitcastConvertOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); auto resultElementType = result.getElementType(); @@ -1171,9 +1168,8 @@ Tensor evalBitcastConvertOp(const Tensor &operand, ShapedType resultType) { return result; } -Tensor evalBroadcastInDimOp(const Tensor &operand, - const Axes &broadcastDimensions, - ShapedType resultType) { +Tensor broadcastInDimOp(const Tensor &operand, const Axes &broadcastDimensions, + ShapedType resultType) { Tensor result(resultType); for (auto resultIt = result.index_begin(); resultIt != result.index_end(); ++resultIt) { @@ -1188,9 +1184,8 @@ Tensor evalBroadcastInDimOp(const Tensor &operand, return result; } -SmallVector evalCaseOp(const Tensor &index, - RegionRange branches, Process *process, - Scope &scope) { +SmallVector caseOp(const Tensor &index, RegionRange branches, + Process *process, Scope &scope) { int64_t indexValue = index.get({}).getIntegerValue().getSExtValue(); if (indexValue < 0 || indexValue >= static_cast(branches.size())) indexValue = branches.size() - 1; @@ -1198,22 +1193,22 @@ SmallVector evalCaseOp(const Tensor &index, return eval(*branches[indexValue], {}, /*fallback=*/nullptr, process, &scope); } -Tensor evalCbrtOp(const Tensor &operand, ShapedType resultType) { +Tensor cbrtOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, cbrt(operand.get(*it))); return result; } -Tensor evalCeilOp(const Tensor &operand, ShapedType resultType) { +Tensor ceilOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, ceil(operand.get(*it))); return result; } -Tensor evalClampOp(const Tensor &min, const Tensor &operand, const Tensor &max, - ShapedType resultType) { +Tensor clampOp(const Tensor &min, const Tensor &operand, const Tensor &max, + ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) { Element minElement = min.getRank() != 0 ? min.get(*it) : min.get({}); @@ -1224,7 +1219,7 @@ Tensor evalClampOp(const Tensor &min, const Tensor &operand, const Tensor &max, return result; } -Tensor evalClzOp(const Tensor &operand, ShapedType resultType) { +Tensor clzOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) { auto element = @@ -1236,9 +1231,9 @@ Tensor evalClzOp(const Tensor &operand, ShapedType resultType) { return result; } -Tensor evalCollectiveBroadcastOp( - const Tensor &operand, SmallVector> replicaGroups, - ChannelId channelId, Process *process) { +Tensor collectiveBroadcastOp(const Tensor &operand, + SmallVector> replicaGroups, + ChannelId channelId, Process *process) { if (!process) llvm::report_fatal_error( "collective_broadcast is only supported when run via " @@ -1253,13 +1248,13 @@ Tensor evalCollectiveBroadcastOp( return process->rendezvous(*processGroup, channelId, operand) .lookup((*processGroup)[0]); - return evalBroadcastInDimOp(constant(0.0, operand.getElementType()), {}, - operand.getType()); + return broadcastInDimOp(constant(0.0, operand.getElementType()), {}, + operand.getType()); } -Tensor evalCollectivePermuteOp( - const Tensor &operand, SmallVector> sourceTargetPairs, - ChannelId channelId, Process *process) { +Tensor collectivePermuteOp(const Tensor &operand, + SmallVector> sourceTargetPairs, + ChannelId channelId, Process *process) { if (!process) llvm::report_fatal_error( "collective_permute is only supported when run via " @@ -1282,13 +1277,13 @@ Tensor evalCollectivePermuteOp( } if (result) return result; - return evalBroadcastInDimOp(constant(0.0, operand.getElementType()), {}, - operand.getType()); + return broadcastInDimOp(constant(0.0, operand.getElementType()), {}, + operand.getType()); } -Tensor evalCompareOp(const Tensor &lhs, const Tensor &rhs, - ComparisonDirection comparisonDirection, - ShapedType resultType) { +Tensor compareOp(const Tensor &lhs, const Tensor &rhs, + ComparisonDirection comparisonDirection, + ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) { switch (comparisonDirection) { @@ -1315,16 +1310,15 @@ Tensor evalCompareOp(const Tensor &lhs, const Tensor &rhs, return result; } -Tensor evalComplexOp(const Tensor &lhs, const Tensor &rhs, - ShapedType resultType) { +Tensor complexOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, complex(lhs.get(*it), rhs.get(*it))); return result; } -Tensor evalConcatenateOp(ArrayRef inputs, Axis dimension, - ShapedType resultType) { +Tensor concatenateOp(ArrayRef inputs, Axis dimension, + ShapedType resultType) { Tensor result(resultType); int64_t dimensionOffset = 0; for (const auto &input : inputs) { @@ -1340,18 +1334,18 @@ Tensor evalConcatenateOp(ArrayRef inputs, Axis dimension, return result; } -Tensor evalConstantOp(ElementsAttr value) { +Tensor constantOp(ElementsAttr value) { return makeTensor(cast(value)); } -Tensor evalConvertOp(const Tensor &operand, ShapedType resultType) { +Tensor convertOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, convert(result.getElementType(), operand.get(*it))); return result; } -Tensor evalConvolutionOp( +Tensor convolutionOp( const Tensor &lhs, const Tensor &rhs, ArrayRef windowStrides, ArrayRef> padding, ArrayRef lhsDilation, ArrayRef rhsDilation, @@ -1370,7 +1364,7 @@ Tensor evalConvolutionOp( resultType.getContext()); SmallVector results; for (auto [left, right] : llvm::zip(lhses, rhses)) - results.push_back(evalConvolutionOp( + results.push_back(convolutionOp( left, right, windowStrides, padding, lhsDilation, rhsDilation, windowReversal, inputBatchDimension, inputFeatureDimension, inputSpatialDimensions, kernelInputFeatureDimension, @@ -1379,7 +1373,7 @@ Tensor evalConvolutionOp( /*featureGroupCount=*/1, batchGroupCount, /*precisionConfig=*/{}, resultType)); - return evalConcatenateOp(results, outputFeatureDimension, result.getType()); + return concatenateOp(results, outputFeatureDimension, result.getType()); } if (batchGroupCount > 1) { @@ -1389,7 +1383,7 @@ Tensor evalConvolutionOp( resultType.getContext()); SmallVector results; for (auto [left, right] : llvm::zip(lhses, rhses)) - results.push_back(evalConvolutionOp( + results.push_back(convolutionOp( left, right, windowStrides, padding, lhsDilation, rhsDilation, windowReversal, inputBatchDimension, inputFeatureDimension, inputSpatialDimensions, kernelInputFeatureDimension, @@ -1398,7 +1392,7 @@ Tensor evalConvolutionOp( featureGroupCount, /*batchGroupCount=*/1, /*precisionConfig=*/{}, resultType)); - return evalConcatenateOp(results, outputFeatureDimension, result.getType()); + return concatenateOp(results, outputFeatureDimension, result.getType()); } Axes lhsPermutation; @@ -1429,8 +1423,8 @@ Tensor evalConvolutionOp( } auto paddingValue = constant(0.0, result.getElementType()); - auto paddedLhs = evalPadOp(lhs, paddingValue, lhsPaddingLow, lhsPaddingHigh, - Sizes(lhsBaseDilations)); + auto paddedLhs = padOp(lhs, paddingValue, lhsPaddingLow, lhsPaddingHigh, + Sizes(lhsBaseDilations)); IndexSpaceIterator outputSpatialIndexIt( extractElements(result.getShape(), outputSpatialDimensions), @@ -1450,14 +1444,14 @@ Tensor evalConvolutionOp( lhsWindowStart[i] + lhsWindowDimensions[i] * lhsWindowDilations[i], paddedLhs.getShape()[i])); - auto lhsWindow = evalSliceOp(paddedLhs, lhsWindowStart, limitIndices, - Sizes(lhsWindowDilations)); + auto lhsWindow = sliceOp(paddedLhs, lhsWindowStart, limitIndices, + Sizes(lhsWindowDilations)); Axes reverseDims; for (auto [i, isReverse] : llvm::enumerate(windowReversal)) if (isReverse) reverseDims.push_back(inputSpatialDimensions[i]); auto reversedLhsWindow = - evalReverseOp(lhsWindow, reverseDims, lhsWindow.getType()); + reverseOp(lhsWindow, reverseDims, lhsWindow.getType()); Axes lhsContractingDimensions(inputSpatialDimensions); lhsContractingDimensions.push_back(inputFeatureDimension); @@ -1466,8 +1460,8 @@ Tensor evalConvolutionOp( rhsContractingDimensions.push_back(kernelInputFeatureDimension); auto dotProduct = - evalDotGeneralOp(reversedLhsWindow, rhs, lhsContractingDimensions, - rhsContractingDimensions); + dotGeneralOp(reversedLhsWindow, rhs, lhsContractingDimensions, + rhsContractingDimensions); Sizes resultNonSpatialDims; for (auto i = 0; i < result.getRank(); ++i) @@ -1495,27 +1489,26 @@ Tensor evalConvolutionOp( return result; } -Tensor evalCosineOp(const Tensor &operand, ShapedType resultType) { +Tensor cosineOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, cosine(operand.get(*it))); return result; } -Tensor evalDivideOp(const Tensor &lhs, const Tensor &rhs, - ShapedType resultType) { +Tensor divideOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, lhs.get(*it) / rhs.get(*it)); return result; } -Tensor evalDotGeneralOp(const Tensor &lhs, const Tensor &rhs, - const Axes &lhsBatchingDimensions, - const Axes &rhsBatchingDimensions, - const Axes &lhsContractingDimensions, - const Axes &rhsContractingDimensions, - ShapedType resultType) { +Tensor dotGeneralOp(const Tensor &lhs, const Tensor &rhs, + const Axes &lhsBatchingDimensions, + const Axes &rhsBatchingDimensions, + const Axes &lhsContractingDimensions, + const Axes &rhsContractingDimensions, + ShapedType resultType) { Tensor result(resultType); Axes lhsResultDims; for (auto i = 0; i < lhs.getType().getRank(); ++i) @@ -1583,8 +1576,8 @@ Tensor evalDotGeneralOp(const Tensor &lhs, const Tensor &rhs, return result; } -Tensor evalDynamicSliceOp(const Tensor &operand, ArrayRef startIndices, - const Sizes &sliceSizes, ShapedType resultType) { +Tensor dynamicSliceOp(const Tensor &operand, ArrayRef startIndices, + const Sizes &sliceSizes, ShapedType resultType) { Tensor result(resultType); auto adjustedStartIndices = clamp(0, evalIndex(startIndices), operand.getShape() - sliceSizes); @@ -1597,9 +1590,9 @@ Tensor evalDynamicSliceOp(const Tensor &operand, ArrayRef startIndices, return result; } -Tensor evalDynamicUpdateSliceOp(const Tensor &operand, const Tensor &update, - ArrayRef startIndices, - ShapedType resultType) { +Tensor dynamicUpdateSliceOp(const Tensor &operand, const Tensor &update, + ArrayRef startIndices, + ShapedType resultType) { Tensor result(resultType); auto adjustedStartIndices = clamp(0, evalIndex(startIndices), operand.getShape() - update.getShape()); @@ -1615,32 +1608,32 @@ Tensor evalDynamicUpdateSliceOp(const Tensor &operand, const Tensor &update, return result; } -Tensor evalExpm1Op(const Tensor &operand, ShapedType resultType) { +Tensor expm1Op(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, exponentialMinusOne(operand.get(*it))); return result; } -Tensor evalExponentialOp(const Tensor &operand, ShapedType resultType) { +Tensor exponentialOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, exponential(operand.get(*it))); return result; } -Tensor evalFloorOp(const Tensor &operand, ShapedType resultType) { +Tensor floorOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, floor(operand.get(*it))); return result; } -Tensor evalGatherOp(const Tensor &operand, const Tensor &startIndices, - const Axes &offsetDims, const Axes &collapsedSliceDims, - const Axes &startIndexMap, Axis indexVectorDim, - const Sizes &sliceSizes, bool indicesAreSorted, - ShapedType resultType) { +Tensor gatherOp(const Tensor &operand, const Tensor &startIndices, + const Axes &offsetDims, const Axes &collapsedSliceDims, + const Axes &startIndexMap, Axis indexVectorDim, + const Sizes &sliceSizes, bool indicesAreSorted, + ShapedType resultType) { Tensor result(resultType); Axes batchDims; for (auto d : result.getAxes()) @@ -1657,7 +1650,7 @@ Tensor evalGatherOp(const Tensor &operand, const Tensor &startIndices, if (indexVectorDim < startIndices.getRank()) startIndicesIndex.insert(startIndicesIndex.begin() + indexVectorDim, kColon); - auto startIndex = evalIndex(evalSliceOp(startIndices, startIndicesIndex)); + auto startIndex = evalIndex(sliceOp(startIndices, startIndicesIndex)); Index fullStartIndex(operand.getRank(), 0); for (auto dOperand : operand.getAxes()) { @@ -1684,35 +1677,35 @@ Tensor evalGatherOp(const Tensor &operand, const Tensor &startIndices, return result; } -Tensor evalGetDimensionSizeOp(const Tensor &operand, Axis dimension, - ShapedType resultType) { +Tensor getDimensionSizeOp(const Tensor &operand, Axis dimension, + ShapedType resultType) { Tensor result(resultType); result.set( {}, convert(resultType.getElementType(), operand.getShape()[dimension])); return result; } -InterpreterValue evalGetTupleElementOp(const Tuple &operand, int32_t index) { +InterpreterValue getTupleElementOp(const Tuple &operand, int32_t index) { return operand.get(index); } -SmallVector evalIfOp(const Tensor &pred, Region &trueBranch, - Region &falseBranch, Process *process, - Scope &scope) { +SmallVector ifOp(const Tensor &pred, Region &trueBranch, + Region &falseBranch, Process *process, + Scope &scope) { return pred.get({}).getBooleanValue() ? eval(trueBranch, {}, /*fallback=*/nullptr, process, &scope) : eval(falseBranch, {}, /*fallback=*/nullptr, process, &scope); } -Tensor evalImagOp(const Tensor &operand, ShapedType resultType) { +Tensor imagOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, imag(operand.get(*it))); return result; } -SmallVector evalInfeedOp(Token token, Process *process, - Region ®ion, Scope &scope) { +SmallVector infeedOp(Token token, Process *process, + Region ®ion, Scope &scope) { if (!process) llvm::report_fatal_error( "infeed is only supported when run via interpreter.run_parallel"); @@ -1726,7 +1719,7 @@ SmallVector evalInfeedOp(Token token, Process *process, return results; } -Tensor evalIotaOp(Axis iotaDimension, ShapedType resultType) { +Tensor iotaOp(Axis iotaDimension, ShapedType resultType) { Tensor result(resultType); auto elementType = result.getElementType(); for (auto it = result.index_begin(); it != result.index_end(); ++it) @@ -1734,36 +1727,36 @@ Tensor evalIotaOp(Axis iotaDimension, ShapedType resultType) { return result; } -Tensor evalIsFiniteOp(const Tensor &operand, ShapedType resultType) { +Tensor isFiniteOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, isFinite(operand.get(*it))); return result; } -Tensor evalLog1pOp(const Tensor &operand, ShapedType resultType) { +Tensor log1pOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, logPlusOne(operand.get(*it))); return result; } -Tensor evalLogOp(const Tensor &operand, ShapedType resultType) { +Tensor logOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, log(operand.get(*it))); return result; } -Tensor evalLogisticOp(const Tensor &operand, ShapedType resultType) { +Tensor logisticOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, logistic(operand.get(*it))); return result; } -Tensor evalMapOp(ArrayRef inputs, Region &computation, Process *process, - Scope &scope, ShapedType resultType) { +Tensor mapOp(ArrayRef inputs, Region &computation, Process *process, + Scope &scope, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) { SmallVector args; @@ -1780,55 +1773,54 @@ Tensor evalMapOp(ArrayRef inputs, Region &computation, Process *process, return result; } -Tensor evalMaxOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { +Tensor maxOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, max(lhs.get(*it), rhs.get(*it))); return result; } -Tensor evalMinOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { +Tensor minOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, min(lhs.get(*it), rhs.get(*it))); return result; } -Tensor evalMultiplyOp(const Tensor &lhs, const Tensor &rhs, - ShapedType resultType) { +Tensor multiplyOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, lhs.get(*it) * rhs.get(*it)); return result; } -Tensor evalNegOp(const Tensor &operand, ShapedType resultType) { +Tensor negOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, -operand.get(*it)); return result; } -Tensor evalNotOp(const Tensor &operand, ShapedType resultType) { +Tensor notOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, ~operand.get(*it)); return result; } -SmallVector evalOptimizationBarrierOp( +SmallVector optimizationBarrierOp( ArrayRef operand) { return SmallVector(operand); } -Tensor evalOrOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { +Tensor orOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, lhs.get(*it) | rhs.get(*it)); return result; } -Token evalOutfeedOp(ArrayRef inputs, Token token, Process *process) { +Token outfeedOp(ArrayRef inputs, Token token, Process *process) { if (!process) llvm::report_fatal_error( "outfeed is only supported when run via interpreter.run_parallel"); @@ -1837,9 +1829,9 @@ Token evalOutfeedOp(ArrayRef inputs, Token token, Process *process) { return token; } -Tensor evalPadOp(const Tensor &operand, const Tensor &paddingValue, - const Sizes &edgePaddingLow, const Sizes &interiorPadding, - ShapedType resultType) { +Tensor padOp(const Tensor &operand, const Tensor &paddingValue, + const Sizes &edgePaddingLow, const Sizes &interiorPadding, + ShapedType resultType) { auto result = makeSplat(resultType, paddingValue.get({})); for (auto operandIt = operand.index_begin(); operandIt != operand.index_end(); ++operandIt) { @@ -1853,7 +1845,7 @@ Tensor evalPadOp(const Tensor &operand, const Tensor &paddingValue, return result; } -Tensor evalPartitionIdOp(Process *process, MLIRContext *context) { +Tensor partitionIdOp(Process *process, MLIRContext *context) { if (!process) llvm::report_fatal_error( "partition_id is only supported when run via interpreter.run_parallel"); @@ -1862,41 +1854,40 @@ Tensor evalPartitionIdOp(Process *process, MLIRContext *context) { return constant(APInt(32, partitionId), elementType); } -Tensor evalPopulationCountOp(const Tensor &operand, ShapedType resultType) { +Tensor populationCountOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, popcnt(operand.get(*it))); return result; } -Tensor evalPowerOp(const Tensor &lhs, const Tensor &rhs, - ShapedType resultType) { +Tensor powerOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, power(lhs.get(*it), rhs.get(*it))); return result; } -Tensor evalRealOp(const Tensor &operand, ShapedType resultType) { +Tensor realOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, real(operand.get(*it))); return result; } -SmallVector evalRecvOp(Token token, ChannelId channelId, - Process *process) { +SmallVector recvOp(Token token, ChannelId channelId, + Process *process) { SmallVector results; for (const auto &tensor : process->recv(channelId)) results.push_back(tensor); results.push_back(token); return results; } -SmallVector evalReduceOp(ArrayRef inputs, - ArrayRef initValues, - const Axes &dimensions, Region &body, - Process *process, Scope &scope, - ArrayRef resultTypes) { +SmallVector reduceOp(ArrayRef inputs, + ArrayRef initValues, + const Axes &dimensions, Region &body, + Process *process, Scope &scope, + ArrayRef resultTypes) { SmallVector results; for (auto [resultType, initValue] : llvm::zip(resultTypes, initValues)) results.push_back(makeSplat(resultType, initValue.get({}))); @@ -1925,8 +1916,8 @@ SmallVector evalReduceOp(ArrayRef inputs, return results; } -Tensor evalReducePrecisionOp(const Tensor &operand, int32_t exponentBits, - int32_t mantissaBits, ShapedType resultType) { +Tensor reducePrecisionOp(const Tensor &operand, int32_t exponentBits, + int32_t mantissaBits, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, @@ -1934,11 +1925,11 @@ Tensor evalReducePrecisionOp(const Tensor &operand, int32_t exponentBits, return result; } -Tensor evalReduceScatterOp(const Tensor &operand, int64_t scatterDimension, - SmallVector> replicaGroups, - ChannelId channelId, bool useGlobalDeviceIds, - Region ®ion, Process *process, Scope &scope, - ShapedType returnType) { +Tensor reduceScatterOp(const Tensor &operand, int64_t scatterDimension, + SmallVector> replicaGroups, + ChannelId channelId, bool useGlobalDeviceIds, + Region ®ion, Process *process, Scope &scope, + ShapedType returnType) { if (!process) llvm::report_fatal_error( "reduce_scatter is only supported when run via " @@ -1959,8 +1950,8 @@ Tensor evalReduceScatterOp(const Tensor &operand, int64_t scatterDimension, process->getId().replicaId, process->getId().partitionId)); auto reducedValue = - evalAllReduceOp(operand, replicaGroups, channelId, useGlobalDeviceIds, - region, process, scope, operand.getType()); + allReduceOp(operand, replicaGroups, channelId, useGlobalDeviceIds, region, + process, scope, operand.getType()); auto parts = split(reducedValue, processGroups[0].size(), scatterDimension, operand.getType().getContext()); @@ -1976,7 +1967,7 @@ Tensor evalReduceScatterOp(const Tensor &operand, int64_t scatterDimension, return result; } -SmallVector evalReduceWindowOp( +SmallVector reduceWindowOp( ArrayRef inputs, ArrayRef initValues, const Sizes &windowDimensions, const Sizes &windowStrides, const Sizes &baseDilations, const Sizes &windowDilations, @@ -1988,8 +1979,8 @@ SmallVector evalReduceWindowOp( SmallVector paddedInputs; for (auto [input, initValue] : llvm::zip(inputs, initValues)) - paddedInputs.push_back(evalPadOp(input, initValue, paddingLow, paddingHigh, - baseDilations - 1)); + paddedInputs.push_back( + padOp(input, initValue, paddingLow, paddingHigh, baseDilations - 1)); for (auto resultIt = results[0].index_begin(); resultIt != results[0].index_end(); ++resultIt) { SmallVector windows; @@ -1997,24 +1988,24 @@ SmallVector evalReduceWindowOp( auto windowEnd = windowStart + (windowDimensions - 1) * windowDilations + 1; for (const auto &paddedInput : paddedInputs) windows.push_back( - evalSliceOp(paddedInput, windowStart, windowEnd, windowDilations)); + sliceOp(paddedInput, windowStart, windowEnd, windowDilations)); - auto reducedValues = evalReduceOp(windows, initValues, inputs[0].getAxes(), - body, process, scope); + auto reducedValues = reduceOp(windows, initValues, inputs[0].getAxes(), + body, process, scope); for (auto [result, value] : llvm::zip(results, reducedValues)) result.set(*resultIt, value.get({})); } return results; } -Tensor evalRemOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { +Tensor remOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, rem(lhs.get(*it), rhs.get(*it))); return result; } -Tensor evalReplicaIdOp(Process *process, MLIRContext *context) { +Tensor replicaIdOp(Process *process, MLIRContext *context) { if (!process) llvm::report_fatal_error( "replica_id is only supported when run via interpreter.run_parallel"); @@ -2023,7 +2014,7 @@ Tensor evalReplicaIdOp(Process *process, MLIRContext *context) { return constant(APInt(32, replicaId), elementType); } -Tensor evalReshapeOp(const Tensor &operand, ShapedType resultType) { +Tensor reshapeOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto resultIt = result.index_begin(), operandIt = operand.index_begin(); resultIt != result.index_end(); ++resultIt, ++operandIt) { @@ -2034,8 +2025,8 @@ Tensor evalReshapeOp(const Tensor &operand, ShapedType resultType) { return result; } -Tensor evalReverseOp(const Tensor &operand, const Axes &dimensions, - ShapedType resultType) { +Tensor reverseOp(const Tensor &operand, const Axes &dimensions, + ShapedType resultType) { Tensor result(resultType); for (auto resultIt = result.index_begin(); resultIt != result.index_end(); ++resultIt) { @@ -2048,28 +2039,28 @@ Tensor evalReverseOp(const Tensor &operand, const Axes &dimensions, return result; } -Tensor evalRoundNearestEvenOp(const Tensor &operand, ShapedType resultType) { +Tensor roundNearestEvenOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, roundNearestEven(operand.get(*it))); return result; } -Tensor evalRoundOp(const Tensor &operand, ShapedType resultType) { +Tensor roundOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, roundNearestAfz(operand.get(*it))); return result; } -Tensor evalRsqrtOp(const Tensor &operand, ShapedType resultType) { +Tensor rsqrtOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, rsqrt(operand.get(*it))); return result; } -SmallVector evalScatterOp( +SmallVector scatterOp( ArrayRef inputs, const Tensor &scatterIndices, ArrayRef updates, const Axes &updateWindowDims, const Axes &insertedWindowDims, const Axes &scatterDimsToOperandDims, @@ -2094,7 +2085,7 @@ SmallVector evalScatterOp( if (indexVectorDim < scatterIndices.getRank()) startIndicesIndex.insert(startIndicesIndex.begin() + indexVectorDim, kColon); - auto startIndex = evalIndex(evalSliceOp(scatterIndices, startIndicesIndex)); + auto startIndex = evalIndex(sliceOp(scatterIndices, startIndicesIndex)); Index fullStartIndex(inputs[0].getRank(), 0); for (auto dInput : inputs[0].getAxes()) { @@ -2132,13 +2123,12 @@ SmallVector evalScatterOp( return results; } -Tensor evalSelectAndScatterOp(const Tensor &operand, const Tensor &source, - const Tensor &initValue, - const Sizes &windowDimensions, - const Sizes &windowStrides, - const Sizes &paddingLow, Region &select, - Region &scatter, Process *process, Scope &scope, - ShapedType resultType) { +Tensor selectAndScatterOp(const Tensor &operand, const Tensor &source, + const Tensor &initValue, + const Sizes &windowDimensions, + const Sizes &windowStrides, const Sizes &paddingLow, + Region &select, Region &scatter, Process *process, + Scope &scope, ShapedType resultType) { auto result = makeSplat(resultType, initValue.get({})); for (auto sourceIt = source.index_begin(); sourceIt != source.index_end(); @@ -2178,8 +2168,8 @@ Tensor evalSelectAndScatterOp(const Tensor &operand, const Tensor &source, RankedTensorType::get({2}, initValue.getElementType())); sourceValues.set({0}, source.get(*sourceIt)); sourceValues.set({1}, result.get(operandIndex)); - auto reducedResult = evalReduceOp({sourceValues}, {initValue}, {0}, - scatter, process, scope); + auto reducedResult = + reduceOp({sourceValues}, {initValue}, {0}, scatter, process, scope); result.set(operandIndex, reducedResult[0].get({})); } }); @@ -2187,8 +2177,8 @@ Tensor evalSelectAndScatterOp(const Tensor &operand, const Tensor &source, return result; } -Tensor evalSelectOp(const Tensor &pred, const Tensor &onTrue, - const Tensor &onFalse, ShapedType resultType) { +Tensor selectOp(const Tensor &pred, const Tensor &onTrue, const Tensor &onFalse, + ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) { Element predValue = pred.getRank() != 0 ? pred.get(*it) : pred.get({}); @@ -2198,52 +2188,52 @@ Tensor evalSelectOp(const Tensor &pred, const Tensor &onTrue, return result; } -Token evalSendOp(ArrayRef inputs, Token token, ChannelId channelId, - Process *process) { +Token sendOp(ArrayRef inputs, Token token, ChannelId channelId, + Process *process) { process->send(inputs, channelId); return token; } -Tensor evalShiftLeftOp(const Tensor &lhs, const Tensor &rhs, - ShapedType resultType) { +Tensor shiftLeftOp(const Tensor &lhs, const Tensor &rhs, + ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, shiftLeft(lhs.get(*it), rhs.get(*it))); return result; } -Tensor evalShiftRightArithmeticOp(const Tensor &lhs, const Tensor &rhs, - ShapedType resultType) { +Tensor shiftRightArithmeticOp(const Tensor &lhs, const Tensor &rhs, + ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, shiftRightArithmetic(lhs.get(*it), rhs.get(*it))); return result; } -Tensor evalShiftRightLogicalOp(const Tensor &lhs, const Tensor &rhs, - ShapedType resultType) { +Tensor shiftRightLogicalOp(const Tensor &lhs, const Tensor &rhs, + ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, shiftRightLogical(lhs.get(*it), rhs.get(*it))); return result; } -Tensor evalSignOp(const Tensor &operand, ShapedType resultType) { +Tensor signOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, sign(operand.get(*it))); return result; } -Tensor evalSineOp(const Tensor &operand, ShapedType resultType) { +Tensor sineOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, sine(operand.get(*it))); return result; } -Tensor evalSliceOp(const Tensor &operand, const Sizes &startIndices, - const Sizes &strides, ShapedType resultType) { +Tensor sliceOp(const Tensor &operand, const Sizes &startIndices, + const Sizes &strides, ShapedType resultType) { Tensor result(resultType); for (auto resultIt = result.index_begin(); resultIt != result.index_end(); ++resultIt) { @@ -2254,9 +2244,9 @@ Tensor evalSliceOp(const Tensor &operand, const Sizes &startIndices, return result; } -SmallVector evalSortOp(ArrayRef inputs, Axis dimension, - bool isStable, Region &comparator, - Process *process, Scope &scope) { +SmallVector sortOp(ArrayRef inputs, Axis dimension, + bool isStable, Region &comparator, Process *process, + Scope &scope) { SmallVector results; for (const auto &input : inputs) results.emplace_back(input.getType()); auto adjustedDimension = @@ -2316,30 +2306,29 @@ SmallVector evalSortOp(ArrayRef inputs, Axis dimension, return results; } -Tensor evalSqrtOp(const Tensor &operand, ShapedType resultType) { +Tensor sqrtOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, sqrt(operand.get(*it))); return result; } -Tensor evalSubtractOp(const Tensor &lhs, const Tensor &rhs, - ShapedType resultType) { +Tensor subtractOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, lhs.get(*it) - rhs.get(*it)); return result; } -Tensor evalTanhOp(const Tensor &operand, ShapedType resultType) { +Tensor tanhOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, tanh(operand.get(*it))); return result; } -Tensor evalTransposeOp(const Tensor &operand, const Axes &permutation, - ShapedType resultType) { +Tensor transposeOp(const Tensor &operand, const Axes &permutation, + ShapedType resultType) { Tensor result(resultType); for (auto operandIt = operand.index_begin(); operandIt != operand.index_end(); ++operandIt) { @@ -2352,14 +2341,14 @@ Tensor evalTransposeOp(const Tensor &operand, const Axes &permutation, return result; } -Tuple evalTupleOp(ArrayRef val, TupleType resultType) { +Tuple tupleOp(ArrayRef val, TupleType resultType) { return Tuple(val, resultType); } -SmallVector evalWhileOp(SmallVector operand, - Region &cond, Region &body, - InterpreterFallback *fallback, - Process *process, Scope &scope) { +SmallVector whileOp(SmallVector operand, + Region &cond, Region &body, + InterpreterFallback *fallback, + Process *process, Scope &scope) { SmallVector results(operand); auto condResults = eval(cond, operand, fallback, process, &scope); @@ -2372,7 +2361,7 @@ SmallVector evalWhileOp(SmallVector operand, return results; } -Tensor evalXorOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { +Tensor xorOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) result.set(*it, lhs.get(*it) ^ rhs.get(*it)); diff --git a/stablehlo/reference/Ops.h b/stablehlo/reference/Ops.h index 8db66de3af..af546e45ac 100644 --- a/stablehlo/reference/Ops.h +++ b/stablehlo/reference/Ops.h @@ -31,53 +31,49 @@ namespace mlir { namespace stablehlo { // Evaluators for StableHLO ops. -Tensor evalAbsOp(const Tensor &operand, ShapedType resultType); -Tensor evalAddOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); -Token evalAfterAllOp(ArrayRef inputs, MLIRContext *context); -Tensor evalAllGatherOp(const Tensor &operand, int64_t allGatherDim, - SmallVector> replicaGroups, - ChannelId channelId, bool useGlobalDeviceIds, - Process *process, ShapedType resultType); -Tensor evalAllReduceOp(const Tensor &operand, - SmallVector> replicaGroups, - ChannelId channelId, bool useGlobalDeviceIds, - Region &computation, Process *process, Scope &scope, - ShapedType resultType); -Tensor evalAllToAllOp(const Tensor &operand, Axis splitDimension, - Axis concatDimension, int64_t splitCount, - SmallVector> replicaGroups, - ChannelId channelId, Process *process, - ShapedType resultType); -Tensor evalAndOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); -Tensor evalAtan2Op(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); -Tensor evalBitcastConvertOp(const Tensor &operand, ShapedType resultType); -Tensor evalBroadcastInDimOp(const Tensor &operand, - const Axes &broadcastDimensions, - ShapedType resultType); -SmallVector evalCaseOp(const Tensor &index, - RegionRange branches, Process *process, - Scope &scope); -Tensor evalCbrtOp(const Tensor &operand, ShapedType resultType); -Tensor evalCeilOp(const Tensor &operand, ShapedType resultType); -Tensor evalClampOp(const Tensor &min, const Tensor &operand, const Tensor &max, +Tensor absOp(const Tensor &operand, ShapedType resultType); +Tensor addOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); +Token afterAllOp(ArrayRef inputs, MLIRContext *context); +Tensor allGatherOp(const Tensor &operand, int64_t allGatherDim, + SmallVector> replicaGroups, + ChannelId channelId, bool useGlobalDeviceIds, + Process *process, ShapedType resultType); +Tensor allReduceOp(const Tensor &operand, + SmallVector> replicaGroups, + ChannelId channelId, bool useGlobalDeviceIds, + Region &computation, Process *process, Scope &scope, ShapedType resultType); -Tensor evalClzOp(const Tensor &operand, ShapedType resultType); -Tensor evalCollectiveBroadcastOp( - const Tensor &operand, SmallVector> replicaGroups, - ChannelId channelId, Process *process); -Tensor evalCollectivePermuteOp( - const Tensor &operand, SmallVector> sourceTargetPairs, - ChannelId channelId, Process *process); -Tensor evalCompareOp(const Tensor &lhs, const Tensor &rhs, - ComparisonDirection comparisonDirection, - ShapedType resultType); -Tensor evalComplexOp(const Tensor &lhs, const Tensor &rhs, +Tensor allToAllOp(const Tensor &operand, Axis splitDimension, + Axis concatDimension, int64_t splitCount, + SmallVector> replicaGroups, + ChannelId channelId, Process *process, ShapedType resultType); +Tensor andOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); +Tensor atan2Op(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); +Tensor bitcastConvertOp(const Tensor &operand, ShapedType resultType); +Tensor broadcastInDimOp(const Tensor &operand, const Axes &broadcastDimensions, + ShapedType resultType); +SmallVector caseOp(const Tensor &index, RegionRange branches, + Process *process, Scope &scope); +Tensor cbrtOp(const Tensor &operand, ShapedType resultType); +Tensor ceilOp(const Tensor &operand, ShapedType resultType); +Tensor clampOp(const Tensor &min, const Tensor &operand, const Tensor &max, + ShapedType resultType); +Tensor clzOp(const Tensor &operand, ShapedType resultType); +Tensor collectiveBroadcastOp(const Tensor &operand, + SmallVector> replicaGroups, + ChannelId channelId, Process *process); +Tensor collectivePermuteOp(const Tensor &operand, + SmallVector> sourceTargetPairs, + ChannelId channelId, Process *process); +Tensor compareOp(const Tensor &lhs, const Tensor &rhs, + ComparisonDirection comparisonDirection, + ShapedType resultType); +Tensor complexOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); +Tensor concatenateOp(ArrayRef inputs, Axis dimension, ShapedType resultType); -Tensor evalConcatenateOp(ArrayRef inputs, Axis dimension, - ShapedType resultType); -Tensor evalConstantOp(ElementsAttr value); -Tensor evalConvertOp(const Tensor &operand, ShapedType resultType); -Tensor evalConvolutionOp( +Tensor constantOp(ElementsAttr value); +Tensor convertOp(const Tensor &operand, ShapedType resultType); +Tensor convolutionOp( const Tensor &lhs, const Tensor &rhs, ArrayRef windowStrides, ArrayRef> padding, ArrayRef lhsDilation, ArrayRef rhsDilation, @@ -87,131 +83,126 @@ Tensor evalConvolutionOp( const Axes &kernelSpatialDimensions, Axis outputBatchDimension, Axis outputFeatureDimension, const Axes &outputSpatialDimensions, int64_t featureGroupCount, int64_t batchGroupCount, ShapedType resultType); -Tensor evalCosineOp(const Tensor &operand, ShapedType resultType); -Tensor evalDivideOp(const Tensor &lhs, const Tensor &rhs, - ShapedType resultType); -Tensor evalDotGeneralOp(const Tensor &lhs, const Tensor &rhs, - const Axes &lhsBatchingDimensions, - const Axes &rhsBatchingDimensions, - const Axes &lhsContractingDimensions, - const Axes &rhsContractingDimensions, - ShapedType resultType); -Tensor evalDynamicSliceOp(const Tensor &operand, ArrayRef startIndices, - const Sizes &sliceSizes, ShapedType resultType); -Tensor evalDynamicUpdateSliceOp(const Tensor &operand, const Tensor &update, - ArrayRef startIndices, - ShapedType resultType); -Tensor evalExpm1Op(const Tensor &operand, ShapedType resultType); -Tensor evalExponentialOp(const Tensor &operand, ShapedType resultType); -Tensor evalFloorOp(const Tensor &operand, ShapedType resultType); -Tensor evalGatherOp(const Tensor &operand, const Tensor &startIndices, - const Axes &offsetDims, const Axes &collapsedSliceDims, - const Axes &startIndexMap, Axis indexVectorDim, - const Sizes &sliceSizes, bool indicesAreSorted, +Tensor cosineOp(const Tensor &operand, ShapedType resultType); +Tensor divideOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); +Tensor dotGeneralOp(const Tensor &lhs, const Tensor &rhs, + const Axes &lhsBatchingDimensions, + const Axes &rhsBatchingDimensions, + const Axes &lhsContractingDimensions, + const Axes &rhsContractingDimensions, ShapedType resultType); -Tensor evalGetDimensionSizeOp(const Tensor &operand, Axis dimension, - ShapedType resultType); -InterpreterValue evalGetTupleElementOp(const Tuple &operand, int32_t index); -SmallVector evalIfOp(const Tensor &pred, Region &trueBranch, - Region &falseBranch, Process *process, - Scope &scope); -Tensor evalImagOp(const Tensor &operand, ShapedType resultType); -SmallVector evalInfeedOp(Token token, Process *process, - Region ®ion, Scope &scope); -Tensor evalIotaOp(Axis iotaDimension, ShapedType resultType); -Tensor evalIsFiniteOp(const Tensor &operand, ShapedType resultType); -Tensor evalLog1pOp(const Tensor &operand, ShapedType resultType); -Tensor evalLogOp(const Tensor &operand, ShapedType resultType); -Tensor evalLogisticOp(const Tensor &operand, ShapedType resultType); -Tensor evalMapOp(ArrayRef inputs, Region &computation, Process *process, - Scope &scope, ShapedType resultType); -Tensor evalMaxOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); -Tensor evalMinOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); -Tensor evalMultiplyOp(const Tensor &lhs, const Tensor &rhs, - ShapedType resultType); -Tensor evalNegOp(const Tensor &operand, ShapedType resultType); -Tensor evalNotOp(const Tensor &operand, ShapedType resultType); -SmallVector evalOptimizationBarrierOp( +Tensor dynamicSliceOp(const Tensor &operand, ArrayRef startIndices, + const Sizes &sliceSizes, ShapedType resultType); +Tensor dynamicUpdateSliceOp(const Tensor &operand, const Tensor &update, + ArrayRef startIndices, + ShapedType resultType); +Tensor expm1Op(const Tensor &operand, ShapedType resultType); +Tensor exponentialOp(const Tensor &operand, ShapedType resultType); +Tensor floorOp(const Tensor &operand, ShapedType resultType); +Tensor gatherOp(const Tensor &operand, const Tensor &startIndices, + const Axes &offsetDims, const Axes &collapsedSliceDims, + const Axes &startIndexMap, Axis indexVectorDim, + const Sizes &sliceSizes, bool indicesAreSorted, + ShapedType resultType); +Tensor getDimensionSizeOp(const Tensor &operand, Axis dimension, + ShapedType resultType); +InterpreterValue getTupleElementOp(const Tuple &operand, int32_t index); +SmallVector ifOp(const Tensor &pred, Region &trueBranch, + Region &falseBranch, Process *process, + Scope &scope); +Tensor imagOp(const Tensor &operand, ShapedType resultType); +SmallVector infeedOp(Token token, Process *process, + Region ®ion, Scope &scope); +Tensor iotaOp(Axis iotaDimension, ShapedType resultType); +Tensor isFiniteOp(const Tensor &operand, ShapedType resultType); +Tensor log1pOp(const Tensor &operand, ShapedType resultType); +Tensor logOp(const Tensor &operand, ShapedType resultType); +Tensor logisticOp(const Tensor &operand, ShapedType resultType); +Tensor mapOp(ArrayRef inputs, Region &computation, Process *process, + Scope &scope, ShapedType resultType); +Tensor maxOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); +Tensor minOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); +Tensor multiplyOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); +Tensor negOp(const Tensor &operand, ShapedType resultType); +Tensor notOp(const Tensor &operand, ShapedType resultType); +SmallVector optimizationBarrierOp( ArrayRef operand); -Tensor evalOrOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); -Token evalOutfeedOp(ArrayRef inputs, Token token, Process *process); -Tensor evalPadOp(const Tensor &operand, const Tensor &paddingValue, - const Sizes &edgePaddingLow, const Sizes &interiorPadding, - ShapedType resultType); -Tensor evalPartitionIdOp(Process *process, MLIRContext *context); -Tensor evalPopulationCountOp(const Tensor &operand, ShapedType resultType); -Tensor evalPowerOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); -Tensor evalRealOp(const Tensor &operand, ShapedType resultType); -SmallVector evalRecvOp(Token token, ChannelId channelId, - Process *process); -SmallVector evalReduceOp(ArrayRef inputs, - ArrayRef initValues, - const Axes &dimensions, Region &body, - Process *process, Scope &scope, - ArrayRef resultTypes); -Tensor evalReducePrecisionOp(const Tensor &operand, int32_t exponentBits, - int32_t mantissaBits, ShapedType resultType); -Tensor evalReduceScatterOp(const Tensor &operand, int64_t scatterDimension, - SmallVector> replicaGroups, - ChannelId channelId, bool useGlobalDeviceIds, - Region ®ion, Process *process, Scope &scope, - ShapedType returnType); -SmallVector evalReduceWindowOp( +Tensor orOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); +Token outfeedOp(ArrayRef inputs, Token token, Process *process); +Tensor padOp(const Tensor &operand, const Tensor &paddingValue, + const Sizes &edgePaddingLow, const Sizes &interiorPadding, + ShapedType resultType); +Tensor partitionIdOp(Process *process, MLIRContext *context); +Tensor populationCountOp(const Tensor &operand, ShapedType resultType); +Tensor powerOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); +Tensor realOp(const Tensor &operand, ShapedType resultType); +SmallVector recvOp(Token token, ChannelId channelId, + Process *process); +SmallVector reduceOp(ArrayRef inputs, + ArrayRef initValues, + const Axes &dimensions, Region &body, + Process *process, Scope &scope, + ArrayRef resultTypes); +Tensor reducePrecisionOp(const Tensor &operand, int32_t exponentBits, + int32_t mantissaBits, ShapedType resultType); +Tensor reduceScatterOp(const Tensor &operand, int64_t scatterDimension, + SmallVector> replicaGroups, + ChannelId channelId, bool useGlobalDeviceIds, + Region ®ion, Process *process, Scope &scope, + ShapedType returnType); +SmallVector reduceWindowOp( ArrayRef inputs, ArrayRef initValues, const Sizes &windowDimensions, const Sizes &windowStrides, const Sizes &baseDilations, const Sizes &windowDilations, const Sizes &paddingLow, const Sizes &paddingHigh, Region &body, Process *process, Scope &scope, ArrayRef resultTypes); -Tensor evalRemOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); -Tensor evalReplicaIdOp(Process *process, MLIRContext *context); -Tensor evalReshapeOp(const Tensor &operand, ShapedType resultType); -Tensor evalReverseOp(const Tensor &operand, const Axes &dimensions, - ShapedType resultType); -Tensor evalRoundOp(const Tensor &operand, ShapedType resultType); -Tensor evalRoundNearestEvenOp(const Tensor &operand, ShapedType resultType); -Tensor evalRsqrtOp(const Tensor &operand, ShapedType resultType); -SmallVector evalScatterOp( +Tensor remOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); +Tensor replicaIdOp(Process *process, MLIRContext *context); +Tensor reshapeOp(const Tensor &operand, ShapedType resultType); +Tensor reverseOp(const Tensor &operand, const Axes &dimensions, + ShapedType resultType); +Tensor roundOp(const Tensor &operand, ShapedType resultType); +Tensor roundNearestEvenOp(const Tensor &operand, ShapedType resultType); +Tensor rsqrtOp(const Tensor &operand, ShapedType resultType); +SmallVector scatterOp( ArrayRef inputs, const Tensor &scatterIndices, ArrayRef updates, const Axes &updateWindowDims, const Axes &insertedWindowDims, const Axes &scatterDimsToOperandDims, Axis indexVectorDim, Region &updateComputation, Process *process, Scope &scope, ArrayRef resultTypes); -Tensor evalSelectOp(const Tensor &pred, const Tensor &onTrue, - const Tensor &onFalse, ShapedType resultType); -Tensor evalSelectAndScatterOp(const Tensor &operand, const Tensor &source, - const Tensor &initValue, - const Sizes &windowDimensions, - const Sizes &windowStrides, - const Sizes &paddingLow, Region &select, - Region &scatter, Process *process, Scope &scope, +Tensor selectOp(const Tensor &pred, const Tensor &onTrue, const Tensor &onFalse, + ShapedType resultType); +Tensor selectAndScatterOp(const Tensor &operand, const Tensor &source, + const Tensor &initValue, + const Sizes &windowDimensions, + const Sizes &windowStrides, const Sizes &paddingLow, + Region &select, Region &scatter, Process *process, + Scope &scope, ShapedType resultType); +Token sendOp(ArrayRef inputs, Token token, ChannelId channelId, + Process *process); +Tensor shiftLeftOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); +Tensor shiftRightArithmeticOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); -Token evalSendOp(ArrayRef inputs, Token token, ChannelId channelId, - Process *process); -Tensor evalShiftLeftOp(const Tensor &lhs, const Tensor &rhs, - ShapedType resultType); -Tensor evalShiftRightArithmeticOp(const Tensor &lhs, const Tensor &rhs, - ShapedType resultType); -Tensor evalShiftRightLogicalOp(const Tensor &lhs, const Tensor &rhs, - ShapedType resultType); -Tensor evalSignOp(const Tensor &operand, ShapedType resultType); -Tensor evalSineOp(const Tensor &operand, ShapedType resultType); -Tensor evalSliceOp(const Tensor &operand, const Sizes &startIndices, - const Sizes &strides, ShapedType resultType); -SmallVector evalSortOp(ArrayRef inputs, Axis dimension, - bool isStable, Region &comparator, - Process *process, Scope &scope); -Tensor evalSqrtOp(const Tensor &operand, ShapedType resultType); -Tensor evalSubtractOp(const Tensor &lhs, const Tensor &rhs, - ShapedType resultType); -Tensor evalTanhOp(const Tensor &operand, ShapedType resultType); -Tensor evalTransposeOp(const Tensor &operand, const Axes &permutation, - ShapedType resultType); -Tuple evalTupleOp(ArrayRef val, TupleType resultType); -SmallVector evalWhileOp(SmallVector operand, - Region &cond, Region &body, - InterpreterFallback *fallback, - Process *process, Scope &scope); -Tensor evalXorOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); +Tensor shiftRightLogicalOp(const Tensor &lhs, const Tensor &rhs, + ShapedType resultType); +Tensor signOp(const Tensor &operand, ShapedType resultType); +Tensor sineOp(const Tensor &operand, ShapedType resultType); +Tensor sliceOp(const Tensor &operand, const Sizes &startIndices, + const Sizes &strides, ShapedType resultType); +SmallVector sortOp(ArrayRef inputs, Axis dimension, + bool isStable, Region &comparator, Process *process, + Scope &scope); +Tensor sqrtOp(const Tensor &operand, ShapedType resultType); +Tensor subtractOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); +Tensor tanhOp(const Tensor &operand, ShapedType resultType); +Tensor transposeOp(const Tensor &operand, const Axes &permutation, + ShapedType resultType); +Tuple tupleOp(ArrayRef val, TupleType resultType); +SmallVector whileOp(SmallVector operand, + Region &cond, Region &body, + InterpreterFallback *fallback, + Process *process, Scope &scope); +Tensor xorOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); /// Evaluates an mlir::Region `region` using the runtime values `args` /// corresponding to the arguments of the entry block of the region.