From 6e238064e45cf721d22f9cf3ddcbc773df934951 Mon Sep 17 00:00:00 2001 From: Nikhil Dev Goyal Date: Thu, 11 Jun 2026 11:16:37 -0700 Subject: [PATCH] [MatMul] Add SFC(Space Filling Curves), specifically the hilbert curve into matmul-inl.h. PiperOrigin-RevId: 930635801 --- BUILD.bazel | 8 ++++ CMakeLists.txt | 2 + ops/gilbert.cc | 117 +++++++++++++++++++++++++++++++++++++++++++++++ ops/gilbert.h | 17 +++++++ ops/matmul-inl.h | 65 ++++++++++++++++++++++++++ ops/matmul.h | 91 ++++++++++++++++++++++++++++++++++-- util/zones.cc | 8 ++++ util/zones.h | 4 ++ 8 files changed, 308 insertions(+), 4 deletions(-) create mode 100644 ops/gilbert.cc create mode 100644 ops/gilbert.h diff --git a/BUILD.bazel b/BUILD.bazel index e8a1bbfe..a2d7803a 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -335,6 +335,7 @@ cc_library( deps = [ ":allocator", ":basics", + ":gilbert", ":mat", ":threading", ":threading_context", @@ -359,6 +360,7 @@ cc_library( deps = [ ":allocator", ":basics", + ":gilbert", ":mat", ":matmul_env", ":threading", @@ -521,6 +523,12 @@ cc_test( ], ) +cc_library( + name = "gilbert", + srcs = ["ops/gilbert.cc"], + hdrs = ["ops/gilbert.h"], +) + cc_test( name = "bench_matmul", size = "small", diff --git a/CMakeLists.txt b/CMakeLists.txt index 6791fbcc..4fb13a5e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -156,6 +156,8 @@ set(SOURCES io/io.cc io/io.h ops/dot-inl.h + ops/gilbert.cc + ops/gilbert.h ops/matmul_static_bf16.cc ops/matmul_static_f32.cc ops/matmul_static_nuq.cc diff --git a/ops/gilbert.cc b/ops/gilbert.cc new file mode 100644 index 00000000..3fae655e --- /dev/null +++ b/ops/gilbert.cc @@ -0,0 +1,117 @@ +// Copyright 2026 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Adapted from code by abetusk (BSD-2-Clause) in +// https://github.com/jakubcerveny/gilbert. 
+
+#include "ops/gilbert.h"
+
+#include  // NOTE(review): header name lost in this copy; abs() needs <stdlib.h>
+#include  // NOTE(review): header name lost in this copy of the patch
+
+namespace gcpp {
+
+namespace {
+// Returns the sign of `x`: -1, 0 or +1.
+int sgn(int x) {
+  if (x < 0) {
+    return -1;
+  }
+  if (x > 0) {
+    return 1;
+  }
+  return 0;
+}
+// Recursive helper of `gilbert_d2xy`; always returns 0. Steps are noted below.
+int gilbert_d2xy_r(int dst_idx, int cur_idx, int* xres, int* yres, int ax,
+                   int ay, int bx, int by) {
+  int nxt_idx;  // curve index at which the following part begins
+  int w, h, x, y, dax, day, dbx, dby, di;
+  int ax2, ay2, bx2, by2, w2, h2;
+  // (x, y) is the cell at `cur_idx`; w and h are the lengths of the two sides.
+  w = abs(ax + ay);
+  h = abs(bx + by);
+  x = *xres;
+  y = *yres;
+  dax = sgn(ax);  // (dax, day) and (dbx, dby): unit steps along each side
+  day = sgn(ay);
+  dbx = sgn(bx);
+  dby = sgn(by);
+  di = dst_idx - cur_idx;  // cells left to advance from (x, y)
+  // One cell thick: the target is `di` unit steps from the start cell.
+  if (h == 1) {
+    *xres = x + dax * di;
+    *yres = y + day * di;
+    return 0;
+  }
+  if (w == 1) {
+    *xres = x + dbx * di;
+    *yres = y + dby * di;
+    return 0;
+  }
+  // Halve both sides. NOTE(review): assumes `>>` floors negative values.
+  ax2 = ax >> 1;
+  ay2 = ay >> 1;
+  bx2 = bx >> 1;
+  by2 = by >> 1;
+  w2 = abs(ax2 + ay2);
+  h2 = abs(bx2 + by2);
+  // Wide case (2w > 3h): split into two parts along (ax, ay).
+  if ((2 * w) > (3 * h)) {
+    if ((w2 & 1) && (w > 2)) {  // odd half: extend it by one unit step
+      ax2 += dax;
+      ay2 += day;
+    }
+    nxt_idx = cur_idx + abs((ax2 + ay2) * (bx + by));
+    if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) {
+      *xres = x;
+      *yres = y;
+      return gilbert_d2xy_r(dst_idx, cur_idx, xres, yres, ax2, ay2, bx, by);
+    }
+    cur_idx = nxt_idx;  // target lies in the second part
+    *xres = x + ax2;
+    *yres = y + ay2;
+    return gilbert_d2xy_r(dst_idx, cur_idx, xres, yres, ax - ax2, ay - ay2, bx,
+                          by);
+  }
+  // Otherwise split into three parts; parts 1 and 3 swap the two sides.
+  if ((h2 & 1) && (h > 2)) {  // odd half: extend it by one unit step
+    bx2 += dbx;
+    by2 += dby;
+  }
+  // Part 1 starts at (x, y).
+  nxt_idx = cur_idx + abs((bx2 + by2) * (ax2 + ay2));
+  if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) {
+    *xres = x;
+    *yres = y;
+    return gilbert_d2xy_r(dst_idx, cur_idx, xres, yres, bx2, by2, ax2, ay2);
+  }
+  cur_idx = nxt_idx;
+  // Part 2 starts at (x + bx2, y + by2) and keeps the full (ax, ay) side.
+  nxt_idx = cur_idx + abs((ax + ay) * ((bx - bx2) + (by - by2)));
+  if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) {
+    *xres = x + bx2;
+    *yres = y + by2;
+    return gilbert_d2xy_r(dst_idx, cur_idx, xres, yres, ax, ay, bx - bx2,
+                          by - by2);
+  }
+  cur_idx = nxt_idx;
+  // Part 3 starts at the far end of (ax, ay) and uses negated side vectors.
+  *xres = x + (ax - dax) + (bx2 - dbx);
+  *yres = y + (ay - day) + (by2 - dby);
+  return gilbert_d2xy_r(dst_idx, cur_idx, xres, yres, -bx2, -by2, -(ax - ax2),
+                        -(ay - ay2));
+}
+
+}  // namespace
+// Starts at (0, 0) with the longer grid side as the first side vector.
+int gilbert_d2xy(int* x, int* y, int idx, int w, int h) {
+  *x = 0;
+  *y = 0;
+  if (w >= h) {
+    return gilbert_d2xy_r(idx, 0, x, y, w, 0, 0, h);
+  }
+  return gilbert_d2xy_r(idx, 0, x, y, 0, h, w, 0);
+}
+
+}  // namespace gcpp
diff --git a/ops/gilbert.h b/ops/gilbert.h
new file mode 100644
index 00000000..4d1b05ba
--- /dev/null
+++ b/ops/gilbert.h
@@ -0,0 +1,17 @@
+// Copyright 2026 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Adapted from code by abetusk (BSD-2-Clause) in
+// https://github.com/jakubcerveny/gilbert.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_OPS_GILBERT_H_
+#define THIRD_PARTY_GEMMA_CPP_OPS_GILBERT_H_
+
+namespace gcpp {
+
+// Writes to `*x`, `*y` the cell at index `idx` of the w x h Hilbert curve.
+int gilbert_d2xy(int* x, int* y, int idx, int w, int h);  // always returns 0
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_OPS_GILBERT_H_
diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h
index fa7d11e5..d6c4d8f8 100644
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -1039,6 +1039,71 @@ class MMLoops {
           }
         });
   }
+
+  // Parallel loops over mc/nc blocks of M/range_n via SFC, single K.
+  template  // NOTE(review): template parameter list missing in this copy
+  static HWY_INLINE void Loop(MMOrderSFC, Parallel parallel,
+                              const StridedViewBF A, const MatPtrT& B,
+                              const MatPtrT* B2, RowPtrs C,
+                              const MMArgs& args) {
+    const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMSFC);
+    HWY_DASSERT(args.ranges_kc.NumTasks() == 1);  // this order needs one K task
+    const IndexRange& range_kc = args.ranges_kc.Range(0);  // whole K
+    // `ForRangesSFC` calls the lambda with one mc range and one nc range.
+    parallel.ForRangesSFC(
+        args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx,
+        [&](const IndexRange& range_mc, const IndexRange& range_nc,
+            size_t worker) HWY_ATTR {
+          MMZone mm_zone;
+          mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
+          MMKernel::B3A2C0(  // first pass: B, output is this block's view of C
+              A, B, range_mc, range_kc, range_nc, args, MMSetC(),
+              C.View(range_mc.begin(), range_nc.begin(), range_nc.Num()));
+          // View sized to this block and selected by `worker`; B2 pass output.
+          const StridedViewBF C2 = args.env.C_tiles.C(
+              Extents2D(range_mc.Num(), range_nc.Num()), worker);
+
+          if (B2 != nullptr) {  // optional second pass with the same A
+            MMKernel::B3A2C0(A, *B2, range_mc, range_kc, range_nc, args,
+                             MMSetC(), C2);
+          }
+          if constexpr (IsBF16()) {  // user callback gets both C and C2
+            args.options.MaybeCallFunc(C, range_mc, range_nc, C2, worker);
+          }
+        });
+  }
+
+  // Parallel loops over mc/nc blocks of M/range_n via SFC, sequential K.
+  template  // NOTE(review): template parameter list missing in this copy
+  static HWY_INLINE void Loop(MMOrderSFC_K, Parallel parallel,
+                              const StridedViewBF A, const MatPtrT& B,
+                              const MatPtrT* B2, RowPtrs C,
+                              const MMArgs& args) {
+    const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMSFC_K);
+    // The whole `ranges_kc` partition is handed to `ForeachKC` for each block.
+    parallel.ForRangesSFC(
+        args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx,
+        [&](const IndexRange& range_mc, const IndexRange& range_nc,
+            size_t worker) HWY_ATTR {
+          MMZone mm_zone;
+          mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
+          MMKernel::ForeachKC(  // first pass: B, output is this block's view of C
+              A, B, range_mc, args.ranges_kc, range_nc, args,
+              C.View(range_mc.begin(), range_nc.begin(), range_nc.Num()));
+          // View sized to this block and selected by `worker`; B2 pass output.
+          const StridedViewBF C2 = args.env.C_tiles.C(
+              Extents2D(range_mc.Num(), range_nc.Num()), worker);
+
+          if (B2 != nullptr) {  // optional second pass with the same A
+            MMKernel::ForeachKC(A, *B2, range_mc, args.ranges_kc, range_nc,
+                                args, C2);
+          }
+
+          if constexpr (IsBF16()) {  // user callback gets both C and C2
+            args.options.MaybeCallFunc(C, range_mc, range_nc, C2, worker);
+          }
+        });
+  }
 };  // MMLoops
 
 // Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
