Skip to content

Commit

Permalink
Augment example test with interpreter usage (#2098)
Browse files Browse the repository at this point in the history
Augmented the example test with interpreter usage demo.

nit: Fixed the dependency of `MLIRParser` on `StablehloReferenceApi` as
API.cc uses `parseSourceFile`.
  • Loading branch information
sdasgup3 authored Mar 25, 2024
1 parent 881154e commit c123a48
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 27 deletions.
10 changes: 6 additions & 4 deletions examples/c++/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ package(
)

cc_binary(
name = "stablehlo-add",
name = "example-add",
srcs = [
"stablehlo_add.cpp",
"ExampleAdd.cpp",
],
deps = [
"//:reference_api",
"//:stablehlo_ops",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
Expand All @@ -32,11 +33,12 @@ cc_binary(
)

cc_test(
name = "stablehlo-add_test",
name = "example-add-test",
srcs = [
"stablehlo_add.cpp",
"ExampleAdd.cpp",
],
deps = [
"//:reference_api",
"//:stablehlo_ops",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
Expand Down
18 changes: 9 additions & 9 deletions examples/c++/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@

# TODO(fzakaria): As we think about adding more tests we should consider
# making this a function to be DRY.
add_executable(stablehlo-add stablehlo_add.cpp)
llvm_update_compile_flags(stablehlo-add)
target_link_libraries(stablehlo-add PRIVATE StablehloOps)
add_executable(example-add ExampleAdd.cpp)
llvm_update_compile_flags(example-add)
target_link_libraries(example-add PRIVATE StablehloOps StablehloReferenceApi)

mlir_check_all_link_libraries(stablehlo-add)
add_custom_target(stablehlo-add-test
COMMAND stablehlo-add
mlir_check_all_link_libraries(example-add)
add_custom_target(example-add-test
COMMAND example-add
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
COMMENT "Executing stablehlo-add to validate it works."
COMMENT "Executing example-add to validate it works."
DEPENDS
stablehlo-add
example-add
)
add_dependencies(check-stablehlo-quick stablehlo-add-test)
add_dependencies(check-stablehlo-quick example-add-test)
35 changes: 23 additions & 12 deletions examples/c++/stablehlo_add.cpp → examples/c++/ExampleAdd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ limitations under the License.
==============================================================================*/
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Verifier.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/reference/Api.h"

int main() {
mlir::MLIRContext context;
Expand All @@ -31,10 +33,11 @@ int main() {
module->setName("test_module");

/** create function **/
// create function argument and result types
mlir::Type arg0 =
// create function argument and result types.
auto tensorType =
mlir::RankedTensorType::get({3, 4}, mlir::FloatType::getF32(&context));
auto func_type = mlir::FunctionType::get(&context, {arg0, arg0}, {arg0});
auto func_type =
mlir::FunctionType::get(&context, {tensorType, tensorType}, {tensorType});

// create the function and map arguments.
llvm::ArrayRef<mlir::NamedAttribute> attrs;
Expand All @@ -43,26 +46,34 @@ int main() {
function.setVisibility(mlir::func::FuncOp::Visibility::Public);
module->push_back(function);

// create function block with add operations
// create function block with add operations.
mlir::Block* block = function.addEntryBlock();
llvm::SmallVector<mlir::Value, 4> arguments(block->args_begin(),
block->args_end());
mlir::OpBuilder block_builder = mlir::OpBuilder::atBlockEnd(block);
mlir::Location loc = block_builder.getUnknownLoc();

llvm::SmallVector<mlir::NamedAttribute, 10> attributes;
block_builder.create<mlir::stablehlo::AddOp>(loc, arguments, attributes)
.getOperation();

mlir::Operation* op =
block_builder
.create<mlir::stablehlo::AddOp>(loc, arg0, arguments, attributes)
block_builder.create<mlir::stablehlo::AddOp>(loc, arguments, attributes)
.getOperation();
block_builder.create<mlir::func::ReturnOp>(loc, op->getResult(0));

// verify the module and dump
/** verify and dump the module **/
assert(mlir::succeeded(mlir::verify(module.get())));
module->dump();

return 0;
/* interpret the function "main" with concrete inputs **/
auto getConstValue = [&](double val) {
return mlir::DenseElementsAttr::get(
tensorType,
block_builder.getFloatAttr(tensorType.getElementType(), val));
};

auto inputValue1 = getConstValue(10.0);
auto inputValue2 = getConstValue(20.0);
auto expectedValue = getConstValue(30.0);

mlir::stablehlo::InterpreterConfiguration config;
auto results = evalModule(*module, {inputValue1, inputValue2}, config);
return failed(results) || (*results)[0] != expectedValue;
}
40 changes: 38 additions & 2 deletions examples/c++/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,41 @@ for how to use StableHLO.

Note: If you have a great example to highlight, we welcome contributions!

* [stablehlo_add](./stablehlo_add.cpp): A simple example that demonstrates
how to use the StableHLO library to add two numbers.
* [example-add](./ExampleAdd.cpp): A simple example that demonstrates how to
* Use the StableHLO library to add two numbers.
* Interpret a StableHLO program using concrete inputs.

```c++
// Assume 'module' is an MLIR module with function "main" containing StableHLO
// operations.
llvm::outs() << "Program:\n " << module << "\n";

// Create concrete inputs to be used for interpreting "main".
auto inputValue1 = mlir::DenseElementsAttr::get(
tensorType, block_builder.getFloatAttr(tensorType.getElementType(),
static_cast<double>(10)));
auto inputValue2 = mlir::DenseElementsAttr::get(
tensorType, block_builder.getFloatAttr(tensorType.getElementType(),
static_cast<double>(20)));
llvm::outs() << "Inputs: " << inputValue1 << ", " << inputValue2 << "\n";


mlir::stablehlo::InterpreterConfiguration config;
auto results = evalModule(module, {inputValue1, inputValue2}, config);
llvm::outs() << "Output: " << (*results)[0];
```
Output:
```mlir
Program:
module @test_module {
func.func @main(%arg0: tensor<3x4xf32>, %arg1: tensor<3x4xf32>) -> tensor<3x4xf32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<3x4xf32>
%1 = stablehlo.add %arg0, %arg1 : tensor<3x4xf32>
return %1 : tensor<3x4xf32>
}
}
Inputs: dense<1.000000e+01> : tensor<3x4xf32>, dense<2.000000e+01> : tensor<3x4xf32>
Output: dense<3.000000e+01> : tensor<3x4xf32>
```
1 change: 1 addition & 0 deletions stablehlo/reference/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_mlir_library(StablehloReferenceApi

LINK_LIBS PUBLIC
MLIRIR
MLIRParser
MLIRSupport
InterpreterOps
StablehloReferenceConfiguration
Expand Down

0 comments on commit c123a48

Please sign in to comment.