Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ cc_library(
deps = [
":allocator",
":basics",
":gilbert",
":mat",
":threading",
":threading_context",
Expand All @@ -359,6 +360,7 @@ cc_library(
deps = [
":allocator",
":basics",
":gilbert",
":mat",
":matmul_env",
":threading",
Expand Down Expand Up @@ -521,6 +523,12 @@ cc_test(
],
)

# Space-filling-curve index -> (x, y) mapping (`gilbert_d2xy`), used by the
# matmul targets that list ":gilbert" in their deps.
cc_library(
name = "gilbert",
srcs = ["ops/gilbert.cc"],
hdrs = ["ops/gilbert.h"],
)

cc_test(
name = "bench_matmul",
size = "small",
Expand Down
117 changes: 117 additions & 0 deletions ops/gilbert.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright 2026 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Adapted from code by abetusk (BSD-2-Clause) in
// https://github.com/jakubcerveny/gilbert.

#include "ops/gilbert.h"

#include <math.h>
#include <stdlib.h>

namespace gcpp {

namespace {

// Returns -1, 0 or +1 according to the sign of `x`.
int sgn(int x) {
  if (x < 0) {
    return -1;
  }
  if (x > 0) {
    return 1;
  }
  return 0;
}

// Recursive worker for `gilbert_d2xy`.
//
// The current sub-rectangle starts at (`*xres`, `*yres`) and is spanned by the
// two vectors (`ax`, `ay`) and (`bx`, `by`). `gilbert_d2xy` and every
// recursive call below pass axis-aligned vectors (one component is zero), so
// `abs(ax + ay)` and `abs(bx + by)` are the two side lengths. `cur_idx` is the
// curve index of the sub-rectangle's first cell and `dst_idx` is the index
// being looked up. On return, `*xres`/`*yres` hold the coordinates of
// `dst_idx`. Always returns 0.
int gilbert_d2xy_r(int dst_idx, int cur_idx, int* xres, int* yres, int ax,
                   int ay, int bx, int by) {
  const int w = abs(ax + ay);
  const int h = abs(bx + by);
  const int x = *xres;
  const int y = *yres;

  // Unit steps along the two spanning vectors.
  const int dax = sgn(ax);
  const int day = sgn(ay);
  const int dbx = sgn(bx);
  const int dby = sgn(by);

  // Offset of the requested index from the start of this sub-rectangle.
  const int di = dst_idx - cur_idx;

  // Base cases: a single row or column is walked linearly.
  if (h == 1) {
    *xres = x + dax * di;
    *yres = y + day * di;
    return 0;
  }
  if (w == 1) {
    *xres = x + dbx * di;
    *yres = y + dby * di;
    return 0;
  }

  // Halve both vectors. The components may be negative; `>> 1` is kept from
  // the upstream code and rounds toward negative infinity when the shift is
  // arithmetic.
  int ax2 = ax >> 1;
  int ay2 = ay >> 1;
  int bx2 = bx >> 1;
  int by2 = by >> 1;
  const int w2 = abs(ax2 + ay2);
  const int h2 = abs(bx2 + by2);

  if ((2 * w) > (3 * h)) {
    // Elongated rectangle: split into two parts along (ax, ay) only.
    if ((w2 & 1) && (w > 2)) {
      // Make the first part's length even.
      ax2 += dax;
      ay2 += day;
    }
    const int nxt_idx = cur_idx + abs((ax2 + ay2) * (bx + by));
    if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) {
      *xres = x;
      *yres = y;
      return gilbert_d2xy_r(dst_idx, cur_idx, xres, yres, ax2, ay2, bx, by);
    }
    *xres = x + ax2;
    *yres = y + ay2;
    return gilbert_d2xy_r(dst_idx, nxt_idx, xres, yres, ax - ax2, ay - ay2, bx,
                          by);
  }

  if ((h2 & 1) && (h > 2)) {
    // Make the first part's length along (bx, by) even.
    bx2 += dbx;
    by2 += dby;
  }

  // Three-part split. Part 1 swaps the roles of the two vectors.
  const int end_first = cur_idx + abs((bx2 + by2) * (ax2 + ay2));
  if ((cur_idx <= dst_idx) && (dst_idx < end_first)) {
    *xres = x;
    *yres = y;
    return gilbert_d2xy_r(dst_idx, cur_idx, xres, yres, bx2, by2, ax2, ay2);
  }

  // Part 2 spans the full (ax, ay) extent over the remainder of (bx, by).
  const int end_second =
      end_first + abs((ax + ay) * ((bx - bx2) + (by - by2)));
  if ((end_first <= dst_idx) && (dst_idx < end_second)) {
    *xres = x + bx2;
    *yres = y + by2;
    return gilbert_d2xy_r(dst_idx, end_first, xres, yres, ax, ay, bx - bx2,
                          by - by2);
  }

  // Part 3 starts from the far corner and uses negated, swapped vectors.
  *xres = x + (ax - dax) + (bx2 - dbx);
  *yres = y + (ay - day) + (by2 - dby);
  return gilbert_d2xy_r(dst_idx, end_second, xres, yres, -bx2, -by2,
                        -(ax - ax2), -(ay - ay2));
}

}  // namespace

// Maps curve index `idx` on a `w` x `h` grid to cell coordinates, stored in
// `*x` and `*y`. The longer grid side is used as the primary direction.
//
// Returns 0 on success. Returns -1, leaving `*x == *y == 0`, if `w` or `h` is
// not positive or `idx` is outside [0, w * h): without this check a 0 x 0 grid
// makes the recursion never terminate, and an out-of-range index yields
// coordinates that do not belong to the grid.
int gilbert_d2xy(int* x, int* y, int idx, int w, int h) {
  *x = 0;
  *y = 0;
  if (w <= 0 || h <= 0 || idx < 0) {
    return -1;
  }
  // Widen before multiplying so that large grids cannot overflow `int`.
  const long long num_cells = static_cast<long long>(w) * h;
  if (static_cast<long long>(idx) >= num_cells) {
    return -1;
  }
  if (w >= h) {
    return gilbert_d2xy_r(idx, 0, x, y, w, 0, 0, h);
  }
  return gilbert_d2xy_r(idx, 0, x, y, 0, h, w, 0);
}

}  // namespace gcpp
17 changes: 17 additions & 0 deletions ops/gilbert.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright 2026 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Adapted from code by abetusk (BSD-2-Clause) in
// https://github.com/jakubcerveny/gilbert.

#ifndef THIRD_PARTY_GEMMA_CPP_OPS_GILBERT_H_
#define THIRD_PARTY_GEMMA_CPP_OPS_GILBERT_H_

namespace gcpp {

// Maps the 1D index `idx` along a Hilbert-style space-filling curve over a
// `w` x `h` grid to 2D cell coordinates, which are written to `*x` and `*y`.
// `idx` is expected to be in [0, w * h) and `w`, `h` positive; `w` and `h`
// need not be equal or powers of two. Returns 0 on success.
int gilbert_d2xy(int* x, int* y, int idx, int w, int h);

} // namespace gcpp

#endif // THIRD_PARTY_GEMMA_CPP_OPS_GILBERT_H_
65 changes: 65 additions & 0 deletions ops/matmul-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,71 @@ class MMLoops {
}
});
}

