Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions ops/ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1788,6 +1788,59 @@ static HWY_INLINE VF8 Reduce8(DF df, VF x_0, VF x_1, VF x_2, VF x_3, VF x_4,
return hn::Load(df8, buf);
}

// Integer (i8 inputs, i32 accumulator) variant. Uses hn::PerBlock2x2MatMul
// when HWY_NATIVE_PER_BLOCK_2X2_MATMUL_INT8 is set; otherwise emulates it with
// scalar code on stack buffers.
//
// The emulation handles four i32 lanes of `c` (and the matching 16 i8 lanes of
// `a` and `b`) at a time. Those 16 i8 lanes are read as two rows of eight, and
// lane `2 * i + j` of the group receives `c` plus the dot product of row `i`
// of `a` with row `j` of `b`.
template <class DN, class VI8, class VI32, HWY_IF_I32_D(DN)>
HWY_API VI32 PerBlock2x2MatMulMaybeEmulate(DN dn, VI8 a, VI8 b, VI32 c) {
#if HWY_NATIVE_PER_BLOCK_2X2_MATMUL_INT8
  return hn::PerBlock2x2MatMul(dn, a, b, c);
#else
  constexpr size_t kMaxLanes = hn::MaxLanes(dn);
  HWY_LANES_CONSTEXPR size_t num_lanes = hn::Lanes(dn);
  const hn::Repartition<int8_t, DN> d8;

  // Spill all three vectors so individual lanes can be addressed; there are
  // four i8 lanes for every i32 lane.
  HWY_ALIGN int8_t a_lanes[kMaxLanes * 4];
  HWY_ALIGN int8_t b_lanes[kMaxLanes * 4];
  HWY_ALIGN int32_t acc[kMaxLanes];
  hn::Store(c, dn, acc);
  hn::Store(a, d8, a_lanes);
  hn::Store(b, d8, b_lanes);

  // NOTE(review): the loop advances four i32 lanes per step, so it presumably
  // expects vectors made of whole groups of four lanes - confirm with callers.
  for (size_t first = 0; first < num_lanes; first += 4) {
    const size_t base = first * 4;  // first i8 lane of this group
    for (size_t tile = 0; tile < 4; ++tile) {
      // tile == 2 * (row of a) + (row of b); each row is eight i8 lanes.
      const size_t a_row = base + (tile / 2) * 8;
      const size_t b_row = base + (tile % 2) * 8;
      int32_t dot = 0;
      for (size_t k = 0; k < 8; ++k) {
        dot += static_cast<int32_t>(a_lanes[a_row + k]) *
               static_cast<int32_t>(b_lanes[b_row + k]);
      }
      acc[first + tile] += dot;
    }
  }
  return hn::Load(dn, acc);
#endif
}

// Floating-point (bf16 inputs, f32 accumulator) variant. Uses
// hn::PerBlock2x2MatMul when HWY_NATIVE_PER_BLOCK_2X2_MATMUL_BF16 is set;
// otherwise builds the result from two widening multiply-accumulates over
// shuffled copies of `a` and `b`, starting from `c`.
template <class DN, class VBF, class VF, HWY_IF_F32_D(DN)>
HWY_API VF PerBlock2x2MatMulMaybeEmulate(DN dn, VBF a, VBF b, VF c) {
#if HWY_NATIVE_PER_BLOCK_2X2_MATMUL_BF16
  return hn::PerBlock2x2MatMul(dn, a, b, c);
#else
  const hn::Repartition<hwy::bfloat16_t, DN> dbf;

  // Reinterpret as f32 so that each shuffled lane carries two adjacent bf16
  // values as one unit.
  const auto a_pairs = hn::BitCast(dn, a);
  const auto b_pairs = hn::BitCast(dn, b);

  // First operand pair: lanes {2, 2, 0, 0} of `a` against {2, 0, 2, 0} of `b`
  // (template-argument order) within each 4-lane block.
  const auto a_lo = hn::BitCast(
      dbf, hn::Per4LaneBlockShuffle<2, 2, 0, 0>(a_pairs));
  const auto b_lo = hn::BitCast(
      dbf, hn::Per4LaneBlockShuffle<2, 0, 2, 0>(b_pairs));

  // Second operand pair: the same pattern shifted to the odd-numbered lanes.
  const auto a_hi = hn::BitCast(
      dbf, hn::Per4LaneBlockShuffle<3, 3, 1, 1>(a_pairs));
  const auto b_hi = hn::BitCast(
      dbf, hn::Per4LaneBlockShuffle<3, 1, 3, 1>(b_pairs));

  // `acc1` is the second accumulator that ReorderWidenMulAccumulate takes in
  // addition to the one it returns; both are merged by the final call.
  VF acc1 = hn::Zero(dn);
  VF acc0 = hn::ReorderWidenMulAccumulate(dn, a_lo, b_lo, c, acc1);
  acc0 = hn::ReorderWidenMulAccumulate(dn, a_hi, b_hi, acc0, acc1);
  return hn::RearrangeToOddPlusEven(acc0, acc1);
#endif
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
Expand Down
Loading