diff --git a/CMakeLists.txt b/CMakeLists.txt index 9acb2d033..3b726b917 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,6 +25,8 @@ option(WITH_TORCH "Enable PyTorch C++ backend" OFF) option(WITH_NINETOOTHED "Enable NineToothed-generated kernels" OFF) +option(WITH_TRITON "Enable Triton-generated kernels" OFF) + # Default OFF until CANN's `extract_host_stub.py` path handling is fixed for # `scikit-build-core` temp-dir builds (triggers `KeyError` on the preprocessed # object path). Enable explicitly with `-DBUILD_CUSTOM_KERNEL=ON` when the @@ -316,6 +318,14 @@ if(WITH_NINETOOTHED) set(NINETOOTHED_PYTHON_EXECUTABLE "" CACHE FILEPATH "Python executable used to run NineToothed code generation") endif() +if(WITH_TRITON AND NOT WITH_NVIDIA) + message(FATAL_ERROR "`WITH_TRITON` temporarily requires `WITH_NVIDIA=ON` because Triton AOT temporarily targets CUDA.") +endif() + +if(WITH_TRITON) + set(TRITON_PYTHON_EXECUTABLE "" CACHE FILEPATH "Python executable used to run Triton AOT code generation") +endif() + if(WITH_NVIDIA) add_compile_definitions(WITH_NVIDIA=1) enable_language(CUDA) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 22be3a48f..8492b845a 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -1183,13 +1183,15 @@ def _filter_ops(ops, op_allowlist, *, strict=False): return {op_name: ops[op_name] for op_name in op_allowlist if op_name in ops} -def _get_all_ops(devices, with_torch=False, with_ninetoothed=False): +def _get_all_ops(devices, with_torch=False, with_ninetoothed=False, with_triton=False): scan_dirs = set(devices) if with_torch: scan_dirs.add("torch") if with_ninetoothed: scan_dirs.add("ninetoothed") + if with_triton: + scan_dirs.add("triton") ops = {} @@ -1328,6 +1330,12 @@ def _dispatch_gen_batch_size(): help="Fail if `--ops` contains operators unavailable for the active devices.", ) + parser.add_argument( + "--with-triton", + action="store_true", + help="Include Triton backend implementations.", + ) + args = parser.parse_args() # Wipe previous outputs so files for ops that have since been removed @@ -1348,6 +1356,7 @@ def _dispatch_gen_batch_size(): args.devices, with_torch=args.with_torch, with_ninetoothed=args.with_ninetoothed, + with_triton=args.with_triton, ) ops = _filter_ops( diff --git a/scripts/triton/aot.py b/scripts/triton/aot.py new file mode 100644 index 000000000..90729a9d5 --- /dev/null +++ b/scripts/triton/aot.py @@ -0,0 +1,253 @@ +import ast +from dataclasses import dataclass +import pathlib +from typing import Any, Sequence + +from triton.tools import link +from triton.tools.compile import CompileArgs, compile_kernel + + +@dataclass(frozen=True) +class Signature: + pointer_dtypes: dict[str, str] + pointer_alignments: dict[str, int | None] | None = None + scalar_dtypes: dict[str, str] | None = None + constexprs: dict[str, Any] | None = None + + +@dataclass(frozen=True) +class CompileConfig: + signature: Signature + grid: str + out_name: str + num_warps: int = 4 + num_stages: int = 3 + target: Any = None + + +def compile( + config: CompileConfig, + *, + path: pathlib.Path, + kernel_name: str, + out_dir: pathlib.Path, + kernel_args: Sequence[str], +) -> list[pathlib.Path]: + _, files = compile_kernel( + CompileArgs( + path=str(path), + kernel_name=kernel_name, + signature=_render_signature(config.signature, kernel_args), + grid=config.grid, + num_warps=config.num_warps, + num_stages=config.num_stages, + out_name=config.out_name, + out_path=out_dir / config.out_name, + target=config.target, + ) + ) + + return [path for path in files if path.suffix == ".h"] + + +def link_headers(headers: Sequence[pathlib.Path], out_base: pathlib.Path): + parser = link.HeaderParser() + for header in headers: + parser.extract_linker_meta(header.read_text()) + + first_meta = next(iter(parser.kernels.values()))[0] + backend_prelude = ( + pathlib.Path(link.__file__).parent / "extra" / parser.backend_name / "link.h" + ).read_text() + + out_base.with_suffix(".h").write_text( + backend_prelude + + "\n".join( + link.make_algo_decls(name, meta) for name, meta in parser.kernels.items() + ) + + "\n" + + link.make_get_num_algos_decl(first_meta) + + "\n" + + link.make_global_decl(first_meta) + ) + + names = list(parser.kernels) + defs = [ + link.make_kernel_hints_dispatcher(name, meta) + for name, meta in parser.kernels.items() + ] + + out_base.with_suffix(".c").write_text( + backend_prelude + + "#include \n#include \n\n" + + "\n".join(defs) + + "\n" + + link.make_func_pointers(names, first_meta) + + "\n" + + link.make_get_num_algos_def(first_meta) + + "\n" + + link.make_kernel_meta_const_dispatcher(first_meta) + + "\n" + + link.make_kernel_load_def(names, first_meta) + + "\n" + + link.make_default_algo_kernel(first_meta) + ) + + +def build( + configs: Sequence[CompileConfig], + *, + path: pathlib.Path, + kernel_name: str, + out_dir: pathlib.Path, + kernel_args: Sequence[str], +) -> pathlib.Path: + if not configs: + raise ValueError("empty compile configs") + + out_name = configs[0].out_name + out_dir.mkdir(parents=True, exist_ok=True) + + headers = [] + for config in configs: + headers.extend( + compile( + config, + path=path, + kernel_name=kernel_name, + out_dir=out_dir, + kernel_args=kernel_args, + ) + ) + + if not headers: + raise ValueError(f"no headers generated for {out_name}") + + out_base = out_dir / out_name + link_headers(headers, out_base) + + return out_base + + +def write_header( + headers: Sequence[pathlib.Path], + out_path: pathlib.Path, + *, + op_name: str, + configs: Sequence[CompileConfig], + kernel_args: Sequence[str], +): + guard = f"INFINI_OPS_GENERATED_{out_path.stem.upper()}_H_" + includes = "\n".join(f'#include "{header.name}"' for header in headers) + params = _dispatch_params(configs[0].signature, kernel_args) + param_decls = ", ".join(f"{ty} {name}" for ty, name in params) + param_names = ", ".join(name for _, name in params) + + body = f"#ifndef {guard}\n#define {guard}\n\n" + body += f'extern "C" {{\n{includes}\n}}\n\n' + body += '#include \n\n#include "data_type.h"\n\n' + body += "namespace infini::ops {\n\n" + + body += f"inline TT_ResultTy launch_infini_ops_triton_{op_name}(\n" + body += f" DataType dtype, TT_StreamTy stream, {param_decls}) {{\n" + body += " switch (dtype) {\n" + for config in configs: + dtype = _out_dtype(config.out_name) + body += f" case DataType::{_data_type(dtype)}:\n" + body += f" return {config.out_name}_default(stream, {param_names});\n" + body += " default:\n return TT_ERROR_INVALID_VALUE;\n }\n}\n\n" + + body += f"inline void load_infini_ops_triton_{op_name}(DataType dtype) {{\n" + body += " switch (dtype) {\n" + for config in configs: + dtype = _out_dtype(config.out_name) + body += f" case DataType::{_data_type(dtype)}: {{\n" + body += " static std::once_flag once;\n" + body += f" std::call_once(once, &load_{config.out_name});\n" + body += " return;\n }\n" + body += " default:\n return;\n }\n}\n\n" + + body += "} // namespace infini::ops\n\n#endif\n" + out_path.write_text(body) + + +def kernel_args(path, kernel_name): + tree = ast.parse(pathlib.Path(path).read_text()) + for node in tree.body: + if isinstance(node, ast.FunctionDef) and node.name == kernel_name: + return tuple(arg.arg for arg in node.args.args) + raise ValueError(f"kernel {kernel_name} not found in {path}") + + +def _render_signature(signature: Signature, args: Sequence[str]) -> str: + pointer_alignments = signature.pointer_alignments or {} + scalar_dtypes = signature.scalar_dtypes or {} + constexprs = signature.constexprs or {} + + parts = [] + for arg in args: + if arg in constexprs: + parts.append(str(constexprs[arg])) + elif arg in signature.pointer_dtypes: + parts.append( + _ptr(signature.pointer_dtypes[arg], pointer_alignments.get(arg)) + ) + elif arg in scalar_dtypes: + parts.append(str(scalar_dtypes[arg])) + else: + raise ValueError(f"missing signature rule for {arg}") + + return ", ".join(parts) + + +def _dispatch_params(signature: Signature, args: Sequence[str]): + scalar_dtypes = signature.scalar_dtypes or {} + constexprs = signature.constexprs or {} + + params = [] + for arg in args: + if arg in constexprs: + continue + if arg in signature.pointer_dtypes: + params.append(("CUdeviceptr", arg)) + elif arg in scalar_dtypes: + params.append((_scalar_ctype(scalar_dtypes[arg]), arg)) + else: + raise ValueError(f"missing dispatch rule for {arg}") + return params + + +def _scalar_ctype(dtype): + return { + "i32": "int32_t", + "i64": "int64_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp32": "float", + "fp64": "double", + }[dtype] + + +def _out_dtype(out_name): + return out_name.rsplit("_", 1)[1] + + +def _data_type(dtype): + return { + "fp16": "kFloat16", + "bf16": "kBFloat16", + "fp32": "kFloat32", + "fp64": "kFloat64", + "i8": "kInt8", + "i16": "kInt16", + "i32": "kInt32", + "i64": "kInt64", + "u8": "kUInt8", + "u16": "kUInt16", + "u32": "kUInt32", + "u64": "kUInt64", + }[dtype] + + +def _ptr(dtype, alignment=None): + return f"*{dtype}" if alignment is None else f"*{dtype}:{alignment}" diff --git a/scripts/triton/generate_ops.py b/scripts/triton/generate_ops.py new file mode 100644 index 000000000..e4ab763c1 --- /dev/null +++ b/scripts/triton/generate_ops.py @@ -0,0 +1,118 @@ +import argparse +import importlib.util +import pathlib +import shutil +import sys + +import aot + +_PROJECT_DIR = pathlib.Path(__file__).resolve().parents[2] +if str(_PROJECT_DIR) not in sys.path: + sys.path.insert(0, str(_PROJECT_DIR)) + +_KERNEL_NAME = "kernel" +_OPS_DIR = _PROJECT_DIR / "src" / "triton" / "ops" + + +def _prepend_sys_path(path): + path = str(path) + if path not in sys.path: + sys.path.insert(0, path) + + +def _find_op_modules(): + return { + path.parent.name: path + for path in sorted(_OPS_DIR.glob("*/build.py")) + if path.is_file() + } + + +def _build_manifest(output_dir): + return sorted(str(path) for path in pathlib.Path(output_dir).rglob("*.c")) + + +def _write_cmake_manifest(output_dir, sources): + manifest_path = pathlib.Path(output_dir) / "manifest.cmake" + lines = ["set(INFINIOPS_TRITON_SOURCES"] + lines.extend(f' "{source}"' for source in sources) + lines.append(")") + lines.append("") + lines.append(f'set(INFINIOPS_TRITON_INCLUDE_DIRS "{output_dir}")') + lines.append("") + manifest_path.write_text("\n".join(lines) + "\n") + + +def _load_op_module(path): + _prepend_sys_path(path.parent) + spec = importlib.util.spec_from_file_location( + f"infiniops_triton_{path.parent.name}_build", + path, + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + sys.modules[spec.name] = module + spec.loader.exec_module(module) + + return module + + +def generate(ops, *, output_dir): + op_modules = _find_op_modules() + unknown_ops = tuple(op for op in ops if op not in op_modules) + + if unknown_ops: + raise ValueError(f"unsupported Triton ops: {', '.join(unknown_ops)}") + + output_dir = pathlib.Path(output_dir) + shutil.rmtree(output_dir, ignore_errors=True) + output_dir.mkdir(parents=True, exist_ok=True) + + for op in ops: + path = op_modules[op] + kernel_path = path.parent / f"{op}.py" + module = _load_op_module(path) + kernel_args = aot.kernel_args(kernel_path, _KERNEL_NAME) + headers = [] + dispatch_configs = [] + for configs in module.configs(): + out_base = aot.build( + configs, + path=kernel_path, + kernel_name=_KERNEL_NAME, + out_dir=output_dir / op, + kernel_args=kernel_args, + ) + headers.append(out_base.with_suffix(".h")) + dispatch_configs.append(configs[0]) + aot.write_header( + headers, + output_dir / op / f"infini_ops_triton_{op}.h", + op_name=op, + configs=dispatch_configs, + kernel_args=kernel_args, + ) + + sources = _build_manifest(output_dir) + _write_cmake_manifest(output_dir, sources) + + return sources + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Generate Triton operator sources for InfiniOps." + ) + parser.add_argument("--output-dir", required=True) + parser.add_argument("--ops", nargs="+", default=tuple(_find_op_modules())) + + return parser.parse_args() + + +def main(): + args = _parse_args() + generate(args.ops, output_dir=args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7cd3c0043..2e1e12f7d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -84,6 +84,41 @@ if(WITH_NINETOOTHED) target_sources(infiniops PRIVATE ${INFINI_OPS_NINETOOTHED_SOURCES}) endif() +if(WITH_TRITON) + find_package(Python COMPONENTS Interpreter REQUIRED) + + if(TRITON_PYTHON_EXECUTABLE) + set(_triton_python "${TRITON_PYTHON_EXECUTABLE}") + elseif(_TORCH_PYTHON) + set(_triton_python "${_TORCH_PYTHON}") + else() + set(_triton_python "${Python_EXECUTABLE}") + endif() + message(STATUS "Triton codegen Python: ${_triton_python}") + + set(_triton_output_dir "${CMAKE_CURRENT_BINARY_DIR}/triton") + set(_triton_generator_args + "${PROJECT_SOURCE_DIR}/scripts/triton/generate_ops.py" + --output-dir "${_triton_output_dir}") + + execute_process( + COMMAND "${_triton_python}" ${_triton_generator_args} + WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}" + RESULT_VARIABLE _triton_generation_result + ) + + if(NOT _triton_generation_result EQUAL 0) + message(FATAL_ERROR "Generating Triton AOT operator sources failed with `${_triton_python}`. Set `TRITON_PYTHON_EXECUTABLE` to a Python with `triton` and CUDA dependencies installed.") + endif() + + enable_language(C) + + include("${_triton_output_dir}/manifest.cmake") + target_compile_definitions(infiniops PUBLIC WITH_TRITON=1) + target_include_directories(infiniops PRIVATE ${INFINIOPS_TRITON_INCLUDE_DIRS}) + target_sources(infiniops PRIVATE ${INFINIOPS_TRITON_SOURCES}) +endif() + if(WITH_ILUVATAR) set(ILUVATAR_PATTERNS "native/cuda/*.cc" @@ -552,6 +587,10 @@ if(GENERATE_OPERATOR_CALL_INSTANTIATIONS OR GENERATE_PYTHON_BINDINGS) list(APPEND GENERATOR_ARGS --with-ninetoothed) endif() + if(WITH_TRITON) + list(APPEND GENERATOR_ARGS --with-triton) + endif() + execute_process( COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py ${GENERATOR_ARGS} WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} @@ -791,6 +830,12 @@ if(GENERATE_PYTHON_BINDINGS) target_include_directories(ops PRIVATE ${INFINI_OPS_NINETOOTHED_INCLUDE_DIRS}) endif() + + if(WITH_TRITON) + target_include_directories(ops PRIVATE + ${INFINIOPS_TRITON_INCLUDE_DIRS}) + endif() + target_link_libraries(ops PRIVATE infiniops) # Cambricon generated dispatch is compiled into the Python extension and diff --git a/src/triton/ops/add/add.h b/src/triton/ops/add/add.h new file mode 100644 index 000000000..45de13f41 --- /dev/null +++ b/src/triton/ops/add/add.h @@ -0,0 +1,87 @@ +#ifndef INFINI_OPS_TRITON_ADD_H_ +#define INFINI_OPS_TRITON_ADD_H_ + +#include + +#include +#include +#include +#include + +#include "add/infini_ops_triton_add.h" +#include "base/add.h" +#include "data_type.h" + +namespace infini::ops { + +template <> +class Operator : public Add { + public: + using Add::operator(); + + Operator(const Tensor input, const Tensor other, Tensor out) + : Add{input, other, out} { + const int ndim = static_cast(ndim_); + std::vector h_metadata(4 * std::max(ndim, 1), 0); + for (int i = 0; i < ndim; ++i) { + h_metadata[0 * ndim + i] = static_cast(out_shape_[i]); + h_metadata[1 * ndim + i] = static_cast(input_strides_[i]); + h_metadata[2 * ndim + i] = static_cast(other_strides_[i]); + h_metadata[3 * ndim + i] = static_cast(out_strides_[i]); + } + + const size_t bytes = h_metadata.size() * sizeof(int64_t); + cuMemAlloc(&d_metadata_, bytes); + cuMemcpyHtoD(d_metadata_, h_metadata.data(), bytes); + + const size_t stride_bytes = ndim * sizeof(int64_t); + d_out_shape_ = d_metadata_ + 0 * stride_bytes; + d_input_stride_ = d_metadata_ + 1 * stride_bytes; + d_other_stride_ = d_metadata_ + 2 * stride_bytes; + d_out_stride_ = d_metadata_ + 3 * stride_bytes; + } + + ~Operator() { + if (d_metadata_) { + cuMemFree(d_metadata_); + } + } + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + load_infini_ops_triton_add(out.dtype()); + + CUstream stream = static_cast(stream_); + auto x = reinterpret_cast(const_cast(input.data())); + auto y = reinterpret_cast(const_cast(other.data())); + auto o = reinterpret_cast(out.data()); + + int32_t n = static_cast(out.numel()); + int32_t ndim_val = static_cast(ndim_); + int32_t x_contig = static_cast(is_input_contiguous_); + int32_t y_contig = static_cast(is_other_contiguous_); + int32_t out_contig = static_cast(is_out_contiguous_); + + auto result = launch_infini_ops_triton_add( + out.dtype(), stream, x, y, o, d_out_shape_, d_input_stride_, + d_other_stride_, d_out_stride_, x_contig, y_contig, out_contig, + ndim_val, n); + + assert(result == CUDA_SUCCESS && "Triton `Add` launch failed"); + } + + private: + CUdeviceptr d_metadata_{0}; + + CUdeviceptr d_out_shape_{0}; + + CUdeviceptr d_input_stride_{0}; + + CUdeviceptr d_other_stride_{0}; + + CUdeviceptr d_out_stride_{0}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/triton/ops/add/add.py b/src/triton/ops/add/add.py new file mode 100644 index 000000000..eb3e35724 --- /dev/null +++ b/src/triton/ops/add/add.py @@ -0,0 +1,52 @@ +import triton +import triton.language as tl + + +@triton.jit +def kernel( + x_ptr, + y_ptr, + out_ptr, + out_shape_ptr, + x_stride_ptr, + y_stride_ptr, + out_stride_ptr, + x_contig, + y_contig, + out_contig, + ndim, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offsets = (pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)).to(tl.int64) + mask = offsets < n_elements + + if (x_contig != 0) and (y_contig != 0) and (out_contig != 0): + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + tl.store(out_ptr + offsets, x + y, mask=mask) + else: + x_offs = tl.zeros([BLOCK_SIZE], dtype=tl.int64) + y_offs = tl.zeros([BLOCK_SIZE], dtype=tl.int64) + out_offs = tl.zeros([BLOCK_SIZE], dtype=tl.int64) + tmp = offsets + + for i in range(ndim): + s = tl.load(out_shape_ptr + (ndim - 1 - i)) + d = tmp % s + tmp = tmp // s + x_offs += d * tl.load(x_stride_ptr + (ndim - 1 - i)) + y_offs += d * tl.load(y_stride_ptr + (ndim - 1 - i)) + out_offs += d * tl.load(out_stride_ptr + (ndim - 1 - i)) + + if x_contig != 0: + x_offs = offsets + if y_contig != 0: + y_offs = offsets + if out_contig != 0: + out_offs = offsets + + x = tl.load(x_ptr + x_offs, mask=mask) + y = tl.load(y_ptr + y_offs, mask=mask) + tl.store(out_ptr + out_offs, x + y, mask=mask) diff --git a/src/triton/ops/add/build.py b/src/triton/ops/add/build.py new file mode 100644 index 000000000..70fcf18be --- /dev/null +++ b/src/triton/ops/add/build.py @@ -0,0 +1,52 @@ +from scripts.triton import aot + +_DTYPES = ( + "fp16", + "bf16", + "fp32", + "fp64", + "i8", + "i16", + "i32", + "i64", + "u8", + "u16", + "u32", + "u64", +) +_BLOCK_SIZES = (512, 1024) +_ALIGNMENTS = (16, None) +_NUM_WARPS = 4 +_NUM_STAGES = 3 +_DATA_PTRS = ("x_ptr", "y_ptr", "out_ptr") +_META_PTRS = ("out_shape_ptr", "x_stride_ptr", "y_stride_ptr", "out_stride_ptr") +_SCALARS = ("x_contig", "y_contig", "out_contig", "ndim", "n_elements") + + +def _signature(dtype, block_size, alignment): + return aot.Signature( + pointer_dtypes={ + **{name: dtype for name in _DATA_PTRS}, + **{name: "i64" for name in _META_PTRS}, + }, + pointer_alignments={name: alignment for name in _DATA_PTRS}, + scalar_dtypes={ + **{name: "i32" for name in _SCALARS}, + }, + constexprs={"BLOCK_SIZE": block_size}, + ) + + +def configs(): + for dtype in _DTYPES: + yield tuple( + aot.CompileConfig( + signature=_signature(dtype, block_size, alignment), + grid=f"(n_elements + {block_size} - 1) / {block_size}, 1, 1", + out_name=f"infini_ops_triton_add_{dtype}", + num_warps=_NUM_WARPS, + num_stages=_NUM_STAGES, + ) + for block_size in _BLOCK_SIZES + for alignment in _ALIGNMENTS + )