// SFC order with a single K task: `parallel.ForRangesSFC` hands out (mc, nc)
// tiles, and each tile is computed over the whole K range by one
// `MMKernel::B3A2C0` call using the `MMSetC` tag.
template <typename TB, typename TC, class Parallel>
static HWY_INLINE void Loop(MMOrderSFC, Parallel parallel,
                            const StridedViewBF A, const MatPtrT<TB>& B,
                            const MatPtrT<TB>* B2, RowPtrs<TC> C,
                            const MMArgs& args) {
  const auto sfc_zone = args.env.ctx.profiler_zones.Get(Zones::kMMSFC);
  // K must not be split for this order, so range 0 is all of K.
  HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
  const IndexRange& whole_k = args.ranges_kc.Range(0);

  const auto compute_tile = [&](const IndexRange& mc_block,
                                const IndexRange& nc_block,
                                size_t worker) HWY_ATTR {
    MMZone tile_zone;
    tile_zone.MaybeEnter(worker, sfc_zone, args.env, &args.autotune);

    // A * B for this tile goes straight into the matching view of C.
    MMKernel::B3A2C0(
        A, B, mc_block, whole_k, nc_block, args, MMSetC(),
        C.View(mc_block.begin(), nc_block.begin(), nc_block.Num()));

    // Tile obtained from `C_tiles` for this worker; receives A * B2 if a
    // second B matrix was supplied.
    const StridedViewBF C2 = args.env.C_tiles.C(
        Extents2D(mc_block.Num(), nc_block.Num()), worker);
    if (B2 != nullptr) {
      MMKernel::B3A2C0(A, *B2, mc_block, whole_k, nc_block, args, MMSetC(),
                       C2);
    }

    // The optional user callback is only invoked for BF16 output.
    if constexpr (IsBF16<TC>()) {
      args.options.MaybeCallFunc(C, mc_block, nc_block, C2, worker);
    }
  };

  parallel.ForRangesSFC(args.env.ctx, args.ranges_mc, args.ranges_nc,
                        args.options.cluster_idx, compute_tile);
}