diff --git a/ops/matmul.h b/ops/matmul.h
index f0d95d1e..66575756 100644
--- a/ops/matmul.h
+++ b/ops/matmul.h
@@ -25,6 +25,7 @@
 
 // IWYU pragma: begin_exports
 #include "ops/brgemm.h"  // BRGeMMConfig, GEMMA_ONEDNN_BRGEMM
+#include "ops/gilbert.h"  // gilbert_d2xy
 #include "util/basics.h"
 #include "util/mat.h"
 #include "util/threading.h"
@@ -90,6 +91,25 @@ struct MMParallelNone {
     }
   }
 
+  // SFC = space-filling curve (specifically the Hilbert curve).
+  template  // NOTE(review): template parameter list missing in this copy
+  void ForRangesSFC(ThreadingContext& ctx,
+                    const IndexRangePartition& ranges_mc,
+                    const IndexRangePartition& ranges_nc, size_t cluster_idx,
+                    const Func& func) const {
+    const size_t worker = ctx.Worker(cluster_idx);
+    const size_t W = ranges_mc.NumTasks();  // passed to gilbert_d2xy as `w`
+    const size_t H = ranges_nc.NumTasks();  // passed to gilbert_d2xy as `h`
+    const size_t num_tasks = W * H;
+    // No thread pool: all tasks run in this loop and report the same `worker`.
+    for (size_t task = 0; task < num_tasks; ++task) {
+      int x, y;  // written by gilbert_d2xy
+      gilbert_d2xy(&x, &y, static_cast(task), static_cast(W),
+                   static_cast(H));  // NOTE(review): cast types lost in copy
+      func(ranges_mc.Range(x), ranges_nc.Range(y), worker);
+    }
+  }
+
   template
   void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
                   size_t cluster_idx, const Func& func) const {
@@ -138,6 +158,29 @@ struct MMParallelWithinCluster {
         });
   }
 
