Skip to content

Commit

Permalink
w
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Feb 10, 2025
1 parent ed0bdd0 commit f15f050
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 6 deletions.
22 changes: 19 additions & 3 deletions tools/pnnx/src/pass_level2/torch_max.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class torch_max_tnn : public GraphRewriterPass
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
tnn.ReduceMax op_0 1 1 input out arg0=%keepdims arg1=%dim
tnn.ReduceMax op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
}
Expand All @@ -202,8 +202,24 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["dim"] = captured_params.at("dim");
op->params["keepdim"] = captured_params.at("keepdims").i ? true : false;
std::vector<int> dim;
for (int i = 1; ; i++)
{
if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end())
break;

dim.push_back(captured_params.at("op_0.arg" + std::to_string(i)).i);
}

if (dim.size() == 1)
{
op->params["dim"] = dim[0];
}
else
{
fprintf(stderr, "fallback to reduce max all\n");
}
op->params["keepdim"] = captured_params.at("op_0.arg0").i ? true : false;
}
};

Expand Down
15 changes: 12 additions & 3 deletions tools/pnnx/src/pass_level2/torch_mean.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class torch_mean_tnn : public GraphRewriterPass
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
tnn.ReduceMean op_0 1 1 input out arg0=%keepdims arg1=%dim
tnn.ReduceMean op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
}
Expand All @@ -168,8 +168,17 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["dim"] = captured_params.at("dim");
op->params["keepdim"] = captured_params.at("keepdims").i ? true : false;
std::vector<int> dim;
for (int i = 1; ; i++)
{
if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end())
break;

dim.push_back(captured_params.at("op_0.arg" + std::to_string(i)).i);
}

op->params["dim"] = dim;
op->params["keepdim"] = captured_params.at("op_0.arg0").i ? true : false;
}
};

Expand Down
43 changes: 43 additions & 0 deletions tools/pnnx/src/pass_level2/torch_min.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,4 +182,47 @@ pnnx.Output output 2 0 out indices

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_onnx_1, 50)

class torch_min_tnn : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
tnn.ReduceMin op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.min";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
std::vector<int> dim;
for (int i = 1; ; i++)
{
if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end())
break;

dim.push_back(captured_params.at("op_0.arg" + std::to_string(i)).i);
}

if (dim.size() == 1)
{
op->params["dim"] = dim[0];
}
else
{
fprintf(stderr, "fallback to reduce min all\n");
}
op->params["keepdim"] = captured_params.at("op_0.arg0").i ? true : false;
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_tnn, 50)

} // namespace pnnx

0 comments on commit f15f050

Please sign in to comment.