// SFC order with K split into several tasks: `parallel.ForRangesSFC` hands
// out (mc, nc) tiles, and `MMKernel::ForeachKC` walks all of `args.ranges_kc`
// for each tile.
template <typename TB, typename TC, class Parallel>
static HWY_INLINE void Loop(MMOrderSFC_K, Parallel parallel,
                            const StridedViewBF A, const MatPtrT<TB>& B,
                            const MatPtrT<TB>* B2, RowPtrs<TC> C,
                            const MMArgs& args) {
  const auto sfc_k_zone = args.env.ctx.profiler_zones.Get(Zones::kMMSFC_K);

  const auto compute_tile = [&](const IndexRange& mc_block,
                                const IndexRange& nc_block,
                                size_t worker) HWY_ATTR {
    MMZone tile_zone;
    tile_zone.MaybeEnter(worker, sfc_k_zone, args.env, &args.autotune);

    // A * B for this tile, written to the matching view of C.
    MMKernel::ForeachKC(
        A, B, mc_block, args.ranges_kc, nc_block, args,
        C.View(mc_block.begin(), nc_block.begin(), nc_block.Num()));

    // Tile obtained from `C_tiles` for this worker; receives A * B2 if a
    // second B matrix was supplied.
    const StridedViewBF C2 = args.env.C_tiles.C(
        Extents2D(mc_block.Num(), nc_block.Num()), worker);
    if (B2 != nullptr) {
      MMKernel::ForeachKC(A, *B2, mc_block, args.ranges_kc, nc_block, args,
                          C2);
    }

    // The optional user callback is only invoked for BF16 output.
    if constexpr (IsBF16<TC>()) {
      args.options.MaybeCallFunc(C, mc_block, nc_block, C2, worker);
    }
  };

  parallel.ForRangesSFC(args.env.ctx, args.ranges_mc, args.ranges_nc,
                        args.options.cluster_idx, compute_tile);
}
}; // MMLoops

// Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
Expand Down
91 changes: 87 additions & 4 deletions ops/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

// IWYU pragma: begin_exports
#include "ops/brgemm.h" // BRGeMMConfig, GEMMA_ONEDNN_BRGEMM
#include "ops/gilbert.h"
#include "util/basics.h"
#include "util/mat.h"
#include "util/threading.h"
Expand Down Expand Up @@ -90,6 +91,25 @@ struct MMParallelNone {
}
}