+  template  // NOTE(review): template parameter list missing in this copy
+  void ForRangesSFC(ThreadingContext& ctx,
+                    const IndexRangePartition& ranges_mc,
+                    const IndexRangePartition& ranges_nc, size_t cluster_idx,
+                    const Func& func) const {
+    const hwy::pool::Caller caller =
+        ctx.pool_callers.Get(Callers::kMMClusterForSFC);
+    // Flatten the mc x nc grid into W * H task indices.
+    const size_t W = ranges_mc.NumTasks();
+    const size_t H = ranges_nc.NumTasks();
+    const size_t num_tasks = W * H;
+    // Each task index is mapped to grid coordinates (x, y) by gilbert_d2xy.
+    ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
+                             [&](uint64_t task, size_t worker) {
+                               int x, y;  // written by gilbert_d2xy
+                               gilbert_d2xy(&x, &y, static_cast(task),
+                                            static_cast(W),
+                                            static_cast(H));
+                               func(ranges_mc.Range(x), ranges_nc.Range(y),
+                                    worker);
+                             });
+  }
+
   template
   void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
                   size_t cluster_idx, const Func& func) const {
@@ -208,6 +251,35 @@ struct MMParallelHierarchical {
         });
   }
 
+  template  // NOTE(review): template parameter list missing in this copy
+  void ForRangesSFC(ThreadingContext& ctx,
+                    const IndexRangePartition& ranges_mc,
+                    const IndexRangePartition& ranges_nc,
+                    size_t caller_cluster_idx, const Func& func) const {
+    HWY_DASSERT(caller_cluster_idx == 0);
+    (void)caller_cluster_idx;  // otherwise only read by the debug assertion
+    const hwy::pool::Caller caller =
+        ctx.pool_callers.Get(Callers::kMMHierForSFC);
+    // Flatten the mc x nc grid into the index range [0, W * H).
+    const size_t W = ranges_mc.NumTasks();
+    const size_t H = ranges_nc.NumTasks();
+    const IndexRange all_range(0, W * H);
+    // Outer call: one sub-range per cluster. Inner call: one task per index.
+    ParallelPartitionAcrossClusters(
+        all_range, /*task_multiple=*/1, /*inner_tasks=*/1, ctx, caller,
+        [&](const IndexRange& cluster_range, size_t cluster_idx) {
+          ParallelForWithinCluster(
+              cluster_range.Num(), ctx, cluster_idx, caller,
+              [&](uint64_t i, size_t worker) {
+                const size_t task = cluster_range.begin() + i;  // global index
+                int x, y;  // written by gilbert_d2xy
+                gilbert_d2xy(&x, &y, static_cast(task),
+                             static_cast(W), static_cast(H));
+                func(ranges_mc.Range(x), ranges_nc.Range(y), worker);
+              });
+        });
+  }
+
   // No multiple/inner_tasks, so this is just HierarchicalParallelFor.
   template
   void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
