Skip to content
32 changes: 32 additions & 0 deletions transformer_lens/benchmarks/component_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,15 @@ def print_detailed_analysis(self) -> None:
class ComponentBenchmarker:
"""Benchmarking utility for testing TransformerBridge components against HuggingFace."""

def _is_delegated_block(self) -> bool:
    """Return True if the blocks component has ``maintain_native_attention`` set.

    Looks up the ``"blocks"`` entry of ``self.adapter.component_mapping`` and
    reads its ``maintain_native_attention`` attribute.

    Returns:
        ``True`` only when the ``"blocks"`` component exposes a truthy
        ``maintain_native_attention`` attribute. ``False`` when there is no
        adapter, no component mapping, no ``"blocks"`` entry, or the
        attribute is absent or falsy.
    """
    adapter = self.adapter
    if adapter is None:
        return False

    # Defensive: a getattr default only covers a *missing* attribute. If the
    # attribute exists but is None, calling .get() on it would raise
    # AttributeError, so fall back to an empty mapping in that case as well.
    mapping = getattr(adapter, "component_mapping", None) or {}
    blocks = mapping.get("blocks")

    # Coerce so the declared ``bool`` return type holds even if the attribute
    # carries some other truthy/falsy value.
    return bool(getattr(blocks, "maintain_native_attention", False))

def __init__(
self,
bridge_model: nn.Module,
Expand Down Expand Up @@ -419,6 +428,23 @@ def _test_component_recursive(
):
return

# Skip attention and PLE submodules when using DelegatedAttentionBlockBridge.
# These architectures delegate all math to HF; the benchmark can't call the HF
# attention in isolation (missing position_embeddings, attention_mask, etc.) and
# PLE submodules receive per-layer inputs at a different dimension than hidden_states.
_is_delegated = self._is_delegated_block()
if _is_delegated and "attn" in component_path:
return
if _is_delegated and any(
name in component_path
for name in (
"per_layer_input_gate",
"per_layer_projection",
"post_per_layer_input_norm",
)
):
return

# Skip models whose MLP/attn forward signatures require extra context from the block:
# - BLOOM: MLP requires residual and alibi bias
# - T5: requires cache_position for relative position embeddings
Expand Down Expand Up @@ -526,6 +552,12 @@ def _test_component(
ComponentTestResult or None if the component cannot be tested
"""
try:
# Skip rotary_emb for DelegatedAttentionBlockBridge architectures.
# Gemma4's RotaryEmbeddingBridge wraps a rotary that returns a set-like
# structure which the benchmark comparison can't subscript.
if self._is_delegated_block() and component_path == "rotary_emb":
return None

# Get bridge component
# The adapter returns nn.Module, but for bridge models it's actually GeneralizedComponent
bridge_component = cast(
Expand Down
10 changes: 5 additions & 5 deletions transformer_lens/benchmarks/main_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ def cleanup_model(model, model_name_str: str):
# Use appropriate AutoModel class (e.g., AutoModelForSeq2SeqLM for T5)
auto_model_class = get_auto_model_class(model_name, trust_remote_code=trust_remote_code)
if verbose and auto_model_class != AutoModelForCausalLM:
print(f"Using {auto_model_class.__name__} for encoder-decoder model")
print(f"Using {auto_model_class.__name__}")
# Ensure pad_token_id exists (some models crash without it during init).
hf_config = AutoConfig.from_pretrained(
model_name, trust_remote_code=trust_remote_code, token=_hf_token()
Expand Down Expand Up @@ -1209,14 +1209,14 @@ def cleanup_model(model, model_name_str: str):
# PHASE 2: Bridge (unprocessed) + HookedTransformer (unprocessed)
# ========================================================================
current_phase[0] = 2
if verbose:
print(f"\n{'='*80}")
print("PHASE 2: TransformerBridge (unprocessed) + HookedTransformer (unprocessed)")
print(f"{'='*80}\n")

# OPTIMIZATION: Run generation benchmarks first (only bridge in memory)
# Then cleanup bridge before loading HT to reduce peak memory
if should_run_phase(2) and bridge_unprocessed:
if verbose:
print(f"\n{'='*80}")
print("PHASE 2: TransformerBridge (unprocessed) + HookedTransformer (unprocessed)")
print(f"{'='*80}\n")
if verbose:
print("Running Phase 2 benchmarks...\n")

Expand Down
4 changes: 4 additions & 0 deletions transformer_lens/model_bridge/generalized_components/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,10 @@ class DelegatedAttentionBlockBridge(BlockBridge):
``DelegatedAttentionBlockBridge`` and knows those hooks are absent.
"""

# Tell the component benchmark this block's attention is delegated wholesale
# to HF and cannot be tested in isolation (requires model-specific kwargs).
maintain_native_attention: bool = True

def __init__(
self,
name: str,
Expand Down
Loading