4617 lines
236 KiB
Diff
4617 lines
236 KiB
Diff
From fc7a8e35819bda632bdcf1cf75fd9abe4d4e077a Mon Sep 17 00:00:00 2001
|
|
From: Christian Sigg <chsigg@users.noreply.github.com>
|
|
Date: Thu, 16 Feb 2023 15:40:53 +0100
|
|
Subject: [PATCH] Rebase Triton to LLVM-15. (#1070)
|
|
|
|
This PR rebases Triton from LLVM-14 to LLVM-15. Most changes are
|
|
mechanical, except for porting the alias and axis-info analyses to MLIR's new sparse data-flow framework (DataFlowSolver).
|
|
---
|
|
CMakeLists.txt | 6 +-
|
|
bin/CMakeLists.txt | 2 +-
|
|
bin/FileCheck/FileCheck.cpp | 3 +
|
|
bin/triton-opt.cpp | 6 +-
|
|
bin/triton-translate.cpp | 7 +-
|
|
include/triton/Analysis/Alias.h | 21 +-
|
|
include/triton/Analysis/Allocation.h | 2 +
|
|
include/triton/Analysis/AxisInfo.h | 56 ++-
|
|
include/triton/Analysis/Utility.h | 6 +-
|
|
include/triton/Conversion/Passes.td | 4 +-
|
|
include/triton/Dialect/Triton/IR/Dialect.h | 7 +-
|
|
.../triton/Dialect/Triton/IR/TritonDialect.td | 8 +-
|
|
include/triton/Dialect/Triton/IR/TritonOps.td | 12 +-
|
|
.../triton/Dialect/Triton/IR/TritonTypes.td | 2 +
|
|
.../Dialect/Triton/Transforms/Passes.td | 3 +-
|
|
include/triton/Dialect/TritonGPU/IR/Dialect.h | 4 +-
|
|
.../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 7 +
|
|
.../Dialect/TritonGPU/IR/TritonGPUDialect.td | 2 +-
|
|
.../Dialect/TritonGPU/IR/TritonGPUOps.td | 13 +-
|
|
lib/Analysis/Alias.cpp | 14 +-
|
|
lib/Analysis/Allocation.cpp | 30 +-
|
|
lib/Analysis/AxisInfo.cpp | 79 ++--
|
|
lib/Analysis/CMakeLists.txt | 2 +-
|
|
lib/Analysis/Membar.cpp | 2 +-
|
|
lib/Analysis/Utility.cpp | 54 +++
|
|
.../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 3 -
|
|
lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h | 10 +-
|
|
.../TritonGPUToLLVM/DotOpToLLVM.cpp | 5 -
|
|
.../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 2 -
|
|
.../TritonGPUToLLVM/LoadStoreOpToLLVM.cpp | 5 +-
|
|
.../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 2 -
|
|
.../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 7 +-
|
|
.../TritonGPUToLLVM/TritonGPUToLLVMBase.h | 26 +-
|
|
.../TritonGPUToLLVM/TritonGPUToLLVMPass.cpp | 52 +--
|
|
lib/Conversion/TritonGPUToLLVM/Utility.h | 5 +-
|
|
.../TritonToTritonGPUPass.cpp | 69 ++--
|
|
lib/Dialect/Triton/IR/CMakeLists.txt | 10 +-
|
|
lib/Dialect/Triton/IR/Ops.cpp | 34 +-
|
|
lib/Dialect/Triton/Transforms/Combine.cpp | 6 +-
|
|
lib/Dialect/Triton/Transforms/Combine.td | 2 +-
|
|
lib/Dialect/TritonGPU/IR/Dialect.cpp | 27 +-
|
|
lib/Dialect/TritonGPU/Transforms/Coalesce.cpp | 20 +-
|
|
lib/Dialect/TritonGPU/Transforms/Combine.cpp | 2 +-
|
|
lib/Dialect/TritonGPU/Transforms/Combine.td | 1 +
|
|
.../Transforms/DecomposeConversions.cpp | 2 +-
|
|
lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 10 +-
|
|
.../Transforms/ReorderInstructions.cpp | 2 +-
|
|
.../Transforms/TritonGPUConversion.cpp | 12 +-
|
|
.../Transforms/UpdateMmaForVolta.cpp | 6 +-
|
|
lib/Dialect/TritonGPU/Transforms/Utility.cpp | 2 +-
|
|
lib/Target/LLVMIR/CMakeLists.txt | 3 +-
|
|
lib/Target/PTX/PTXTranslation.cpp | 3 +
|
|
python/setup.py | 15 +-
|
|
python/src/triton.cc | 85 +++--
|
|
python/test/unit/language/test_core.py | 2 +-
|
|
python/triton/compiler.py | 4 +-
|
|
test/Analysis/test-alias.mlir | 24 +-
|
|
test/Analysis/test-alignment.mlir | 344 +++++++++---------
|
|
test/Analysis/test-allocation.mlir | 32 +-
|
|
test/Analysis/test-membar.mlir | 38 +-
|
|
test/Conversion/triton_ops.mlir | 10 +-
|
|
test/Conversion/triton_to_tritongpu.mlir | 6 +-
|
|
test/Conversion/tritongpu_to_llvm.mlir | 94 ++---
|
|
test/Target/tritongpu_to_llvmir.mlir | 4 +-
|
|
test/Target/tritongpu_to_ptx.mlir | 2 +-
|
|
test/Triton/combine.mlir | 40 +-
|
|
test/Triton/vecadd.mlir | 4 +-
|
|
test/TritonGPU/coalesce.mlir | 2 +-
|
|
test/TritonGPU/combine.mlir | 38 +-
|
|
test/TritonGPU/loop-pipeline.mlir | 22 +-
|
|
test/TritonGPU/matmul.mlir | 4 +-
|
|
test/TritonGPU/prefetch.mlir | 4 +-
|
|
test/TritonGPU/update-mma-for-volta.mlir | 4 +-
|
|
test/lib/Analysis/TestAlias.cpp | 29 +-
|
|
test/lib/Analysis/TestAllocation.cpp | 5 +-
|
|
test/lib/Analysis/TestAxisInfo.cpp | 51 +--
|
|
test/lib/Analysis/TestMembar.cpp | 7 +-
|
|
78 files changed, 808 insertions(+), 742 deletions(-)
|
|
|
|
diff --git a/CMakeLists.txt b/CMakeLists.txt
|
|
index d0d361fc7c..b281a28400 100644
|
|
--- a/CMakeLists.txt
|
|
+++ b/CMakeLists.txt
|
|
@@ -1,4 +1,7 @@
|
|
cmake_minimum_required(VERSION 3.6)
|
|
+
|
|
+cmake_policy(SET CMP0116 OLD)
|
|
+
|
|
include(ExternalProject)
|
|
|
|
set(CMAKE_CXX_STANDARD 17)
|
|
@@ -155,7 +158,6 @@ if(TRITON_BUILD_PYTHON_MODULE)
|
|
endif()
|
|
endif()
|
|
|
|
-
|
|
# # Triton
|
|
# file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
|
|
# if (WIN32 AND TRITON_BUILD_PYTHON_MODULE)
|
|
@@ -212,7 +214,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
|
|
# optimizations
|
|
MLIRPass
|
|
MLIRTransforms
|
|
- MLIRLLVMIR
|
|
+ MLIRLLVMDialect
|
|
MLIRSupport
|
|
MLIRTargetLLVMIRExport
|
|
MLIRExecutionEngine
|
|
diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt
|
|
index 906f635f8b..695b3479fd 100644
|
|
--- a/bin/CMakeLists.txt
|
|
+++ b/bin/CMakeLists.txt
|
|
@@ -48,7 +48,7 @@ llvm_update_compile_flags(triton-translate)
|
|
# MLIR core
|
|
MLIROptLib
|
|
MLIRIR
|
|
- MLIRLLVMIR
|
|
+ MLIRLLVMDialect
|
|
MLIRPass
|
|
MLIRSupport
|
|
MLIRTransforms
|
|
diff --git a/bin/FileCheck/FileCheck.cpp b/bin/FileCheck/FileCheck.cpp
|
|
index 819efc3541..9ac6f1b277 100644
|
|
--- a/bin/FileCheck/FileCheck.cpp
|
|
+++ b/bin/FileCheck/FileCheck.cpp
|
|
@@ -19,6 +19,7 @@
|
|
#include "llvm/Support/CommandLine.h"
|
|
#include "llvm/Support/InitLLVM.h"
|
|
#include "llvm/Support/Process.h"
|
|
+#include "llvm/Support/SourceMgr.h"
|
|
#include "llvm/Support/WithColor.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include <cmath>
|
|
@@ -360,6 +361,8 @@ static std::string GetCheckTypeAbbreviation(Check::FileCheckType Ty) {
|
|
return "bad-not";
|
|
case Check::CheckBadCount:
|
|
return "bad-count";
|
|
+ case Check::CheckMisspelled:
|
|
+ return "misspelled";
|
|
case Check::CheckNone:
|
|
llvm_unreachable("invalid FileCheckType");
|
|
}
|
|
diff --git a/bin/triton-opt.cpp b/bin/triton-opt.cpp
|
|
index 9f3b53b7ae..f96232e1b0 100644
|
|
--- a/bin/triton-opt.cpp
|
|
+++ b/bin/triton-opt.cpp
|
|
@@ -8,7 +8,7 @@
|
|
|
|
#include "mlir/IR/Dialect.h"
|
|
#include "mlir/InitAllPasses.h"
|
|
-#include "mlir/Support/MlirOptMain.h"
|
|
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
@@ -33,8 +33,8 @@ int main(int argc, char **argv) {
|
|
// TODO: register Triton & TritonGPU passes
|
|
mlir::DialectRegistry registry;
|
|
registry.insert<mlir::triton::TritonDialect,
|
|
- mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
|
|
- mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect,
|
|
+ mlir::triton::gpu::TritonGPUDialect, mlir::func::FuncDialect,
|
|
+ mlir::math::MathDialect, mlir::arith::ArithmeticDialect,
|
|
mlir::scf::SCFDialect, mlir::gpu::GPUDialect>();
|
|
|
|
return mlir::asMainReturnCode(mlir::MlirOptMain(
|
|
diff --git a/bin/triton-translate.cpp b/bin/triton-translate.cpp
|
|
index 05ba15e453..56b5d65857 100644
|
|
--- a/bin/triton-translate.cpp
|
|
+++ b/bin/triton-translate.cpp
|
|
@@ -3,7 +3,7 @@
|
|
#include "mlir/IR/AsmState.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/Dialect.h"
|
|
-#include "mlir/Parser.h"
|
|
+#include "mlir/Parser/Parser.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Support/FileUtilities.h"
|
|
@@ -38,7 +38,7 @@ OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
|
|
mlir::DialectRegistry registry;
|
|
registry.insert<TritonDialect, triton::gpu::TritonGPUDialect,
|
|
mlir::math::MathDialect, arith::ArithmeticDialect,
|
|
- StandardOpsDialect, scf::SCFDialect>();
|
|
+ scf::SCFDialect>();
|
|
|
|
context.appendDialectRegistry(registry);
|
|
|
|
@@ -50,7 +50,8 @@ OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
|
|
context.loadAllAvailableDialects();
|
|
context.allowUnregisteredDialects();
|
|
|
|
- OwningOpRef<ModuleOp> module(parseSourceFile(sourceMgr, &context));
|
|
+ OwningOpRef<ModuleOp> module =
|
|
+ parseSourceFile<ModuleOp>(sourceMgr, &context);
|
|
if (!module) {
|
|
llvm::errs() << "Parse MLIR file failed.";
|
|
return nullptr;
|
|
diff --git a/include/triton/Analysis/Alias.h b/include/triton/Analysis/Alias.h
|
|
index fa6b906fc9..631df518bc 100644
|
|
--- a/include/triton/Analysis/Alias.h
|
|
+++ b/include/triton/Analysis/Alias.h
|
|
@@ -2,7 +2,7 @@
|
|
#define TRITON_ANALYSIS_ALIAS_H
|
|
|
|
#include "mlir/Analysis/AliasAnalysis.h"
|
|
-#include "mlir/Analysis/DataFlowAnalysis.h"
|
|
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
|
|
#include "llvm/ADT/DenseSet.h"
|
|
|
|
namespace mlir {
|
|
@@ -21,7 +21,7 @@ class AliasInfo {
|
|
}
|
|
|
|
/// The pessimistic value state of a value without alias
|
|
- static AliasInfo getPessimisticValueState(MLIRContext *context) {
|
|
+ static AliasInfo getPessimisticValueState(MLIRContext *context = nullptr) {
|
|
return AliasInfo();
|
|
}
|
|
static AliasInfo getPessimisticValueState(Value value) { return AliasInfo(); }
|
|
@@ -29,6 +29,10 @@ class AliasInfo {
|
|
/// The union of both arguments
|
|
static AliasInfo join(const AliasInfo &lhs, const AliasInfo &rhs);
|
|
|
|
+ void print(raw_ostream &os) const {
|
|
+ llvm::interleaveComma(allocs, os, [&](Value alloc) { alloc.print(os); });
|
|
+ }
|
|
+
|
|
private:
|
|
/// The set of allocated values that are aliased by this lattice.
|
|
/// For now, we only consider aliased value produced by the following
|
|
@@ -58,9 +62,13 @@ class AliasInfo {
|
|
//===----------------------------------------------------------------------===//
|
|
// Shared Memory Alias Analysis
|
|
//===----------------------------------------------------------------------===//
|
|
-class SharedMemoryAliasAnalysis : public ForwardDataFlowAnalysis<AliasInfo> {
|
|
+class SharedMemoryAliasAnalysis
|
|
+ : public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AliasInfo>> {
|
|
public:
|
|
- using ForwardDataFlowAnalysis<AliasInfo>::ForwardDataFlowAnalysis;
|
|
+ using dataflow::SparseDataFlowAnalysis<
|
|
+ dataflow::Lattice<AliasInfo>>::SparseDataFlowAnalysis;
|
|
+ using dataflow::SparseDataFlowAnalysis<
|
|
+ dataflow::Lattice<AliasInfo>>::getLatticeElement;
|
|
|
|
/// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use.
|
|
/// Given two values, returns their aliasing behavior.
|
|
@@ -70,9 +78,10 @@ class SharedMemoryAliasAnalysis : public ForwardDataFlowAnalysis<AliasInfo> {
|
|
ModRefResult getModRef(Operation *op, Value location);
|
|
|
|
/// Computes if the alloc set of the results are changed.
|
|
- ChangeResult
|
|
+ void
|
|
visitOperation(Operation *op,
|
|
- ArrayRef<LatticeElement<AliasInfo> *> operands) override;
|
|
+ ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
|
|
+ ArrayRef<dataflow::Lattice<AliasInfo> *> results) override;
|
|
};
|
|
|
|
} // namespace mlir
|
|
diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h
|
|
index b7c136d602..89b77034cc 100644
|
|
--- a/include/triton/Analysis/Allocation.h
|
|
+++ b/include/triton/Analysis/Allocation.h
|
|
@@ -188,6 +188,8 @@ class Allocation {
|
|
friend class triton::AllocationAnalysis;
|
|
};
|
|
|
|
+template <typename T> Interval(T, T) -> Interval<T>;
|
|
+
|
|
} // namespace mlir
|
|
|
|
#endif // TRITON_ANALYSIS_ALLOCATION_H
|
|
diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h
|
|
index fdfbd8fbb3..7083b9c43b 100644
|
|
--- a/include/triton/Analysis/AxisInfo.h
|
|
+++ b/include/triton/Analysis/AxisInfo.h
|
|
@@ -1,9 +1,10 @@
|
|
#ifndef TRITON_ANALYSIS_AXISINFO_H
|
|
#define TRITON_ANALYSIS_AXISINFO_H
|
|
|
|
-#include "mlir/Analysis/DataFlowAnalysis.h"
|
|
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
|
|
+#include "mlir/Support/LLVM.h"
|
|
#include "triton/Analysis/Utility.h"
|
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
@@ -62,7 +63,7 @@ class AxisInfo {
|
|
}
|
|
|
|
/// The pessimistic value state of the contiguity is unknown.
|
|
- static AxisInfo getPessimisticValueState(MLIRContext *context) {
|
|
+ static AxisInfo getPessimisticValueState(MLIRContext *context = nullptr) {
|
|
return AxisInfo();
|
|
}
|
|
static AxisInfo getPessimisticValueState(Value value);
|
|
@@ -70,6 +71,22 @@ class AxisInfo {
|
|
/// The gcd of both arguments for each dimension
|
|
static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs);
|
|
|
|
+ void print(raw_ostream &os) const {
|
|
+ auto print = [&](StringRef name, DimVectorT vec) {
|
|
+ os << name << " = [";
|
|
+ llvm::interleaveComma(vec, os);
|
|
+ os << "]";
|
|
+ };
|
|
+ print("contiguity", contiguity);
|
|
+ print(", divisibility", divisibility);
|
|
+ print(", constancy", constancy);
|
|
+ os << ", constant_value = ";
|
|
+ if (constantValue)
|
|
+ os << *constantValue;
|
|
+ else
|
|
+ os << "<none>";
|
|
+ }
|
|
+
|
|
private:
|
|
/// The _contiguity_ information maps the `d`-th
|
|
/// dimension to the length of the shortest
|
|
@@ -147,7 +164,8 @@ class AxisInfoVisitor {
|
|
}
|
|
|
|
virtual AxisInfo
|
|
- getAxisInfo(Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) = 0;
|
|
+ getAxisInfo(Operation *op,
|
|
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) = 0;
|
|
|
|
virtual bool match(Operation *op) = 0;
|
|
};
|
|
@@ -157,15 +175,16 @@ template <typename OpTy> class AxisInfoVisitorImpl : public AxisInfoVisitor {
|
|
public:
|
|
using AxisInfoVisitor::AxisInfoVisitor;
|
|
|
|
- AxisInfo getAxisInfo(Operation *op,
|
|
- ArrayRef<LatticeElement<AxisInfo> *> operands) final {
|
|
+ AxisInfo
|
|
+ getAxisInfo(Operation *op,
|
|
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) final {
|
|
return getAxisInfo(cast<OpTy>(op), operands);
|
|
}
|
|
|
|
bool match(Operation *op) final { return isa<OpTy>(op); }
|
|
|
|
- virtual AxisInfo getAxisInfo(OpTy op,
|
|
- ArrayRef<LatticeElement<AxisInfo> *> operands) {
|
|
+ virtual AxisInfo
|
|
+ getAxisInfo(OpTy op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) {
|
|
llvm_unreachable("Unimplemented getAxisInfo");
|
|
}
|
|
};
|
|
@@ -176,8 +195,9 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
|
|
public:
|
|
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
|
|
|
|
- AxisInfo getAxisInfo(OpTy op,
|
|
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
|
+ AxisInfo
|
|
+ getAxisInfo(OpTy op,
|
|
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
|
auto lhsInfo = operands[0]->getValue();
|
|
auto rhsInfo = operands[1]->getValue();
|
|
auto rank = lhsInfo.getRank();
|
|
@@ -230,7 +250,8 @@ class AxisInfoVisitorList {
|
|
(visitors.emplace_back(std::make_unique<Ts>()), ...);
|
|
}
|
|
|
|
- AxisInfo apply(Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
|
|
+ AxisInfo apply(Operation *op,
|
|
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) {
|
|
for (auto &visitor : visitors)
|
|
if (visitor->match(op))
|
|
return visitor->getAxisInfo(op, operands);
|
|
@@ -241,16 +262,19 @@ class AxisInfoVisitorList {
|
|
std::vector<std::unique_ptr<AxisInfoVisitor>> visitors;
|
|
};
|
|
|
|
-class AxisInfoAnalysis : public ForwardDataFlowAnalysis<AxisInfo> {
|
|
+class AxisInfoAnalysis
|
|
+ : public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>> {
|
|
private:
|
|
AxisInfoVisitorList visitors;
|
|
|
|
public:
|
|
- AxisInfoAnalysis(MLIRContext *context);
|
|
+ AxisInfoAnalysis(DataFlowSolver &solver);
|
|
+ using dataflow::SparseDataFlowAnalysis<
|
|
+ dataflow::Lattice<AxisInfo>>::getLatticeElement;
|
|
|
|
- ChangeResult
|
|
- visitOperation(Operation *op,
|
|
- ArrayRef<LatticeElement<AxisInfo> *> operands) override;
|
|
+ void visitOperation(Operation *op,
|
|
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
|
|
+ ArrayRef<dataflow::Lattice<AxisInfo> *> results) override;
|
|
|
|
unsigned getPtrContiguity(Value ptr);
|
|
|
|
@@ -261,4 +285,4 @@ class AxisInfoAnalysis : public ForwardDataFlowAnalysis<AxisInfo> {
|
|
|
|
} // namespace mlir
|
|
|
|
-#endif
|
|
\ No newline at end of file
|
|
+#endif
|
|
diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h
|
|
index c5ac137dc1..ee7fadb59d 100644
|
|
--- a/include/triton/Analysis/Utility.h
|
|
+++ b/include/triton/Analysis/Utility.h
|
|
@@ -1,6 +1,7 @@
|
|
#ifndef TRITON_ANALYSIS_UTILITY_H
|
|
#define TRITON_ANALYSIS_UTILITY_H
|
|
|
|
+#include "mlir/Analysis/DataFlowFramework.h"
|
|
#include "mlir/Analysis/SliceAnalysis.h"
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
#include <algorithm>
|
|
@@ -12,7 +13,7 @@ namespace mlir {
|
|
class ReduceOpHelper {
|
|
public:
|
|
explicit ReduceOpHelper(triton::ReduceOp op) : op(op) {
|
|
- srcTy = op.operand().getType().cast<RankedTensorType>();
|
|
+ srcTy = op.getOperand().getType().cast<RankedTensorType>();
|
|
}
|
|
|
|
ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
|
|
@@ -103,6 +104,9 @@ SetVector<Operation *>
|
|
multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter = nullptr,
|
|
TransitiveFilter forwardFilter = nullptr);
|
|
|
|
+// Create a basic DataFlowSolver with constant and dead code analysis included.
|
|
+std::unique_ptr<DataFlowSolver> createDataFlowSolver();
|
|
+
|
|
} // namespace mlir
|
|
|
|
#endif // TRITON_ANALYSIS_UTILITY_H
|
|
diff --git a/include/triton/Conversion/Passes.td b/include/triton/Conversion/Passes.td
|
|
index 70bb20b78e..be00eb2dac 100644
|
|
--- a/include/triton/Conversion/Passes.td
|
|
+++ b/include/triton/Conversion/Passes.td
|
|
@@ -12,7 +12,6 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
|
|
|
|
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
|
"mlir::math::MathDialect",
|
|
- "mlir::StandardOpsDialect",
|
|
// TODO: Does this pass depend on SCF?
|
|
"mlir::scf::SCFDialect",
|
|
"mlir::triton::TritonDialect",
|
|
@@ -41,8 +40,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
|
|
"mlir::tensor::TensorDialect",
|
|
"mlir::triton::TritonDialect",
|
|
"mlir::triton::gpu::TritonGPUDialect",
|
|
- "mlir::NVVM::NVVMDialect",
|
|
- "mlir::StandardOpsDialect"];
|
|
+ "mlir::NVVM::NVVMDialect"];
|
|
|
|
let options = [
|
|
Option<"computeCapability", "compute-capability",
|
|
diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h
|
|
index e8012a51df..15869e262e 100644
|
|
--- a/include/triton/Dialect/Triton/IR/Dialect.h
|
|
+++ b/include/triton/Dialect/Triton/IR/Dialect.h
|
|
@@ -1,14 +1,15 @@
|
|
#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_
|
|
#define TRITON_DIALECT_TRITON_IR_DIALECT_H_
|
|
|
|
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
|
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
|
+#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
|
-#include "mlir/Dialect/SCF/SCF.h"
|
|
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
+#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/Dialect.h"
|
|
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
|
-
|
|
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
|
|
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
|
|
#include "triton/Dialect/Triton/IR/Traits.h"
|
|
diff --git a/include/triton/Dialect/Triton/IR/TritonDialect.td b/include/triton/Dialect/Triton/IR/TritonDialect.td
|
|
index 07b069e14f..d98ce73884 100644
|
|
--- a/include/triton/Dialect/Triton/IR/TritonDialect.td
|
|
+++ b/include/triton/Dialect/Triton/IR/TritonDialect.td
|
|
@@ -25,12 +25,9 @@ def Triton_Dialect : Dialect {
|
|
let dependentDialects = [
|
|
"arith::ArithmeticDialect",
|
|
"math::MathDialect",
|
|
- "StandardOpsDialect",
|
|
"scf::SCFDialect",
|
|
-
|
|
- // Since LLVM 15
|
|
- // "cf::ControlFlowDialect",
|
|
- // "func::FuncDialect"
|
|
+ "cf::ControlFlowDialect",
|
|
+ "func::FuncDialect"
|
|
];
|
|
|
|
let extraClassDeclaration = [{
|
|
@@ -38,6 +35,7 @@ def Triton_Dialect : Dialect {
|
|
}];
|
|
|
|
let hasConstantMaterializer = 1;
|
|
+ let useDefaultTypePrinterParser = 1;
|
|
}
|
|
|
|
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
|
diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
|
|
index 779e0b648c..0a69211179 100644
|
|
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
|
|
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
|
|
@@ -141,11 +141,7 @@ def TT_LoadOp : TT_Op<"load",
|
|
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
|
|
];
|
|
|
|
- // let assemblyFormat = "operands attr-dict `:` type($result)";
|
|
- let parser = [{ return mlir::triton::parseLoadOp(parser, result); }];
|
|
-
|
|
- let printer = [{ return mlir::triton::printLoadOp(p, *this); }];
|
|
-
|
|
+ let hasCustomAssemblyFormat = 1;
|
|
let hasCanonicalizer = 1;
|
|
}
|
|
|
|
@@ -170,11 +166,7 @@ def TT_StoreOp : TT_Op<"store",
|
|
"triton::EvictionPolicy":$evict)>,
|
|
];
|
|
|
|
- // let assemblyFormat = "operands attr-dict `:` type($value)";
|
|
- let parser = [{ return mlir::triton::parseStoreOp(parser, result); }];
|
|
-
|
|
- let printer = [{ return mlir::triton::printStoreOp(p, *this); }];
|
|
-
|
|
+ let hasCustomAssemblyFormat = 1;
|
|
let hasCanonicalizer = 1;
|
|
}
|
|
|
|
diff --git a/include/triton/Dialect/Triton/IR/TritonTypes.td b/include/triton/Dialect/Triton/IR/TritonTypes.td
|
|
index 66d2a7b9a9..2fe2fd077d 100644
|
|
--- a/include/triton/Dialect/Triton/IR/TritonTypes.td
|
|
+++ b/include/triton/Dialect/Triton/IR/TritonTypes.td
|
|
@@ -1,6 +1,7 @@
|
|
#ifndef TRITON_TYPES
|
|
#define TRITON_TYPES
|
|
|
|
+include "mlir/IR/AttrTypeBase.td"
|
|
include "triton/Dialect/Triton/IR/TritonDialect.td"
|
|
|
|
//
|
|
@@ -58,6 +59,7 @@ def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> {
|
|
}]>
|
|
];
|
|
|
|
+ let hasCustomAssemblyFormat = 1;
|
|
let skipDefaultBuilders = 1;
|
|
}
|
|
def TT_PtrTensor : TensorOf<[TT_Ptr]>;
|
|
diff --git a/include/triton/Dialect/Triton/Transforms/Passes.td b/include/triton/Dialect/Triton/Transforms/Passes.td
|
|
index 8f77aed774..a25cdc5680 100644
|
|
--- a/include/triton/Dialect/Triton/Transforms/Passes.td
|
|
+++ b/include/triton/Dialect/Triton/Transforms/Passes.td
|
|
@@ -16,8 +16,7 @@ def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp">
|
|
|
|
let constructor = "mlir::triton::createCombineOpsPass()";
|
|
|
|
- let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
|
- /*SelectOp*/"mlir::StandardOpsDialect"];
|
|
+ let dependentDialects = ["mlir::arith::ArithmeticDialect"];
|
|
}
|
|
|
|
#endif
|
|
diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
|
|
index b4c8daec7b..dfc5f53ab1 100644
|
|
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
|
|
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
|
|
@@ -1,19 +1,17 @@
|
|
#ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
|
|
#define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
|
|
|
|
-#include "mlir/Dialect/GPU/GPUDialect.h"
|
|
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/Dialect.h"
|
|
|
|
// TritonGPU depends on Triton
|
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
|
-
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
|
|
#include "triton/Dialect/TritonGPU/IR/Traits.h"
|
|
|
|
#define GET_ATTRDEF_CLASSES
|
|
-#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
|
|
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
|
|
|
|
#define GET_OP_CLASSES
|
|
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
|
|
index 0242c3cc17..af2aeb03a8 100644
|
|
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
|
|
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
|
|
@@ -1,6 +1,7 @@
|
|
#ifndef TRITONGPU_ATTRDEFS
|
|
#define TRITONGPU_ATTRDEFS
|
|
|
|
+include "mlir/IR/AttrTypeBase.td"
|
|
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
|
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
|
|
|
|
@@ -136,6 +137,7 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
|
];
|
|
|
|
let extraClassDeclaration = extraBaseClassDeclaration;
|
|
+ let hasCustomAssemblyFormat = 1;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
@@ -273,6 +275,7 @@ for
|
|
// ArrayRefParameter<"unsigned">:$sizePerCTA
|
|
);
|
|
|
|
+ let hasCustomAssemblyFormat = 1;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
@@ -422,6 +425,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
|
|
static constexpr int numBitsToHoldMmaV1ID{5};
|
|
}];
|
|
|
|
+ let hasCustomAssemblyFormat = 1;
|
|
}
|
|
|
|
def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
|
|
@@ -456,6 +460,8 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
|
|
template<class T>
|
|
SmallVector<T> paddedShape(ArrayRef<T> shape) const;
|
|
}];
|
|
+
|
|
+ let hasCustomAssemblyFormat = 1;
|
|
}
|
|
|
|
def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> {
|
|
@@ -492,6 +498,7 @@ section 9.7.13.4.1 for more details.
|
|
|
|
];
|
|
|
|
+ let hasCustomAssemblyFormat = 1;
|
|
let extraClassDeclaration = extraBaseClassDeclaration;
|
|
}
|
|
|
|
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
|
|
index 87ec1d36c6..6489a721b4 100644
|
|
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
|
|
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
|
|
@@ -30,7 +30,7 @@ def TritonGPU_Dialect : Dialect {
|
|
}
|
|
}];
|
|
|
|
-
|
|
+ let useDefaultAttributePrinterParser = 1;
|
|
}
|
|
|
|
#endif
|
|
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
|
|
index 510f8d0183..7aba11dc75 100644
|
|
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
|
|
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
|
|
@@ -59,7 +59,7 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
|
|
// This is needed because these ops don't
|
|
// handle encodings
|
|
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111
|
|
-def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise,
|
|
+def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise,
|
|
SameOperandsAndResultShape,
|
|
SameOperandsAndResultEncoding]> {
|
|
let summary = "integer comparison operation";
|
|
@@ -73,7 +73,7 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise,
|
|
let results = (outs TT_BoolLike:$result);
|
|
}
|
|
|
|
-def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise,
|
|
+def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise,
|
|
SameOperandsAndResultShape,
|
|
SameOperandsAndResultEncoding]> {
|
|
let summary = "floating-point comparison operation";
|
|
@@ -88,8 +88,8 @@ def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise,
|
|
}
|
|
|
|
// TODO: migrate to arith::SelectOp on LLVM16
|
|
-def TTG_SelectOp : TTG_Op<"select", [NoSideEffect, Elementwise,
|
|
- SameOperandsAndResultShape,
|
|
+def TTG_SelectOp : TTG_Op<"select", [NoSideEffect, Elementwise,
|
|
+ SameOperandsAndResultShape,
|
|
SameOperandsAndResultEncoding]> {
|
|
let summary = "select operation";
|
|
|
|
@@ -188,10 +188,7 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
|
}
|
|
}];
|
|
|
|
- // The custom parser could be replaced with oilist in LLVM-16
|
|
- let parser = [{ return parseInsertSliceAsyncOp(parser, result); }];
|
|
-
|
|
- let printer = [{ return printInsertSliceAsyncOp(p, *this); }];
|
|
+ let hasCustomAssemblyFormat = 1;
|
|
}
|
|
|
|
def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [MemoryEffects<[MemAlloc]>, // Allocate shared memory
|
|
diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp
|
|
index a39e4de9aa..208fdd4afc 100644
|
|
--- a/lib/Analysis/Alias.cpp
|
|
+++ b/lib/Analysis/Alias.cpp
|
|
@@ -18,8 +18,9 @@ AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) {
|
|
return ret;
|
|
}
|
|
|
|
-ChangeResult SharedMemoryAliasAnalysis::visitOperation(
|
|
- Operation *op, ArrayRef<LatticeElement<AliasInfo> *> operands) {
|
|
+void SharedMemoryAliasAnalysis::visitOperation(
|
|
+ Operation *op, ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
|
|
+ ArrayRef<dataflow::Lattice<AliasInfo> *> results) {
|
|
AliasInfo aliasInfo;
|
|
bool pessimistic = true;
|
|
if (maybeSharedAllocationOp(op)) {
|
|
@@ -44,14 +45,11 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
|
|
}
|
|
|
|
if (pessimistic) {
|
|
- return markAllPessimisticFixpoint(op->getResults());
|
|
+ return markAllPessimisticFixpoint(results);
|
|
}
|
|
// Join all lattice elements
|
|
- ChangeResult result = ChangeResult::NoChange;
|
|
- for (Value value : op->getResults()) {
|
|
- result |= getLatticeElement(value).join(aliasInfo);
|
|
- }
|
|
- return result;
|
|
+ for (auto *result : results)
|
|
+ propagateIfChanged(result, result->join(aliasInfo));
|
|
}
|
|
|
|
AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) {
|
|
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
|
|
index 712c08c475..b4de8dcd9d 100644
|
|
--- a/lib/Analysis/Allocation.cpp
|
|
+++ b/lib/Analysis/Allocation.cpp
|
|
@@ -1,4 +1,5 @@
|
|
#include "triton/Analysis/Allocation.h"
|
|
+#include "mlir/Analysis/DataFlowFramework.h"
|
|
#include "mlir/Analysis/Liveness.h"
|
|
#include "mlir/Analysis/SliceAnalysis.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
@@ -33,10 +34,8 @@ constexpr int kPtrBitWidth = 64;
|
|
|
|
static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
|
|
getCvtOrder(const Attribute &srcLayout, const Attribute &dstLayout) {
|
|
- auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
|
|
auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
|
|
auto srcDotLayout = srcLayout.dyn_cast<DotOperandEncodingAttr>();
|
|
- auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
|
|
auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
|
|
auto dstDotLayout = dstLayout.dyn_cast<DotOperandEncodingAttr>();
|
|
assert(!(srcMmaLayout && dstMmaLayout) &&
|
|
@@ -224,14 +223,12 @@ class AllocationAnalysis {
|
|
}
|
|
|
|
void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) {
|
|
- LatticeElement<AliasInfo> *latticeElement =
|
|
- analysis.lookupLatticeElement(value);
|
|
- if (latticeElement) {
|
|
- auto &info = latticeElement->getValue();
|
|
- if (!info.getAllocs().empty()) {
|
|
- for (auto alloc : info.getAllocs()) {
|
|
- allocation->addAlias(value, alloc);
|
|
- }
|
|
+ dataflow::Lattice<AliasInfo> *latticeElement =
|
|
+ analysis.getLatticeElement(value);
|
|
+ if (latticeElement && !latticeElement->isUninitialized()) {
|
|
+ AliasInfo &info = latticeElement->getValue();
|
|
+ for (auto alloc : info.getAllocs()) {
|
|
+ allocation->addAlias(value, alloc);
|
|
}
|
|
}
|
|
}
|
|
@@ -244,14 +241,19 @@ class AllocationAnalysis {
|
|
getScratchValueSize(op);
|
|
});
|
|
// Get the alias values
|
|
- SharedMemoryAliasAnalysis aliasAnalysis(operation->getContext());
|
|
- aliasAnalysis.run(operation);
|
|
+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
|
+ SharedMemoryAliasAnalysis *aliasAnalysis =
|
|
+ solver->load<SharedMemoryAliasAnalysis>();
|
|
+ if (failed(solver->initializeAndRun(operation))) {
|
|
+ // TODO: return error instead of bailing out..
|
|
+ llvm_unreachable("failed to run SharedMemoryAliasAnalysis");
|
|
+ }
|
|
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
|
|
for (auto operand : op->getOperands()) {
|
|
- getValueAlias(operand, aliasAnalysis);
|
|
+ getValueAlias(operand, *aliasAnalysis);
|
|
}
|
|
for (auto value : op->getResults()) {
|
|
- getValueAlias(value, aliasAnalysis);
|
|
+ getValueAlias(value, *aliasAnalysis);
|
|
}
|
|
});
|
|
}
|
|
diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
|
|
index 0b7142b04d..4af46c3fbb 100644
|
|
--- a/lib/Analysis/AxisInfo.cpp
|
|
+++ b/lib/Analysis/AxisInfo.cpp
|
|
@@ -1,4 +1,4 @@
|
|
-#include "mlir/Analysis/DataFlowAnalysis.h"
|
|
+#include "mlir/Analysis/DataFlowFramework.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
|
|
@@ -52,7 +52,7 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
|
|
BlockArgument blockArg = value.dyn_cast<BlockArgument>();
|
|
if (blockArg && blockArg.getOwner()->isEntryBlock()) {
|
|
Operation *op = blockArg.getOwner()->getParentOp();
|
|
- if (FuncOp fun = dyn_cast<FuncOp>(op)) {
|
|
+ if (func::FuncOp fun = dyn_cast<func::FuncOp>(op)) {
|
|
Attribute attr =
|
|
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
|
|
if (attr)
|
|
@@ -136,8 +136,9 @@ class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
|
|
public:
|
|
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
|
|
|
|
- AxisInfo getAxisInfo(OpTy op,
|
|
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
|
+ AxisInfo
|
|
+ getAxisInfo(OpTy op,
|
|
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
|
return operands[0]->getValue();
|
|
}
|
|
};
|
|
@@ -147,8 +148,9 @@ class MakeRangeOpAxisInfoVisitor final
|
|
public:
|
|
using AxisInfoVisitorImpl<triton::MakeRangeOp>::AxisInfoVisitorImpl;
|
|
|
|
- AxisInfo getAxisInfo(triton::MakeRangeOp op,
|
|
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
|
+ AxisInfo
|
|
+ getAxisInfo(triton::MakeRangeOp op,
|
|
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
|
auto start = op.start();
|
|
auto end = op.end();
|
|
return AxisInfo(/*contiguity=*/{end - start},
|
|
@@ -162,8 +164,9 @@ class ConstantOpAxisInfoVisitor final
|
|
public:
|
|
using AxisInfoVisitorImpl<arith::ConstantOp>::AxisInfoVisitorImpl;
|
|
|
|
- AxisInfo getAxisInfo(arith::ConstantOp op,
|
|
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
|
+ AxisInfo
|
|
+ getAxisInfo(arith::ConstantOp op,
|
|
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
|
auto intAttr = op.getValue().dyn_cast<IntegerAttr>();
|
|
auto boolAttr = op.getValue().dyn_cast<BoolAttr>();
|
|
if (intAttr || boolAttr) {
|
|
@@ -416,8 +419,9 @@ class SplatOpAxisInfoVisitor final
|
|
public:
|
|
using AxisInfoVisitorImpl<triton::SplatOp>::AxisInfoVisitorImpl;
|
|
|
|
- AxisInfo getAxisInfo(triton::SplatOp op,
|
|
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
|
+ AxisInfo
|
|
+ getAxisInfo(triton::SplatOp op,
|
|
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
|
Type _retTy = *op->result_type_begin();
|
|
TensorType retTy = _retTy.cast<TensorType>();
|
|
AxisInfo opInfo = operands[0]->getValue();
|
|
@@ -439,8 +443,9 @@ class ExpandDimsOpAxisInfoVisitor final
|
|
public:
|
|
using AxisInfoVisitorImpl<triton::ExpandDimsOp>::AxisInfoVisitorImpl;
|
|
|
|
- AxisInfo getAxisInfo(triton::ExpandDimsOp op,
|
|
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
|
+ AxisInfo
|
|
+ getAxisInfo(triton::ExpandDimsOp op,
|
|
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
|
AxisInfo opInfo = operands[0]->getValue();
|
|
AxisInfo::DimVectorT contiguity = opInfo.getContiguity();
|
|
AxisInfo::DimVectorT divisibility = opInfo.getDivisibility();
|
|
@@ -458,8 +463,9 @@ class BroadcastOpAxisInfoVisitor final
|
|
public:
|
|
using AxisInfoVisitorImpl<triton::BroadcastOp>::AxisInfoVisitorImpl;
|
|
|
|
- AxisInfo getAxisInfo(triton::BroadcastOp op,
|
|
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
|
+ AxisInfo
|
|
+ getAxisInfo(triton::BroadcastOp op,
|
|
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
|
Type _retTy = *op->result_type_begin();
|
|
Type _opTy = *op->operand_type_begin();
|
|
TensorType retTy = _retTy.cast<TensorType>();
|
|
@@ -486,8 +492,9 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
|
|
public:
|
|
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
|
|
|
|
- AxisInfo getAxisInfo(OpTy op,
|
|
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
|
+ AxisInfo
|
|
+ getAxisInfo(OpTy op,
|
|
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
|
auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
|
|
if (!resTy)
|
|
return AxisInfo();
|
|
@@ -596,8 +603,9 @@ class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
|
|
public:
|
|
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
|
|
|
|
- AxisInfo getAxisInfo(OpTy op,
|
|
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
|
+ AxisInfo
|
|
+ getAxisInfo(OpTy op,
|
|
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
|
auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
|
|
if (!resTy)
|
|
return AxisInfo();
|
|
@@ -757,8 +765,9 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
|
|
public:
|
|
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
|
|
|
|
- AxisInfo getAxisInfo(OpTy op,
|
|
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
|
|
+ AxisInfo
|
|
+ getAxisInfo(OpTy op,
|
|
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
|
auto lhsInfo = operands[0]->getValue();
|
|
auto rhsInfo = operands[1]->getValue();
|
|
std::optional<int64_t> constantValue;
|
|
@@ -786,8 +795,8 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
|
|
// AxisInfoAnalysis
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
-AxisInfoAnalysis::AxisInfoAnalysis(MLIRContext *context)
|
|
- : ForwardDataFlowAnalysis<AxisInfo>(context) {
|
|
+AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
|
|
+ : dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(solver) {
|
|
// UnrealizedConversionCast:
|
|
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
|
|
// in the process of a PartialConversion, where UnrealizedConversionCast
|
|
@@ -819,7 +828,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(MLIRContext *context)
|
|
visitors.append<LogicalOpAxisInfoVisitor<arith::AndIOp>,
|
|
LogicalOpAxisInfoVisitor<arith::OrIOp>,
|
|
LogicalOpAxisInfoVisitor<arith::XOrIOp>>();
|
|
- visitors.append<SelectOpAxisInfoVisitor<mlir::SelectOp>,
|
|
+ visitors.append<SelectOpAxisInfoVisitor<mlir::arith::SelectOp>,
|
|
SelectOpAxisInfoVisitor<triton::gpu::SelectOp>>();
|
|
visitors.append<ShLIOpAxisInfoVisitor, ShROpAxisInfoVisitor<arith::ShRUIOp>,
|
|
ShROpAxisInfoVisitor<arith::ShRSIOp>>();
|
|
@@ -829,11 +838,12 @@ AxisInfoAnalysis::AxisInfoAnalysis(MLIRContext *context)
|
|
MaxMinOpAxisInfoVisitor<arith::MinUIOp>>();
|
|
}
|
|
|
|
-ChangeResult AxisInfoAnalysis::visitOperation(
|
|
- Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
|
|
+void AxisInfoAnalysis::visitOperation(
|
|
+ Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
|
|
+ ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
|
|
AxisInfo curr = visitors.apply(op, operands);
|
|
if (curr.getRank() == 0) {
|
|
- return markAllPessimisticFixpoint(op->getResults());
|
|
+ return markAllPessimisticFixpoint(results);
|
|
}
|
|
// override with hint
|
|
auto newContiguity = curr.getContiguity();
|
|
@@ -854,11 +864,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
|
curr = mlir::AxisInfo(newContiguity, newDivisibility, newConstancy,
|
|
curr.getConstantValue());
|
|
// join all lattice elements
|
|
- ChangeResult result = ChangeResult::NoChange;
|
|
- for (Value value : op->getResults()) {
|
|
- result |= getLatticeElement(value).join(curr);
|
|
- }
|
|
- return result;
|
|
+ for (auto *result : results)
|
|
+ propagateIfChanged(result, result->join(curr));
|
|
}
|
|
|
|
unsigned AxisInfoAnalysis::getPtrContiguity(Value ptr) {
|
|
@@ -884,7 +891,10 @@ unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
|
|
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
|
|
if (!tensorTy)
|
|
return 1;
|
|
- auto axisInfo = lookupLatticeElement(ptr)->getValue();
|
|
+ dataflow::Lattice<AxisInfo> *latticeElement = getLatticeElement(ptr);
|
|
+ if (!latticeElement || latticeElement->isUninitialized())
|
|
+ return 1;
|
|
+ auto axisInfo = latticeElement->getValue();
|
|
auto layout = tensorTy.getEncoding();
|
|
auto order = triton::gpu::getOrder(layout);
|
|
auto maxMultipleBytes = axisInfo.getDivisibility(order[0]);
|
|
@@ -900,8 +910,11 @@ unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) {
|
|
auto tensorTy = mask.getType().dyn_cast<RankedTensorType>();
|
|
if (!tensorTy)
|
|
return 1;
|
|
+ dataflow::Lattice<AxisInfo> *latticeElement = getLatticeElement(mask);
|
|
+ if (!latticeElement || latticeElement->isUninitialized())
|
|
+ return 1;
|
|
+ auto maskAxis = latticeElement->getValue();
|
|
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
|
|
- auto maskAxis = lookupLatticeElement(mask)->getValue();
|
|
auto alignment = std::max<unsigned>(maskAxis.getConstancy(maskOrder[0]), 1);
|
|
return alignment;
|
|
}
|
|
diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt
|
|
index afbc692510..1f761f845c 100644
|
|
--- a/lib/Analysis/CMakeLists.txt
|
|
+++ b/lib/Analysis/CMakeLists.txt
|
|
@@ -8,7 +8,7 @@ add_mlir_library(TritonAnalysis
|
|
DEPENDS
|
|
TritonTableGen
|
|
TritonGPUAttrDefsIncGen
|
|
-
|
|
+
|
|
LINK_LIBS PUBLIC
|
|
MLIRAnalysis
|
|
)
|
|
diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp
|
|
index acc885e827..910274b2ac 100644
|
|
--- a/lib/Analysis/Membar.cpp
|
|
+++ b/lib/Analysis/Membar.cpp
|
|
@@ -2,7 +2,7 @@
|
|
#include "triton/Analysis/Alias.h"
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
|
|
-#include "mlir/Dialect/GPU/GPUDialect.h"
|
|
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
|
|
namespace mlir {
|
|
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
|
|
index d9e917e731..6ea52df272 100644
|
|
--- a/lib/Analysis/Utility.cpp
|
|
+++ b/lib/Analysis/Utility.cpp
|
|
@@ -1,5 +1,8 @@
|
|
#include "triton/Analysis/Utility.h"
|
|
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
|
|
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
|
|
#include "mlir/IR/Dialect.h"
|
|
+#include "mlir/IR/Matchers.h"
|
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
#include <deque>
|
|
@@ -325,4 +328,55 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
|
|
return multiRootTopologicalSort(slice);
|
|
}
|
|
|
|
+namespace {
|
|
+// Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis
|
|
+// interacts with constant propagation, but SparseConstantPropagation
|
|
+// doesn't seem to be sufficient.
|
|
+struct ConstantAnalysis : public DataFlowAnalysis {
|
|
+ using DataFlowAnalysis::DataFlowAnalysis;
|
|
+
|
|
+ LogicalResult initialize(Operation *top) override {
|
|
+ WalkResult result = top->walk([&](Operation *op) {
|
|
+ if (failed(visit(op)))
|
|
+ return WalkResult::interrupt();
|
|
+ return WalkResult::advance();
|
|
+ });
|
|
+ return success(!result.wasInterrupted());
|
|
+ }
|
|
+
|
|
+ LogicalResult visit(ProgramPoint point) override {
|
|
+ Operation *op = point.get<Operation *>();
|
|
+ Attribute value;
|
|
+ if (matchPattern(op, m_Constant(&value))) {
|
|
+ auto *constant = getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(
|
|
+ op->getResult(0));
|
|
+ propagateIfChanged(constant, constant->join(dataflow::ConstantValue(
|
|
+ value, op->getDialect())));
|
|
+ return success();
|
|
+ }
|
|
+ setAllToUnknownConstants(op->getResults());
|
|
+ for (Region &region : op->getRegions())
|
|
+ setAllToUnknownConstants(region.getArguments());
|
|
+ return success();
|
|
+ }
|
|
+
|
|
+ /// Set all given values as not constants.
|
|
+ void setAllToUnknownConstants(ValueRange values) {
|
|
+ dataflow::ConstantValue unknownConstant(nullptr, nullptr);
|
|
+ for (Value value : values) {
|
|
+ auto *constant =
|
|
+ getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(value);
|
|
+ propagateIfChanged(constant, constant->join(unknownConstant));
|
|
+ }
|
|
+ }
|
|
+};
|
|
+} // namespace
|
|
+
|
|
+std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
|
|
+ auto solver = std::make_unique<DataFlowSolver>();
|
|
+ solver->load<dataflow::DeadCodeAnalysis>();
|
|
+ solver->load<ConstantAnalysis>();
|
|
+ return solver;
|
|
+}
|
|
+
|
|
} // namespace mlir
|
|
diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
|
|
index 6a46265bd7..e352eb3698 100644
|
|
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
|
|
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
|
|
@@ -159,9 +159,6 @@ struct ConvertLayoutOpConversion
|
|
Value smemBase) const {
|
|
auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
|
|
auto layout = type.getEncoding();
|
|
- auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
|
|
- auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>();
|
|
- auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
|
|
auto rank = type.getRank();
|
|
auto sizePerThread = getSizePerThread(layout);
|
|
auto accumSizePerThread = product<unsigned>(sizePerThread);
|
|
diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h b/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h
|
|
index 4b89965aa9..1d9e00519b 100644
|
|
--- a/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h
|
|
+++ b/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h
|
|
@@ -7,10 +7,8 @@
|
|
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
|
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
|
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
|
|
-#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
|
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
|
-#include "mlir/Dialect/GPU/GPUDialect.h"
|
|
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
@@ -422,9 +420,9 @@ struct MMA16816ConversionHelper {
|
|
MMA16816ConversionHelper(Type dotOperand, MmaEncodingAttr mmaLayout,
|
|
Value thread, ConversionPatternRewriter &rewriter,
|
|
TypeConverter *typeConverter, Location loc)
|
|
- : mmaLayout(mmaLayout), thread(thread), helper(mmaLayout),
|
|
- rewriter(rewriter), typeConverter(typeConverter), loc(loc),
|
|
- ctx(mmaLayout.getContext()), wpt(mmaLayout.getWarpsPerCTA()) {
|
|
+ : mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()), thread(thread),
|
|
+ helper(mmaLayout), rewriter(rewriter), typeConverter(typeConverter),
|
|
+ loc(loc), ctx(mmaLayout.getContext()) {
|
|
helper.deduceMmaType(dotOperand);
|
|
|
|
Value _32 = i32_val(32);
|
|
diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp
|
|
index 0f8070ca9f..e4bd47c411 100644
|
|
--- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp
|
|
+++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp
|
|
@@ -115,8 +115,6 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
|
auto DTensorTy = D.getType().cast<RankedTensorType>();
|
|
auto AShape = ATensorTy.getShape();
|
|
auto BShape = BTensorTy.getShape();
|
|
- auto DShape = DTensorTy.getShape();
|
|
- auto wpt = mmaLayout.getWarpsPerCTA();
|
|
|
|
bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
|
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
|
@@ -221,7 +219,6 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
|
ConversionPatternRewriter &rewriter) const {
|
|
auto *ctx = rewriter.getContext();
|
|
auto loc = op.getLoc();
|
|
- auto threadId = getThreadId(rewriter, loc);
|
|
|
|
auto A = op.a();
|
|
auto B = op.b();
|
|
@@ -230,12 +227,10 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
|
|
|
auto aTensorTy = A.getType().cast<RankedTensorType>();
|
|
auto bTensorTy = B.getType().cast<RankedTensorType>();
|
|
- auto cTensorTy = C.getType().cast<RankedTensorType>();
|
|
auto dTensorTy = D.getType().cast<RankedTensorType>();
|
|
|
|
auto aShape = aTensorTy.getShape();
|
|
auto bShape = bTensorTy.getShape();
|
|
- auto cShape = cTensorTy.getShape();
|
|
|
|
BlockedEncodingAttr dLayout =
|
|
dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
|
|
diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
|
|
index deb71b9597..0b9e67674b 100644
|
|
--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
|
|
+++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
|
|
@@ -61,7 +61,6 @@ struct FpToFpOpConversion
|
|
convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
|
|
const Value &v0, const Value &v1, const Value &v2,
|
|
const Value &v3) {
|
|
- auto ctx = rewriter.getContext();
|
|
auto fp16x2VecTy = vec_ty(f16_ty, 2);
|
|
Value fp16x2Vec0 = undef(fp16x2VecTy);
|
|
Value fp16x2Vec1 = undef(fp16x2VecTy);
|
|
@@ -153,7 +152,6 @@ struct FpToFpOpConversion
|
|
convertBf16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
|
|
const Value &v0, const Value &v1, const Value &v2,
|
|
const Value &v3) {
|
|
- auto ctx = rewriter.getContext();
|
|
auto bf16x2VecTy = vec_ty(i16_ty, 2);
|
|
Value bf16x2Vec0 = undef(bf16x2VecTy);
|
|
Value bf16x2Vec1 = undef(bf16x2VecTy);
|
|
diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
|
|
index 9a8b4702bc..bae675f0cb 100644
|
|
--- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
|
|
+++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
|
|
@@ -109,7 +109,8 @@ struct LoadOpConversion
|
|
DenseElementsAttr constAttr;
|
|
int64_t splatVal = 0;
|
|
if (other && valueElemTy.isa<IntegerType>() &&
|
|
- matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat()) {
|
|
+ matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat() &&
|
|
+ constAttr.getElementType().isa<IntegerType>()) {
|
|
otherIsSplatConstInt = true;
|
|
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
|
|
}
|
|
@@ -333,7 +334,6 @@ struct StoreOpConversion
|
|
elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
|
|
elem = bitcast(elem, valueElemTy);
|
|
|
|
- Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
|
|
llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
|
|
}
|
|
llWord = bitcast(llWord, valArgTy);
|
|
@@ -387,7 +387,6 @@ struct AtomicCASOpConversion
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto loc = op.getLoc();
|
|
MLIRContext *ctx = rewriter.getContext();
|
|
- Value ptr = op.ptr();
|
|
|
|
Value llPtr = adaptor.ptr();
|
|
Value llCmp = adaptor.cmp();
|
|
diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
|
|
index 69abd889be..1c973dc196 100644
|
|
--- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
|
|
+++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
|
|
@@ -286,7 +286,6 @@ struct ReduceOpConversion
|
|
auto srcTy = op.operand().getType().cast<RankedTensorType>();
|
|
auto srcLayout = srcTy.getEncoding();
|
|
auto srcShape = srcTy.getShape();
|
|
- auto srcRank = srcTy.getRank();
|
|
auto order = getOrder(srcLayout);
|
|
|
|
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
|
|
@@ -351,7 +350,6 @@ struct ReduceOpConversion
|
|
|
|
Value zero = i32_val(0);
|
|
Value laneZero = icmp_eq(laneIdAxis, zero);
|
|
- Value warpZero = icmp_eq(warpIdAxis, zero);
|
|
|
|
for (auto it : accs) {
|
|
const SmallVector<unsigned> &key = it.first;
|
|
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
|
|
index 5b77150b1a..78cfa076bd 100644
|
|
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
|
|
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
|
|
@@ -11,11 +11,11 @@ using ::mlir::LLVM::getStructFromElements;
|
|
using ::mlir::triton::gpu::getElemsPerThread;
|
|
using ::mlir::triton::gpu::SharedEncodingAttr;
|
|
|
|
-struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
|
|
- using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;
|
|
+struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
|
|
+ using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
- matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
|
|
+ matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
unsigned numArguments = op.getNumOperands();
|
|
|
|
@@ -476,7 +476,6 @@ struct ExtractSliceOpConversion
|
|
|
|
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
|
|
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
|
|
- auto resTy = op.getType().dyn_cast<RankedTensorType>();
|
|
smemObj = SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset),
|
|
strideVals, offsetVals);
|
|
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
|
|
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
|
|
index bb10d5b24a..00e399f848 100644
|
|
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
|
|
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
|
|
@@ -4,6 +4,7 @@
|
|
// TODO: refactor so that it doesn't fail if Allocation.h
|
|
// is included after utility.h (due to conflict in `store` macro
|
|
// and <atomic>
|
|
+#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "triton/Analysis/Allocation.h"
|
|
|
|
//
|
|
@@ -39,15 +40,15 @@ void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
|
|
// TODO(Superjomn): remove the code when MLIR v15.0 is included.
|
|
// All the rights are reserved by the LLVM community.
|
|
|
|
-struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
|
|
+struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
|
|
private:
|
|
/// Only retain those attributes that are not constructed by
|
|
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
|
|
/// attributes.
|
|
- static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
|
|
- bool filterArgAttrs,
|
|
+ static void filterFuncAttributes(func::FuncOp op, bool filterArgAttrs,
|
|
SmallVectorImpl<NamedAttribute> &result) {
|
|
- for (const auto &attr : attrs) {
|
|
+
|
|
+ for (const auto &attr : op->getAttrs()) {
|
|
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
|
|
attr.getName() == FunctionOpInterface::getTypeAttrName() ||
|
|
attr.getName() == "std.varargs" ||
|
|
@@ -65,27 +66,27 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
|
|
}
|
|
|
|
protected:
|
|
- using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
|
|
+ using ConvertOpToLLVMPattern<func::FuncOp>::ConvertOpToLLVMPattern;
|
|
|
|
// Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
|
|
// to this legalization pattern.
|
|
LLVM::LLVMFuncOp
|
|
- convertFuncOpToLLVMFuncOp(FuncOp funcOp,
|
|
+ convertFuncOpToLLVMFuncOp(func::FuncOp funcOp,
|
|
ConversionPatternRewriter &rewriter) const {
|
|
// Convert the original function arguments. They are converted using the
|
|
// LLVMTypeConverter provided to this legalization pattern.
|
|
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs");
|
|
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
|
|
auto llvmType = getTypeConverter()->convertFunctionSignature(
|
|
- funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
|
|
+ funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(),
|
|
+ result);
|
|
if (!llvmType)
|
|
return nullptr;
|
|
|
|
// Propagate argument/result attributes to all converted arguments/result
|
|
// obtained after converting a given original argument/result.
|
|
SmallVector<NamedAttribute, 4> attributes;
|
|
- filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
|
|
- attributes);
|
|
+ filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, attributes);
|
|
if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
|
|
assert(!resAttrDicts.empty() && "expected array to be non-empty");
|
|
auto newResAttrDicts =
|
|
@@ -131,7 +132,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
|
|
}
|
|
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
|
|
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
|
|
- /*dsoLocal*/ false, attributes);
|
|
+ /*dsoLocal*/ false, LLVM::CConv::C, attributes);
|
|
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
|
|
newFuncOp.end());
|
|
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
|
|
@@ -191,8 +192,8 @@ class ConvertTritonGPUOpToLLVMPatternBase {
|
|
const Allocation *allocation,
|
|
Value smem,
|
|
IndexCacheInfo indexCacheInfo)
|
|
- : converter(&typeConverter), indexCacheInfo(indexCacheInfo),
|
|
- allocation(allocation), smem(smem) {}
|
|
+ : converter(&typeConverter), allocation(allocation), smem(smem),
|
|
+ indexCacheInfo(indexCacheInfo) {}
|
|
|
|
LLVMTypeConverter *getTypeConverter() const { return converter; }
|
|
|
|
@@ -861,7 +862,6 @@ class ConvertTritonGPUOpToLLVMPatternBase {
|
|
ArrayRef<int64_t> shape) const {
|
|
auto parent = sliceLayout.getParent();
|
|
unsigned dim = sliceLayout.getDim();
|
|
- size_t rank = shape.size();
|
|
auto parentIndices =
|
|
emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape));
|
|
unsigned numIndices = parentIndices.size();
|
|
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
|
|
index ff1af09835..6f66af4e34 100644
|
|
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
|
|
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
|
|
@@ -1,10 +1,11 @@
|
|
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
|
|
|
|
+#include "mlir/Analysis/DataFlowFramework.h"
|
|
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
|
+#include "mlir/Conversion/ControlFlowToLLVM//ControlFlowToLLVM.h"
|
|
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
|
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
|
|
-#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
|
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
|
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
@@ -40,7 +41,6 @@ class TritonLLVMConversionTarget : public ConversionTarget {
|
|
addIllegalDialect<triton::TritonDialect>();
|
|
addIllegalDialect<triton::gpu::TritonGPUDialect>();
|
|
addIllegalDialect<mlir::gpu::GPUDialect>();
|
|
- addIllegalDialect<mlir::StandardOpsDialect>();
|
|
addLegalOp<mlir::UnrealizedConversionCastOp>();
|
|
}
|
|
};
|
|
@@ -51,7 +51,7 @@ class TritonLLVMFunctionConversionTarget : public ConversionTarget {
|
|
: ConversionTarget(ctx) {
|
|
addLegalDialect<LLVM::LLVMDialect>();
|
|
addLegalDialect<NVVM::NVVMDialect>();
|
|
- addIllegalOp<mlir::FuncOp>();
|
|
+ addIllegalOp<mlir::func::FuncOp>();
|
|
addLegalOp<mlir::UnrealizedConversionCastOp>();
|
|
}
|
|
};
|
|
@@ -69,7 +69,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
|
|
: FuncOpConversionBase(converter, benefit), numWarps(numWarps) {}
|
|
|
|
LogicalResult
|
|
- matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
|
|
+ matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
|
|
if (!newFuncOp)
|
|
@@ -133,7 +133,8 @@ class ConvertTritonGPUToLLVM
|
|
decomposeBlockedToDotOperand(mod);
|
|
|
|
// Step 2
|
|
- decomposeInsertSliceAsyncOp(mod);
|
|
+ if (failed(decomposeInsertSliceAsyncOp(mod)))
|
|
+ return signalPassFailure();
|
|
|
|
// Step 3
|
|
Allocation allocation(mod);
|
|
@@ -142,7 +143,7 @@ class ConvertTritonGPUToLLVM
|
|
|
|
// Step 4
|
|
RewritePatternSet scf_patterns(context);
|
|
- mlir::populateLoopToStdConversionPatterns(scf_patterns);
|
|
+ mlir::populateSCFToControlFlowConversionPatterns(scf_patterns);
|
|
mlir::ConversionTarget scf_target(*context);
|
|
scf_target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp,
|
|
scf::WhileOp, scf::ExecuteRegionOp>();
|
|
@@ -159,8 +160,10 @@ class ConvertTritonGPUToLLVM
|
|
return signalPassFailure();
|
|
|
|
// Step 6 - get axis and shared memory info
|
|
- AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
|
|
- axisInfoAnalysis.run(mod);
|
|
+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
|
+ AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
|
|
+ if (failed(solver->initializeAndRun(mod)))
|
|
+ return signalPassFailure();
|
|
initSharedMemory(allocation.getSharedMemorySize(), typeConverter);
|
|
mod->setAttr("triton_gpu.shared",
|
|
mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32),
|
|
@@ -178,38 +181,39 @@ class ConvertTritonGPUToLLVM
|
|
|
|
// Normal conversions
|
|
populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps,
|
|
- axisInfoAnalysis, &allocation, smem,
|
|
+ *axisInfoAnalysis, &allocation, smem,
|
|
indexCacheInfo, /*benefit=*/10);
|
|
// ConvertLayoutOp
|
|
populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
|
- axisInfoAnalysis, &allocation, smem,
|
|
+ *axisInfoAnalysis, &allocation, smem,
|
|
indexCacheInfo, /*benefit=*/10);
|
|
// DotOp
|
|
populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
|
- axisInfoAnalysis, &allocation, smem,
|
|
+ *axisInfoAnalysis, &allocation, smem,
|
|
/*benefit=*/10);
|
|
// ElementwiseOp
|
|
populateElementwiseOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
|
- axisInfoAnalysis, &allocation, smem,
|
|
+ *axisInfoAnalysis, &allocation, smem,
|
|
/*benefit=*/10);
|
|
// LoadStoreOp
|
|
populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
|
- axisInfoAnalysis, &allocation, smem,
|
|
+ *axisInfoAnalysis, &allocation, smem,
|
|
indexCacheInfo, /*benefit=*/10);
|
|
// ReduceOp
|
|
populateReduceOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
|
- axisInfoAnalysis, &allocation, smem,
|
|
+ *axisInfoAnalysis, &allocation, smem,
|
|
indexCacheInfo, /*benefit=*/10);
|
|
// ViewOp
|
|
populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
|
- axisInfoAnalysis, &allocation, smem,
|
|
+ *axisInfoAnalysis, &allocation, smem,
|
|
/*benefit=*/10);
|
|
|
|
// Add arith/math's patterns to help convert scalar expression to LLVM.
|
|
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
|
patterns);
|
|
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
|
|
- mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
|
+ mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
|
|
+ patterns);
|
|
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
|
|
|
|
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
|
@@ -306,9 +310,11 @@ class ConvertTritonGPUToLLVM
|
|
});
|
|
}
|
|
|
|
- void decomposeInsertSliceAsyncOp(ModuleOp mod) const {
|
|
- AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
|
|
- axisInfoAnalysis.run(mod);
|
|
+ LogicalResult decomposeInsertSliceAsyncOp(ModuleOp mod) const {
|
|
+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
|
+ AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
|
|
+ if (failed(solver->initializeAndRun(mod)))
|
|
+ return failure();
|
|
// TODO(Keren): This is a hacky knob that may cause performance regression
|
|
// when decomposition has been performed. We should remove this knob once we
|
|
// have thorough analysis on async wait. Currently, we decompose
|
|
@@ -342,7 +348,7 @@ class ConvertTritonGPUToLLVM
|
|
auto resSharedLayout =
|
|
dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
|
|
auto resElemTy = dstTy.getElementType();
|
|
- unsigned inVec = axisInfoAnalysis.getPtrContiguity(src);
|
|
+ unsigned inVec = axisInfoAnalysis->getPtrContiguity(src);
|
|
unsigned outVec = resSharedLayout.getVec();
|
|
unsigned minVec = std::min(outVec, inVec);
|
|
auto maxBitWidth =
|
|
@@ -400,11 +406,11 @@ class ConvertTritonGPUToLLVM
|
|
} else if (decomposed) {
|
|
// Wait for all previous async ops
|
|
OpBuilder builder(asyncWaitOp);
|
|
- auto newAsyncWaitOp =
|
|
- builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
|
|
+ builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
|
|
asyncWaitOp.erase();
|
|
}
|
|
});
|
|
+ return success();
|
|
}
|
|
};
|
|
|
|
diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h
|
|
index d35dac28c5..11976908cf 100644
|
|
--- a/lib/Conversion/TritonGPUToLLVM/Utility.h
|
|
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.h
|
|
@@ -220,10 +220,7 @@ struct SharedMemoryObject {
|
|
ConversionPatternRewriter &rewriter)
|
|
: base(base) {
|
|
strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter);
|
|
-
|
|
- for (auto idx : order) {
|
|
- offsets.emplace_back(i32_val(0));
|
|
- }
|
|
+ offsets.append(order.size(), i32_val(0));
|
|
}
|
|
|
|
SmallVector<Value> getElems() const {
|
|
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
|
|
index fe42202c34..5f230f787f 100644
|
|
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
|
|
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
|
|
@@ -1,10 +1,10 @@
|
|
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
|
|
|
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
|
-#include "mlir/Dialect/GPU/GPUDialect.h"
|
|
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
|
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
|
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
|
@@ -59,10 +59,13 @@ class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
|
|
Type retType = getTypeConverter()->convertType(op.getType());
|
|
auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
|
|
assert(value);
|
|
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
|
- op, retType,
|
|
- value.reshape(retType) // This is a hack. We just want to add encoding
|
|
- );
|
|
+ if (value.getElementType().isInteger(1) && value.isSplat())
|
|
+ // Workaround until https://reviews.llvm.org/D133743 is included.
|
|
+ value = DenseElementsAttr::get(retType, value.getSplatValue<bool>());
|
|
+ else
|
|
+ // This is a hack. We just want to add encoding
|
|
+ value = value.reshape(retType);
|
|
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, retType, value);
|
|
return success();
|
|
}
|
|
};
|
|
@@ -127,12 +130,12 @@ void populateArithmeticPatternsAndLegality(
|
|
}
|
|
|
|
// this shouldn't exist if mlir's SelectOp checked encodings properly
|
|
-class StdSelectPattern : public OpConversionPattern<SelectOp> {
|
|
+class StdSelectPattern : public OpConversionPattern<arith::SelectOp> {
|
|
public:
|
|
- using OpConversionPattern<SelectOp>::OpConversionPattern;
|
|
+ using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
- matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor,
|
|
+ matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Type retType = this->getTypeConverter()->convertType(op.getType());
|
|
rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
|
|
@@ -148,8 +151,8 @@ void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
|
MLIRContext *context = patterns.getContext();
|
|
// Rewrite rule
|
|
patterns.add<StdSelectPattern>(typeConverter, context);
|
|
- target.addLegalOp<ReturnOp>(); // this is ok because all functions are inlined
|
|
- // by the frontend
|
|
+ target.addLegalOp<func::ReturnOp>(); // this is ok because all functions are
|
|
+ // inlined by the frontend
|
|
}
|
|
|
|
void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
|
@@ -455,18 +458,19 @@ struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
|
|
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
|
RewritePatternSet &patterns) {
|
|
MLIRContext *context = patterns.getContext();
|
|
- patterns.add< // TODO: view should have custom pattern that views the layout
|
|
- TritonGenericPattern<triton::ViewOp>,
|
|
- TritonGenericPattern<triton::BitcastOp>,
|
|
- TritonGenericPattern<triton::FpToFpOp>,
|
|
- TritonGenericPattern<triton::IntToPtrOp>,
|
|
- TritonGenericPattern<triton::PtrToIntOp>,
|
|
- TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
|
|
- TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
|
|
- TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
|
|
- TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
|
|
- TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
|
|
- TritonAtomicRMWPattern>(typeConverter, context);
|
|
+ patterns
|
|
+ .insert< // TODO: view should have custom pattern that views the layout
|
|
+ TritonGenericPattern<triton::ViewOp>,
|
|
+ TritonGenericPattern<triton::BitcastOp>,
|
|
+ TritonGenericPattern<triton::FpToFpOp>,
|
|
+ TritonGenericPattern<triton::IntToPtrOp>,
|
|
+ TritonGenericPattern<triton::PtrToIntOp>,
|
|
+ TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
|
|
+ TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
|
|
+ TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
|
|
+ TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
|
|
+ TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
|
|
+ TritonAtomicRMWPattern>(typeConverter, context);
|
|
}
|
|
|
|
//
|
|
@@ -623,29 +627,28 @@ void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
|
|
|
|
// CF
|
|
|
|
-class CFBranchPattern : public OpConversionPattern<BranchOp> {
|
|
+class CFBranchPattern : public OpConversionPattern<cf::BranchOp> {
|
|
public:
|
|
- using OpConversionPattern<BranchOp>::OpConversionPattern;
|
|
+ using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
- matchAndRewrite(BranchOp op, BranchOp::Adaptor adaptor,
|
|
+ matchAndRewrite(cf::BranchOp op, cf::BranchOp::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
- auto converter = getTypeConverter();
|
|
- auto newOp = rewriter.replaceOpWithNewOp<BranchOp>(op, op.getSuccessor(),
|
|
- adaptor.getOperands());
|
|
+ auto newOp = rewriter.replaceOpWithNewOp<cf::BranchOp>(
|
|
+ op, op.getSuccessor(), adaptor.getOperands());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
-class CFCondBranchPattern : public OpConversionPattern<CondBranchOp> {
|
|
+class CFCondBranchPattern : public OpConversionPattern<cf::CondBranchOp> {
|
|
public:
|
|
- using OpConversionPattern<CondBranchOp>::OpConversionPattern;
|
|
+ using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
- matchAndRewrite(CondBranchOp op, CondBranchOp::Adaptor adaptor,
|
|
+ matchAndRewrite(cf::CondBranchOp op, cf::CondBranchOp::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto converter = getTypeConverter();
|
|
- auto newOp = rewriter.replaceOpWithNewOp<CondBranchOp>(
|
|
+ auto newOp = rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
|
|
op, adaptor.getCondition(), op.getTrueDest(),
|
|
adaptor.getTrueDestOperands(), op.getFalseDest(),
|
|
adaptor.getFalseDestOperands());
|
|
diff --git a/lib/Dialect/Triton/IR/CMakeLists.txt b/lib/Dialect/Triton/IR/CMakeLists.txt
|
|
index 2d679b21fd..705554ba6b 100644
|
|
--- a/lib/Dialect/Triton/IR/CMakeLists.txt
|
|
+++ b/lib/Dialect/Triton/IR/CMakeLists.txt
|
|
@@ -10,11 +10,7 @@ add_mlir_dialect_library(TritonIR
|
|
|
|
LINK_LIBS PUBLIC
|
|
MLIRIR
|
|
- MLIRArithmetic
|
|
- MLIRSCF
|
|
-
|
|
- # Since LLVM 15
|
|
- # MLIRFunc
|
|
- # else
|
|
- MLIRStandard
|
|
+ MLIRArithmeticDialect
|
|
+ MLIRSCFDialect
|
|
+ MLIRFuncDialect
|
|
)
|
|
diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp
|
|
index 3aadbfa0c0..86570359c5 100644
|
|
--- a/lib/Dialect/Triton/IR/Ops.cpp
|
|
+++ b/lib/Dialect/Triton/IR/Ops.cpp
|
|
@@ -1,10 +1,9 @@
|
|
-#include "triton/Dialect/Triton/IR/Dialect.h"
|
|
-#include "triton/Dialect/Triton/IR/Types.h"
|
|
-
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/OperationSupport.h"
|
|
+#include "triton/Dialect/Triton/IR/Dialect.h"
|
|
+#include "triton/Dialect/Triton/IR/Types.h"
|
|
|
|
namespace mlir {
|
|
namespace triton {
|
|
@@ -38,8 +37,8 @@ static Type getPointerTypeSameShape(Type type) {
|
|
}
|
|
|
|
// Parser & printer for assembly forms
|
|
-ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
|
|
- SmallVector<OpAsmParser::OperandType, 4> allOperands;
|
|
+ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
|
|
Type resultTypes[1];
|
|
SMLoc allOperandLoc = parser.getCurrentLocation();
|
|
if (parser.parseOperandList(allOperands) ||
|
|
@@ -73,18 +72,18 @@ ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
|
|
return success();
|
|
}
|
|
|
|
-void printLoadOp(OpAsmPrinter &printer, LoadOp loadOp) {
|
|
+void LoadOp::print(OpAsmPrinter &printer) {
|
|
printer << " ";
|
|
- printer << loadOp.getOperation()->getOperands();
|
|
+ printer << getOperation()->getOperands();
|
|
// "operand_segment_sizes" can be deduced, so we don't print it.
|
|
- printer.printOptionalAttrDict(loadOp->getAttrs(),
|
|
- {loadOp.operand_segment_sizesAttrName()});
|
|
+ printer.printOptionalAttrDict(getOperation()->getAttrs(),
|
|
+ {operand_segment_sizesAttrName()});
|
|
printer << " : ";
|
|
- printer.printStrippedAttrOrType(loadOp.result().getType());
|
|
+ printer.printStrippedAttrOrType(getResult().getType());
|
|
}
|
|
|
|
-ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
|
|
- SmallVector<OpAsmParser::OperandType, 4> allOperands;
|
|
+ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
|
|
Type valueType;
|
|
SMLoc allOperandLoc = parser.getCurrentLocation();
|
|
if (parser.parseOperandList(allOperands) ||
|
|
@@ -104,12 +103,12 @@ ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
|
|
return success();
|
|
}
|
|
|
|
-void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) {
|
|
+void StoreOp::print(OpAsmPrinter &printer) {
|
|
printer << " ";
|
|
- printer << storeOp.getOperation()->getOperands();
|
|
- printer.printOptionalAttrDict(storeOp->getAttrs(), /*elidedAttrs=*/{});
|
|
+ printer << getOperation()->getOperands();
|
|
+ printer.printOptionalAttrDict(getOperation()->getAttrs(), /*elidedAttrs=*/{});
|
|
printer << " : ";
|
|
- printer.printStrippedAttrOrType(storeOp.value().getType());
|
|
+ printer.printStrippedAttrOrType(value().getType());
|
|
}
|
|
|
|
} // namespace triton
|
|
@@ -319,7 +318,8 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
|
|
if (!constOperand)
|
|
return {};
|
|
auto shapedType = getType().cast<ShapedType>();
|
|
- auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()});
|
|
+ auto ret = SplatElementsAttr::get(
|
|
+ shapedType, ArrayRef<Attribute>(constOperand.getValue()));
|
|
return ret;
|
|
}
|
|
|
|
diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp
|
|
index 2261472170..11570283d6 100644
|
|
--- a/lib/Dialect/Triton/Transforms/Combine.cpp
|
|
+++ b/lib/Dialect/Triton/Transforms/Combine.cpp
|
|
@@ -57,13 +57,13 @@ DenseElementsAttr getConstantValue(Builder &builder, Attribute value,
|
|
class CombineSelectMaskedLoadPattern : public mlir::RewritePattern {
|
|
public:
|
|
CombineSelectMaskedLoadPattern(mlir::MLIRContext *context)
|
|
- : mlir::RewritePattern(mlir::SelectOp::getOperationName(), 3, context,
|
|
- {triton::LoadOp::getOperationName()}) {}
|
|
+ : mlir::RewritePattern(mlir::arith::SelectOp::getOperationName(), 3,
|
|
+ context, {triton::LoadOp::getOperationName()}) {}
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mlir::Operation *op,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
- auto selectOp = llvm::dyn_cast<mlir::SelectOp>(op);
|
|
+ auto selectOp = llvm::dyn_cast<mlir::arith::SelectOp>(op);
|
|
if (!selectOp)
|
|
return mlir::failure();
|
|
|
|
diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td
|
|
index 14f286b26e..ded0e346e6 100644
|
|
--- a/lib/Dialect/Triton/Transforms/Combine.td
|
|
+++ b/lib/Dialect/Triton/Transforms/Combine.td
|
|
@@ -1,9 +1,9 @@
|
|
#ifndef TRITON_PATTERNS
|
|
#define TRITON_PATTERNS
|
|
|
|
-include "mlir/Dialect/StandardOps/IR/Ops.td"
|
|
include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td"
|
|
include "triton/Dialect/Triton/IR/TritonOps.td"
|
|
+include "mlir/IR/PatternBase.td"
|
|
|
|
|
|
// AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d)
|
|
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
|
|
index 1fbc609e88..bfc3f3d3da 100644
|
|
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
|
|
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
|
|
@@ -1,14 +1,14 @@
|
|
+#include "triton/Dialect/Triton/IR/Dialect.h"
|
|
+
|
|
#include <numeric>
|
|
|
|
#include "mlir/IR/DialectImplementation.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "triton/Analysis/Utility.h"
|
|
-#include "triton/Dialect/Triton/IR/Dialect.h"
|
|
+#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
|
|
-#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
|
-
|
|
using namespace mlir;
|
|
using namespace mlir::triton::gpu;
|
|
|
|
@@ -366,7 +366,6 @@ template SmallVector<int64_t>
|
|
SliceEncodingAttr::paddedShape<int64_t>(ArrayRef<int64_t> shape) const;
|
|
|
|
unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
|
- size_t rank = shape.size();
|
|
auto parent = getParent();
|
|
return ::getElemsPerThread(parent, paddedShape(shape));
|
|
}
|
|
@@ -655,9 +654,9 @@ void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
|
// InsertSliceAsyncOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
-ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
|
|
- OperationState &result) {
|
|
- SmallVector<OpAsmParser::OperandType, 8> allOperands;
|
|
+ParseResult InsertSliceAsyncOp::parse(OpAsmParser &parser,
|
|
+ OperationState &result) {
|
|
+ SmallVector<OpAsmParser::UnresolvedOperand, 8> allOperands;
|
|
Type srcType, dstType;
|
|
SMLoc allOperandLoc = parser.getCurrentLocation();
|
|
if (parser.parseOperandList(allOperands) ||
|
|
@@ -696,18 +695,16 @@ ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
|
|
return success();
|
|
}
|
|
|
|
-void printInsertSliceAsyncOp(OpAsmPrinter &printer,
|
|
- InsertSliceAsyncOp insertSliceAsyncOp) {
|
|
+void InsertSliceAsyncOp::print(OpAsmPrinter &printer) {
|
|
printer << " ";
|
|
- printer << insertSliceAsyncOp.getOperation()->getOperands();
|
|
+ printer << getOperation()->getOperands();
|
|
// "operand_segment_sizes" can be deduced, so we don't print it.
|
|
- printer.printOptionalAttrDict(
|
|
- insertSliceAsyncOp->getAttrs(),
|
|
- {insertSliceAsyncOp.operand_segment_sizesAttrName()});
|
|
+ printer.printOptionalAttrDict(getOperation()->getAttrs(),
|
|
+ {operand_segment_sizesAttrName()});
|
|
printer << " : ";
|
|
- printer.printStrippedAttrOrType(insertSliceAsyncOp.src().getType());
|
|
+ printer.printStrippedAttrOrType(src().getType());
|
|
printer << " -> ";
|
|
- printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
|
|
+ printer.printStrippedAttrOrType(result().getType());
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
|
|
index 82407980d3..ee6009f44a 100644
|
|
--- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
|
|
+++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
|
|
@@ -27,7 +27,11 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
|
auto origType = ptr.getType().cast<RankedTensorType>();
|
|
// Get the shape of the tensor.
|
|
size_t rank = origType.getRank();
|
|
- AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
|
|
+ dataflow::Lattice<AxisInfo> *latticeElement =
|
|
+ axisInfo.getLatticeElement(ptr);
|
|
+ AxisInfo info = latticeElement && !latticeElement->isUninitialized()
|
|
+ ? latticeElement->getValue()
|
|
+ : AxisInfo();
|
|
// Get the contiguity order of `ptr`
|
|
auto order = argSort(info.getContiguity());
|
|
// The desired divisibility is the maximum divisibility
|
|
@@ -40,7 +44,7 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
|
for (Value val : op->getResults()) {
|
|
if (val.getType() != origType)
|
|
continue;
|
|
- auto valInfo = axisInfo.lookupLatticeElement(val);
|
|
+ auto valInfo = axisInfo.getLatticeElement(val);
|
|
auto currOrder = argSort(valInfo->getValue().getContiguity());
|
|
if (order == currOrder)
|
|
withSameOrder.insert(val);
|
|
@@ -55,7 +59,7 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
|
unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);
|
|
unsigned perThread = 1;
|
|
for (Value val : withSameOrder) {
|
|
- AxisInfo info = axisInfo.lookupLatticeElement(val)->getValue();
|
|
+ AxisInfo info = axisInfo.getLatticeElement(val)->getValue();
|
|
unsigned maxMultipleBytes = info.getDivisibility(order[0]);
|
|
unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u);
|
|
unsigned maxContig = info.getContiguity(order[0]);
|
|
@@ -123,8 +127,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
|
void runOnOperation() override {
|
|
Operation *op = getOperation();
|
|
// Run axis info analysis
|
|
- AxisInfoAnalysis axisInfo(&getContext());
|
|
- axisInfo.run(op);
|
|
+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
|
+ AxisInfoAnalysis *axisInfo = solver->load<AxisInfoAnalysis>();
|
|
+ if (failed(solver->initializeAndRun(op)))
|
|
+ return signalPassFailure();
|
|
|
|
// For each i/o operation, we determine what layout
|
|
// the pointers should have for best memory coalescing
|
|
@@ -146,10 +152,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
|
RankedTensorType ty = ptr.getType().template dyn_cast<RankedTensorType>();
|
|
if (!ty || !ty.getElementType().isa<PointerType>())
|
|
return;
|
|
- AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
|
|
+ AxisInfo info = axisInfo->getLatticeElement(ptr)->getValue();
|
|
auto mod = curr->getParentOfType<ModuleOp>();
|
|
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
|
- auto convertType = getTypeConverter(axisInfo, ptr, numWarps);
|
|
+ auto convertType = getTypeConverter(*axisInfo, ptr, numWarps);
|
|
layoutMap[ptr] = convertType;
|
|
});
|
|
|
|
diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
|
|
index efa37ff2dc..089ce3996c 100644
|
|
--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp
|
|
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
|
|
@@ -1,6 +1,6 @@
|
|
#include "Utility.h"
|
|
#include "mlir/Analysis/SliceAnalysis.h"
|
|
-#include "mlir/Dialect/SCF/SCF.h"
|
|
+#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/IR/BlockAndValueMapping.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.td b/lib/Dialect/TritonGPU/Transforms/Combine.td
|
|
index 6bf1b14866..6a7b10dbcb 100644
|
|
--- a/lib/Dialect/TritonGPU/Transforms/Combine.td
|
|
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.td
|
|
@@ -3,5 +3,6 @@
|
|
|
|
include "triton/Dialect/TritonGPU/IR/TritonGPUOps.td"
|
|
include "triton/Dialect/Triton/IR/TritonOps.td"
|
|
+include "mlir/IR/PatternBase.td"
|
|
|
|
#endif
|
|
diff --git a/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp b/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp
|
|
index 4bd3bc76bf..b2f8defd81 100644
|
|
--- a/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp
|
|
+++ b/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp
|
|
@@ -1,5 +1,5 @@
|
|
#include "mlir/Analysis/SliceAnalysis.h"
|
|
-#include "mlir/Dialect/SCF/SCF.h"
|
|
+#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/IR/BlockAndValueMapping.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
|
|
index 9b2f42231e..85f746c1dc 100644
|
|
--- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
|
|
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
|
|
@@ -2,6 +2,7 @@
|
|
#include "mlir/IR/BlockAndValueMapping.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include "triton/Analysis/AxisInfo.h"
|
|
+#include "triton/Analysis/Utility.h"
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
|
|
|
@@ -160,15 +161,18 @@ ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
|
|
LogicalResult LoopPipeliner::initialize() {
|
|
Block *loop = forOp.getBody();
|
|
|
|
- AxisInfoAnalysis axisInfoAnalysis(forOp.getContext());
|
|
- axisInfoAnalysis.run(forOp->getParentOfType<ModuleOp>());
|
|
+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
|
+ AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
|
|
+ if (failed(solver->initializeAndRun(forOp->getParentOfType<ModuleOp>()))) {
|
|
+ return failure();
|
|
+ }
|
|
|
|
// can we use forOp.walk(...) here?
|
|
SmallVector<triton::LoadOp, 2> allLoads;
|
|
for (Operation &op : *loop)
|
|
if (auto loadOp = dyn_cast<triton::LoadOp>(&op)) {
|
|
auto ptr = loadOp.ptr();
|
|
- unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
|
|
+ unsigned vec = axisInfoAnalysis->getPtrContiguity(ptr);
|
|
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
|
|
if (!tensorTy)
|
|
continue;
|
|
diff --git a/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp b/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp
|
|
index 0e7dbe5264..b95a4f50a6 100644
|
|
--- a/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp
|
|
+++ b/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp
|
|
@@ -1,5 +1,5 @@
|
|
#include "mlir/Analysis/SliceAnalysis.h"
|
|
-#include "mlir/Dialect/SCF/SCF.h"
|
|
+#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/IR/BlockAndValueMapping.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
|
|
index 37ac710995..762e887f36 100644
|
|
--- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
|
|
+++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
|
|
@@ -82,12 +82,12 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
|
scf::ReduceReturnOp>();
|
|
|
|
addDynamicallyLegalDialect<arith::ArithmeticDialect, math::MathDialect,
|
|
- triton::TritonDialect, StandardOpsDialect,
|
|
- scf::SCFDialect>([&](Operation *op) {
|
|
- if (typeConverter.isLegal(op))
|
|
- return true;
|
|
- return false;
|
|
- });
|
|
+ triton::TritonDialect, scf::SCFDialect>(
|
|
+ [&](Operation *op) {
|
|
+ if (typeConverter.isLegal(op))
|
|
+ return true;
|
|
+ return false;
|
|
+ });
|
|
|
|
// We have requirements for the data layouts
|
|
addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
|
|
diff --git a/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp b/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp
|
|
index c229104286..c911fd4a5c 100644
|
|
--- a/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp
|
|
+++ b/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp
|
|
@@ -1,5 +1,5 @@
|
|
#include "Utility.h"
|
|
-#include "mlir/Dialect/SCF/SCF.h"
|
|
+#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
@@ -118,8 +118,8 @@ void setOpResultType(Operation *op, ArrayRef<Type> newTypes) {
|
|
.get("value")
|
|
.dyn_cast<mlir::DenseElementsAttr>();
|
|
if (attr) {
|
|
- auto newAttr = mlir::DenseElementsAttr::getFromRawBuffer(
|
|
- newType, attr.getRawData(), true);
|
|
+ auto newAttr =
|
|
+ mlir::DenseElementsAttr::getFromRawBuffer(newType, attr.getRawData());
|
|
op->setAttr("value", newAttr);
|
|
}
|
|
}
|
|
diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
|
|
index ed15f02f67..6400f1633a 100644
|
|
--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
|
|
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
|
|
@@ -1,5 +1,5 @@
|
|
#include "Utility.h"
|
|
-#include "mlir/Dialect/SCF/SCF.h"
|
|
+#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/IR/BlockAndValueMapping.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
diff --git a/lib/Target/LLVMIR/CMakeLists.txt b/lib/Target/LLVMIR/CMakeLists.txt
|
|
index f1bbd0bf4e..ac8973ad19 100644
|
|
--- a/lib/Target/LLVMIR/CMakeLists.txt
|
|
+++ b/lib/Target/LLVMIR/CMakeLists.txt
|
|
@@ -6,8 +6,7 @@ add_mlir_translation_library(TritonLLVMIR
|
|
|
|
LINK_LIBS PUBLIC
|
|
MLIRIR
|
|
- MLIRLLVMIR
|
|
- MLIRSCFToStandard
|
|
+ MLIRLLVMDialect
|
|
MLIRSupport
|
|
MLIRTargetLLVMIRExport
|
|
)
|
|
diff --git a/lib/Target/PTX/PTXTranslation.cpp b/lib/Target/PTX/PTXTranslation.cpp
|
|
index 4cb0d8193c..6a5453a6e7 100644
|
|
--- a/lib/Target/PTX/PTXTranslation.cpp
|
|
+++ b/lib/Target/PTX/PTXTranslation.cpp
|
|
@@ -1,11 +1,14 @@
|
|
#include "triton/Target/PTX/PTXTranslation.h"
|
|
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
|
+#include <optional>
|
|
|
|
#include "llvm/IR/IRBuilder.h"
|
|
#include "llvm/IR/LegacyPassManager.h"
|
|
#include "llvm/IR/Module.h"
|
|
#include "llvm/IR/Verifier.h"
|
|
#include "llvm/MC/TargetRegistry.h"
|
|
+#include "llvm/Pass.h"
|
|
+#include "llvm/Support/CommandLine.h"
|
|
#include "llvm/Support/TargetSelect.h"
|
|
#include "llvm/Target/TargetMachine.h"
|
|
|
|
diff --git a/python/setup.py b/python/setup.py
|
|
index 2ac3accd25..4530b36714 100644
|
|
--- a/python/setup.py
|
|
+++ b/python/setup.py
|
|
@@ -57,19 +57,10 @@ def get_pybind11_package_info():
|
|
def get_llvm_package_info():
|
|
# download if nothing is installed
|
|
system = platform.system()
|
|
- if system == "Darwin":
|
|
- system_suffix = "apple-darwin"
|
|
- elif system == "Linux":
|
|
- vglibc = tuple(map(int, platform.libc_ver()[1].split('.')))
|
|
- vglibc = vglibc[0] * 100 + vglibc[1]
|
|
- linux_suffix = 'ubuntu-18.04' if vglibc > 217 else 'centos-7'
|
|
- system_suffix = f"linux-gnu-{linux_suffix}"
|
|
- else:
|
|
- raise RuntimeError(f"unsupported system: {system}")
|
|
+ system_suffix = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}[system]
|
|
use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
|
|
- release_suffix = "assert" if use_assert_enabled_llvm else "release"
|
|
- name = f'llvm+mlir-14.0.6-x86_64-{system_suffix}-{release_suffix}'
|
|
- url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/llvm-14.0.6-f28c006a5895/{name}.tar.xz"
|
|
+ name = 'llvm+mlir-15.0.7-x86_64-{}-{}'.format(system_suffix, "assert" if use_assert_enabled_llvm else "release")
|
|
+ url = "https://github.com/ptillet/triton-llvm-releases/releases/download/llvm-15.0.7-8dfdcc7b7bf6/{}.tar.xz".format(name)
|
|
return Package("llvm", name, url, "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
|
|
|
|
|
|
diff --git a/python/src/triton.cc b/python/src/triton.cc
|
|
index c40b117a55..f190eacc34 100644
|
|
--- a/python/src/triton.cc
|
|
+++ b/python/src/triton.cc
|
|
@@ -8,9 +8,10 @@
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Transforms/Passes.h"
|
|
|
|
-#include "mlir/Parser.h"
|
|
+#include "mlir/Parser/Parser.h"
|
|
#include "mlir/Support/FileUtilities.h"
|
|
|
|
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "triton/Analysis/Allocation.h"
|
|
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
|
|
@@ -195,7 +196,7 @@ void init_triton_ir(py::module &&m) {
|
|
std::string attrName = name + "_arg" + std::to_string(id);
|
|
mlir::Block *owner = arg.getOwner();
|
|
if (owner->isEntryBlock() &&
|
|
- !mlir::isa<mlir::FuncOp>(owner->getParentOp())) {
|
|
+ !mlir::isa<mlir::func::FuncOp>(owner->getParentOp())) {
|
|
owner->getParentOp()->setAttr(attrName, attr);
|
|
}
|
|
}
|
|
@@ -348,7 +349,7 @@ void init_triton_ir(py::module &&m) {
|
|
return str;
|
|
})
|
|
.def("push_back",
|
|
- [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
|
|
+ [](mlir::ModuleOp &self, mlir::func::FuncOp &funcOp) -> void {
|
|
self.push_back(funcOp);
|
|
})
|
|
.def("has_function",
|
|
@@ -358,16 +359,18 @@ void init_triton_ir(py::module &&m) {
|
|
return false;
|
|
})
|
|
.def("get_function",
|
|
- [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp {
|
|
- return self.lookupSymbol<mlir::FuncOp>(funcName);
|
|
- })
|
|
- .def("get_single_function", [](mlir::ModuleOp &self) -> mlir::FuncOp {
|
|
- llvm::SmallVector<mlir::FuncOp> funcs;
|
|
- self.walk([&](mlir::FuncOp func) { funcs.push_back(func); });
|
|
- if (funcs.size() != 1)
|
|
- throw std::runtime_error("Expected a single function");
|
|
- return funcs[0];
|
|
- });
|
|
+ [](mlir::ModuleOp &self,
|
|
+ std::string &funcName) -> mlir::func::FuncOp {
|
|
+ return self.lookupSymbol<mlir::func::FuncOp>(funcName);
|
|
+ })
|
|
+ .def("get_single_function",
|
|
+ [](mlir::ModuleOp &self) -> mlir::func::FuncOp {
|
|
+ llvm::SmallVector<mlir::func::FuncOp> funcs;
|
|
+ self.walk([&](mlir::func::FuncOp func) { funcs.push_back(func); });
|
|
+ if (funcs.size() != 1)
|
|
+ throw std::runtime_error("Expected a single function");
|
|
+ return funcs[0];
|
|
+ });
|
|
|
|
m.def("make_attr",
|
|
[](const std::vector<int> &values, mlir::MLIRContext &context) {
|
|
@@ -388,47 +391,48 @@ void init_triton_ir(py::module &&m) {
|
|
registry.insert<mlir::triton::TritonDialect,
|
|
mlir::triton::gpu::TritonGPUDialect,
|
|
mlir::math::MathDialect, mlir::arith::ArithmeticDialect,
|
|
- mlir::StandardOpsDialect, mlir::scf::SCFDialect>();
|
|
+ mlir::func::FuncDialect, mlir::scf::SCFDialect>();
|
|
context.appendDialectRegistry(registry);
|
|
context.loadAllAvailableDialects();
|
|
|
|
// parse module
|
|
- mlir::OwningOpRef<mlir::ModuleOp> module(
|
|
- mlir::parseSourceFile(inputFilename, &context));
|
|
+ mlir::OwningOpRef<mlir::ModuleOp> module =
|
|
+ mlir::parseSourceFile<mlir::ModuleOp>(inputFilename, &context);
|
|
+ if (!module)
|
|
+ throw std::runtime_error("Parse MLIR file failed.");
|
|
// locations are incompatible with ptx < 7.5 !
|
|
module->walk([](mlir::Operation *op) {
|
|
op->setLoc(mlir::UnknownLoc::get(op->getContext()));
|
|
});
|
|
- if (!module)
|
|
- throw std::runtime_error("Parse MLIR file failed.");
|
|
|
|
return module->clone();
|
|
},
|
|
ret::take_ownership);
|
|
|
|
- py::class_<mlir::FuncOp, mlir::OpState>(m, "function")
|
|
+ py::class_<mlir::func::FuncOp, mlir::OpState>(m, "function")
|
|
// .def_property_readonly("attrs", &ir::function::attrs)
|
|
// .def("add_attr", &ir::function::add_attr);
|
|
.def("args",
|
|
- [](mlir::FuncOp &self, unsigned idx) -> mlir::BlockArgument {
|
|
+ [](mlir::func::FuncOp &self, unsigned idx) -> mlir::BlockArgument {
|
|
return self.getArgument(idx);
|
|
})
|
|
.def(
|
|
"add_entry_block",
|
|
- [](mlir::FuncOp &self) -> mlir::Block * {
|
|
+ [](mlir::func::FuncOp &self) -> mlir::Block * {
|
|
return self.addEntryBlock();
|
|
},
|
|
ret::reference)
|
|
.def(
|
|
"set_arg_attr",
|
|
- [](mlir::FuncOp &self, int arg_no, const std::string &name, int val) {
|
|
+ [](mlir::func::FuncOp &self, int arg_no, const std::string &name,
|
|
+ int val) {
|
|
// set arg attributes "name" to value "val"
|
|
auto attrTy = mlir::IntegerType::get(self.getContext(), 32);
|
|
self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
|
|
},
|
|
ret::reference)
|
|
- .def_property_readonly("type", &mlir::FuncOp::getType)
|
|
- .def("reset_type", &mlir::FuncOp::setType);
|
|
+ .def_property_readonly("type", &mlir::func::FuncOp::getFunctionType)
|
|
+ .def("reset_type", &mlir::func::FuncOp::setType);
|
|
|
|
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
|
|
|
|
@@ -445,13 +449,13 @@ void init_triton_ir(py::module &&m) {
|
|
.def("ret",
|
|
[](mlir::OpBuilder &self, std::vector<mlir::Value> &vals) -> void {
|
|
auto loc = self.getUnknownLoc();
|
|
- self.create<mlir::ReturnOp>(loc, vals);
|
|
+ self.create<mlir::func::ReturnOp>(loc, vals);
|
|
})
|
|
.def("call",
|
|
- [](mlir::OpBuilder &self, mlir::FuncOp &func,
|
|
+ [](mlir::OpBuilder &self, mlir::func::FuncOp &func,
|
|
std::vector<mlir::Value> &args) -> mlir::OpState {
|
|
auto loc = self.getUnknownLoc();
|
|
- return self.create<mlir::CallOp>(loc, func, args);
|
|
+ return self.create<mlir::func::CallOp>(loc, func, args);
|
|
})
|
|
// insertion block/point
|
|
.def("set_insertion_point_to_start",
|
|
@@ -618,15 +622,16 @@ void init_triton_ir(py::module &&m) {
|
|
.def("get_or_insert_function",
|
|
[](mlir::OpBuilder &self, mlir::ModuleOp &module,
|
|
std::string &funcName, mlir::Type &funcType,
|
|
- std::string &visibility) -> mlir::FuncOp {
|
|
+ std::string &visibility) -> mlir::func::FuncOp {
|
|
if (mlir::Operation *funcOperation = module.lookupSymbol(funcName))
|
|
- return llvm::dyn_cast<mlir::FuncOp>(funcOperation);
|
|
+ return llvm::dyn_cast<mlir::func::FuncOp>(funcOperation);
|
|
auto loc = self.getUnknownLoc();
|
|
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
|
|
llvm::SmallVector<mlir::NamedAttribute> attrs = {
|
|
mlir::NamedAttribute(self.getStringAttr("sym_visibility"),
|
|
self.getStringAttr(visibility))};
|
|
- return self.create<mlir::FuncOp>(loc, funcName, funcTy, attrs);
|
|
+ return self.create<mlir::func::FuncOp>(loc, funcName, funcTy,
|
|
+ attrs);
|
|
}
|
|
throw std::runtime_error("invalid function type");
|
|
})
|
|
@@ -658,15 +663,15 @@ void init_triton_ir(py::module &&m) {
|
|
[](mlir::OpBuilder &self, mlir::Value condition,
|
|
mlir::Block *trueDest, mlir::Block *falseDest) {
|
|
auto loc = self.getUnknownLoc();
|
|
- self.create<mlir::CondBranchOp>(loc, condition, trueDest,
|
|
- falseDest);
|
|
+ self.create<mlir::cf::CondBranchOp>(loc, condition, trueDest,
|
|
+ falseDest);
|
|
return;
|
|
})
|
|
.def("create_branch",
|
|
[](mlir::OpBuilder &self, mlir::Block *dest,
|
|
std::vector<mlir::Value> &args) {
|
|
auto loc = self.getUnknownLoc();
|
|
- self.create<mlir::BranchOp>(loc, dest, args);
|
|
+ self.create<mlir::cf::BranchOp>(loc, dest, args);
|
|
return;
|
|
})
|
|
// Structured control flow
|
|
@@ -792,14 +797,14 @@ void init_triton_ir(py::module &&m) {
|
|
.def("create_to_index",
|
|
[](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value {
|
|
auto loc = self.getUnknownLoc();
|
|
- return self.create<mlir::arith::IndexCastOp>(loc, input,
|
|
- self.getIndexType());
|
|
+ return self.create<mlir::arith::IndexCastOp>(
|
|
+ loc, self.getIndexType(), input);
|
|
})
|
|
.def("create_index_to_si",
|
|
[](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value {
|
|
auto loc = self.getUnknownLoc();
|
|
- return self.create<mlir::arith::IndexCastOp>(loc, input,
|
|
- self.getI32Type());
|
|
+ return self.create<mlir::arith::IndexCastOp>(
|
|
+ loc, self.getI32Type(), input);
|
|
})
|
|
.def("create_fmul",
|
|
[](mlir::OpBuilder &self, mlir::Value &lhs,
|
|
@@ -1316,8 +1321,8 @@ void init_triton_ir(py::module &&m) {
|
|
[](mlir::OpBuilder &self, mlir::Value &condition,
|
|
mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value {
|
|
auto loc = self.getUnknownLoc();
|
|
- return self.create<mlir::SelectOp>(loc, condition, trueValue,
|
|
- falseValue);
|
|
+ return self.create<mlir::arith::SelectOp>(loc, condition,
|
|
+ trueValue, falseValue);
|
|
})
|
|
.def("create_printf",
|
|
[](mlir::OpBuilder &self, const std::string &prefix,
|
|
@@ -1429,7 +1434,7 @@ void init_triton_ir(py::module &&m) {
|
|
self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
|
|
})
|
|
.def("add_scf_to_cfg", [](mlir::PassManager &self) {
|
|
- self.addPass(mlir::createLowerToCFGPass());
|
|
+ self.addPass(mlir::createConvertSCFToCFPass());
|
|
});
|
|
}
|
|
|
|
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
|
|
index 432544a8a4..018f544714 100644
|
|
--- a/python/test/unit/language/test_core.py
|
|
+++ b/python/test/unit/language/test_core.py
|
|
@@ -1918,7 +1918,7 @@ def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
|
|
#dst = {dst_layout}
|
|
""" + """
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
- func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
|
+ func.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
|
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
|
|
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
|
|
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>
|
|
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
|
|
index 5d167634df..c36589037c 100644
|
|
--- a/python/triton/compiler.py
|
|
+++ b/python/triton/compiler.py
|
|
@@ -1514,14 +1514,14 @@ def make_hash(fn, **kwargs):
|
|
return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest()
|
|
|
|
|
|
-# - ^\s*func\s+ : match the start of the string, any leading whitespace, the keyword func,
|
|
+# - ^\s*func\.func\s+ : match the start of the string, any leading whitespace, the op name func.func,
|
|
# and any following whitespace
|
|
# - (public\s+)? : optionally match the keyword public and any following whitespace
|
|
# - (@\w+) : match an @ symbol followed by one or more word characters
|
|
# (letters, digits, or underscores), and capture it as group 1 (the function name)
|
|
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
|
|
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
|
|
-mlir_prototype_pattern = r'^\s*func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
|
|
+mlir_prototype_pattern = r'^\s*func\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
|
|
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
|
|
prototype_pattern = {
|
|
"ttir": mlir_prototype_pattern,
|
|
diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir
|
|
index b3d5673f85..bb21615e68 100644
|
|
--- a/test/Analysis/test-alias.mlir
|
|
+++ b/test/Analysis/test-alias.mlir
|
|
@@ -11,7 +11,7 @@
|
|
|
|
// CHECK-LABEL: matmul_loop
|
|
// There shouldn't be any aliasing with the dot op encoding.
|
|
-func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
+func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
|
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
|
|
%a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
|
@@ -36,7 +36,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
|
}
|
|
|
|
// CHECK-LABEL: alloc
|
|
-func @alloc(%A : !tt.ptr<f16>) {
|
|
+func.func @alloc(%A : !tt.ptr<f16>) {
|
|
// CHECK: %cst -> %cst
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
|
|
@@ -46,7 +46,7 @@ func @alloc(%A : !tt.ptr<f16>) {
|
|
}
|
|
|
|
// CHECK-LABEL: convert
|
|
-func @convert(%A : !tt.ptr<f16>) {
|
|
+func.func @convert(%A : !tt.ptr<f16>) {
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
|
// CHECK: %0 -> %0
|
|
%cst1 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED>
|
|
@@ -54,7 +54,7 @@ func @convert(%A : !tt.ptr<f16>) {
|
|
}
|
|
|
|
// CHECK-LABEL: trans
|
|
-func @trans(%A : !tt.ptr<f16>) {
|
|
+func.func @trans(%A : !tt.ptr<f16>) {
|
|
// CHECK: %cst -> %cst
|
|
%tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
|
|
// CHECK: %0 -> %cst
|
|
@@ -63,7 +63,7 @@ func @trans(%A : !tt.ptr<f16>) {
|
|
}
|
|
|
|
// CHECK-LABEL: insert_slice_async
|
|
-func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
|
+func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
|
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
|
|
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
|
|
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
|
@@ -76,7 +76,7 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
|
}
|
|
|
|
// CHECK-LABEL: insert_slice
|
|
-func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
|
|
+func.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
|
|
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
|
|
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
|
|
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
|
@@ -90,7 +90,7 @@ func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
|
|
}
|
|
|
|
// CHECK-LABEL: extract_slice
|
|
-func @extract_slice(%A : !tt.ptr<f16>) {
|
|
+func.func @extract_slice(%A : !tt.ptr<f16>) {
|
|
// CHECK: %cst -> %cst
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
|
|
%index = arith.constant 0 : index
|
|
@@ -100,7 +100,7 @@ func @extract_slice(%A : !tt.ptr<f16>) {
|
|
}
|
|
|
|
// CHECK-LABEL: if_cat
|
|
-func @if_cat(%i1 : i1) {
|
|
+func.func @if_cat(%i1 : i1) {
|
|
// CHECK: %cst -> %cst
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
// CHECK: %cst_0 -> %cst_0
|
|
@@ -119,7 +119,7 @@ func @if_cat(%i1 : i1) {
|
|
}
|
|
|
|
// CHECK-LABEL: if_alias
|
|
-func @if_alias(%i1 : i1) {
|
|
+func.func @if_alias(%i1 : i1) {
|
|
// CHECK: %cst -> %cst
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
// CHECK-NEXT: %cst_0 -> %cst_0
|
|
@@ -134,7 +134,7 @@ func @if_alias(%i1 : i1) {
|
|
}
|
|
|
|
// CHECK-LABEL: for
|
|
-func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
+func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
// CHECK: %cst -> %cst
|
|
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
|
// CHECK-NEXT: %cst_0 -> %cst_0
|
|
@@ -154,7 +154,7 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.p
|
|
}
|
|
|
|
// CHECK-LABEL: for_if
|
|
-func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
|
+func.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
|
// CHECK: %cst -> %cst
|
|
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
|
// CHECK-NEXT: %cst_0 -> %cst_0
|
|
@@ -180,7 +180,7 @@ func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !t
|
|
}
|
|
|
|
// CHECK-LABEL: for_if_for
|
|
-func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
|
+func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
|
// CHECK: %cst -> %cst
|
|
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
|
// CHECK-NEXT: %cst_0 -> %cst_0
|
|
diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir
|
|
index 0ab34c7a78..af8ea6f856 100644
|
|
--- a/test/Analysis/test-alignment.mlir
|
|
+++ b/test/Analysis/test-alignment.mlir
|
|
@@ -1,288 +1,288 @@
|
|
-// RUN: triton-opt %s -test-print-alignment -split-input-file 2>&1 | FileCheck %s
|
|
+// RUN: triton-opt %s -test-print-alignment -split-input-file -o %t 2>&1 | FileCheck %s
|
|
|
|
-// CHECK-LABEL: cast
|
|
-func @cast() {
|
|
- // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [1]
|
|
+// CHECK-LABEL: @cast
|
|
+func.func @cast() {
|
|
+ // CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
|
|
%cst = arith.constant 1 : i32
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [1]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
|
|
%0 = arith.extsi %cst : i32 to i64
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
|
%cst_tensor = arith.constant dense<1> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
|
%1 = tt.bitcast %cst_tensor : tensor<128xi32> -> tensor<128xi64>
|
|
return
|
|
}
|
|
|
|
// -----
|
|
|
|
-// CHECK-LABEL: add
|
|
-func @add() {
|
|
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
|
+// CHECK-LABEL: @add
|
|
+func.func @add() {
|
|
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
|
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
|
%1 = arith.constant dense<1> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [128], divisibility = [1], constancy = [1], constant_value = <none>
|
|
%2 = arith.addi %0, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [127]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 127
|
|
%3 = arith.constant dense<127> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
|
|
%4 = arith.addi %1, %3 : tensor<128xi32>
|
|
return
|
|
}
|
|
|
|
// -----
|
|
|
|
-// CHECK-LABEL: sub
|
|
-func @sub() {
|
|
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
|
+// CHECK-LABEL: @sub
|
|
+func.func @sub() {
|
|
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
|
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
|
%1 = arith.constant dense<1> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [128], divisibility = [1], constancy = [1], constant_value = <none>
|
|
%2 = arith.subi %0, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [129]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 129
|
|
%3 = arith.constant dense<129> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
|
|
%4 = arith.subi %3, %1 : tensor<128xi32>
|
|
return
|
|
}
|
|
|
|
// -----
|
|
|
|
-// CHECK-LABEL: mul
|
|
-func @mul() {
|
|
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
|
+// CHECK-LABEL: @mul
|
|
+func.func @mul() {
|
|
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
|
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
|
%1 = arith.constant dense<1> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
|
%2 = arith.muli %0, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
|
|
%3 = arith.constant dense<128> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
|
|
%4 = arith.muli %3, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [2]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 2
|
|
%5 = arith.constant dense<2> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [256] ; Constancy: [128] ; ConstantValue: [256]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [256], constancy = [128], constant_value = 256
|
|
%6 = arith.muli %4, %5 : tensor<128xi32>
|
|
return
|
|
}
|
|
|
|
// -----
|
|
|
|
-// CHECK-LABEL: div
|
|
-func @div() {
|
|
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
|
+// CHECK-LABEL: @div
|
|
+func.func @div() {
|
|
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
|
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
|
%1 = arith.constant dense<1> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
|
%2 = arith.divsi %0, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
|
%3 = arith.divui %1, %0 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
|
|
%4 = arith.constant dense<64> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16777216] ; Constancy: [64] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = <none>
|
|
%5 = arith.divsi %0, %4 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
|
%6 = arith.divsi %4, %0 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
|
|
%7 = arith.divsi %4, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [66]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 66
|
|
%8 = arith.constant dense<66> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [2] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [2], constant_value = <none>
|
|
%9 = arith.divui %0, %8 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [8192] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [128], divisibility = [8192], constancy = [1], constant_value = <none>
|
|
%10 = tt.make_range {end = 8320 : i32, start = 8192 : i32} : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [64] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [64], constant_value = <none>
|
|
%11 = arith.divsi %10, %4 : tensor<128xi32>
|
|
- return
|
|
+ return
|
|
}
|
|
|
|
// -----
|
|
|
|
-// CHECK-LABEL: rem
|
|
-func @rem() {
|
|
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
|
+// CHECK-LABEL: @rem
|
|
+func.func @rem() {
|
|
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
|
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
|
|
%1 = arith.constant dense<1> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
|
|
%2 = arith.remsi %0, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
|
%3 = arith.remui %1, %0 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
|
|
%4 = arith.constant dense<64> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [64] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [64], divisibility = [64], constancy = [1], constant_value = <none>
|
|
%5 = arith.remsi %0, %4 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [1], constant_value = <none>
|
|
%6 = arith.remsi %4, %0 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [66]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 66
|
|
%7 = arith.constant dense<66> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [2] ; Divisibility: [2] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [2], divisibility = [2], constancy = [1], constant_value = <none>
|
|
%8 = arith.remui %0, %7 : tensor<128xi32>
|
|
- return
|
|
+ return
|
|
}
|
|
|
|
// -----
|
|
|
|
-// CHECK-LABEL: broadcast
|
|
-func @broadcast() {
|
|
- // CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
|
|
+// CHECK-LABEL: @broadcast
|
|
+func.func @broadcast() {
|
|
+ // CHECK: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
|
|
%0 = arith.constant dense<64> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [64, 1] ; Constancy: [128, 1] ; ConstantValue: [64]
|
|
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 1], constant_value = 64
|
|
%1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
|
|
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [64, 1] ; Constancy: [128, 128] ; ConstantValue: [64]
|
|
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 128], constant_value = 64
|
|
%2 = tt.broadcast %1 : (tensor<128x1xi32>) -> tensor<128x128xi32>
|
|
return
|
|
}
|
|
|
|
// -----
|
|
|
|
-// CHECK-LABEL: splat
|
|
-func @splat(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
|
- // CHECK: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 128] ; ConstantValue: [None]
|
|
+// CHECK-LABEL: @splat
|
|
+func.func @splat(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
|
+ // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = <none>
|
|
%0 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
|
|
return
|
|
}
|
|
|
|
// -----
|
|
|
|
-// CHECK-LABEL: cmp
|
|
-func @cmp() {
|
|
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
|
+// CHECK-LABEL: @cmp
|
|
+func.func @cmp() {
|
|
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
|
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
|
|
%1 = arith.constant dense<0> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
|
|
%2 = arith.cmpi eq, %0, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
|
|
%3 = arith.cmpi slt, %0, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
|
%4 = arith.cmpi sle, %0, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
|
|
%5 = arith.cmpi sge, %0, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8
|
|
%6 = arith.constant dense<8> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none>
|
|
%7 = arith.cmpi sgt, %0, %6 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [0]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 0
|
|
%8 = arith.cmpi sgt, %1, %6 : tensor<128xi32>
|
|
return
|
|
}
|
|
|
|
// -----
|
|
|
|
-// CHECK-LABEL: logic
|
|
-func @logic() {
|
|
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
|
+// CHECK-LABEL: @logic
|
|
+func.func @logic() {
|
|
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
|
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
|
|
%1 = arith.constant dense<64> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16777216] ; Constancy: [64] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = <none>
|
|
%2 = arith.divsi %0, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8
|
|
%3 = arith.constant dense<8> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [134217728] ; Constancy: [8] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [134217728], constancy = [8], constant_value = <none>
|
|
%4 = arith.divsi %0, %3 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
|
%5 = arith.andi %0, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
|
%6 = arith.ori %0, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
|
%7 = arith.xori %0, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none>
|
|
%8 = arith.andi %2, %4 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none>
|
|
%9 = arith.ori %2, %4 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none>
|
|
%10 = arith.xori %2, %4 : tensor<128xi32>
|
|
return
|
|
}
|
|
|
|
// -----
|
|
|
|
-// CHECK-LABEL: select
|
|
-func @select() {
|
|
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
|
+// CHECK-LABEL: @select
|
|
+func.func @select() {
|
|
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
|
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
|
|
%1 = arith.constant dense<0> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
|
|
%2 = arith.cmpi eq, %0, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
|
|
%3 = arith.cmpi slt, %0, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [1] ; ConstantValue: [0]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0
|
|
%4 = arith.constant 0 : i1
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
|
|
%7 = tt.splat %4 : (i1) -> tensor<128xi1>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
|
|
- %5 = select %4, %3, %7 : tensor<128xi1>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
|
|
+ %5 = arith.select %4, %3, %7 : tensor<128xi1>
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
|
|
%8 = "triton_gpu.select"(%7, %3, %2) : (tensor<128xi1>, tensor<128xi1>, tensor<128xi1>) -> tensor<128xi1>
|
|
return
|
|
}
|
|
|
|
// -----
|
|
|
|
-func @shift() {
|
|
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
|
+func.func @shift() {
|
|
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
|
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8
|
|
%1 = arith.constant dense<8> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4] ; Constancy: [128] ; ConstantValue: [4]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4
|
|
%2 = arith.constant dense<4> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [274877906944] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [274877906944], constancy = [1], constant_value = <none>
|
|
%3 = arith.shli %0, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [67108864] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [67108864], constancy = [1], constant_value = <none>
|
|
%4 = arith.shrsi %0, %2 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
|
|
%5 = arith.shli %1, %2 : tensor<128xi32>
|
|
return
|
|
}
|
|
|
|
// -----
|
|
|
|
-func @max_min() {
|
|
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
|
+func.func @max_min() {
|
|
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
|
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [128], divisibility = [64], constancy = [1], constant_value = <none>
|
|
%1 = tt.make_range {end = 192 : i32, start = 64 : i32} : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
|
%2 = arith.maxsi %0, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
|
%3 = arith.minsi %0, %1 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8
|
|
%4 = arith.constant dense<8> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4] ; Constancy: [128] ; ConstantValue: [4]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4
|
|
%5 = arith.constant dense<4> : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [8]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 8
|
|
%6 = arith.maxsi %4, %5 : tensor<128xi32>
|
|
return
|
|
}
|
|
|
|
// -----
|
|
|
|
-// CHECK-LABEL: for
|
|
-func @for() {
|
|
- // CHECK: Contiguity: [1, 1] ; Divisibility: [4611686018427387904, 4611686018427387904] ; Constancy: [128, 32] ; ConstantValue: [0]
|
|
+// CHECK-LABEL: @for
|
|
+func.func @for() {
|
|
+ // CHECK: contiguity = [1, 1], divisibility = [4611686018427387904, 4611686018427387904], constancy = [128, 32], constant_value = 0
|
|
%a_init = arith.constant dense<0> : tensor<128x32xi32>
|
|
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [1]
|
|
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = 1
|
|
%b_init = arith.constant dense<1> : tensor<128x32xi32>
|
|
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [4, 4] ; Constancy: [128, 32] ; ConstantValue: [4]
|
|
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4
|
|
%c_init = arith.constant dense<4> : tensor<128x32xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [128]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128
|
|
%ub = arith.constant 128 : index
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [1] ; ConstantValue: [0]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0
|
|
%lb = arith.constant 0 : index
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [16]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16
|
|
%step = arith.constant 16 : index
|
|
%a, %b, %c = scf.for %iv = %lb to %ub step %step iter_args(%a = %a_init, %b = %b_init, %c = %c_init) -> (tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>) {
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = <none>
|
|
%t = arith.index_cast %iv : index to i32
|
|
- // CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [None]
|
|
- // CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [None]
|
|
- // CHECK: Contiguity: [1, 1] ; Divisibility: [4, 4] ; Constancy: [128, 32] ; ConstantValue: [4]
|
|
+ // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = <none>
|
|
+ // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = <none>
|
|
+ // CHECK: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4
|
|
scf.yield %b, %a, %c : tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>
|
|
}
|
|
return
|
|
@@ -290,53 +290,53 @@ func @for() {
|
|
|
|
// -----
|
|
|
|
-// CHECK-LABEL: permute_2d
|
|
-func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
|
- // CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 128] ; ConstantValue: [1]
|
|
+// CHECK-LABEL: @permute_2d
|
|
+func.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
|
+ // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 128], constant_value = 1
|
|
%cst = arith.constant dense<true> : tensor<128x128xi1>
|
|
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
|
|
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
|
|
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
|
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
|
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [1073741824, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = <none>
|
|
%2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
|
|
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none>
|
|
%3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32>
|
|
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [17179869184, 16] ; Constancy: [1, 1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [17179869184, 16], constancy = [1, 1], constant_value = <none>
|
|
%4 = arith.muli %2, %3 : tensor<128x1xi32>
|
|
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none>
|
|
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
|
|
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = <none>
|
|
%6 = tt.addptr %5, %4 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32>
|
|
- // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [1, 1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = <none>
|
|
%7 = tt.expand_dims %1 {axis = 0 : i32}: (tensor<128xi32>) -> tensor<1x128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = <none>
|
|
%8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
|
|
- // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [128, 1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [128, 1], constant_value = <none>
|
|
%9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 16], constancy = [1, 1], constant_value = <none>
|
|
%10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
|
|
- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [1073741824, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = <none>
|
|
%11 = tt.expand_dims %0 {axis = 1 : i32}: (tensor<128xi32>) -> tensor<128x1xi32>
|
|
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none>
|
|
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
|
|
- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = <none>
|
|
%13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32>
|
|
- // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [1, 1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = <none>
|
|
%14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = <none>
|
|
%15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 17179869184] ; Constancy: [1, 1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [1, 1], constant_value = <none>
|
|
%16 = arith.muli %14, %15 : tensor<1x128xi32>
|
|
- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 128] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 128], constant_value = <none>
|
|
%17 = tt.broadcast %13 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
|
|
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 17179869184] ; Constancy: [128, 1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [128, 1], constant_value = <none>
|
|
%18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
|
- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = <none>
|
|
%19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
|
|
%20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
|
|
tt.store %19, %20, %cst : tensor<128x128xf32>
|
|
return
|
|
@@ -347,29 +347,29 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
|
|
module {
|
|
|
|
// This is a tiny test for verifying StoreOp-related alignment, It simply store a constant to a buffer.
|
|
-// CHECK-LABEL: store_constant_align
|
|
-func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
|
|
- // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
|
+// CHECK-LABEL: @store_constant_align
|
|
+func.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
|
|
+ // CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
|
%pid = tt.get_program_id {axis = 0 : i32} : i32
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [128]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128
|
|
%c128_i32 = arith.constant 128 : i32
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = <none>
|
|
%1 = arith.muli %pid, %c128_i32 : i32
|
|
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
|
%2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = <none>
|
|
%3 = tt.splat %1 : (i32) -> tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [128], divisibility = [128], constancy = [1], constant_value = <none>
|
|
%4 = arith.addi %3, %2 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [128], constant_value = <none>
|
|
%5 = tt.splat %addr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
|
|
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [128], divisibility = [16], constancy = [1], constant_value = <none>
|
|
%6 = tt.addptr %5, %4 : tensor<128x!tt.ptr<f32>>, tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [128], constant_value = <none>
|
|
%9 = tt.splat %n : (i32) -> tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [16] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [16], constant_value = <none>
|
|
%mask = arith.cmpi slt, %4, %9 : tensor<128xi32>
|
|
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
|
|
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
|
%cst = arith.constant dense<0.0> : tensor<128xf32>
|
|
tt.store %5, %cst, %mask : tensor<128xf32>
|
|
return
|
|
@@ -381,8 +381,8 @@ func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n:
|
|
|
|
// This IR is dumped from vecadd test.
|
|
// Note, the hint {tt.divisibility = 16 : i32} for %n_elements affects the alignment of mask.
|
|
-// CHECK-LABEL: vecadd_mask_align_16
|
|
-func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
|
|
+// CHECK-LABEL: @vecadd_mask_align_16
|
|
+func.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
|
|
%c64_i32 = arith.constant 64 : i32
|
|
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
|
%1 = arith.muli %0, %c64_i32 : i32
|
|
@@ -394,13 +394,13 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ar
|
|
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
|
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
|
|
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
|
|
- // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [16] ; ConstantValue: [None] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
|
|
+ // CHECK: arith.cmpi slt, %{{.*}} => contiguity = [1], divisibility = [1], constancy = [16], constant_value = <none>
|
|
%mask = arith.cmpi slt, %4, %9 : tensor<64xi32>
|
|
%11 = tt.load %6, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
|
%12 = tt.load %8, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
|
%13 = arith.addf %11, %12 : tensor<64xf32>
|
|
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
|
- // CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr<f32>>, tensor<64xi32> )
|
|
+ // CHECK: tt.addptr %{{.*}} => contiguity = [64], divisibility = [16], constancy = [1], constant_value = <none>
|
|
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
|
|
tt.store %15, %13, %mask : tensor<64xf32>
|
|
return
|
|
@@ -410,8 +410,8 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ar
|
|
|
|
// This IR is dumped from vecadd test.
|
|
// Note, there is no divisibility hint for %n_elements, Triton should assume its divisibility to be 1 by default.
|
|
-// CHECK-LABEL: vecadd_mask_align_1
|
|
-func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
|
+// CHECK-LABEL: @vecadd_mask_align_1
|
|
+func.func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
|
%c64_i32 = arith.constant 64 : i32
|
|
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
|
%1 = arith.muli %0, %c64_i32 : i32
|
|
@@ -423,7 +423,7 @@ func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg
|
|
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
|
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
|
|
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
|
|
- // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
|
|
+ // CHECK: arith.cmpi slt, %{{.*}} => contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
|
%10 = arith.cmpi slt, %4, %9 : tensor<64xi32>
|
|
%11 = tt.load %6, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
|
%12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
|
diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir
|
|
index efb00c404d..f79222aa7b 100644
|
|
--- a/test/Analysis/test-allocation.mlir
|
|
+++ b/test/Analysis/test-allocation.mlir
|
|
@@ -13,7 +13,7 @@
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
|
|
// CHECK-LABEL: matmul_loop
|
|
-func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
+func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
|
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
|
|
|
|
@@ -46,7 +46,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
|
|
|
// Shared memory is available after a tensor's liveness range ends
|
|
// CHECK-LABEL: reusable
|
|
-func @reusable(%A : !tt.ptr<f16>) {
|
|
+func.func @reusable(%A : !tt.ptr<f16>) {
|
|
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
|
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
|
|
%cst3 = arith.constant dense<true> : tensor<32x128xi1, #AL>
|
|
@@ -78,7 +78,7 @@ func @reusable(%A : !tt.ptr<f16>) {
|
|
// %cst1->%cst4
|
|
// %cst3->%g->%h->%i
|
|
// CHECK-LABEL: preallocate
|
|
-func @preallocate(%A : !tt.ptr<f16>) {
|
|
+func.func @preallocate(%A : !tt.ptr<f16>) {
|
|
// CHECK: offset = 0, size = 512
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
// CHECK-NEXT: offset = 1024, size = 512
|
|
@@ -113,7 +113,7 @@ func @preallocate(%A : !tt.ptr<f16>) {
|
|
|
|
// Unused tensors are immediately released
|
|
// CHECK-LABEL: unused
|
|
-func @unused(%A : !tt.ptr<f16>) {
|
|
+func.func @unused(%A : !tt.ptr<f16>) {
|
|
// CHECK: offset = 0, size = 1024
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A_SHARED>
|
|
// CHECK-NEXT: offset = 0, size = 512
|
|
@@ -128,7 +128,7 @@ func @unused(%A : !tt.ptr<f16>) {
|
|
|
|
// cst0 is alive through the entire function, it cannot be released before the end of the function
|
|
// CHECK-LABEL: longlive
|
|
-func @longlive(%A : !tt.ptr<f16>) {
|
|
+func.func @longlive(%A : !tt.ptr<f16>) {
|
|
// CHECK: offset = 0, size = 512
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
// CHECK-NEXT: offset = 512, size = 512
|
|
@@ -156,7 +156,7 @@ func @longlive(%A : !tt.ptr<f16>) {
|
|
}
|
|
|
|
// CHECK-LABEL: alloc
|
|
-func @alloc(%A : !tt.ptr<f16>) {
|
|
+func.func @alloc(%A : !tt.ptr<f16>) {
|
|
// CHECK: offset = 0, size = 512
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
|
|
@@ -167,7 +167,7 @@ func @alloc(%A : !tt.ptr<f16>) {
|
|
}
|
|
|
|
// CHECK-LABEL: scratch
|
|
-func @scratch() {
|
|
+func.func @scratch() {
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
|
// CHECK: scratch offset = 0, size = 512
|
|
%b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
|
|
@@ -176,7 +176,7 @@ func @scratch() {
|
|
}
|
|
|
|
// CHECK-LABEL: trans
|
|
-func @trans(%A : !tt.ptr<f16>) {
|
|
+func.func @trans(%A : !tt.ptr<f16>) {
|
|
// CHECK: offset = 0, size = 1024
|
|
%tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
|
|
%b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T>
|
|
@@ -184,7 +184,7 @@ func @trans(%A : !tt.ptr<f16>) {
|
|
}
|
|
|
|
// CHECK-LABEL: insert_slice_async
|
|
-func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
|
+func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
|
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
|
|
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
|
|
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
|
@@ -197,7 +197,7 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
|
}
|
|
|
|
// CHECK-LABEL: extract_slice
|
|
-func @extract_slice(%A : !tt.ptr<f16>) {
|
|
+func.func @extract_slice(%A : !tt.ptr<f16>) {
|
|
// CHECK: offset = 0, size = 512
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
|
|
%index = arith.constant 0 : index
|
|
@@ -209,7 +209,7 @@ func @extract_slice(%A : !tt.ptr<f16>) {
|
|
// B0 -> (B1) -> B0
|
|
// Memory used by B1 can be reused by B0.
|
|
// CHECK-LABEL: if
|
|
-func @if(%i1 : i1) {
|
|
+func.func @if(%i1 : i1) {
|
|
// CHECK: offset = 0, size = 512
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
// CHECK-NEXT: offset = 512, size = 512
|
|
@@ -233,7 +233,7 @@ func @if(%i1 : i1) {
|
|
// B0 -> (B1) -> (B2) -> B0
|
|
// Memory used by B0 cannot be reused by B1 or B2.
|
|
// CHECK-LABEL: if_else
|
|
-func @if_else(%i1 : i1) {
|
|
+func.func @if_else(%i1 : i1) {
|
|
// CHECK: offset = 0, size = 512
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
// CHECK-NEXT: offset = 512, size = 512
|
|
@@ -260,7 +260,7 @@ func @if_else(%i1 : i1) {
|
|
// Block arguments and yields are memory aliases that do not trigger a new
|
|
// allocation.
|
|
// CHECK-LABEL: for
|
|
-func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
+func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
// CHECK: offset = 0, size = 8192
|
|
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
|
// CHECK-NEXT: offset = 8192, size = 8192
|
|
@@ -275,7 +275,7 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.p
|
|
}
|
|
|
|
// CHECK-LABEL: for_if_slice
|
|
-func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
|
+func.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
|
// CHECK: offset = 0, size = 8192
|
|
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
|
// CHECK-NEXT: offset = 8192, size = 8192
|
|
@@ -296,7 +296,7 @@ func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %
|
|
|
|
// c0 cannot be released in the loop
|
|
// CHECK-LABEL: for_use_ancestor
|
|
-func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
|
+func.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
|
// CHECK: offset = 0, size = 8192
|
|
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
|
// CHECK-NEXT: offset = 8192, size = 8192
|
|
@@ -316,7 +316,7 @@ func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16
|
|
// a_shared_init, b_shared_init, and c_shared_init's liveness ranges are span over the entire function before cst2.
|
|
// So they cannot be reused by cst0 and cst1, but can be reused by cst2.
|
|
// CHECK-LABEL: for_if_for
|
|
-func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
|
+func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
|
// CHECK: offset = 0, size = 8192
|
|
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
|
// CHECK-NEXT: offset = 8192, size = 8192
|
|
diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir
|
|
index 7199e5f53d..17880b2094 100644
|
|
--- a/test/Analysis/test-membar.mlir
|
|
+++ b/test/Analysis/test-membar.mlir
|
|
@@ -14,7 +14,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
|
|
// CHECK-LABEL: matmul_loop
|
|
// There shouldn't be any membar with the dot op encoding.
|
|
-func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
+func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
|
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
|
|
|
|
@@ -42,7 +42,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
|
}
|
|
|
|
// CHECK-LABEL: raw_single_block
|
|
-func @raw_single_block(%A : !tt.ptr<f16>) {
|
|
+func.func @raw_single_block(%A : !tt.ptr<f16>) {
|
|
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
|
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
|
|
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
|
@@ -54,7 +54,7 @@ func @raw_single_block(%A : !tt.ptr<f16>) {
|
|
}
|
|
|
|
// CHECK-LABEL: war_single_block
|
|
-func @war_single_block(%A : !tt.ptr<f16>) {
|
|
+func.func @war_single_block(%A : !tt.ptr<f16>) {
|
|
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
|
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
|
|
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
|
@@ -70,7 +70,7 @@ func @war_single_block(%A : !tt.ptr<f16>) {
|
|
}
|
|
|
|
// CHECK-LABEL: scratch
|
|
-func @scratch() {
|
|
+func.func @scratch() {
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
// CHECK: Membar 1
|
|
%a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
|
@@ -81,7 +81,7 @@ func @scratch() {
|
|
}
|
|
|
|
// CHECK-LABEL: async_wait
|
|
-func @async_wait() {
|
|
+func.func @async_wait() {
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
// CHECK: Membar 1
|
|
%a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
|
@@ -92,7 +92,7 @@ func @async_wait() {
|
|
}
|
|
|
|
// CHECK-LABEL: alloc
|
|
-func @alloc() {
|
|
+func.func @alloc() {
|
|
%cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
|
|
%a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
|
// CHECK: Membar 2
|
|
@@ -101,7 +101,7 @@ func @alloc() {
|
|
}
|
|
|
|
// CHECK-LABEL: extract_slice
|
|
-func @extract_slice() {
|
|
+func.func @extract_slice() {
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
|
|
%index = arith.constant 0 : index
|
|
%cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
|
|
@@ -113,14 +113,14 @@ func @extract_slice() {
|
|
}
|
|
|
|
// CHECK-LABEL: trans
|
|
-func @trans() {
|
|
+func.func @trans() {
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
|
|
%b = tt.trans %cst0 : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T>
|
|
return
|
|
}
|
|
|
|
// CHECK-LABEL: insert_slice_async
|
|
-func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
|
+func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
|
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
|
|
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
|
|
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
|
@@ -135,7 +135,7 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
|
}
|
|
|
|
// CHECK-LABEL: insert_slice
|
|
-func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
|
|
+func.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
|
|
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
|
|
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
|
|
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
|
@@ -153,7 +153,7 @@ func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
|
|
|
|
// If branch inserted a barrier for %cst0 and %cst1, but else didn't, then the barrier should be inserted in the parent region
|
|
// CHECK-LABEL: multi_blocks
|
|
-func @multi_blocks(%i1 : i1) {
|
|
+func.func @multi_blocks(%i1 : i1) {
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
scf.if %i1 {
|
|
@@ -174,7 +174,7 @@ func @multi_blocks(%i1 : i1) {
|
|
|
|
// Both branches inserted a barrier for %cst0 and %cst1, then the barrier doesn't need to be inserted in the parent region
|
|
// CHECK-LABEL: multi_blocks_join_barrier
|
|
-func @multi_blocks_join_barrier(%i1 : i1) {
|
|
+func.func @multi_blocks_join_barrier(%i1 : i1) {
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
scf.if %i1 {
|
|
@@ -192,7 +192,7 @@ func @multi_blocks_join_barrier(%i1 : i1) {
|
|
|
|
// Read yielded tensor requires a barrier
|
|
// CHECK-LABEL: multi_blocks_yield
|
|
-func @multi_blocks_yield(%i1 : i1) {
|
|
+func.func @multi_blocks_yield(%i1 : i1) {
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
%a = scf.if %i1 -> (tensor<32x16xf16, #A_SHARED>) {
|
|
@@ -212,7 +212,7 @@ func @multi_blocks_yield(%i1 : i1) {
|
|
|
|
// Conservatively add a barrier as if the branch (%i1) is never taken
|
|
// CHECK-LABEL: multi_blocks_noelse
|
|
-func @multi_blocks_noelse(%i1 : i1) {
|
|
+func.func @multi_blocks_noelse(%i1 : i1) {
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
scf.if %i1 {
|
|
@@ -226,7 +226,7 @@ func @multi_blocks_noelse(%i1 : i1) {
|
|
|
|
// Conservatively add a barrier as if the branch (%i2) is never taken
|
|
// CHECK-LABEL: multi_blocks_nested_scf
|
|
-func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
|
|
+func.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
|
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
|
scf.if %i1 {
|
|
@@ -247,7 +247,7 @@ func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
|
|
}
|
|
|
|
// CHECK-LABEL: for
|
|
-func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
+func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
|
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
|
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
|
@@ -262,7 +262,7 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.p
|
|
// Although a_shared and b_shared are synced before entering the loop,
|
|
// they are reassociated with aliases (c_shared) and thus require a barrier.
|
|
// CHECK-LABEL: for_alias
|
|
-func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
+func.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
|
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
|
// CHECK-NEXT: Membar 2
|
|
@@ -282,7 +282,7 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
|
|
// Although cst2 is not an argument of scf.yield, its memory is reused by cst1.
|
|
// So we need a barrier both before and after cst1
|
|
// CHECK-LABEL: for_reuse
|
|
-func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
+func.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
|
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
|
// CHECK-NEXT: Membar 2
|
|
@@ -302,7 +302,7 @@ func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
|
|
|
|
|
|
// CHECK-LABEL: for_reuse_nested
|
|
-func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
+func.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
|
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
|
// CHECK-NEXT: Membar 2
|
|
diff --git a/test/Conversion/triton_ops.mlir b/test/Conversion/triton_ops.mlir
|
|
index e9ee502435..0e979b148d 100644
|
|
--- a/test/Conversion/triton_ops.mlir
|
|
+++ b/test/Conversion/triton_ops.mlir
|
|
@@ -1,6 +1,6 @@
|
|
// RUN: triton-opt %s | FileCheck %s
|
|
|
|
-func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
|
|
+func.func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
|
|
// scalar -> scalar
|
|
// CHECK: i64 -> !tt.ptr<f32>
|
|
%0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr<f32>
|
|
@@ -35,7 +35,7 @@ func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
|
|
return
|
|
}
|
|
|
|
-func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
|
|
+func.func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
|
|
// scalar -> scalar
|
|
// CHECK: !tt.ptr<f32>
|
|
%0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr<f32>, i32
|
|
@@ -54,7 +54,7 @@ func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
|
|
return
|
|
}
|
|
|
|
-func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %mask : i1) {
|
|
+func.func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %mask : i1) {
|
|
// Test if Load/Store ops can handle scalar values
|
|
%other = arith.constant 0.0e+0 : f32
|
|
|
|
@@ -76,7 +76,7 @@ func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ma
|
|
return
|
|
}
|
|
|
|
-func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
|
|
+func.func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
|
|
// Test if reduce ops infer types correctly
|
|
|
|
// CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32>
|
|
@@ -101,7 +101,7 @@ func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
|
|
return
|
|
}
|
|
|
|
-func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
|
|
+func.func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
|
|
// Test if reduce ops infer types correctly
|
|
%v128x32 = tt.splat %v : (f32) -> tensor<128x32xf32>
|
|
%v32x128 = tt.splat %v : (f32) -> tensor<32x128xf32>
|
|
diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir
|
|
index a160bc8815..b461ca542f 100644
|
|
--- a/test/Conversion/triton_to_tritongpu.mlir
|
|
+++ b/test/Conversion/triton_to_tritongpu.mlir
|
|
@@ -1,6 +1,6 @@
|
|
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
|
|
|
|
-func @ops() {
|
|
+func.func @ops() {
|
|
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
|
|
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
|
|
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
|
|
@@ -11,7 +11,7 @@ func @ops() {
|
|
|
|
// -----
|
|
|
|
-func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
|
+func.func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
|
// Test if LoadOp is lowered properly (see #771)
|
|
%ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
|
|
%mask = arith.constant dense<true> : tensor<128xi1>
|
|
@@ -30,7 +30,7 @@ func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
|
|
|
// -----
|
|
|
|
-func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
|
+func.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
|
// Test if the total number of threadsPerWarp is 32
|
|
// Test if the total number of warps is 2
|
|
// CHECK: #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 2], order = [0, 1]}>
|
|
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
|
|
index e9e7d5a340..507b362c99 100644
|
|
--- a/test/Conversion/tritongpu_to_llvm.mlir
|
|
+++ b/test/Conversion/tritongpu_to_llvm.mlir
|
|
@@ -4,7 +4,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
|
|
// Here the 128 comes from the 4 in module attribute multiples 32
|
|
// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : i32} {{.*}}
|
|
- func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
|
+ func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
|
// CHECK: llvm.return
|
|
return
|
|
}
|
|
@@ -15,7 +15,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK-LABEL: basic_load
|
|
- func @basic_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
|
|
+ func.func @basic_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
|
|
// CHECK: llvm.inline_asm
|
|
// CHECK: llvm.inline_asm
|
|
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
|
@@ -28,7 +28,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK-LABEL: vectorized_load
|
|
- func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
|
|
+ func.func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
|
|
// CHECK: llvm.inline_asm
|
|
// CHECK-SAME: ld.global.b32
|
|
// CHECK: llvm.inline_asm
|
|
@@ -43,7 +43,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
// CHECK-LABEL: vectorized_load_f16
|
|
- func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr<f16>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) {
|
|
+ func.func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr<f16>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) {
|
|
// CHECK: llvm.inline_asm
|
|
// CHECK-SAME: ld.global.b16
|
|
// CHECK: llvm.inline_asm
|
|
@@ -59,7 +59,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK-LABEL: masked_load_const_other
|
|
- func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
|
|
+ func.func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
|
|
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
|
|
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
|
return
|
|
@@ -72,7 +72,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK-LABEL: masked_load_const_other_vec
|
|
- func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
|
|
+ func.func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
|
|
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
|
|
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
|
return
|
|
@@ -84,7 +84,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
|
// CHECK-LABEL: global_load_store_no_vec
|
|
- func @global_load_store_no_vec(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
|
|
+ func.func @global_load_store_no_vec(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
|
|
%c256_i32 = arith.constant 256 : i32
|
|
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
|
%1 = arith.muli %0, %c256_i32 : i32
|
|
@@ -128,7 +128,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
|
// CHECK-LABEL: global_load_store_vec4
|
|
- func @global_load_store_vec4(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
|
+ func.func @global_load_store_vec4(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
|
%c256_i32 = arith.constant 256 : i32
|
|
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
|
%1 = arith.muli %0, %c256_i32 : i32
|
|
@@ -165,7 +165,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
|
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
|
|
// Note, the %n_elements doesn't have a "tt.divisibility" hint, so Triton assumes it's divisibility is 1, this should effect the mask's alignment and further restrict the load/store ops' vector width to be 1.
|
|
module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
|
- func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
|
+ func.func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
|
%c64_i32 = arith.constant 64 : i32
|
|
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
|
%1 = arith.muli %0, %c64_i32 : i32
|
|
@@ -195,7 +195,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
// CHECK-LABEL: global_load_store_vec2
|
|
- func @global_load_store_vec2(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg3: i32) {
|
|
+ func.func @global_load_store_vec2(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg3: i32) {
|
|
%c256_i32 = arith.constant 256 : i32
|
|
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
|
%1 = arith.muli %0, %c256_i32 : i32
|
|
@@ -240,7 +240,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
// CHECK-LABEL: global_load_store_vec8
|
|
- func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
|
+ func.func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
|
%c256_i32 = arith.constant 256 : i32
|
|
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
|
%1 = arith.muli %0, %c256_i32 : i32
|
|
@@ -283,7 +283,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK-LABEL: basic_view_broadcast
|
|
- func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) {
|
|
+ func.func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) {
|
|
// CHECK: llvm.mlir.undef
|
|
// CHECK: %[[T0:.*]] = llvm.extractvalue
|
|
// CHECK: %[[T1:.*]] = llvm.extractvalue
|
|
@@ -307,7 +307,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK-LABEL: basic_make_range
|
|
- func @basic_make_range() {
|
|
+ func.func @basic_make_range() {
|
|
// CHECK: nvvm.read.ptx.sreg.tid.x
|
|
// CHECK: llvm.mlir.undef
|
|
// CHECK: llvm.insertvalue
|
|
@@ -322,7 +322,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK-LABEL: basic_addf
|
|
- func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) {
|
|
+ func.func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) {
|
|
// CHECK: llvm.fadd
|
|
// CHECK: llvm.fadd
|
|
%1 = arith.addf %arg0, %arg1 : tensor<256xf32,#blocked0>
|
|
@@ -335,7 +335,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK-LABEL: basic_addi
|
|
- func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
|
|
+ func.func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
|
|
// CHECK: llvm.add
|
|
// CHECK: llvm.add
|
|
%1 = arith.addi %arg0, %arg1 : tensor<256xi32,#blocked0>
|
|
@@ -347,7 +347,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK-LABEL: basic_program_id
|
|
- func @basic_program_id() {
|
|
+ func.func @basic_program_id() {
|
|
// CHECK: nvvm.read.ptx.sreg.ctaid.x : i32
|
|
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
|
return
|
|
@@ -359,7 +359,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK-LABEL: basic_addptr
|
|
- func @basic_addptr(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
|
|
+ func.func @basic_addptr(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
|
|
// CHECK: llvm.getelementptr
|
|
// CHECK: llvm.getelementptr
|
|
%0 = tt.addptr %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
|
|
@@ -373,7 +373,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK: llvm.mlir.global external @global_smem
|
|
// CHECK-LABEL: basic_alloc_tensor
|
|
- func @basic_alloc_tensor() {
|
|
+ func.func @basic_alloc_tensor() {
|
|
// CHECK: llvm.mlir.addressof @global_smem
|
|
// CHECK-NEXT: llvm.bitcast
|
|
// CHECK-NEXT: llvm.mlir.constant
|
|
@@ -390,7 +390,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK: llvm.mlir.global external @global_smem
|
|
// CHECK-LABEL: basic_extract_slice
|
|
- func @basic_extract_slice() {
|
|
+ func.func @basic_extract_slice() {
|
|
// CHECK: llvm.mlir.addressof @global_smem
|
|
// CHECK: llvm.extractvalue
|
|
// CHECK-NEXT: llvm.extractvalue
|
|
@@ -423,7 +423,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK-LABEL: basic_async_wait
|
|
- func @basic_async_wait() {
|
|
+ func.func @basic_async_wait() {
|
|
// CHECK: cp.async.wait_group 0x4
|
|
triton_gpu.async_wait {num = 4: i32}
|
|
return
|
|
@@ -442,7 +442,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK-LABEL: basic_insert_slice_async_fallback
|
|
- func @basic_insert_slice_async_fallback(%arg0: !tt.ptr<f16> {tt.divisibility = 1 : i32}) {
|
|
+ func.func @basic_insert_slice_async_fallback(%arg0: !tt.ptr<f16> {tt.divisibility = 1 : i32}) {
|
|
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
|
|
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
|
|
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
|
|
@@ -481,7 +481,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK-LABEL: basic_insert_slice_async_v4
|
|
- func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 32 : i32}) {
|
|
+ func.func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 32 : i32}) {
|
|
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
|
|
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
|
|
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
|
|
@@ -523,7 +523,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK-LABEL: basic_insert_slice_async_v1
|
|
- func @basic_insert_slice_async_v1(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
|
|
+ func.func @basic_insert_slice_async_v1(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
|
|
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
|
|
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0>
|
|
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
|
|
@@ -568,7 +568,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK-LABEL: basic_insert_slice_async_v1_multictas
|
|
- func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
|
|
+ func.func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
|
|
%off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice2d1>
|
|
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0>
|
|
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<32xi32, #slice2d1>) -> tensor<32x1xi32, #block2>
|
|
@@ -619,7 +619,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK: basic_splat
|
|
- func @basic_splat(%ptr: !tt.ptr<f32>) {
|
|
+ func.func @basic_splat(%ptr: !tt.ptr<f32>) {
|
|
// CHECK: llvm.mlir.undef
|
|
// CHECK: llvm.insertvalue
|
|
// CHECK: llvm.insertvalue
|
|
@@ -633,7 +633,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK-LABEL: basic_store
|
|
- func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
|
|
+ func.func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
|
|
// CHECK: llvm.inline_asm
|
|
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
|
|
// CHECK: llvm.inline_asm
|
|
@@ -650,7 +650,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
|
// CHECK-LABEL: convert_layout_blocked_blocked
|
|
- func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
|
|
+ func.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
|
|
// CHECK: llvm.mlir.addressof @global_smem
|
|
// CHECK: llvm.store
|
|
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
|
|
@@ -697,7 +697,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
|
// CHECK-LABEL: convert_layout_blocked_blocked_vec
|
|
- func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) {
|
|
+ func.func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) {
|
|
// CHECK: llvm.mlir.addressof @global_smem
|
|
// CHECK: llvm.store
|
|
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
|
|
@@ -720,7 +720,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
|
// CHECK-LABEL: convert_layout_blocked_blocked_multi_rep
|
|
- func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
|
|
+ func.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
|
|
// CHECK: llvm.mlir.addressof @global_smem
|
|
// CHECK: llvm.store
|
|
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
|
|
@@ -751,7 +751,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0}>
|
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
// CHECK-LABEL: convert_dot
|
|
- func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
|
|
+ func.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
|
|
%AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
|
|
%BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
|
|
// CHECK: llvm.inline_asm
|
|
@@ -775,7 +775,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
// TODO: problems in MLIR's parser on slice layout
|
|
// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
|
|
// module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
-// func @make_range_sliced_layout() {
|
|
+// func.func @make_range_sliced_layout() {
|
|
// %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
|
|
// return
|
|
// }
|
|
@@ -788,7 +788,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
|
// CHECK-LABEL: convert_layout_mmav2_block
|
|
- func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
|
|
+ func.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
|
|
// CHECK: llvm.store
|
|
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
|
|
// CHECK: llvm.store
|
|
@@ -808,7 +808,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
|
// CHECK-LABEL: convert_layout_mmav1_block
|
|
- func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) {
|
|
+ func.func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) {
|
|
// CHECK: llvm.store
|
|
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
|
|
// CHECK: llvm.store
|
|
@@ -831,7 +831,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
|
// CHECK-LABEL: convert_layout_blocked_shared
|
|
- func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) {
|
|
+ func.func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) {
|
|
// CHECK: llvm.store
|
|
// CHECK-SAME: !llvm.ptr<vector<8xf32>, 3>
|
|
// CHECK: llvm.store
|
|
@@ -847,7 +847,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
|
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
// CHECK-LABEL: convert_blocked1d_to_slice0
|
|
- func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) {
|
|
+ func.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) {
|
|
// CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
|
|
%cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
|
return
|
|
@@ -860,7 +860,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
|
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
// CHECK-LABEL: convert_blocked1d_to_slice1
|
|
- func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
|
|
+ func.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
|
|
// CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
|
|
%cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
|
return
|
|
@@ -873,7 +873,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
#blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
// CHECK-LABEL: convert_blocked_to_blocked_ptr
|
|
- func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr<f32>, #blocked0>) {
|
|
+ func.func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr<f32>, #blocked0>) {
|
|
// CHECK: llvm.ptrtoint
|
|
// CHECK: llvm.store
|
|
// CHECK: nvvm.barrier0
|
|
@@ -892,7 +892,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
|
|
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
- func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
|
+ func.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
|
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
|
|
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
|
|
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
|
|
@@ -918,7 +918,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
|
|
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
- func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
|
+ func.func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
|
%a:tensor<32x64xf16, #shared0>, %b:tensor<64x64xf16, #shared1>) {
|
|
%cst = arith.constant dense<0.000000e+00> : tensor<32x64xf32, #mma>
|
|
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
|
|
@@ -941,7 +941,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#blocked}>
|
|
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
- func @matmul_fmadot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
|
+ func.func @matmul_fmadot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
|
%a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
|
|
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
|
|
// CHECK: llvm.intr.fmuladd
|
|
@@ -965,7 +965,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK-LABEL: matmul_tf32dot
|
|
- func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
|
+ func.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
|
%a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
|
|
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
|
|
// CHECK: llvm.inline_asm
|
|
@@ -1000,7 +1000,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK-LABEL: atomic_add_f32
|
|
- func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
|
|
+ func.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
|
|
// CHECK: llvm.inline_asm
|
|
// CHECK-SAME: atom.global.gpu.add.f32
|
|
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
|
|
@@ -1012,7 +1012,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
|
|
-func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
|
+func.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
|
%blockidx = tt.get_program_id {axis=0:i32} : i32
|
|
%blockidy = tt.get_program_id {axis=1:i32} : i32
|
|
%blockidz = tt.get_program_id {axis=2:i32} : i32
|
|
@@ -1032,7 +1032,7 @@ func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
|
// -----
|
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
- func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
|
+ func.func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
|
// CHECK: nvvm.read.ptx.sreg.nctaid.x
|
|
// CHECK: nvvm.read.ptx.sreg.nctaid.y
|
|
// CHECK: nvvm.read.ptx.sreg.nctaid.z
|
|
@@ -1052,7 +1052,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK-LABEL: test_index_cache
|
|
- func @test_index_cache() {
|
|
+ func.func @test_index_cache() {
|
|
// CHECK: nvvm.read.ptx.sreg.tid.x
|
|
%0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
|
|
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
|
|
@@ -1066,7 +1066,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
// CHECK-LABEL: test_base_index_cache
|
|
- func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
|
|
+ func.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
|
|
// CHECK: nvvm.read.ptx.sreg.tid.x
|
|
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
|
|
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
|
|
@@ -1080,7 +1080,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|
// CHECK-LABEL: test_index_cache_different_block
|
|
- func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
|
|
+ func.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
|
|
// CHECK: nvvm.read.ptx.sreg.tid.x
|
|
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
|
|
scf.if %arg1 {
|
|
diff --git a/test/Target/tritongpu_to_llvmir.mlir b/test/Target/tritongpu_to_llvmir.mlir
|
|
index cafff3ca60..114d3a9eb2 100644
|
|
--- a/test/Target/tritongpu_to_llvmir.mlir
|
|
+++ b/test/Target/tritongpu_to_llvmir.mlir
|
|
@@ -4,11 +4,11 @@
|
|
// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'
|
|
// CHECK: define void @test_empty_kernel
|
|
// CHECK: !nvvm.annotations
|
|
-// CHECK: !{void (i32, half addrspace(1)*)* @test_empty_kernel, !"maxntidx", i32 128}
|
|
+// CHECK: !{ptr @test_empty_kernel, !"maxntidx", i32 128}
|
|
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
|
|
-func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
|
+func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
|
|
|
return
|
|
}
|
|
diff --git a/test/Target/tritongpu_to_ptx.mlir b/test/Target/tritongpu_to_ptx.mlir
|
|
index 404e970a29..12742ad9e2 100644
|
|
--- a/test/Target/tritongpu_to_ptx.mlir
|
|
+++ b/test/Target/tritongpu_to_ptx.mlir
|
|
@@ -6,7 +6,7 @@
|
|
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
|
|
-func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
|
+func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
|
|
|
return
|
|
}
|
|
diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir
|
|
index 050a3f7565..5ef6790e69 100644
|
|
--- a/test/Triton/combine.mlir
|
|
+++ b/test/Triton/combine.mlir
|
|
@@ -2,10 +2,10 @@
|
|
// RUN: triton-opt %s -split-input-file -canonicalize -triton-combine | FileCheck %s
|
|
|
|
// CHECK-LABEL: @test_combine_dot_add_pattern
|
|
-func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) {
|
|
- // CHECK: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>
|
|
- // CHECK: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32>
|
|
- // CHECK: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32>
|
|
+func.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) {
|
|
+ // CHECK-DAG: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>
|
|
+ // CHECK-DAG: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32>
|
|
+ // CHECK-DAG: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32>
|
|
%a = arith.constant dense<1.0> : tensor<128x128xf32>
|
|
%b = arith.constant dense<2.0> : tensor<128x128xf32>
|
|
%zero = arith.constant dense<0.0> : tensor<128x128xf32>
|
|
@@ -24,7 +24,7 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32
|
|
|
|
|
|
// COM: CHECK-LABEL: @test_combine_addptr_pattern
|
|
-func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
|
|
+func.func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
|
|
%off0 = arith.constant 10 : i32
|
|
%off1 = arith.constant 15 : i32
|
|
|
|
@@ -47,46 +47,46 @@ func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>
|
|
|
|
|
|
// CHECK-LABEL: @test_combine_select_masked_load_pattern
|
|
-func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
|
|
+func.func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
|
|
%mask = tt.broadcast %cond : (i1) -> tensor<8xi1>
|
|
%false_val = arith.constant dense<0.0> : tensor<8xf32>
|
|
|
|
// CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
|
%x = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
|
- %0 = select %cond, %x, %false_val : tensor<8xf32>
|
|
+ %0 = arith.select %cond, %x, %false_val : tensor<8xf32>
|
|
|
|
// CHECK: %[[res2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
|
%y = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
|
- %1 = select %cond, %y, %false_val : tensor<8xf32>
|
|
+ %1 = arith.select %cond, %y, %false_val : tensor<8xf32>
|
|
|
|
// CHECK: return %[[res1]], %[[res2]] : tensor<8xf32>, tensor<8xf32>
|
|
return %0, %1 : tensor<8xf32>, tensor<8xf32>
|
|
}
|
|
|
|
// CHECK-LABEL: @test_combine_select_masked_load_fail_pattern
|
|
-func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
|
|
+func.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
|
|
%false_val = arith.constant dense<0.0> : tensor<8xf32>
|
|
|
|
// Case 1: value at the "load" position is not an "op". Select should not be canonicalized.
|
|
- // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
|
- %0 = select %cond0, %dummy_load, %false_val : tensor<8xf32>
|
|
+ // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
|
+ %0 = arith.select %cond0, %dummy_load, %false_val : tensor<8xf32>
|
|
|
|
// Case 2: value at the "broadcast" position is not an "op". Select should not be canonicalized.
|
|
%real_load0 = tt.load %ptr, %dummy_broadcast, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
|
- // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
|
- %1 = select %cond0, %real_load0, %false_val : tensor<8xf32>
|
|
+ // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
|
+ %1 = arith.select %cond0, %real_load0, %false_val : tensor<8xf32>
|
|
|
|
// Case 3: condition of "broadcast" is not the same as the condition of "select". Select should not be canonicalized.
|
|
%cond0_ = tt.broadcast %cond0 : (i1) -> tensor<8xi1>
|
|
%real_load1 = tt.load %ptr, %cond0_, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
|
- // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
|
- %2 = select %cond1, %real_load1, %false_val : tensor<8xf32>
|
|
+ // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
|
+ %2 = arith.select %cond1, %real_load1, %false_val : tensor<8xf32>
|
|
|
|
return %0, %1, %2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
|
|
}
|
|
|
|
// CHECK-LABEL: @test_combine_broadcast_constant_pattern
|
|
-func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
|
|
+func.func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
|
|
// CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
|
|
%const = arith.constant dense<1.0> : tensor<8xf32>
|
|
%bst_out = tt.broadcast %const : (tensor<8xf32>) -> tensor<8x2xf32>
|
|
@@ -96,7 +96,7 @@ func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
|
|
}
|
|
|
|
// CHECK-LABEL: @test_canonicalize_masked_load_pattern
|
|
-func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
|
|
+func.func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
|
|
%true_mask = arith.constant dense<true> : tensor<8xi1>
|
|
%false_mask = arith.constant dense<false> : tensor<8xi1>
|
|
%other_val = arith.constant dense<0.0> : tensor<8xf32>
|
|
@@ -117,7 +117,7 @@ func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (te
|
|
}
|
|
|
|
// CHECK-LABEL: @test_canonicalize_masked_load_fail_pattern
|
|
-func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) {
|
|
+func.func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) {
|
|
%other_val = arith.constant dense<0.0> : tensor<8xf32>
|
|
|
|
// Case: value at the "mask" position is not an "op". Load should not be canonicalized.
|
|
@@ -130,7 +130,7 @@ func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %
|
|
}
|
|
|
|
// CHECK-LABEL: @test_canonicalize_masked_store_pattern
|
|
-func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) {
|
|
+func.func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) {
|
|
%true_mask = arith.constant dense<true> : tensor<8xi1>
|
|
%false_mask = arith.constant dense<false> : tensor<8xi1>
|
|
|
|
@@ -144,7 +144,7 @@ func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val:
|
|
}
|
|
|
|
// CHECK-LABEL: @test_canonicalize_masked_store_fail_pattern
|
|
-func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) {
|
|
+func.func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) {
|
|
// Case: value at the "mask" position is not an "op". Store should not be canonicalized.
|
|
// CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
|
tt.store %ptr, %val, %mask : tensor<8xf32>
|
|
diff --git a/test/Triton/vecadd.mlir b/test/Triton/vecadd.mlir
|
|
index 0b69ef3054..f5019b1cdd 100644
|
|
--- a/test/Triton/vecadd.mlir
|
|
+++ b/test/Triton/vecadd.mlir
|
|
@@ -1,7 +1,7 @@
|
|
// RUN: triton-opt %s -verify-diagnostics
|
|
|
|
module {
|
|
- func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
|
|
+ func.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
|
|
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
|
%c256_i32 = arith.constant 256 : i32
|
|
%1 = arith.muli %0, %c256_i32 : i32
|
|
@@ -43,7 +43,7 @@ module {
|
|
}
|
|
}
|
|
// module {
|
|
-// func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
|
|
+// func.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
|
|
// %c64 = arith.constant 64 : index
|
|
// %c32 = arith.constant 32 : index
|
|
// %c0 = arith.constant 0 : index
|
|
diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir
|
|
index 60e359f527..51cccccfbd 100644
|
|
--- a/test/TritonGPU/coalesce.mlir
|
|
+++ b/test/TritonGPU/coalesce.mlir
|
|
@@ -19,7 +19,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
// CHECK: [[store_val:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]>
|
|
// CHECK: [[store_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]>
|
|
// CHECK: tt.store [[store_ptr]], [[store_val]], [[store_mask]]
|
|
-func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
|
+func.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
|
%arg1: i32 {tt.divisibility = 16 : i32},
|
|
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
|
%arg3: i32 {tt.divisibility = 16 : i32}) {
|
|
diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
|
|
index 2c009ffa48..7e9cb9d504 100644
|
|
--- a/test/TritonGPU/combine.mlir
|
|
+++ b/test/TritonGPU/combine.mlir
|
|
@@ -9,7 +9,7 @@
|
|
// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
|
|
// CHECK: [[col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
|
|
// CHECK-LABEL: cst
|
|
-func @cst() -> tensor<1024xi32, #layout1> {
|
|
+func.func @cst() -> tensor<1024xi32, #layout1> {
|
|
%cst = arith.constant dense<0> : tensor<1024xi32, #layout0>
|
|
%1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
|
|
// CHECK-NOT: triton_gpu.convert_layout
|
|
@@ -18,7 +18,7 @@ func @cst() -> tensor<1024xi32, #layout1> {
|
|
}
|
|
|
|
// CHECK-LABEL: range
|
|
-func @range() -> tensor<1024xi32, #layout1> {
|
|
+func.func @range() -> tensor<1024xi32, #layout1> {
|
|
%0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
|
|
%1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
|
|
// CHECK-NOT: triton_gpu.convert_layout
|
|
@@ -27,7 +27,7 @@ func @range() -> tensor<1024xi32, #layout1> {
|
|
}
|
|
|
|
// CHECK-LABEL: splat
|
|
-func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
|
+func.func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
|
%0 = tt.splat %arg0 : (i32) -> tensor<1024xi32, #layout0>
|
|
%1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
|
|
// CHECK-NOT: triton_gpu.convert_layout
|
|
@@ -36,7 +36,7 @@ func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
|
}
|
|
|
|
// CHECK-LABEL: remat
|
|
-func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
|
+func.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
|
%0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
|
|
%1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
|
|
%2 = arith.muli %0, %1 : tensor<1024xi32, #layout0>
|
|
@@ -56,7 +56,7 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
|
}
|
|
|
|
// CHECK-LABEL: remat_load_store
|
|
-func @remat_load_store(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
+func.func @remat_load_store(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0>
|
|
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout0>
|
|
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout0>, tensor<64xi32, #layout0>
|
|
@@ -70,7 +70,7 @@ func @remat_load_store(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
|
|
// Don't rematerialize vectorized loads
|
|
// CHECK-LABEL: remat_expensive
|
|
-func @remat_expensive(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
+func.func @remat_expensive(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout1>
|
|
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout1>
|
|
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout1>, tensor<64xi32, #layout1>
|
|
@@ -85,7 +85,7 @@ func @remat_expensive(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
|
|
// Don't rematerialize loads when original and target layouts are different
|
|
// CHECK-LABEL: remat_multi_layout
|
|
-func @remat_multi_layout(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
+func.func @remat_multi_layout(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0>
|
|
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout0>
|
|
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout0>, tensor<64xi32, #layout0>
|
|
@@ -100,7 +100,7 @@ func @remat_multi_layout(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
|
|
// Always rematerialize single value loads
|
|
// CHECK-LABEL: remat_single_value
|
|
-func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
+func.func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
%0 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<1x!tt.ptr<i32>, #layout1>
|
|
%1 = tt.load %0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1xi32, #layout1>
|
|
// CHECK-NOT: triton_gpu.convert_layout
|
|
@@ -111,7 +111,7 @@ func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
}
|
|
|
|
// CHECK-LABEL: if
|
|
-func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
+func.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
// CHECK-NOT: triton_gpu.convert_layout
|
|
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout1>
|
|
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
|
@@ -128,7 +128,7 @@ func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
}
|
|
|
|
// CHECK-LABEL: if_convert_else_not
|
|
-func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
+func.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
|
|
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
|
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
|
|
@@ -149,7 +149,7 @@ func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16
|
|
}
|
|
|
|
// CHECK-LABEL: if_not_else_convert
|
|
-func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
+func.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
|
|
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
|
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
|
|
@@ -170,7 +170,7 @@ func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16
|
|
}
|
|
|
|
// CHECK-LABEL: if_else_both_convert
|
|
-func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
+func.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
|
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
|
|
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
|
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
|
|
@@ -200,7 +200,7 @@ func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16
|
|
#blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
|
|
|
|
// CHECK-LABEL: transpose
|
|
-func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
|
+func.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
|
// CHECK-NOT: triton_gpu.convert_layout
|
|
// CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
|
|
// CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
|
|
@@ -241,7 +241,7 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
|
|
}
|
|
|
|
// CHECK-LABEL: loop
|
|
-func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
|
|
+func.func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
|
|
// CHECK-NOT: triton_gpu.convert_layout
|
|
// CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[row_layout]]>)
|
|
// CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64xf32, [[row_layout]]>
|
|
@@ -295,7 +295,7 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
|
|
}
|
|
|
|
// CHECK-LABEL: vecadd
|
|
-func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
|
+func.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
|
// CHECK-NOT: triton_gpu.convert_layout
|
|
%c256_i32 = arith.constant 256 : i32
|
|
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
|
@@ -327,7 +327,7 @@ func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f3
|
|
|
|
// Select has args with different element types
|
|
// CHECK-LABEL: select
|
|
-func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) {
|
|
+func.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) {
|
|
// CHECK-NOT: triton_gpu.convert_layout
|
|
%cst = arith.constant dense<30000> : tensor<1x1xi32, #blocked2>
|
|
%cst_0 = arith.constant dense<30000> : tensor<1x512xi32, #blocked2>
|
|
@@ -378,7 +378,7 @@ func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f6
|
|
|
|
// Make sure the following IR doesn't hang the compiler.
|
|
// CHECK-LABEL: long_func
|
|
-func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg13: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg14: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) {
|
|
+func.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg13: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg14: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) {
|
|
%cst = arith.constant dense<1.000000e+00> : tensor<1024xf32, #blocked0>
|
|
%cst_0 = arith.constant dense<5.000000e-04> : tensor<1024xf32, #blocked0>
|
|
%cst_1 = arith.constant dense<0.999499976> : tensor<1024xf32, #blocked0>
|
|
@@ -775,7 +775,7 @@ func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1:
|
|
// A mnist model from torch inductor.
|
|
// Check if topological sort is working correct and there's no unnecessary convert
|
|
// CHECK-LABEL: mnist
|
|
-func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) {
|
|
+func.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) {
|
|
// CHECK-NOT: triton_gpu.convert_layout
|
|
%cst = arith.constant dense<10> : tensor<16x1xi32, #blocked2>
|
|
%cst_0 = arith.constant dense<10> : tensor<1x16xi32, #blocked3>
|
|
@@ -862,7 +862,7 @@ func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.
|
|
#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
|
|
// cmpf and cmpi have different operands and result types
|
|
// CHECK-LABEL: cmp
|
|
-func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
|
|
+func.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
|
|
%c64 = arith.constant 64 : index
|
|
%c2048 = arith.constant 2048 : index
|
|
%c0 = arith.constant 0 : index
|
|
diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir
|
|
index 6ee3b15fbc..663f2da7b0 100644
|
|
--- a/test/TritonGPU/loop-pipeline.mlir
|
|
+++ b/test/TritonGPU/loop-pipeline.mlir
|
|
@@ -10,7 +10,7 @@
|
|
#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
|
|
#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
|
|
|
|
-// CHECK: func @matmul_loop
|
|
+// CHECK: func.func @matmul_loop
|
|
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
|
|
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
|
|
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
|
|
@@ -46,8 +46,8 @@
|
|
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
|
|
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
|
|
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
|
|
-func @matmul_loop(%lb : index, %ub : index, %step : index,
|
|
- %A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
|
|
+func.func @matmul_loop(%lb : index, %ub : index, %step : index,
|
|
+ %A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
|
|
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
|
// A ptrs
|
|
%a_ptr_splat = tt.splat %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
|
@@ -61,7 +61,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index,
|
|
%b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : (tensor<128xi32, #BLs0>) -> tensor<1x128xi32, #BL>
|
|
%b_offs = tt.broadcast %b_tmp1 : (tensor<1x128xi32, #BL>) -> tensor<32x128xi32, #BL>
|
|
%b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
|
|
-
|
|
+
|
|
|
|
%a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
|
%a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL>
|
|
@@ -88,7 +88,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index,
|
|
}
|
|
|
|
|
|
-// CHECK: func @matmul_loop_nested
|
|
+// CHECK: func.func @matmul_loop_nested
|
|
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
|
|
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
|
|
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
|
|
@@ -118,8 +118,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index,
|
|
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
|
|
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
|
|
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
|
|
-func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
|
|
- %A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
|
|
+func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
|
|
+ %A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
|
|
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
|
scf.for %iv0 = %lb to %ub step %step {
|
|
// A ptrs
|
|
@@ -134,7 +134,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
|
|
%b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : (tensor<128xi32, #BLs0>) -> tensor<1x128xi32, #BL>
|
|
%b_offs = tt.broadcast %b_tmp1 : (tensor<1x128xi32, #BL>) -> tensor<32x128xi32, #BL>
|
|
%b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
|
|
-
|
|
+
|
|
%a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
|
%a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL>
|
|
%b_mask = arith.constant dense<true> : tensor<32x128xi1, #BL>
|
|
@@ -161,7 +161,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
|
|
}
|
|
|
|
|
|
-// CHECK: func @matmul_loop_single_pipeline
|
|
+// CHECK: func.func @matmul_loop_single_pipeline
|
|
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
|
|
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
|
|
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
|
|
@@ -183,8 +183,8 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
|
|
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
|
|
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
|
|
// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
|
|
-func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
|
|
- %A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
|
|
+func.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
|
|
+ %A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
|
|
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
|
// A ptrs
|
|
%a_ptr_splat = tt.splat %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
|
diff --git a/test/TritonGPU/matmul.mlir b/test/TritonGPU/matmul.mlir
|
|
index 9bd5318e1e..01dc3f0ab1 100644
|
|
--- a/test/TritonGPU/matmul.mlir
|
|
+++ b/test/TritonGPU/matmul.mlir
|
|
@@ -4,7 +4,7 @@
|
|
// CHECK: offset = 49152, size = 49152
|
|
// CHECK: size = 98304
|
|
module {
|
|
-func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) {
|
|
+func.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) {
|
|
%cst = arith.constant dense<true> : tensor<64x64xi1>
|
|
%c64 = arith.constant 64 : index
|
|
%c0 = arith.constant 0 : index
|
|
@@ -22,7 +22,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
|
|
%7 = arith.muli %6, %c8_i32 : i32
|
|
%8 = arith.subi %2, %7 : i32
|
|
%9 = arith.cmpi slt, %8, %c8_i32 : i32
|
|
- %10 = select %9, %8, %c8_i32 : i32
|
|
+ %10 = arith.select %9, %8, %c8_i32 : i32
|
|
%11 = arith.remsi %0, %10 : i32
|
|
%12 = arith.addi %7, %11 : i32
|
|
%13 = arith.remsi %0, %5 : i32
|
|
diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir
|
|
index 52b4dddec1..b427547890 100644
|
|
--- a/test/TritonGPU/prefetch.mlir
|
|
+++ b/test/TritonGPU/prefetch.mlir
|
|
@@ -11,7 +11,7 @@
|
|
#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
|
|
|
|
|
|
-// CHECK: func @matmul_loop
|
|
+// CHECK: func.func @matmul_loop
|
|
// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = tensor.extract_slice %[[A0:.*]][0, 0] [128, 16]
|
|
// CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.convert_layout %[[A0_PREFETCH_SMEM]]
|
|
// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = tensor.extract_slice %[[B0:.*]][0, 0] [16, 128]
|
|
@@ -28,7 +28,7 @@
|
|
// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = tensor.extract_slice {{.*}}[0, 0] [16, 128]
|
|
// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_B_PREFETCH_SMEM]]
|
|
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH]], %[[NEXT_B_PREFETCH]]
|
|
-func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
+func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
|
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
|
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
|
|
|
|
diff --git a/test/TritonGPU/update-mma-for-volta.mlir b/test/TritonGPU/update-mma-for-volta.mlir
|
|
index d587fffcca..7571ec6185 100644
|
|
--- a/test/TritonGPU/update-mma-for-volta.mlir
|
|
+++ b/test/TritonGPU/update-mma-for-volta.mlir
|
|
@@ -15,7 +15,7 @@
|
|
// CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [4, 2]}>
|
|
module attributes {"triton_gpu.num-warps" = 16 : i32} {
|
|
// CHECK-LABEL: dot_mmav1
|
|
- func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
|
|
+ func.func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
|
|
%C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked0>
|
|
%AA = triton_gpu.convert_layout %A : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_a>
|
|
%BB = triton_gpu.convert_layout %B : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_b>
|
|
@@ -50,7 +50,7 @@ module attributes {"triton_gpu.num-warps" = 16 : i32} {
|
|
|
|
module attributes {"triton_gpu.num-warps" = 16 : i32} {
|
|
// CHECK-LABEL: dot_mmav1
|
|
- func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
|
|
+ func.func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
|
|
%C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked0>
|
|
%AA = triton_gpu.convert_layout %A : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_a>
|
|
%BB = triton_gpu.convert_layout %B : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_b>
|
|
diff --git a/test/lib/Analysis/TestAlias.cpp b/test/lib/Analysis/TestAlias.cpp
|
|
index 88a4118fe9..3fd0cfd0d3 100644
|
|
--- a/test/lib/Analysis/TestAlias.cpp
|
|
+++ b/test/lib/Analysis/TestAlias.cpp
|
|
@@ -9,10 +9,10 @@ using namespace mlir;
|
|
namespace {
|
|
|
|
struct TestAliasPass
|
|
- : public PassWrapper<TestAliasPass, OperationPass<FuncOp>> {
|
|
+ : public PassWrapper<TestAliasPass, OperationPass<func::FuncOp>> {
|
|
+
|
|
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass);
|
|
|
|
- // LLVM15+
|
|
- // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass);
|
|
static void print(StringRef name, SmallVector<std::string, 4> &vals,
|
|
raw_ostream &os) {
|
|
if (vals.empty())
|
|
@@ -39,23 +39,24 @@ struct TestAliasPass
|
|
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
|
|
os << opName << "\n";
|
|
|
|
- SharedMemoryAliasAnalysis analysis(&getContext());
|
|
- analysis.run(operation);
|
|
+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
|
+ SharedMemoryAliasAnalysis *analysis =
|
|
+ solver->load<SharedMemoryAliasAnalysis>();
|
|
+ if (failed(solver->initializeAndRun(operation)))
|
|
+ return signalPassFailure();
|
|
|
|
AsmState state(operation->getParentOfType<ModuleOp>());
|
|
// Get operation ids of value's aliases
|
|
auto getAllocOpNames = [&](Value value) {
|
|
- LatticeElement<AliasInfo> *latticeElement =
|
|
- analysis.lookupLatticeElement(value);
|
|
+ dataflow::Lattice<AliasInfo> *latticeElement =
|
|
+ analysis->getLatticeElement(value);
|
|
SmallVector<std::string, 4> opNames;
|
|
- if (latticeElement) {
|
|
+ if (latticeElement && !latticeElement->isUninitialized()) {
|
|
auto &info = latticeElement->getValue();
|
|
- if (!info.getAllocs().empty()) {
|
|
- for (auto &alias : info.getAllocs()) {
|
|
- auto opName =
|
|
- getValueOperandName(alias.getDefiningOp()->getResult(0), state);
|
|
- opNames.push_back(std::move(opName));
|
|
- }
|
|
+ for (auto &alias : info.getAllocs()) {
|
|
+ auto opName =
|
|
+ getValueOperandName(alias.getDefiningOp()->getResult(0), state);
|
|
+ opNames.push_back(std::move(opName));
|
|
}
|
|
}
|
|
// Ensure deterministic output
|
|
diff --git a/test/lib/Analysis/TestAllocation.cpp b/test/lib/Analysis/TestAllocation.cpp
|
|
index 84108c4d36..35e42242bd 100644
|
|
--- a/test/lib/Analysis/TestAllocation.cpp
|
|
+++ b/test/lib/Analysis/TestAllocation.cpp
|
|
@@ -6,10 +6,9 @@ using namespace mlir;
|
|
namespace {
|
|
|
|
struct TestAllocationPass
|
|
- : public PassWrapper<TestAllocationPass, OperationPass<FuncOp>> {
|
|
+ : public PassWrapper<TestAllocationPass, OperationPass<func::FuncOp>> {
|
|
|
|
- // LLVM15+
|
|
- // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);
|
|
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);
|
|
|
|
StringRef getArgument() const final { return "test-print-allocation"; }
|
|
StringRef getDescription() const final {
|
|
diff --git a/test/lib/Analysis/TestAxisInfo.cpp b/test/lib/Analysis/TestAxisInfo.cpp
|
|
index a5205bb0a0..22347c32f0 100644
|
|
--- a/test/lib/Analysis/TestAxisInfo.cpp
|
|
+++ b/test/lib/Analysis/TestAxisInfo.cpp
|
|
@@ -1,25 +1,15 @@
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "triton/Analysis/AxisInfo.h"
|
|
+#include "triton/Analysis/Utility.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
struct TestAxisInfoPass
|
|
- : public PassWrapper<TestAxisInfoPass, OperationPass<FuncOp>> {
|
|
+ : public PassWrapper<TestAxisInfoPass, OperationPass<func::FuncOp>> {
|
|
|
|
- // LLVM15+
|
|
- // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAlignmentPass);
|
|
-
|
|
- void print(const std::string &name, raw_ostream &os, ArrayRef<int64_t> vals) {
|
|
- os << name << ": [";
|
|
- for (size_t d = 0; d < vals.size(); d++) {
|
|
- if (d != 0)
|
|
- os << ", ";
|
|
- os << vals[d];
|
|
- }
|
|
- os << "]";
|
|
- }
|
|
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass);
|
|
|
|
StringRef getArgument() const final { return "test-print-alignment"; }
|
|
StringRef getDescription() const final {
|
|
@@ -30,38 +20,19 @@ struct TestAxisInfoPass
|
|
Operation *operation = getOperation();
|
|
auto &os = llvm::errs();
|
|
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
|
|
- os << opName << "\n";
|
|
- AxisInfoAnalysis analysis(&getContext());
|
|
- analysis.run(operation);
|
|
+ os << "@" << opName << "\n";
|
|
+
|
|
+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
|
+ AxisInfoAnalysis *analysis = solver->load<AxisInfoAnalysis>();
|
|
+ if (failed(solver->initializeAndRun(operation)))
|
|
+ return signalPassFailure();
|
|
operation->walk([&](Operation *op) {
|
|
if (op->getNumResults() < 1)
|
|
return;
|
|
for (Value result : op->getResults()) {
|
|
- // std::ostringstream oss;
|
|
- // result.print(oss);
|
|
- // os << " => ";
|
|
- LatticeElement<AxisInfo> *latticeElement =
|
|
- analysis.lookupLatticeElement(result);
|
|
- if (!latticeElement) {
|
|
- os << "None\n";
|
|
- return;
|
|
- }
|
|
- AxisInfo &info = latticeElement->getValue();
|
|
- print("Contiguity", os, info.getContiguity());
|
|
- os << " ; ";
|
|
- print("Divisibility", os, info.getDivisibility());
|
|
- os << " ; ";
|
|
- print("Constancy", os, info.getConstancy());
|
|
- os << " ; ";
|
|
- auto constantValue = info.getConstantValue();
|
|
- os << "ConstantValue: [";
|
|
- if (constantValue.has_value())
|
|
- os << constantValue.value();
|
|
- else
|
|
- os << "None";
|
|
- os << "] ( ";
|
|
result.print(os);
|
|
- os << " ) ";
|
|
+ os << " => ";
|
|
+ analysis->getLatticeElement(result)->getValue().print(os);
|
|
os << "\n";
|
|
}
|
|
});
|
|
diff --git a/test/lib/Analysis/TestMembar.cpp b/test/lib/Analysis/TestMembar.cpp
|
|
index df4279fe24..ab9b9f3fb7 100644
|
|
--- a/test/lib/Analysis/TestMembar.cpp
|
|
+++ b/test/lib/Analysis/TestMembar.cpp
|
|
@@ -1,4 +1,4 @@
|
|
-#include "mlir/Dialect/GPU/GPUDialect.h"
|
|
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
|
#include "mlir/IR/Dialect.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "triton/Analysis/Allocation.h"
|
|
@@ -9,10 +9,9 @@ using namespace mlir;
|
|
namespace {
|
|
|
|
struct TestMembarPass
|
|
- : public PassWrapper<TestMembarPass, OperationPass<FuncOp>> {
|
|
+ : public PassWrapper<TestMembarPass, OperationPass<func::FuncOp>> {
|
|
|
|
- // LLVM15+
|
|
- // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMembarPass);
|
|
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMembarPass);
|
|
|
|
StringRef getArgument() const final { return "test-print-membar"; }
|
|
StringRef getDescription() const final {
|