@@ -305,9 +377,9 @@ enum class MMOrder : uint8_t {
   kNT_MT_K,
   kNT_MT,  // Specialization of `kNT_MT_K` for a single K task with `MMSetC`.
 
-  // Resident C (`kK_M_NT`) should be good for large K relative to M and N.
-  // However, it does not (much) outperform `kNT_K` on SKX and Zen4. There are
-  // no kM* because we expect M (batch size) to be small relative to K and N.
+  // Space-filling curve (Hilbert curve) traversal over blocks of M and N.
+  kSFC_K,
+  kSFC,  // Specialization of `kSFC_K` for a single K task with `MMSetC`.
 };
 
 // Tag types for `DispatchOrder`.
@@ -315,6 +387,8 @@ struct MMOrderNT_K {};
 struct MMOrderNT {};
 struct MMOrderNT_MT_K {};
 struct MMOrderNT_MT {};
+struct MMOrderSFC_K {};
+struct MMOrderSFC {};
 
 template
 void DispatchOrder(MMOrder order, const Func& func, Args&&... args) {
@@ -327,6 +401,10 @@ void DispatchOrder(MMOrder order, const Func& func, Args&&...
args) { return func(MMOrderNT_MT_K(), std::forward(args)...); case MMOrder::kNT_MT: return func(MMOrderNT_MT(), std::forward(args)...); + case MMOrder::kSFC_K: + return func(MMOrderSFC_K(), std::forward(args)...); + case MMOrder::kSFC: + return func(MMOrderSFC(), std::forward(args)...); default: HWY_UNREACHABLE; } @@ -337,7 +415,8 @@ static inline bool IsOneMC(MMOrder order) { } static inline bool IsOneKC(MMOrder order) { - return order == MMOrder::kNT || order == MMOrder::kNT_MT; + return order == MMOrder::kNT || order == MMOrder::kNT_MT || + order == MMOrder::kSFC; } static inline const char* StringFromOrder(MMOrder order) { @@ -350,6 +429,10 @@ static inline const char* StringFromOrder(MMOrder order) { return "NT_MT_K"; case MMOrder::kNT_MT: return "NT_MT"; + case MMOrder::kSFC_K: + return "SFC_K"; + case MMOrder::kSFC: + return "SFC"; default: return nullptr; } diff --git a/util/zones.cc b/util/zones.cc index b552bb17..acc26b2f 100644 --- a/util/zones.cc +++ b/util/zones.cc @@ -71,6 +71,10 @@ const char* ZoneName(Zones zone) { return "MM.NT_MT"; case Zones::kMMNT: return "MM.NT"; + case Zones::kMMSFC: + return "MM.SFC"; + case Zones::kMMSFC_K: + return "MM.SFC_K"; case Zones::kMMTwoMatMul: return "MM.TwoMatMul"; case Zones::kOpsAddFrom: @@ -153,12 +157,16 @@ const char* CallerName(Callers caller) { return "MM.ClusterForMCNC"; case Callers::kMMClusterForN: return "MM.ClusterForN"; + case Callers::kMMClusterForSFC: + return "MM.ClusterForSFC"; case Callers::kMMHierForMC: return "MM.HierForMC"; case Callers::kMMHierForMCNC: return "MM.HierForMCNC"; case Callers::kMMHierForN: return "MM.HierForN"; + case Callers::kMMHierForSFC: + return "MM.HierForSFC"; case Callers::kOpsAddFromBatched: return "Ops.AddFromBatched"; case Callers::kOpsMaybeLogitsSoftCapBatched: diff --git a/util/zones.h b/util/zones.h index ba3d5a9b..e731ff0d 100644 --- a/util/zones.h +++ b/util/zones.h @@ -41,6 +41,8 @@ enum class Zones { // Keep sorted kMMNT_MT_K, kMMNT_MT, kMMNT, + kMMSFC, + 
kMMSFC_K, kMMTwoMatMul, kOpsAddFrom, kOpsLogitsSoftCap, @@ -90,9 +92,11 @@ enum class Callers { // Keep sorted kMMClusterForMC, kMMClusterForMCNC, kMMClusterForN, + kMMClusterForSFC, kMMHierForMC, kMMHierForMCNC, kMMHierForN, + kMMHierForSFC, kOpsAddFromBatched, kOpsMaybeLogitsSoftCapBatched, kOpsRMSNormBatched,