// SFC = space-filling curve (specifically the Hilbert-curve order produced by
// `gilbert_d2xy`). Serial variant: calls `func(range_mc, range_nc, worker)`
// once per curve index in [0, W * H), all with the same `worker`.
template <class Func>
void ForRangesSFC(ThreadingContext& ctx,
                  const IndexRangePartition& ranges_mc,
                  const IndexRangePartition& ranges_nc, size_t cluster_idx,
                  const Func& func) const {
  const size_t worker = ctx.Worker(cluster_idx);
  const size_t num_mc = ranges_mc.NumTasks();
  const size_t num_nc = ranges_nc.NumTasks();
  // `gilbert_d2xy` works on `int`; the grid is W = #mc tasks by H = #nc tasks.
  const int grid_w = static_cast<int>(num_mc);
  const int grid_h = static_cast<int>(num_nc);

  const size_t total = num_mc * num_nc;
  for (size_t curve_pos = 0; curve_pos != total; ++curve_pos) {
    int mc_idx;
    int nc_idx;
    gilbert_d2xy(&mc_idx, &nc_idx, static_cast<int>(curve_pos), grid_w,
                 grid_h);
    func(ranges_mc.Range(mc_idx), ranges_nc.Range(nc_idx), worker);
  }
}

template <class Func>
void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
size_t cluster_idx, const Func& func) const {
Expand Down Expand Up @@ -138,6 +158,29 @@ struct MMParallelWithinCluster {
});
}

// Space-filling-curve variant: the W * H (mc, nc) tiles are enumerated as a
// flat task index, which `ParallelForWithinCluster` distributes and
// `gilbert_d2xy` converts back to (mc, nc) task indices.
template <class Func>
void ForRangesSFC(ThreadingContext& ctx,
                  const IndexRangePartition& ranges_mc,
                  const IndexRangePartition& ranges_nc, size_t cluster_idx,
                  const Func& func) const {
  const hwy::pool::Caller caller =
      ctx.pool_callers.Get(Callers::kMMClusterForSFC);

  const size_t num_mc = ranges_mc.NumTasks();
  const size_t num_nc = ranges_nc.NumTasks();
  // `gilbert_d2xy` works on `int`.
  const int grid_w = static_cast<int>(num_mc);
  const int grid_h = static_cast<int>(num_nc);

  const auto run_tile = [&](uint64_t curve_pos, size_t worker) {
    int mc_idx;
    int nc_idx;
    gilbert_d2xy(&mc_idx, &nc_idx, static_cast<int>(curve_pos), grid_w,
                 grid_h);
    func(ranges_mc.Range(mc_idx), ranges_nc.Range(nc_idx), worker);
  };

  ParallelForWithinCluster(num_mc * num_nc, ctx, cluster_idx, caller,
                           run_tile);
}

template <class Func>
void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
size_t cluster_idx, const Func& func) const {
Expand Down Expand Up @@ -208,6 +251,35 @@ struct MMParallelHierarchical {
});
}

// Space-filling-curve variant: the flat index range [0, W * H) is first
// partitioned across clusters, then each cluster's sub-range is distributed
// by `ParallelForWithinCluster`. `gilbert_d2xy` converts each flat index to
// (mc, nc) task indices.
template <class Func>
void ForRangesSFC(ThreadingContext& ctx,
                  const IndexRangePartition& ranges_mc,
                  const IndexRangePartition& ranges_nc,
                  size_t caller_cluster_idx, const Func& func) const {
  HWY_DASSERT(caller_cluster_idx == 0);
  (void)caller_cluster_idx;  // only referenced by the debug assertion
  const hwy::pool::Caller caller =
      ctx.pool_callers.Get(Callers::kMMHierForSFC);

  const size_t num_mc = ranges_mc.NumTasks();
  const size_t num_nc = ranges_nc.NumTasks();
  // `gilbert_d2xy` works on `int`.
  const int grid_w = static_cast<int>(num_mc);
  const int grid_h = static_cast<int>(num_nc);
  const IndexRange all_range(0, num_mc * num_nc);

  const auto run_cluster = [&](const IndexRange& cluster_range,
                               size_t cluster_idx) {
    const auto run_tile = [&](uint64_t offset, size_t worker) {
      // `offset` is relative to this cluster's share of the curve.
      const size_t curve_pos = cluster_range.begin() + offset;
      int mc_idx;
      int nc_idx;
      gilbert_d2xy(&mc_idx, &nc_idx, static_cast<int>(curve_pos), grid_w,
                   grid_h);
      func(ranges_mc.Range(mc_idx), ranges_nc.Range(nc_idx), worker);
    };
    ParallelForWithinCluster(cluster_range.Num(), ctx, cluster_idx, caller,
                             run_tile);
  };

  ParallelPartitionAcrossClusters(all_range, /*task_multiple=*/1,
                                  /*inner_tasks=*/1, ctx, caller,
                                  run_cluster);
}

// No multiple/inner_tasks, so this is just HierarchicalParallelFor.
template <class Func>
void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
Expand Down Expand Up @@ -305,16 +377,18 @@ enum class MMOrder : uint8_t {
kNT_MT_K,
kNT_MT, // Specialization of `kNT_MT_K` for a single K task with `MMSetC`.

// Resident C (`kK_M_NT`) should be good for large K relative to M and N.
// However, it does not (much) outperform `kNT_K` on SKX and Zen4. There are
// no kM* because we expect M (batch size) to be small relative to K and N.
// Space-filling curve (Hilbert curve) traversal over blocks of M and N.
kSFC_K,
kSFC,
};

// Tag types for `DispatchOrder`. Each is an empty struct; `DispatchOrder`
// passes an instance of the one matching the runtime `MMOrder` as the first
// argument to `func`, which selects the corresponding overload.
struct MMOrderNT_K {};
struct MMOrderNT {};
struct MMOrderNT_MT_K {};
struct MMOrderNT_MT {};
// Space-filling-curve orders (`MMOrder::kSFC_K` and `MMOrder::kSFC`).
struct MMOrderSFC_K {};
struct MMOrderSFC {};

template <class Func, typename... Args>
void DispatchOrder(MMOrder order, const Func& func, Args&&... args) {
Expand All @@ -327,6 +401,10 @@ void DispatchOrder(MMOrder order, const Func& func, Args&&... args) {
return func(MMOrderNT_MT_K(), std::forward<Args>(args)...);
case MMOrder::kNT_MT:
return func(MMOrderNT_MT(), std::forward<Args>(args)...);
case MMOrder::kSFC_K:
return func(MMOrderSFC_K(), std::forward<Args>(args)...);
case MMOrder::kSFC:
return func(MMOrderSFC(), std::forward<Args>(args)...);
default:
HWY_UNREACHABLE;
}
Expand All @@ -337,7 +415,8 @@ static inline bool IsOneMC(MMOrder order) {
}

// Returns whether `order` is one of the single-K-task orders: `kNT`, `kNT_MT`
// or `kSFC`. The `kSFC` case must be part of the one and only return
// statement; an earlier return without it would make the check unreachable.
static inline bool IsOneKC(MMOrder order) {
  return order == MMOrder::kNT || order == MMOrder::kNT_MT ||
         order == MMOrder::kSFC;
}

static inline const char* StringFromOrder(MMOrder order) {
Expand All @@ -350,6 +429,10 @@ static inline const char* StringFromOrder(MMOrder order) {
return "NT_MT_K";
case MMOrder::kNT_MT:
return "NT_MT";
case MMOrder::kSFC_K:
return "SFC_K";
case MMOrder::kSFC:
return "SFC";
default:
return nullptr;
}
Expand Down
Loading
Loading