From 6b350d2465ed5b1fd44e0ef4df5bb27c8ab0e7b0 Mon Sep 17 00:00:00 2001 From: Zhang Shuo <52872288+fuyou4546@users.noreply.github.com> Date: Mon, 8 Jun 2026 05:20:34 +0000 Subject: [PATCH 1/2] feat(triton): add AOT backend with Add operator --- CMakeLists.txt | 10 ++ scripts/generate_triton_ops.py | 83 +++++++++++++ scripts/generate_wrappers.py | 11 +- src/CMakeLists.txt | 45 +++++++ src/triton/ops/add/add.h | 208 +++++++++++++++++++++++++++++++++ src/triton/ops/add/add.py | 52 +++++++++ src/triton/ops/add/build.py | 114 ++++++++++++++++++ 7 files changed, 522 insertions(+), 1 deletion(-) create mode 100644 scripts/generate_triton_ops.py create mode 100644 src/triton/ops/add/add.h create mode 100644 src/triton/ops/add/add.py create mode 100644 src/triton/ops/add/build.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 9acb2d033..3b726b917 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,6 +25,8 @@ option(WITH_TORCH "Enable PyTorch C++ backend" OFF) option(WITH_NINETOOTHED "Enable NineToothed-generated kernels" OFF) +option(WITH_TRITON "Enable Triton-generated kernels" OFF) + # Default OFF until CANN's `extract_host_stub.py` path handling is fixed for # `scikit-build-core` temp-dir builds (triggers `KeyError` on the preprocessed # object path). Enable explicitly with `-DBUILD_CUSTOM_KERNEL=ON` when the @@ -316,6 +318,14 @@ if(WITH_NINETOOTHED) set(NINETOOTHED_PYTHON_EXECUTABLE "" CACHE FILEPATH "Python executable used to run NineToothed code generation") endif() +if(WITH_TRITON AND NOT WITH_NVIDIA) + message(FATAL_ERROR "`WITH_TRITON` temporarily requires `WITH_NVIDIA=ON` because Triton AOT temporarily targets CUDA.") +endif() + +if(WITH_TRITON) + set(TRITON_PYTHON_EXECUTABLE "" CACHE FILEPATH "Python executable used to run Triton AOT code generation") +endif() + if(WITH_NVIDIA) add_compile_definitions(WITH_NVIDIA=1) enable_language(CUDA) diff --git a/scripts/generate_triton_ops.py b/scripts/generate_triton_ops.py new file mode 100644 index 000000000..917981b67 --- /dev/null +++ b/scripts/generate_triton_ops.py @@ -0,0 +1,83 @@ +import argparse +import importlib.util +import pathlib +import shutil +import sys + +_PROJECT_DIR = pathlib.Path(__file__).resolve().parents[1] +_OPS_DIR = _PROJECT_DIR / "src" / "triton" / "ops" + + +def _find_op_modules(): + return { + path.parent.name: path + for path in sorted(_OPS_DIR.glob("*/build.py")) + if path.is_file() + } + + +def _build_manifest(output_dir): + return sorted(str(path) for path in pathlib.Path(output_dir).rglob("*.c")) + + +def _write_cmake_manifest(output_dir, sources): + manifest_path = pathlib.Path(output_dir) / "manifest.cmake" + lines = ["set(INFINIOPS_TRITON_SOURCES"] + lines.extend(f' "{source}"' for source in sources) + lines.append(")") + lines.append("") + lines.append(f'set(INFINIOPS_TRITON_INCLUDE_DIRS "{output_dir}")') + lines.append("") + manifest_path.write_text("\n".join(lines) + "\n") + + +def _load_op_module(op): + path = _find_op_modules()[op] + sys.path.insert(0, str(path.parent)) + spec = importlib.util.spec_from_file_location(path.stem, path) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + sys.modules[spec.name] = module + spec.loader.exec_module(module) + + return module + + +def generate(ops, *, output_dir): + op_modules = _find_op_modules() + unknown_ops = tuple(op for op in ops if op not in op_modules) + + if unknown_ops: + raise ValueError(f"unsupported Triton ops: {', '.join(unknown_ops)}") + + output_dir = pathlib.Path(output_dir) + shutil.rmtree(output_dir, ignore_errors=True) + output_dir.mkdir(parents=True, exist_ok=True) + + for op in ops: + module = _load_op_module(op) + module.build(output_dir) + + sources = _build_manifest(output_dir) + _write_cmake_manifest(output_dir, sources) + + return sources + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Generate Triton operator sources for InfiniOps." + ) + parser.add_argument("--output-dir", required=True) + parser.add_argument("--ops", nargs="+", default=tuple(_find_op_modules())) + + return parser.parse_args() + + +def main(): + args = _parse_args() + generate(args.ops, output_dir=args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 22be3a48f..8492b845a 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -1183,13 +1183,15 @@ def _filter_ops(ops, op_allowlist, *, strict=False): return {op_name: ops[op_name] for op_name in op_allowlist if op_name in ops} -def _get_all_ops(devices, with_torch=False, with_ninetoothed=False): +def _get_all_ops(devices, with_torch=False, with_ninetoothed=False, with_triton=False): scan_dirs = set(devices) if with_torch: scan_dirs.add("torch") if with_ninetoothed: scan_dirs.add("ninetoothed") + if with_triton: + scan_dirs.add("triton") ops = {} @@ -1328,6 +1330,12 @@ def _dispatch_gen_batch_size(): help="Fail if `--ops` contains operators unavailable for the active devices.", ) + parser.add_argument( + "--with-triton", + action="store_true", + help="Include Triton backend implementations.", + ) + args = parser.parse_args() # Wipe previous outputs so files for ops that have since been removed @@ -1348,6 +1356,7 @@ def _dispatch_gen_batch_size(): args.devices, with_torch=args.with_torch, with_ninetoothed=args.with_ninetoothed, + with_triton=args.with_triton, ) ops = _filter_ops( diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7cd3c0043..4a465b356 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -84,6 +84,41 @@ if(WITH_NINETOOTHED) target_sources(infiniops PRIVATE ${INFINI_OPS_NINETOOTHED_SOURCES}) endif() +if(WITH_TRITON) + find_package(Python COMPONENTS Interpreter REQUIRED) + + if(TRITON_PYTHON_EXECUTABLE) + set(_triton_python "${TRITON_PYTHON_EXECUTABLE}") + elseif(_TORCH_PYTHON) + set(_triton_python "${_TORCH_PYTHON}") + else() + set(_triton_python "${Python_EXECUTABLE}") + endif() + message(STATUS "Triton codegen Python: ${_triton_python}") + + set(_triton_output_dir "${CMAKE_CURRENT_BINARY_DIR}/triton") + set(_triton_generator_args + "${PROJECT_SOURCE_DIR}/scripts/generate_triton_ops.py" + --output-dir "${_triton_output_dir}") + + execute_process( + COMMAND "${_triton_python}" ${_triton_generator_args} + WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}" + RESULT_VARIABLE _triton_generation_result + ) + + if(NOT _triton_generation_result EQUAL 0) + message(FATAL_ERROR "Generating Triton AOT operator sources failed with `${_triton_python}`. Set `TRITON_PYTHON_EXECUTABLE` to a Python with `triton` and CUDA dependencies installed.") + endif() + + enable_language(C) + + include("${_triton_output_dir}/manifest.cmake") + target_compile_definitions(infiniops PUBLIC WITH_TRITON=1) + target_include_directories(infiniops PRIVATE ${INFINIOPS_TRITON_INCLUDE_DIRS}) + target_sources(infiniops PRIVATE ${INFINIOPS_TRITON_SOURCES}) +endif() + if(WITH_ILUVATAR) set(ILUVATAR_PATTERNS "native/cuda/*.cc" @@ -552,6 +587,10 @@ if(GENERATE_OPERATOR_CALL_INSTANTIATIONS OR GENERATE_PYTHON_BINDINGS) list(APPEND GENERATOR_ARGS --with-ninetoothed) endif() + if(WITH_TRITON) + list(APPEND GENERATOR_ARGS --with-triton) + endif() + execute_process( COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py ${GENERATOR_ARGS} WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} @@ -791,6 +830,12 @@ if(GENERATE_PYTHON_BINDINGS) target_include_directories(ops PRIVATE ${INFINI_OPS_NINETOOTHED_INCLUDE_DIRS}) endif() + + if(WITH_TRITON) + target_include_directories(ops PRIVATE + ${INFINIOPS_TRITON_INCLUDE_DIRS}) + endif() + target_link_libraries(ops PRIVATE infiniops) # Cambricon generated dispatch is compiled into the Python extension and diff --git a/src/triton/ops/add/add.h b/src/triton/ops/add/add.h new file mode 100644 index 000000000..70b0525c3 --- /dev/null +++ b/src/triton/ops/add/add.h @@ -0,0 +1,208 @@ +#ifndef INFINI_OPS_TRITON_ADD_H_ +#define INFINI_OPS_TRITON_ADD_H_ + +#include + +#include +#include +#include +#include +#include + +#include "base/add.h" +#include "data_type.h" + +extern "C" { +#include "add/infini_ops_triton_add_bf16.h" +#include "add/infini_ops_triton_add_fp16.h" +#include "add/infini_ops_triton_add_fp32.h" +#include "add/infini_ops_triton_add_fp64.h" +#include "add/infini_ops_triton_add_i16.h" +#include "add/infini_ops_triton_add_i32.h" +#include "add/infini_ops_triton_add_i64.h" +#include "add/infini_ops_triton_add_i8.h" +#include "add/infini_ops_triton_add_u16.h" +#include "add/infini_ops_triton_add_u32.h" +#include "add/infini_ops_triton_add_u64.h" +#include "add/infini_ops_triton_add_u8.h" +} + +namespace infini::ops { + +template <> +class Operator : public Add { + public: + using Add::operator(); + + Operator(const Tensor input, const Tensor other, Tensor out) + : Add{input, other, out} { + const int ndim = static_cast(ndim_); + std::vector h_metadata(4 * std::max(ndim, 1), 0); + for (int i = 0; i < ndim; ++i) { + h_metadata[0 * ndim + i] = static_cast(out_shape_[i]); + h_metadata[1 * ndim + i] = static_cast(input_strides_[i]); + h_metadata[2 * ndim + i] = static_cast(other_strides_[i]); + h_metadata[3 * ndim + i] = static_cast(out_strides_[i]); + } + + const size_t bytes = h_metadata.size() * sizeof(int64_t); + cuMemAlloc(&d_metadata_, bytes); + cuMemcpyHtoD(d_metadata_, h_metadata.data(), bytes); + + const size_t stride_bytes = ndim * sizeof(int64_t); + d_out_shape_ = d_metadata_ + 0 * stride_bytes; + d_input_stride_ = d_metadata_ + 1 * stride_bytes; + d_other_stride_ = d_metadata_ + 2 * stride_bytes; + d_out_stride_ = d_metadata_ + 3 * stride_bytes; + } + + ~Operator() { + if (d_metadata_) { + cuMemFree(d_metadata_); + } + } + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + EnsureLoaded(out.dtype()); + + CUstream stream = static_cast(stream_); + auto x = reinterpret_cast(const_cast(input.data())); + auto y = reinterpret_cast(const_cast(other.data())); + auto o = reinterpret_cast(out.data()); + + int32_t n = static_cast(out.numel()); + int32_t ndim_val = static_cast(ndim_); + int32_t x_contig = static_cast(is_input_contiguous_); + int32_t y_contig = static_cast(is_other_contiguous_); + int32_t out_contig = static_cast(is_out_contiguous_); + + CUresult rc = CUDA_ERROR_INVALID_VALUE; + switch (out.dtype()) { + case DataType::kFloat16: + rc = infini_ops_triton_add_fp16_default( + stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, + d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); + break; + case DataType::kBFloat16: + rc = infini_ops_triton_add_bf16_default( + stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, + d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); + break; + case DataType::kFloat32: + rc = infini_ops_triton_add_fp32_default( + stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, + d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); + break; + case DataType::kFloat64: + rc = infini_ops_triton_add_fp64_default( + stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, + d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); + break; + case DataType::kInt8: + rc = infini_ops_triton_add_i8_default( + stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, + d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); + break; + case DataType::kInt16: + rc = infini_ops_triton_add_i16_default( + stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, + d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); + break; + case DataType::kInt32: + rc = infini_ops_triton_add_i32_default( + stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, + d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); + break; + case DataType::kInt64: + rc = infini_ops_triton_add_i64_default( + stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, + d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); + break; + case DataType::kUInt8: + rc = infini_ops_triton_add_u8_default( + stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, + d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); + break; + case DataType::kUInt16: + rc = infini_ops_triton_add_u16_default( + stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, + d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); + break; + case DataType::kUInt32: + rc = infini_ops_triton_add_u32_default( + stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, + d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); + break; + case DataType::kUInt64: + rc = infini_ops_triton_add_u64_default( + stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, + d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); + break; + default: + assert(false && "triton `Add` does not support this dtype"); + } + + assert(rc == CUDA_SUCCESS && "Triton `Add` launch failed"); + } + + private: + CUdeviceptr d_metadata_{0}; + + CUdeviceptr d_out_shape_{0}; + + CUdeviceptr d_input_stride_{0}; + + CUdeviceptr d_other_stride_{0}; + + CUdeviceptr d_out_stride_{0}; + + static void EnsureLoaded(DataType dtype) { + static std::once_flag fp16, bf16, fp32, fp64, i8, i16, i32, i64, u8, u16, + u32, u64; + switch (dtype) { + case DataType::kFloat16: + std::call_once(fp16, &load_infini_ops_triton_add_fp16); + break; + case DataType::kBFloat16: + std::call_once(bf16, &load_infini_ops_triton_add_bf16); + break; + case DataType::kFloat32: + std::call_once(fp32, &load_infini_ops_triton_add_fp32); + break; + case DataType::kFloat64: + std::call_once(fp64, &load_infini_ops_triton_add_fp64); + break; + case DataType::kInt8: + std::call_once(i8, &load_infini_ops_triton_add_i8); + break; + case DataType::kInt16: + std::call_once(i16, &load_infini_ops_triton_add_i16); + break; + case DataType::kInt32: + std::call_once(i32, &load_infini_ops_triton_add_i32); + break; + case DataType::kInt64: + std::call_once(i64, &load_infini_ops_triton_add_i64); + break; + case DataType::kUInt8: + std::call_once(u8, &load_infini_ops_triton_add_u8); + break; + case DataType::kUInt16: + std::call_once(u16, &load_infini_ops_triton_add_u16); + break; + case DataType::kUInt32: + std::call_once(u32, &load_infini_ops_triton_add_u32); + break; + case DataType::kUInt64: + std::call_once(u64, &load_infini_ops_triton_add_u64); + break; + default: + break; + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/triton/ops/add/add.py b/src/triton/ops/add/add.py new file mode 100644 index 000000000..eb3e35724 --- /dev/null +++ b/src/triton/ops/add/add.py @@ -0,0 +1,52 @@ +import triton +import triton.language as tl + + +@triton.jit +def kernel( + x_ptr, + y_ptr, + out_ptr, + out_shape_ptr, + x_stride_ptr, + y_stride_ptr, + out_stride_ptr, + x_contig, + y_contig, + out_contig, + ndim, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offsets = (pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)).to(tl.int64) + mask = offsets < n_elements + + if (x_contig != 0) and (y_contig != 0) and (out_contig != 0): + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + tl.store(out_ptr + offsets, x + y, mask=mask) + else: + x_offs = tl.zeros([BLOCK_SIZE], dtype=tl.int64) + y_offs = tl.zeros([BLOCK_SIZE], dtype=tl.int64) + out_offs = tl.zeros([BLOCK_SIZE], dtype=tl.int64) + tmp = offsets + + for i in range(ndim): + s = tl.load(out_shape_ptr + (ndim - 1 - i)) + d = tmp % s + tmp = tmp // s + x_offs += d * tl.load(x_stride_ptr + (ndim - 1 - i)) + y_offs += d * tl.load(y_stride_ptr + (ndim - 1 - i)) + out_offs += d * tl.load(out_stride_ptr + (ndim - 1 - i)) + + if x_contig != 0: + x_offs = offsets + if y_contig != 0: + y_offs = offsets + if out_contig != 0: + out_offs = offsets + + x = tl.load(x_ptr + x_offs, mask=mask) + y = tl.load(y_ptr + y_offs, mask=mask) + tl.store(out_ptr + out_offs, x + y, mask=mask) diff --git a/src/triton/ops/add/build.py b/src/triton/ops/add/build.py new file mode 100644 index 000000000..28f77872a --- /dev/null +++ b/src/triton/ops/add/build.py @@ -0,0 +1,114 @@ +import pathlib + +from triton.tools.compile import CompileArgs, compile_kernel +from triton.tools import link + +_KERNEL_PATH = pathlib.Path(__file__).parent / "add.py" +_KERNEL_NAME = "kernel" +_DTYPES = ( + "fp16", + "bf16", + "fp32", + "fp64", + "i8", + "i16", + "i32", + "i64", + "u8", + "u16", + "u32", + "u64", +) +_BLOCK_SIZES = (512, 1024) +_NUM_WARPS = 4 +_NUM_STAGES = 3 + + +def _compile_variants(variant_dir, dtype): + out_name = f"infini_ops_triton_add_{dtype}" + headers = [] + for block_size in _BLOCK_SIZES: + aligned_sig = ( + f"*{dtype}:16, *{dtype}:16, *{dtype}:16, " + f"*i64, *i64, *i64, *i64, " + f"i32, i32, i32, i32, i32, {block_size}" + ) + _, files = compile_kernel( + CompileArgs( + path=str(_KERNEL_PATH), + kernel_name=_KERNEL_NAME, + signature=aligned_sig, + grid=f"(n_elements + {block_size} - 1) / {block_size}, 1, 1", + num_warps=_NUM_WARPS, + num_stages=_NUM_STAGES, + out_name=out_name, + out_path=variant_dir / out_name, + target=None, + ) + ) + headers.extend(f for f in files if f.suffix == ".h") + + generic_sig = ( + f"*{dtype}, *{dtype}, *{dtype}, " + f"*i64, *i64, *i64, *i64, " + f"i32, i32, i32, i32, i32, {block_size}" + ) + _, files = compile_kernel( + CompileArgs( + path=str(_KERNEL_PATH), + kernel_name=_KERNEL_NAME, + signature=generic_sig, + grid=f"(n_elements + {block_size} - 1) / {block_size}, 1, 1", + num_warps=_NUM_WARPS, + num_stages=_NUM_STAGES, + out_name=out_name, + out_path=variant_dir / out_name, + target=None, + ) + ) + headers.extend(f for f in files if f.suffix == ".h") + return headers, out_name + + +def _link_one_dtype(variant_dir, headers, out_name): + parser = link.HeaderParser() + for h in headers: + parser.extract_linker_meta(h.read_text()) + + out_base = variant_dir / out_name + first_meta = next(iter(parser.kernels.values()))[0] + backend_prelude = ( + pathlib.Path(link.__file__).parent / "extra" / parser.backend_name / "link.h" + ).read_text() + + algo_decls = [link.make_algo_decls(name, m) for name, m in parser.kernels.items()] + out_base.with_suffix(".h").write_text( + backend_prelude + + "\n".join(algo_decls) + + "\n" + + link.make_get_num_algos_decl(first_meta) + + "\n" + + link.make_global_decl(first_meta) + ) + defs = [ + link.make_kernel_hints_dispatcher(name, m) for name, m in parser.kernels.items() + ] + names = list(parser.kernels.keys()) + src = backend_prelude + src += "#include \n#include \n\n" + src += "\n".join(defs) + "\n" + src += link.make_func_pointers(names, first_meta) + "\n" + src += link.make_get_num_algos_def(first_meta) + "\n" + src += link.make_kernel_meta_const_dispatcher(first_meta) + "\n" + src += link.make_kernel_load_def(names, first_meta) + "\n" + src += link.make_default_algo_kernel(first_meta) + out_base.with_suffix(".c").write_text(src) + + +def build(output_dir: pathlib.Path): + variant_dir = output_dir / "add" + variant_dir.mkdir(parents=True, exist_ok=True) + + for dtype in _DTYPES: + headers, out_name = _compile_variants(variant_dir, dtype) + _link_one_dtype(variant_dir, headers, out_name) From 8c15446b631487d16a7d64a3002b78c3d4daee8a Mon Sep 17 00:00:00 2001 From: Zhang Shuo <52872288+fuyou4546@users.noreply.github.com> Date: Tue, 16 Jun 2026 13:31:19 +0000 Subject: [PATCH 2/2] refactor(triton): encapsulate aot build script --- scripts/triton/aot.py | 253 ++++++++++++++++++ .../generate_ops.py} | 49 +++- src/CMakeLists.txt | 2 +- src/triton/ops/add/add.h | 135 +--------- src/triton/ops/add/build.py | 116 ++------ 5 files changed, 330 insertions(+), 225 deletions(-) create mode 100644 scripts/triton/aot.py rename scripts/{generate_triton_ops.py => triton/generate_ops.py} (59%) diff --git a/scripts/triton/aot.py b/scripts/triton/aot.py new file mode 100644 index 000000000..90729a9d5 --- /dev/null +++ b/scripts/triton/aot.py @@ -0,0 +1,253 @@ +import ast +from dataclasses import dataclass +import pathlib +from typing import Any, Sequence + +from triton.tools import link +from triton.tools.compile import CompileArgs, compile_kernel + + +@dataclass(frozen=True) +class Signature: + pointer_dtypes: dict[str, str] + pointer_alignments: dict[str, int | None] | None = None + scalar_dtypes: dict[str, str] | None = None + constexprs: dict[str, Any] | None = None + + +@dataclass(frozen=True) +class CompileConfig: + signature: Signature + grid: str + out_name: str + num_warps: int = 4 + num_stages: int = 3 + target: Any = None + + +def compile( + config: CompileConfig, + *, + path: pathlib.Path, + kernel_name: str, + out_dir: pathlib.Path, + kernel_args: Sequence[str], +) -> list[pathlib.Path]: + _, files = compile_kernel( + CompileArgs( + path=str(path), + kernel_name=kernel_name, + signature=_render_signature(config.signature, kernel_args), + grid=config.grid, + num_warps=config.num_warps, + num_stages=config.num_stages, + out_name=config.out_name, + out_path=out_dir / config.out_name, + target=config.target, + ) + ) + + return [path for path in files if path.suffix == ".h"] + + +def link_headers(headers: Sequence[pathlib.Path], out_base: pathlib.Path): + parser = link.HeaderParser() + for header in headers: + parser.extract_linker_meta(header.read_text()) + + first_meta = next(iter(parser.kernels.values()))[0] + backend_prelude = ( + pathlib.Path(link.__file__).parent / "extra" / parser.backend_name / "link.h" + ).read_text() + + out_base.with_suffix(".h").write_text( + backend_prelude + + "\n".join( + link.make_algo_decls(name, meta) for name, meta in parser.kernels.items() + ) + + "\n" + + link.make_get_num_algos_decl(first_meta) + + "\n" + + link.make_global_decl(first_meta) + ) + + names = list(parser.kernels) + defs = [ + link.make_kernel_hints_dispatcher(name, meta) + for name, meta in parser.kernels.items() + ] + + out_base.with_suffix(".c").write_text( + backend_prelude + + "#include \n#include \n\n" + + "\n".join(defs) + + "\n" + + link.make_func_pointers(names, first_meta) + + "\n" + + link.make_get_num_algos_def(first_meta) + + "\n" + + link.make_kernel_meta_const_dispatcher(first_meta) + + "\n" + + link.make_kernel_load_def(names, first_meta) + + "\n" + + link.make_default_algo_kernel(first_meta) + ) + + +def build( + configs: Sequence[CompileConfig], + *, + path: pathlib.Path, + kernel_name: str, + out_dir: pathlib.Path, + kernel_args: Sequence[str], +) -> pathlib.Path: + if not configs: + raise ValueError("empty compile configs") + + out_name = configs[0].out_name + out_dir.mkdir(parents=True, exist_ok=True) + + headers = [] + for config in configs: + headers.extend( + compile( + config, + path=path, + kernel_name=kernel_name, + out_dir=out_dir, + kernel_args=kernel_args, + ) + ) + + if not headers: + raise ValueError(f"no headers generated for {out_name}") + + out_base = out_dir / out_name + link_headers(headers, out_base) + + return out_base + + +def write_header( + headers: Sequence[pathlib.Path], + out_path: pathlib.Path, + *, + op_name: str, + configs: Sequence[CompileConfig], + kernel_args: Sequence[str], +): + guard = f"INFINI_OPS_GENERATED_{out_path.stem.upper()}_H_" + includes = "\n".join(f'#include "{header.name}"' for header in headers) + params = _dispatch_params(configs[0].signature, kernel_args) + param_decls = ", ".join(f"{ty} {name}" for ty, name in params) + param_names = ", ".join(name for _, name in params) + + body = f"#ifndef {guard}\n#define {guard}\n\n" + body += f'extern "C" {{\n{includes}\n}}\n\n' + body += '#include \n\n#include "data_type.h"\n\n' + body += "namespace infini::ops {\n\n" + + body += f"inline TT_ResultTy launch_infini_ops_triton_{op_name}(\n" + body += f" DataType dtype, TT_StreamTy stream, {param_decls}) {{\n" + body += " switch (dtype) {\n" + for config in configs: + dtype = _out_dtype(config.out_name) + body += f" case DataType::{_data_type(dtype)}:\n" + body += f" return {config.out_name}_default(stream, {param_names});\n" + body += " default:\n return TT_ERROR_INVALID_VALUE;\n }\n}\n\n" + + body += f"inline void load_infini_ops_triton_{op_name}(DataType dtype) {{\n" + body += " switch (dtype) {\n" + for config in configs: + dtype = _out_dtype(config.out_name) + body += f" case DataType::{_data_type(dtype)}: {{\n" + body += " static std::once_flag once;\n" + body += f" std::call_once(once, &load_{config.out_name});\n" + body += " return;\n }\n" + body += " default:\n return;\n }\n}\n\n" + + body += "} // namespace infini::ops\n\n#endif\n" + out_path.write_text(body) + + +def kernel_args(path, kernel_name): + tree = ast.parse(pathlib.Path(path).read_text()) + for node in tree.body: + if isinstance(node, ast.FunctionDef) and node.name == kernel_name: + return tuple(arg.arg for arg in node.args.args) + raise ValueError(f"kernel {kernel_name} not found in {path}") + + +def _render_signature(signature: Signature, args: Sequence[str]) -> str: + pointer_alignments = signature.pointer_alignments or {} + scalar_dtypes = signature.scalar_dtypes or {} + constexprs = signature.constexprs or {} + + parts = [] + for arg in args: + if arg in constexprs: + parts.append(str(constexprs[arg])) + elif arg in signature.pointer_dtypes: + parts.append( + _ptr(signature.pointer_dtypes[arg], pointer_alignments.get(arg)) + ) + elif arg in scalar_dtypes: + parts.append(str(scalar_dtypes[arg])) + else: + raise ValueError(f"missing signature rule for {arg}") + + return ", ".join(parts) + + +def _dispatch_params(signature: Signature, args: Sequence[str]): + scalar_dtypes = signature.scalar_dtypes or {} + constexprs = signature.constexprs or {} + + params = [] + for arg in args: + if arg in constexprs: + continue + if arg in signature.pointer_dtypes: + params.append(("CUdeviceptr", arg)) + elif arg in scalar_dtypes: + params.append((_scalar_ctype(scalar_dtypes[arg]), arg)) + else: + raise ValueError(f"missing dispatch rule for {arg}") + return params + + +def _scalar_ctype(dtype): + return { + "i32": "int32_t", + "i64": "int64_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp32": "float", + "fp64": "double", + }[dtype] + + +def _out_dtype(out_name): + return out_name.rsplit("_", 1)[1] + + +def _data_type(dtype): + return { + "fp16": "kFloat16", + "bf16": "kBFloat16", + "fp32": "kFloat32", + "fp64": "kFloat64", + "i8": "kInt8", + "i16": "kInt16", + "i32": "kInt32", + "i64": "kInt64", + "u8": "kUInt8", + "u16": "kUInt16", + "u32": "kUInt32", + "u64": "kUInt64", + }[dtype] + + +def _ptr(dtype, alignment=None): + return f"*{dtype}" if alignment is None else f"*{dtype}:{alignment}" diff --git a/scripts/generate_triton_ops.py b/scripts/triton/generate_ops.py similarity index 59% rename from scripts/generate_triton_ops.py rename to scripts/triton/generate_ops.py index 917981b67..e4ab763c1 100644 --- a/scripts/generate_triton_ops.py +++ b/scripts/triton/generate_ops.py @@ -4,10 +4,22 @@ import shutil import sys -_PROJECT_DIR = pathlib.Path(__file__).resolve().parents[1] +import aot + +_PROJECT_DIR = pathlib.Path(__file__).resolve().parents[2] +if str(_PROJECT_DIR) not in sys.path: + sys.path.insert(0, str(_PROJECT_DIR)) + +_KERNEL_NAME = "kernel" _OPS_DIR = _PROJECT_DIR / "src" / "triton" / "ops" +def _prepend_sys_path(path): + path = str(path) + if path not in sys.path: + sys.path.insert(0, path) + + def _find_op_modules(): return { path.parent.name: path @@ -31,10 +43,12 @@ def _write_cmake_manifest(output_dir, sources): manifest_path.write_text("\n".join(lines) + "\n") -def _load_op_module(op): - path = _find_op_modules()[op] - sys.path.insert(0, str(path.parent)) - spec = importlib.util.spec_from_file_location(path.stem, path) +def _load_op_module(path): + _prepend_sys_path(path.parent) + spec = importlib.util.spec_from_file_location( + f"infiniops_triton_{path.parent.name}_build", + path, + ) module = importlib.util.module_from_spec(spec) assert spec.loader is not None sys.modules[spec.name] = module @@ -55,8 +69,29 @@ def generate(ops, *, output_dir): output_dir.mkdir(parents=True, exist_ok=True) for op in ops: - module = _load_op_module(op) - module.build(output_dir) + path = op_modules[op] + kernel_path = path.parent / f"{op}.py" + module = _load_op_module(path) + kernel_args = aot.kernel_args(kernel_path, _KERNEL_NAME) + headers = [] + dispatch_configs = [] + for configs in module.configs(): + out_base = aot.build( + configs, + path=kernel_path, + kernel_name=_KERNEL_NAME, + out_dir=output_dir / op, + kernel_args=kernel_args, + ) + headers.append(out_base.with_suffix(".h")) + dispatch_configs.append(configs[0]) + aot.write_header( + headers, + output_dir / op / f"infini_ops_triton_{op}.h", + op_name=op, + configs=dispatch_configs, + kernel_args=kernel_args, + ) sources = _build_manifest(output_dir) _write_cmake_manifest(output_dir, sources) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4a465b356..2e1e12f7d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -98,7 +98,7 @@ if(WITH_TRITON) set(_triton_output_dir "${CMAKE_CURRENT_BINARY_DIR}/triton") set(_triton_generator_args - "${PROJECT_SOURCE_DIR}/scripts/generate_triton_ops.py" + "${PROJECT_SOURCE_DIR}/scripts/triton/generate_ops.py" --output-dir "${_triton_output_dir}") execute_process( diff --git a/src/triton/ops/add/add.h b/src/triton/ops/add/add.h index 70b0525c3..45de13f41 100644 --- a/src/triton/ops/add/add.h +++ b/src/triton/ops/add/add.h @@ -6,27 +6,12 @@ #include #include #include -#include #include +#include "add/infini_ops_triton_add.h" #include "base/add.h" #include "data_type.h" -extern "C" { -#include "add/infini_ops_triton_add_bf16.h" -#include "add/infini_ops_triton_add_fp16.h" -#include "add/infini_ops_triton_add_fp32.h" -#include "add/infini_ops_triton_add_fp64.h" -#include "add/infini_ops_triton_add_i16.h" -#include "add/infini_ops_triton_add_i32.h" -#include "add/infini_ops_triton_add_i64.h" -#include "add/infini_ops_triton_add_i8.h" -#include "add/infini_ops_triton_add_u16.h" -#include "add/infini_ops_triton_add_u32.h" -#include "add/infini_ops_triton_add_u64.h" -#include "add/infini_ops_triton_add_u8.h" -} - namespace infini::ops { template <> @@ -64,7 +49,7 @@ class Operator : public Add { void operator()(const Tensor input, const Tensor other, Tensor out) const override { - EnsureLoaded(out.dtype()); + load_infini_ops_triton_add(out.dtype()); CUstream stream = static_cast(stream_); auto x = reinterpret_cast(const_cast(input.data())); @@ -77,73 +62,12 @@ class Operator : public Add { int32_t y_contig = static_cast(is_other_contiguous_); int32_t out_contig = static_cast(is_out_contiguous_); - CUresult rc = CUDA_ERROR_INVALID_VALUE; - switch (out.dtype()) { - case DataType::kFloat16: - rc = infini_ops_triton_add_fp16_default( - stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, - d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); - break; - case DataType::kBFloat16: - rc = infini_ops_triton_add_bf16_default( - stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, - d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); - break; - case DataType::kFloat32: - rc = infini_ops_triton_add_fp32_default( - stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, - d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); - break; - case DataType::kFloat64: - rc = infini_ops_triton_add_fp64_default( - stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, - d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); - break; - case DataType::kInt8: - rc = infini_ops_triton_add_i8_default( - stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, - d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); - break; - case DataType::kInt16: - rc = infini_ops_triton_add_i16_default( - stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, - d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); - break; - case DataType::kInt32: - rc = infini_ops_triton_add_i32_default( - stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, - d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); - break; - case DataType::kInt64: - rc = infini_ops_triton_add_i64_default( - stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, - d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); - break; - case DataType::kUInt8: - rc = infini_ops_triton_add_u8_default( - stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, - d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); - break; - case DataType::kUInt16: - rc = infini_ops_triton_add_u16_default( - stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, - d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); - break; - case DataType::kUInt32: - rc = infini_ops_triton_add_u32_default( - stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, - d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); - break; - case DataType::kUInt64: - rc = infini_ops_triton_add_u64_default( - stream, x, y, o, d_out_shape_, d_input_stride_, d_other_stride_, - d_out_stride_, x_contig, y_contig, out_contig, ndim_val, n); - break; - default: - assert(false && "triton `Add` does not support this dtype"); - } + auto result = launch_infini_ops_triton_add( + out.dtype(), stream, x, y, o, d_out_shape_, d_input_stride_, + d_other_stride_, d_out_stride_, x_contig, y_contig, out_contig, + ndim_val, n); - assert(rc == CUDA_SUCCESS && "Triton `Add` launch failed"); + assert(result == CUDA_SUCCESS && "Triton `Add` launch failed"); } private: @@ -156,51 +80,6 @@ class Operator : public Add { CUdeviceptr d_other_stride_{0}; CUdeviceptr d_out_stride_{0}; - - static void EnsureLoaded(DataType dtype) { - static std::once_flag fp16, bf16, fp32, fp64, i8, i16, i32, i64, u8, u16, - u32, u64; - switch (dtype) { - case DataType::kFloat16: - std::call_once(fp16, &load_infini_ops_triton_add_fp16); - break; - case DataType::kBFloat16: - std::call_once(bf16, &load_infini_ops_triton_add_bf16); - break; - case DataType::kFloat32: - std::call_once(fp32, &load_infini_ops_triton_add_fp32); - break; - case DataType::kFloat64: - std::call_once(fp64, &load_infini_ops_triton_add_fp64); - break; - case DataType::kInt8: - std::call_once(i8, &load_infini_ops_triton_add_i8); - break; - case DataType::kInt16: - std::call_once(i16, &load_infini_ops_triton_add_i16); - break; - case DataType::kInt32: - std::call_once(i32, &load_infini_ops_triton_add_i32); - break; - case DataType::kInt64: - std::call_once(i64, &load_infini_ops_triton_add_i64); - break; - case DataType::kUInt8: - std::call_once(u8, &load_infini_ops_triton_add_u8); - break; - case DataType::kUInt16: - std::call_once(u16, &load_infini_ops_triton_add_u16); - break; - case DataType::kUInt32: - std::call_once(u32, &load_infini_ops_triton_add_u32); - break; - case DataType::kUInt64: - std::call_once(u64, &load_infini_ops_triton_add_u64); - break; - default: - break; - } - } }; } // namespace infini::ops diff --git a/src/triton/ops/add/build.py b/src/triton/ops/add/build.py index 28f77872a..70fcf18be 100644 --- a/src/triton/ops/add/build.py +++ b/src/triton/ops/add/build.py @@ -1,10 +1,5 @@ -import pathlib +from scripts.triton import aot -from triton.tools.compile import CompileArgs, compile_kernel -from triton.tools import link - -_KERNEL_PATH = pathlib.Path(__file__).parent / "add.py" -_KERNEL_NAME = "kernel" _DTYPES = ( "fp16", "bf16", @@ -20,95 +15,38 @@ "u64", ) _BLOCK_SIZES = (512, 1024) +_ALIGNMENTS = (16, None) _NUM_WARPS = 4 _NUM_STAGES = 3 +_DATA_PTRS = ("x_ptr", "y_ptr", "out_ptr") +_META_PTRS = ("out_shape_ptr", "x_stride_ptr", "y_stride_ptr", "out_stride_ptr") +_SCALARS = ("x_contig", "y_contig", "out_contig", "ndim", "n_elements") + + +def _signature(dtype, block_size, alignment): + return aot.Signature( + pointer_dtypes={ + **{name: dtype for name in _DATA_PTRS}, + **{name: "i64" for name in _META_PTRS}, + }, + pointer_alignments={name: alignment for name in _DATA_PTRS}, + scalar_dtypes={ + **{name: "i32" for name in _SCALARS}, + }, + constexprs={"BLOCK_SIZE": block_size}, + ) -def _compile_variants(variant_dir, dtype): - out_name = f"infini_ops_triton_add_{dtype}" - headers = [] - for block_size in _BLOCK_SIZES: - aligned_sig = ( - f"*{dtype}:16, *{dtype}:16, *{dtype}:16, " - f"*i64, *i64, *i64, *i64, " - f"i32, i32, i32, i32, i32, {block_size}" - ) - _, files = compile_kernel( - CompileArgs( - path=str(_KERNEL_PATH), - kernel_name=_KERNEL_NAME, - signature=aligned_sig, - grid=f"(n_elements + {block_size} - 1) / {block_size}, 1, 1", - num_warps=_NUM_WARPS, - num_stages=_NUM_STAGES, - out_name=out_name, - out_path=variant_dir / out_name, - target=None, - ) - ) - headers.extend(f for f in files if f.suffix == ".h") - - generic_sig = ( - f"*{dtype}, *{dtype}, *{dtype}, " - f"*i64, *i64, *i64, *i64, " - f"i32, i32, i32, i32, i32, {block_size}" - ) - _, files = compile_kernel( - CompileArgs( - path=str(_KERNEL_PATH), - kernel_name=_KERNEL_NAME, - signature=generic_sig, +def configs(): + for dtype in _DTYPES: + yield tuple( + aot.CompileConfig( + signature=_signature(dtype, block_size, alignment), grid=f"(n_elements + {block_size} - 1) / {block_size}, 1, 1", + out_name=f"infini_ops_triton_add_{dtype}", num_warps=_NUM_WARPS, num_stages=_NUM_STAGES, - out_name=out_name, - out_path=variant_dir / out_name, - target=None, ) + for block_size in _BLOCK_SIZES + for alignment in _ALIGNMENTS ) - headers.extend(f for f in files if f.suffix == ".h") - return headers, out_name - - -def _link_one_dtype(variant_dir, headers, out_name): - parser = link.HeaderParser() - for h in headers: - parser.extract_linker_meta(h.read_text()) - - out_base = variant_dir / out_name - first_meta = next(iter(parser.kernels.values()))[0] - backend_prelude = ( - pathlib.Path(link.__file__).parent / "extra" / parser.backend_name / "link.h" - ).read_text() - - algo_decls = [link.make_algo_decls(name, m) for name, m in parser.kernels.items()] - out_base.with_suffix(".h").write_text( - backend_prelude - + "\n".join(algo_decls) - + "\n" - + link.make_get_num_algos_decl(first_meta) - + "\n" - + link.make_global_decl(first_meta) - ) - defs = [ - link.make_kernel_hints_dispatcher(name, m) for name, m in parser.kernels.items() - ] - names = list(parser.kernels.keys()) - src = backend_prelude - src += "#include \n#include \n\n" - src += "\n".join(defs) + "\n" - src += link.make_func_pointers(names, first_meta) + "\n" - src += link.make_get_num_algos_def(first_meta) + "\n" - src += link.make_kernel_meta_const_dispatcher(first_meta) + "\n" - src += link.make_kernel_load_def(names, first_meta) + "\n" - src += link.make_default_algo_kernel(first_meta) - out_base.with_suffix(".c").write_text(src) - - -def build(output_dir: pathlib.Path): - variant_dir = output_dir / "add" - variant_dir.mkdir(parents=True, exist_ok=True) - - for dtype in _DTYPES: - headers, out_name = _compile_variants(variant_dir, dtype) - _link_one_dtype(variant_dir, headers, out_name)