diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index 427bdf10..30534ec8 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -484,7 +484,10 @@ const std::vector att_out_gold = { 0.010839, 0.010652, 0.010471, 0.010297, 0.010128, 0.009965, 0.009807, 0.009653}; -void TestTiledFlashAttention() { +template +void RunTiledFlashAttentionTest(gcpp::KVEncoding kv_encoding, + AttentionImpl attention_impl, + float tol, float tol_exp, float tol_max) { size_t qkv_dim = 64; size_t kv_seq_len = 60; // number of tokens we will attend to. // Not divisible by tiles size to test the padding logic. @@ -498,11 +501,15 @@ void TestTiledFlashAttention() { kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep); ThreadingArgs threading_args; ThreadingContext ctx(threading_args); - MatStorageT kv("kv", - Extents2D(padded_kv_seq_len / gcpp::KVCache::kTileSize, - 2 * qkv_dim * gcpp::KVCache::kTileSize), - ctx.allocator, MatPadding::kPacked); - PopulateTestKVCache(kv, gcpp::KVEncoding::kF32, qkv_dim); + + size_t num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize; + size_t tile_size_bytes = *gcpp::GetTileSizeBytes(kv_encoding, qkv_dim); + size_t tile_size_in_elements = tile_size_bytes / sizeof(KV_T); + + MatStorageT kv("kv", Extents2D(num_tiles, tile_size_in_elements), + ctx.allocator, MatPadding::kPacked); + PopulateTestKVCache(kv, kv_encoding, qkv_dim); + AlignedFloatVector q_all = PopulateTestQueries(num_queries, qkv_dim); MatStorageT att_out("att_out", Extents2D(num_queries, qkv_dim), @@ -535,347 +542,82 @@ void TestTiledFlashAttention() { } hwy::Span kvs(&kv, 1); - DispatchTileFlashAttentionReturnExpSumsAndMaxLogits( - kvs, num_queries, q_all.data(), - hwy::Span(start_pos_per_query), - hwy::Span(last_pos_per_query), att_cap, att_out, - exp_denominator_sums.data(), max_logits.data()); - - // TODO: Replace with Other implementation for generating goldens. - // Current values are taken from a point in time where code was run with gemma - // and output looked good. Not ideal but should be good enough to test the - // plumbing and detect regressions. - PrintMatPtr(att_out); - for (size_t i = 0; i < num_queries; ++i) { - std::cerr << "exp_d: " << exp_denominator_sums[i] - << " max_logit: " << max_logits[i] << std::endl; - EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 1e-3f) - << "i=" << i; - EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-6f) << "i=" << i; - for (size_t j = 0; j < qkv_dim; ++j) { - EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-5f); - } + if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) { + std::vector> bf16_queries(num_queries * + qkv_dim); + CompressQueriesBF16Contiguous(q_all.data(), qkv_dim, num_queries, + bf16_queries.data()); + + DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16( + kvs, num_queries, bf16_queries.data(), + hwy::Span(start_pos_per_query), + hwy::Span(last_pos_per_query), att_cap, att_out, + exp_denominator_sums.data(), max_logits.data()); + } else if (attention_impl == AttentionImpl::kFlashTransposedQsInt16) { + std::vector> int16_queries( + num_queries * qkv_dim); + AlignedFloatVector q_scales(num_queries); + CompressQueriesInt16Contiguous(q_all.data(), qkv_dim, num_queries, + int16_queries.data(), q_scales.data()); + + DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16( + kvs, num_queries, int16_queries.data(), q_scales, + hwy::Span(start_pos_per_query), + hwy::Span(last_pos_per_query), att_cap, att_out, + exp_denominator_sums.data(), max_logits.data()); + } else { + DispatchTileFlashAttentionReturnExpSumsAndMaxLogits( + kvs, num_queries, q_all.data(), + hwy::Span(start_pos_per_query), + hwy::Span(last_pos_per_query), att_cap, att_out, + exp_denominator_sums.data(), max_logits.data()); } -} -void TestTiledFlashAttentionBF16() { - size_t qkv_dim = 64; - size_t kv_seq_len = 60; // number of tokens we will attend to. - // Not divisible by tiles size to test the padding logic. - size_t padded_kv_seq_len = - hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize); - float att_cap = 10.0f; - size_t num_queries = 8; - size_t num_queries_per_timestep = 4; - size_t num_tokens = num_queries / num_queries_per_timestep; - size_t kv_seq_end = - kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep); - ThreadingArgs threading_args; - ThreadingContext ctx(threading_args); - MatStorageT kv("kv", - Extents2D(padded_kv_seq_len / gcpp::KVCache::kTileSize, - 2 * qkv_dim * gcpp::KVCache::kTileSize), - ctx.allocator, MatPadding::kPacked); - PopulateTestKVCache(kv, gcpp::KVEncoding::kBF16TwoTranspositions, qkv_dim); - - AlignedFloatVector q_all = PopulateTestQueries(num_queries, qkv_dim); - std::vector> bf16_queries(num_queries * - qkv_dim); - CompressQueriesBF16Contiguous(q_all.data(), qkv_dim, num_queries, - bf16_queries.data()); - - MatStorageT att_out("att_out", Extents2D(num_queries, qkv_dim), - ctx.allocator, MatPadding::kPacked); - - HWY_LANES_CONSTEXPR size_t lanes = 4; - size_t num_queries_rounded_to_laness = hwy::RoundUpTo(num_queries, lanes); - std::vector exp_denominator_sums(num_queries_rounded_to_laness); - std::vector max_logits(num_queries_rounded_to_laness); - for (size_t i = 0; i < num_queries; ++i) { - hwy::ZeroBytes(att_out.Row(i), - qkv_dim * sizeof(decltype(att_out.Row(i)[0]))); - exp_denominator_sums[i] = 0.0f; - max_logits[i] = -std::numeric_limits::max() / 2.0f; - } - std::vector> start_pos_per_query; - std::vector> last_pos_per_query; - start_pos_per_query.reserve(num_queries); - last_pos_per_query.reserve(num_queries); - for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { - ssize_t query_last_pos = kv_seq_end + token_idx; - ssize_t query_start_pos = - std::max(query_last_pos - 100000 + 1, static_cast(0)); - for (size_t q_head_idx = 0; q_head_idx < num_queries_per_timestep; - ++q_head_idx) { - start_pos_per_query.push_back(query_start_pos); - last_pos_per_query.push_back(query_last_pos); - } - } - hwy::Span kvs(&kv, 1); - DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16( - kvs, num_queries, bf16_queries.data(), - hwy::Span(start_pos_per_query), - hwy::Span(last_pos_per_query), att_cap, att_out, - exp_denominator_sums.data(), max_logits.data()); - - // TODO: Replace with Other implementation for generating goldens. - // Current values are taken from a point in time where code was run with gemma - // and output looked good. Not ideal but should be good enough to test the - // plumbing and detect regressions. PrintMatPtr(att_out); for (size_t i = 0; i < num_queries; ++i) { std::cerr << "exp_d: " << exp_denominator_sums[i] << " max_logit: " << max_logits[i] << std::endl; - EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 4e-2f) + EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], tol_exp) << "i=" << i; - EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i; + EXPECT_NEAR(max_logits[i], max_logits_gold[i], tol_max) << "i=" << i; for (size_t j = 0; j < qkv_dim; ++j) { - if (j == 0) { + if (j == 0 && attention_impl == AttentionImpl::kFlashTransposedQsBF16 && + kv_encoding == gcpp::KVEncoding::kBF16TwoTranspositions) { std::cerr << "att_out[0][" << j << "]=" << att_out.Row(i)[j] << " gold=" << att_out_gold[i * qkv_dim + j] << "\n"; } - EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-3f); + EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], tol); } } } -void TestTiledFlashAttentionInt8() { - size_t qkv_dim = 64; - // number of tokens we will attend to. - // Not divisible by tiles size to test the padding logic. - size_t kv_seq_len = 60; - size_t padded_kv_seq_len = - hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize); - float att_cap = 10.0f; - size_t num_queries = 8; - size_t num_queries_per_timestep = 4; - size_t num_tokens = num_queries / num_queries_per_timestep; - size_t kv_seq_end = - kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep); - ThreadingArgs threading_args; - ThreadingContext ctx(threading_args); - - size_t num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize; - size_t tile_size_bytes = 2 * qkv_dim * gcpp::KVCache::kTileSize + - 2 * sizeof(BF16) * gcpp::KVCache::kTileSize; - - MatStorageT kv("kv", Extents2D(num_tiles, tile_size_bytes), - ctx.allocator, MatPadding::kPacked); - PopulateTestKVCache(kv, gcpp::KVEncoding::kInt8, qkv_dim); - - AlignedFloatVector q_all = PopulateTestQueries(num_queries, qkv_dim); - - MatStorageT att_out("att_out", Extents2D(num_queries, qkv_dim), - ctx.allocator, MatPadding::kPacked); - using DF = hn::ScalableTag; - const DF df; - HWY_LANES_CONSTEXPR size_t lanes = hn::Lanes(df); - size_t num_queries_rounded_to_laness = hwy::RoundUpTo(num_queries, lanes); - std::vector exp_denominator_sums(num_queries_rounded_to_laness); - std::vector max_logits(num_queries_rounded_to_laness); - for (size_t i = 0; i < num_queries; ++i) { - hwy::ZeroBytes(att_out.Row(i), - qkv_dim * sizeof(decltype(att_out.Row(i)[0]))); - exp_denominator_sums[i] = 0.0f; - max_logits[i] = -std::numeric_limits::max() / 2.0f; - } - std::vector> start_pos_per_query; - std::vector> last_pos_per_query; - start_pos_per_query.reserve(num_queries); - last_pos_per_query.reserve(num_queries); - for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { - ssize_t query_last_pos = kv_seq_end + token_idx; - ssize_t query_start_pos = - std::max(query_last_pos - 100000 + 1, static_cast(0)); - for (size_t q_head_idx = 0; q_head_idx < num_queries_per_timestep; - ++q_head_idx) { - start_pos_per_query.push_back(query_start_pos); - last_pos_per_query.push_back(query_last_pos); - } - } +void TestTiledFlashAttention() { + RunTiledFlashAttentionTest(gcpp::KVEncoding::kF32, + AttentionImpl::kFlash, 1e-5f, 1e-3f, 1e-6f); +} - hwy::Span kvs(&kv, 1); - DispatchTileFlashAttentionReturnExpSumsAndMaxLogits( - kvs, num_queries, q_all.data(), - hwy::Span(start_pos_per_query), - hwy::Span(last_pos_per_query), att_cap, att_out, - exp_denominator_sums.data(), max_logits.data()); - - // TODO: Replace with Other implementation for generating goldens. - // Current values are taken from a point in time where code was run with gemma - // and output looked good. Not ideal but should be good enough to test the - // plumbing and detect regressions. - PrintMatPtr(att_out); - for (size_t i = 0; i < num_queries; ++i) { - std::cerr << "exp_d: " << exp_denominator_sums[i] - << " max_logit: " << max_logits[i] << std::endl; - EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 2e-2f) - << "i=" << i; - EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i; - for (size_t j = 0; j < qkv_dim; ++j) { - EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 5e-3f); - } - } +void TestTiledFlashAttentionBF16() { + RunTiledFlashAttentionTest(gcpp::KVEncoding::kBF16TwoTranspositions, + AttentionImpl::kFlashTransposedQsBF16, 1e-3f, + 4e-2f, 1e-3f); } +void TestTiledFlashAttentionInt8() { + RunTiledFlashAttentionTest(gcpp::KVEncoding::kInt8, + AttentionImpl::kFlash, 5e-3f, 2e-2f, + 1e-3f); +} void TestTiledFlashAttentionInt8BF16() { - size_t qkv_dim = 64; - size_t kv_seq_len = 60; // number of tokens we will attend to. Not divisible - // by tiles size to test the padding logic. - size_t padded_kv_seq_len = - hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize); - float att_cap = 10.0f; - size_t num_queries = 8; - size_t num_queries_per_timestep = 4; - size_t num_tokens = num_queries / num_queries_per_timestep; - size_t kv_seq_end = - kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep); - ThreadingArgs threading_args; - ThreadingContext ctx(threading_args); - - size_t num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize; - size_t tile_size_bytes = 2 * qkv_dim * gcpp::KVCache::kTileSize + - 2 * sizeof(BF16) * gcpp::KVCache::kTileSize; - - MatStorageT kv("kv", Extents2D(num_tiles, tile_size_bytes), - ctx.allocator, MatPadding::kPacked); - - // fill in kvs with predictable, synthetic data matching BF16 paired layout - PopulateTestKVCache(kv, gcpp::KVEncoding::kInt8TwoTranspositions, qkv_dim); - - AlignedFloatVector q_all = PopulateTestQueries(num_queries, qkv_dim); - std::vector> bf16_queries(num_queries * - qkv_dim); - CompressQueriesBF16Contiguous(q_all.data(), qkv_dim, num_queries, - bf16_queries.data()); - - MatStorageT att_out("att_out", Extents2D(num_queries, qkv_dim), - ctx.allocator, MatPadding::kPacked); - using DF = hn::ScalableTag; - const DF df; - HWY_LANES_CONSTEXPR size_t lanes = hn::Lanes(df); - size_t num_queries_rounded_to_laness = hwy::RoundUpTo(num_queries, lanes); - std::vector exp_denominator_sums(num_queries_rounded_to_laness); - std::vector max_logits(num_queries_rounded_to_laness); - for (size_t i = 0; i < num_queries; ++i) { - hwy::ZeroBytes(att_out.Row(i), - qkv_dim * sizeof(decltype(att_out.Row(i)[0]))); - exp_denominator_sums[i] = 0.0f; - max_logits[i] = -std::numeric_limits::max() / 2.0f; - } - std::vector> start_pos_per_query; - std::vector> last_pos_per_query; - start_pos_per_query.reserve(num_queries); - last_pos_per_query.reserve(num_queries); - for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { - ssize_t query_last_pos = kv_seq_end + token_idx; - ssize_t query_start_pos = - std::max(query_last_pos - 100000 + 1, static_cast(0)); - for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep; - ++q_head_idx) { - start_pos_per_query.push_back(query_start_pos); - last_pos_per_query.push_back(query_last_pos); - } - } - - hwy::Span kvs(&kv, 1); - DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16( - kvs, num_queries, bf16_queries.data(), - hwy::Span(start_pos_per_query), - hwy::Span(last_pos_per_query), att_cap, att_out, - exp_denominator_sums.data(), max_logits.data()); - - PrintMatPtr(att_out); - for (int i = 0; i < num_queries; ++i) { - std::cerr << "exp_d: " << exp_denominator_sums[i] - << " max_logit: " << max_logits[i] << std::endl; - EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 2e-2f) - << "i=" << i; - EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i; - for (int j = 0; j < qkv_dim; ++j) { - EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 5e-3f); - } - } + RunTiledFlashAttentionTest(gcpp::KVEncoding::kInt8TwoTranspositions, + AttentionImpl::kFlashTransposedQsBF16, + 5e-3f, 2e-2f, 1e-3f); } void TestTiledFlashAttentionInt8Int16() { - int qkv_dim = 64; - int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by - // tiles size to test the padding logic. - int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize); - float att_cap = 10.0f; - int num_queries = 8; - int num_queries_per_timestep = 4; - int num_tokens = num_queries / num_queries_per_timestep; - int kv_seq_end = - kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep); - ThreadingArgs threading_args; - ThreadingContext ctx(threading_args); - - int num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize; - int tile_size_bytes = 2 * qkv_dim * gcpp::KVCache::kTileSize + - 2 * sizeof(BF16) * gcpp::KVCache::kTileSize; - - MatStorageT kv("kv", Extents2D(num_tiles, tile_size_bytes), - ctx.allocator, MatPadding::kPacked); - - // fill in kvs with predictable, synthetic data matching BF16 paired layout - PopulateTestKVCache(kv, gcpp::KVEncoding::kInt8TwoTranspositions, qkv_dim); - - AlignedFloatVector q_all = PopulateTestQueries(num_queries, qkv_dim); - std::vector> int16_queries( - num_queries * qkv_dim); - AlignedFloatVector q_scales(num_queries); - CompressQueriesInt16Contiguous(q_all.data(), qkv_dim, num_queries, - int16_queries.data(), q_scales.data()); - MatStorageT att_out("att_out", Extents2D(num_queries, qkv_dim), - ctx.allocator, MatPadding::kPacked); - using DF = hn::ScalableTag; - const DF df; - HWY_LANES_CONSTEXPR size_t lanes = hn::Lanes(df); - size_t num_queries_rounded_to_laness = hwy::RoundUpTo(num_queries, lanes); - std::vector exp_denominator_sums(num_queries_rounded_to_laness); - std::vector max_logits(num_queries_rounded_to_laness); - for (size_t i = 0; i < num_queries; ++i) { - hwy::ZeroBytes(att_out.Row(i), - qkv_dim * sizeof(decltype(att_out.Row(i)[0]))); - exp_denominator_sums[i] = 0.0f; - max_logits[i] = -std::numeric_limits::max() / 2.0f; - } - std::vector> start_pos_per_query; - std::vector> last_pos_per_query; - start_pos_per_query.reserve(num_queries); - last_pos_per_query.reserve(num_queries); - for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { - ssize_t query_last_pos = kv_seq_end + token_idx; - ssize_t query_start_pos = - std::max(query_last_pos - 100000 + 1, static_cast(0)); - for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep; - ++q_head_idx) { - start_pos_per_query.push_back(query_start_pos); - last_pos_per_query.push_back(query_last_pos); - } - } - - hwy::Span kvs(&kv, 1); - DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16( - kvs, num_queries, int16_queries.data(), q_scales, - hwy::Span(start_pos_per_query), - hwy::Span(last_pos_per_query), att_cap, att_out, - exp_denominator_sums.data(), max_logits.data()); - - PrintMatPtr(att_out); - for (int i = 0; i < num_queries; ++i) { - std::cerr << "exp_d: " << exp_denominator_sums[i] - << " max_logit: " << max_logits[i] << std::endl; - EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 2e-2f) - << "i=" << i; - EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i; - for (int j = 0; j < qkv_dim; ++j) { - EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 5e-3f); - } - } + RunTiledFlashAttentionTest(gcpp::KVEncoding::kInt8TwoTranspositions, + AttentionImpl::kFlashTransposedQsInt16, + 5e-3f, 2e-2f, 1e-3f); } // NOLINTNEXTLINE(google-readability-namespace-comments)