From f5e78cc506cbbae02199d81c7817d74fd5289e34 Mon Sep 17 00:00:00 2001 From: Krzysztof Rymski Date: Wed, 10 Jun 2026 07:27:14 -0700 Subject: [PATCH] Helper functions to use arm matrix multiplication instructions in a way that is testable on other platofrms PiperOrigin-RevId: 929840936 --- ops/ops-inl.h | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 2b2f310b..d5d1b49e 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -1788,6 +1788,59 @@ static HWY_INLINE VF8 Reduce8(DF df, VF x_0, VF x_1, VF x_2, VF x_3, VF x_4, return hn::Load(df8, buf); } +template +HWY_API VI32 PerBlock2x2MatMulMaybeEmulate(DN dn, VI8 a, VI8 b, VI32 c) { +#if HWY_NATIVE_PER_BLOCK_2X2_MATMUL_INT8 + return hn::PerBlock2x2MatMul(dn, a, b, c); +#else + const hn::Repartition di8; + constexpr size_t kMaxN = hn::MaxLanes(dn); + HWY_LANES_CONSTEXPR size_t N = hn::Lanes(dn); + HWY_ALIGN int8_t in_a[kMaxN * 4]; + HWY_ALIGN int8_t in_b[kMaxN * 4]; + HWY_ALIGN int32_t expected[kMaxN]; + hn::Store(a, di8, in_a); + hn::Store(b, di8, in_b); + hn::Store(c, dn, expected); + + for (size_t block = 0; block < N; block += 4) { + const size_t block_i8 = block * 4; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + int32_t sum = 0; + for (int k = 0; k < 8; ++k) { + sum += static_cast(in_a[block_i8 + i * 8 + k]) * + static_cast(in_b[block_i8 + j * 8 + k]); + } + expected[block + i * 2 + j] += sum; + } + } + } + return hn::Load(dn, expected); +#endif +} + +template +HWY_API VF PerBlock2x2MatMulMaybeEmulate(DN dn, VBF a, VBF b, VF c) { +#if HWY_NATIVE_PER_BLOCK_2X2_MATMUL_BF16 + return hn::PerBlock2x2MatMul(dn, a, b, c); +#else + const hn::Repartition dbf; + const auto a_f = hn::BitCast(dn, a); + const auto a1 = hn::BitCast(dbf, hn::Per4LaneBlockShuffle<2, 2, 0, 0>(a_f)); + const auto a2 = hn::BitCast(dbf, hn::Per4LaneBlockShuffle<3, 3, 1, 1>(a_f)); + + const auto b_f = hn::BitCast(dn, b); + const auto b1 = hn::BitCast(dbf, hn::Per4LaneBlockShuffle<2, 0, 2, 0>(b_f)); + const auto b2 = hn::BitCast(dbf, hn::Per4LaneBlockShuffle<3, 1, 3, 1>(b_f)); + + VF sum1 = hn::Zero(dn); + VF sum0 = hn::ReorderWidenMulAccumulate(dn, a1, b1, c, sum1); + sum0 = hn::ReorderWidenMulAccumulate(dn, a2, b2, sum0, sum1); + return hn::RearrangeToOddPlusEven(sum0, sum1); +#endif +} + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp