from typing import Any

import numpy as np
import torch
from torch import nn

from vllm.config import CUDAGraphMode, VllmConfig
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu_input_batch import InputBatch

# NOTE: DefaultModelState, EncoderCache, AttentionGroup, MambaHybridAttnMetadata,
# and build_attn_metadata are provided by the surrounding refactor; their import
# paths are not pinned down here and are assumed to live alongside this module.


class MambaHybridModelState(DefaultModelState):
    """Model state for hybrid attention + Mamba / linear-attention models."""

def __init__(
self,
vllm_config: VllmConfig,
model: nn.Module,
encoder_cache: EncoderCache | None,
device: torch.device,
) -> None:
super().__init__(vllm_config, model, encoder_cache, device)
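        # Per-request count of tokens accepted in the previous step, stored in
        # persistent-slot order. Initialized to 1 because Mamba-style layers
        # treat num_accepted_tokens == 1 as the neutral, non-speculative value
        # (see postprocess_state).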
self.num_accepted_tokens_gpu = torch.ones(
self.max_num_reqs, dtype=torch.int32, device=self.device
)

    def prepare_attn(
self,
input_batch: InputBatch,
cudagraph_mode: CUDAGraphMode,
block_tables: tuple[torch.Tensor, ...],
slot_mappings: torch.Tensor,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
for_capture: bool = False,
) -> dict[str, Any]:
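        """Build the attention metadata for one forward pass.

        Under CUDAGraphMode.FULL the padded request/token counts are used so
        that a captured graph can be replayed for any smaller real batch;
        otherwise the actual batch dimensions are used.
        """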
if cudagraph_mode == CUDAGraphMode.FULL:
num_reqs = input_batch.num_reqs_after_padding
num_tokens = input_batch.num_tokens_after_padding
else:
num_reqs = input_batch.num_reqs
num_tokens = input_batch.num_tokens
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
max_query_len = input_batch.num_scheduled_tokens.max().item()
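        # Pad the per-request prefill flags out to num_reqs; padded slots stay
        # False, i.e. they look like decode requests.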
is_prefilling = torch.zeros(num_reqs, dtype=torch.bool, device="cpu")
is_prefilling[: input_batch.num_reqs] = torch.from_numpy(
input_batch.is_prefilling_np
)
# During CUDAGraph capture, num_decode_draft_tokens_cpu and num_accepted_tokens
# are created by attn_metadata_builder.build_for_cudagraph_capture, so we only
# compute them during actual (non-capture) forward execution.
num_accepted_tokens = None
num_decode_draft_tokens_cpu = None
if not for_capture:
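            # Gather the persistent accepted-token counts into current batch
            # order; padded slots keep the neutral value of 1 from new_ones.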
num_accepted_tokens = self.num_accepted_tokens_gpu.new_ones(num_reqs)
num_accepted_tokens[: input_batch.num_reqs] = self.num_accepted_tokens_gpu[
input_batch.idx_mapping
]
# GDN uses >= 0 to select spec-decode rows, so non-decode rows
# need the -1 sentinel rather than a raw zero draft count.
num_decode_draft_tokens_np = np.full(num_reqs, -1, dtype=np.int32)
if input_batch.num_draft_tokens_per_req is not None:
spec_decode_mask = (
input_batch.num_draft_tokens_per_req > 0
) & ~input_batch.is_prefilling_np
num_decode_draft_tokens_np[: input_batch.num_reqs] = np.where(
spec_decode_mask,
input_batch.num_draft_tokens_per_req,
-1,
)
num_decode_draft_tokens_cpu = torch.from_numpy(num_decode_draft_tokens_np)
mamba_attn_metadata = MambaHybridAttnMetadata(
is_prefilling=is_prefilling,
num_accepted_tokens=num_accepted_tokens,
num_decode_draft_tokens_cpu=num_decode_draft_tokens_cpu,
)
return build_attn_metadata(
attn_groups=attn_groups,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=input_batch.query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=input_batch.seq_lens,
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
model_specific_attn_metadata=mamba_attn_metadata,
for_cudagraph_capture=for_capture,
)

    def postprocess_state(
self,
input_batch: InputBatch,
num_sampled: torch.Tensor,
) -> None:
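        """Persist the number of tokens each request accepted this step."""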
# Chunked prefill does not sample a token, so num_sampled can be 0.
# Mamba treats num_accepted_tokens=1 as the neutral non-spec value.
self.num_accepted_tokens_gpu[input_batch.idx_mapping] = torch.clamp(
num_sampled, min=1
)
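

# A minimal driver-side sketch (assumption: the runner, batch, and metadata
# objects shown here come from the surrounding model-runner code, not from
# this module):
#
#     state = MambaHybridModelState(vllm_config, model, encoder_cache, device)
#     metadata = state.prepare_attn(
#         input_batch, cudagraph_mode, block_tables, slot_mappings,
#         attn_groups, kv_cache_config,
#     )
#     ... run the forward pass with `metadata` ...
#     state.postprocess_state(input_batch, num_sampled)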