Skip to content

Commit

Permalink
Check casts and cast TMA coords
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobhinkle committed Feb 7, 2025
1 parent 6ebbc35 commit efa8507
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 14 deletions.
9 changes: 9 additions & 0 deletions csrc/index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <expr_simplifier.h>
#include <instrumentation.h>
#include <ir/all_nodes.h>
#include <ir/builder.h>
#include <ir/iostream.h>
#include <ir/utils.h>
#include <logical_domain_map.h>
Expand Down Expand Up @@ -2717,6 +2718,14 @@ std::pair<Val*, Val*> Index::getCpAsyncBulkGmemIndex(
auto indices_inner_to_outer =
indexer.getIndexFor(ldst, !is_load, ids_to_index, loops);

// These are the box coordinates of the TMA box, which must be of type
// int32_t. Possible overflow in each of these dims should be checked
// elsewhere.
for (size_t i : c10::irange(indices_inner_to_outer.size())) {
indices_inner_to_outer[i] =
IrBuilder::maybeCastExpr(DataType::Int32, indices_inner_to_outer[i]);
}

auto coordinate = IrBuilder::arrayExpr(indices_inner_to_outer);
auto descriptor = tma_info.tensorMap();
if (is_load) {
Expand Down
72 changes: 58 additions & 14 deletions csrc/runtime/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,6 @@ class ScalarBoundsCalculator : kir::IrVisitor {
if (val->dtype() != dtype) {
continue;
}
std::cout << " [" << b.min << ", " << b.max << "] "
<< val->toInlineString() << std::endl;
if (!initialized) {
ret = b;
initialized = true;
Expand All @@ -391,14 +389,49 @@ class ScalarBoundsCalculator : kir::IrVisitor {
}
if (b.min < std::numeric_limits<int32_t>::min() ||
b.max > std::numeric_limits<int32_t>::max()) {
std::cout << "Found value " << val->toInlineString()
<< " which has bounds [" << b.min << ", " << b.max << "]"
<< std::endl;
}
}
return ret;
}

//! Look at all casts (T)x where x is of type nvfuser_index_t, to ensure that
//! these casts are safe i.e. that the bounds of x do not overflow those
//! representable by T.
bool castsFromIndexAreSafe() const {
return std::all_of(
casts_from_index_.begin(), casts_from_index_.end(), [&](UnaryOp* cast) {
const BoundedInt& bounds = bounds_.at(cast->in());
DataType out_type = cast->out()->dtype();
if (out_type == DataType::Int) {
return true;
} else if (out_type == DataType::Int32) {
return bounds.min >= std::numeric_limits<int32_t>::min() &&
bounds.max <= std::numeric_limits<int32_t>::max();
} else if (out_type == DataType::Short) {
return bounds.min >= std::numeric_limits<int16_t>::min() &&
bounds.max <= std::numeric_limits<int16_t>::max();
} else if (out_type == DataType::Char) {
return bounds.min >= std::numeric_limits<int8_t>::min() &&
bounds.max <= std::numeric_limits<int8_t>::max();
} else if (out_type == DataType::UInt64) {
// upper limit is above that of int64_t, which is the type of
// bounds.max
return bounds.min >= 0L;
} else if (out_type == DataType::UInt32) {
return bounds.min >= std::numeric_limits<uint32_t>::min() &&
bounds.max <= std::numeric_limits<uint32_t>::max();
} else if (out_type == DataType::UInt16) {
return bounds.min >= std::numeric_limits<uint16_t>::min() &&
bounds.max <= std::numeric_limits<uint16_t>::max();
} else if (out_type == DataType::Byte) {
return bounds.min >= std::numeric_limits<uint8_t>::min() &&
bounds.max <= std::numeric_limits<uint8_t>::max();
} else {
NVF_THROW("Unhandled DataType ", cast->out()->dtype());
}
});
}

private:
void setBounds(Val* val, const BoundedInt& bounds) {
bounds_[val] = bounds;
Expand Down Expand Up @@ -467,6 +500,15 @@ class ScalarBoundsCalculator : kir::IrVisitor {
return;
}

if (auto* uop = dynamic_cast<UnaryOp*>(expr); uop &&
uop->getUnaryOpType() == UnaryOpType::Cast &&
uop->in()->dtype() == DataType::Index &&
uop->out()->isIntegralScalar()) {
// Collect casts _from_ Index scalars, so that we can check that these are
// safe.
casts_from_index_.push_back(uop);
}

if (!expr->isA<ForLoop>() &&
std::all_of(
expr->outputs().begin(), expr->outputs().end(), [](Val* outp) {
Expand Down Expand Up @@ -543,6 +585,12 @@ class ScalarBoundsCalculator : kir::IrVisitor {
case UnaryOpType::BitwiseNot:
result = ~a;
break;
case UnaryOpType::Cast:
// This assumes there is no loss or overflow, since those should not
// occur in our kernels. We can check that later for index types using
// castsFromIndexAreSafe().
result = a;
break;
case UnaryOpType::Neg:
result = {-a.max, -a.min};
break;
Expand Down Expand Up @@ -617,6 +665,7 @@ class ScalarBoundsCalculator : kir::IrVisitor {
const LaunchParams& launch_params_;
std::unordered_map<const Val*, BoundedInt> bounds_;
std::unordered_map<const Val*, PolymorphicValue> known_scalars_;
std::vector<UnaryOp*> casts_from_index_;
};

PrimDataType getSmallestIndexTypeByBoundingExpressions(
Expand All @@ -627,11 +676,10 @@ PrimDataType getSmallestIndexTypeByBoundingExpressions(
ScalarBoundsCalculator calc(kernel, expr_eval, launch_params);
// Compute the range of all nvfuser_index_t values in the fusion
BoundedInt index_bounds = calc.boundByDataType();
// TODO: while we still have the ScalarBoundsCalculator computed, we should
// check the inputs to any TMA expressions to ensure that it is safe to cast
// them to Int32. Doing so would allow us to no longer throw an error when
// the line below returns Int instead of Int32, e.g. in cases where TMA is
// used for loads but not for stores with large problems.
// while we still have the ScalarBoundsCalculator computed, check the inputs
// to any TMA expressions to ensure that it is safe to cast them to Int32.
NVF_ERROR(
calc.castsFromIndexAreSafe(), "Found unsafe casts from DataType::Index");
return (index_bounds.min < (int64_t)std::numeric_limits<int32_t>::min() ||
index_bounds.max > (int64_t)std::numeric_limits<int32_t>::max())
? PrimDataType::Int
Expand Down Expand Up @@ -777,10 +825,6 @@ void KernelExecutor::compile(
expr_eval.precomputedValues()->bindValues(kernel->inputs(), args);
compile_params.index_type = getSmallestIndexTypeByBoundingExpressions(
kernel, expr_eval, launch_params);
NVF_ERROR(
compile_params.index_type.value() == PrimDataType::Int32,
"Compilation with int64 is requested but int32 is required because ",
"of TMA operations.");
}

// Now that we have launch parameters we can compile the kernel. It's a bit
Expand Down

0 comments on commit efa8507

Please sign in to comment.