dflash简易实现

发布时间:2026/6/28 9:17:01
dflash简易实现 dflash简易实现 DFlash: Block Diffusion for Flash Speculative Decoding. Key innovations over EAGLE: 1. Block diffusion drafting: generates an entire block of tokens in a single parallel forward pass (vs EAGLEs sequential autoregression). 2. KV injection: target models hidden states are injected directly into the draft model, so it skips context processing and focuses purely on predicting the next block. 3. Sequence mixing: W_mix matrix lets unmasked tokens influence adjacent [MASK] positions during denoising. 4. Positional encoding: breaks symmetry across [MASK] positions. Reference: https://arxiv.org/abs/2602.06036 from __future__ import annotations import time from dataclasses import dataclass, field import numpy as np # --------------------------------------------------------------------------- # 配置 # --------------------------------------------------------------------------- VOCAB_SIZE 1000 HIDDEN_DIM 64 NUM_LAYERS 4 BLOCK_SIZE 8 NUM_DENOISE_STEPS 2 MASK_TOKEN_ID 0 TEMPERATURE 1.0 # --------------------------------------------------------------------------- # Target Model # --------------------------------------------------------------------------- dataclass class TargetModel: layers: list[np.ndarray] field(default_factorylist) lm_head: np.ndarray field(default_factorylambda: np.zeros((1, 1))) token_embed: np.ndarray field(default_factorylambda: np.zeros((1, 1))) def __post_init__(self): rng np.random.RandomState(42) self.layers [ rng.randn(HIDDEN_DIM, HIDDEN_DIM) * 0.1 for _ in range(NUM_LAYERS) ] self.lm_head rng.randn(HIDDEN_DIM, VOCAB_SIZE) * 0.1 self.token_embed rng.randn(VOCAB_SIZE, HIDDEN_DIM) * 0.1 def embed_tokens(self, token_ids: list[int]) - np.ndarray: return np.stack([self.token_embed[tid] for tid in token_ids], axis0) def forward(self, input_ids: list[int]) - dict: h self.embed_tokens(input_ids) for i, layer in enumerate(self.layers): h h layer if i NUM_LAYERS - 1: h np.maximum(h, 0) logits h self.lm_head return {logits: logits, hidden: h, all_hidden: h} def sample_token(self, logits: np.ndarray, temperature: float TEMPERATURE) - int: if temperature 0: probs np.exp(logits / temperature) probs probs / probs.sum() return int(np.random.choice(VOCAB_SIZE, pprobs)) return int(np.argmax(logits)) def logits_to_probs(self, logits: np.ndarray, temperature: float TEMPERATURE) - np.ndarray: if temperature 0: logits logits / temperature logits logits - logits.max() probs np.exp(logits) return probs / probs.sum() # --------------------------------------------------------------------------- # DFlash Block Diffusion Draft Model # --------------------------------------------------------------------------- dataclass class DFlashDraftModel: Lightweight block diffusion draft model. Key design: 1. KV injection: receives targets last hidden state as context, skipping full forward computation. 2. Block diffusion: predicts entire token block in one parallel forward, refining through multi-step denoising. 3. Sequence mixing: W_mix matrix lets unmasked tokens influence adjacent [MASK] positions. 4. Positional encoding: breaks symmetry across [MASK] positions. W_proj: np.ndarray field(default_factorylambda: np.zeros((1, 1))) W_block: np.ndarray field(default_factorylambda: np.zeros((1, 1))) W_mix: np.ndarray field(default_factorylambda: np.zeros((1, 1))) pos_embed: np.ndarray field(default_factorylambda: np.zeros((1, 1))) lm_head: np.ndarray field(default_factorylambda: np.zeros((1, 1))) token_embed: np.ndarray field(default_factorylambda: np.zeros((1, 1))) def __post_init__(self): rng np.random.RandomState(789) self.W_proj rng.randn(HIDDEN_DIM, HIDDEN_DIM) * 0.1 self.W_block rng.randn(HIDDEN_DIM HIDDEN_DIM, HIDDEN_DIM) * 0.2 self.W_mix rng.randn(BLOCK_SIZE, BLOCK_SIZE) * 0.1 self.pos_embed rng.randn(BLOCK_SIZE, HIDDEN_DIM) * 0.1 self.lm_head rng.randn(HIDDEN_DIM, VOCAB_SIZE) * 0.1 self.token_embed rng.randn(VOCAB_SIZE, HIDDEN_DIM) * 0.1 def forward( self, target_hidden: np.ndarray, block_tokens: list[int], ) - np.ndarray: One forward pass predicting logits for the entire block. Args: target_hidden: [HIDDEN_DIM] targets last hidden state (KV injection) block_tokens: [BLOCK_SIZE] current block tokens (some may be MASK) Returns: logits: [BLOCK_SIZE, VOCAB_SIZE] # KV injection: project target hidden into draft space ctx target_hidden self.W_proj # [HIDDEN_DIM] # Token embeddings positional encoding embeds np.stack([self.token_embed[t] for t in block_tokens], axis0) embeds embeds self.pos_embed # [BLOCK_SIZE, HIDDEN_DIM] ctx_expanded np.tile(ctx, (BLOCK_SIZE, 1)) # [BLOCK_SIZE, HIDDEN_DIM] # Concatenate context and token embedding combined np.concatenate([ctx_expanded, embeds], axis-1) hidden combined self.W_block # [BLOCK_SIZE, HIDDEN_DIM] # Sequence mixing: let unmasked tokens influence adjacent [MASK] positions hidden self.W_mix hidden # [BLOCK_SIZE, HIDDEN_DIM] hidden np.maximum(hidden, 0) # ReLU logits hidden self.lm_head # [BLOCK_SIZE, VOCAB_SIZE] return logits def logits_to_probs(self, logits: np.ndarray) - np.ndarray: logits logits - logits.max(axis-1, keepdimsTrue) probs np.exp(logits) return probs / probs.sum(axis-1, keepdimsTrue) # --------------------------------------------------------------------------- # DFlash Decoder # --------------------------------------------------------------------------- dataclass class DFlashDecoder: DFlash speculative decoder. Flow: 1. Target forward - hidden state sample t1 token 2. Block diffusion draft: init all-MASK block - multi-step denoising - final tokens 3. Target one forward to verify entire block 4. Speculative accept/reject target: TargetModel field(default_factoryTargetModel) draft: DFlashDraftModel field(default_factoryDFlashDraftModel) block_size: int BLOCK_SIZE num_denoise_steps: int NUM_DENOISE_STEPS mask_token_id: int MASK_TOKEN_ID # ------------------------------------------------------------------ # Block Diffusion Draft # ------------------------------------------------------------------ def _block_diffusion_draft( self, target_hidden: np.ndarray, t_1: int, ) - tuple[list[int], list[np.ndarray]]: Block diffusion draft generation. 1. Init: all [MASK] 2. Each denoise step: draft model predicts all positions in parallel, then unmask highest-confidence positions 3. Final output: fully determined token sequence Args: target_hidden: [HIDDEN_DIM] targets last hidden state t_1: token just sampled by target, used as key condition for draft # Fuse t_1 embedding into context so draft knows what was just generated t_1_embed self.target.token_embed[t_1] fused_hidden target_hidden t_1_embed # [HIDDEN_DIM] # Init: all MASK block [self.mask_token_id] * self.block_size unmasked [False] * self.block_size # Denoising schedule: progressively unmask more positions for step in range(self.num_denoise_steps): logits self.draft.forward(fused_hidden, block) probs self.draft.logits_to_probs(logits) # Confidence max probability at each position confidences probs.max(axis-1) # [BLOCK_SIZE] # Decide how many positions to unmask this step remaining sum(1 for u in unmasked if not u) if remaining 0: break if step self.num_denoise_steps - 1: num_to_unmask remaining # final step: unmask all else: num_to_unmask max(1, remaining // (self.num_denoise_steps - step)) # Pick highest-confidence masked positions masked_indices [i for i in range(self.block_size) if not unmasked[i]] masked_indices.sort(keylambda i: -confidences[i]) to_unmask masked_indices[:num_to_unmask] for i in to_unmask: # Sample token (with temperature for early steps) p probs[i] if step self.num_denoise_steps - 1: p np.exp(np.log(p 1e-12) / 0.8) p p / p.sum() token int(np.random.choice(VOCAB_SIZE, pp)) block[i] token unmasked[i] True # Final probability distribution (from last forward) final_logits self.draft.forward(fused_hidden, block) final_probs self.draft.logits_to_probs(final_logits) return block, [final_probs[i] for i in range(self.block_size)] # ------------------------------------------------------------------ # Verification # ------------------------------------------------------------------ def _verify_block( self, prefix_ids: list[int], draft_tokens: list[int], ) - tuple[np.ndarray, np.ndarray]: Target model verifies entire draft block in one forward pass. verify_ids prefix_ids draft_tokens outputs self.target.forward(verify_ids) verify_logits outputs[logits] verify_probs np.array([ self.target.logits_to_probs(verify_logits[i]) for i in range(len(verify_ids)) ]) return verify_logits, verify_probs # ------------------------------------------------------------------ # Speculative Accept/Reject # ------------------------------------------------------------------ def _speculative_accept( self, draft_tokens: list[int], draft_probs: list[np.ndarray], verify_probs: np.ndarray, prefix_len: int, ) - list[int]: Speculative accept/reject for draft block. Key: draft_tokens[i] is verified using the targets prediction at position (prefix_len - 1 i), which is the output after processing the token *before* draft_tokens[i]. Once a token is rejected, all subsequent tokens are discarded. accepted: list[int] [] for i, draft_tok in enumerate(draft_tokens): # draft_tokens[i] is at index prefix_len i in verify_ids, # but the target predicts it based on the token at index # prefix_len - 1 i (the preceding position) target_pos prefix_len - 1 i target_prob verify_probs[target_pos] p_target float(target_prob[draft_tok]) p_draft float(draft_probs[i][draft_tok]) if p_target p_draft: accepted.append(draft_tok) elif np.random.random() p_target / max(p_draft, 1e-12): accepted.append(draft_tok) else: # Rejected: sample from residual distribution residual target_prob - draft_probs[i] residual np.maximum(residual, 0) residual_sum residual.sum() if residual_sum 1e-12: residual / residual_sum corrected int(np.random.choice(VOCAB_SIZE, presidual)) else: corrected int(np.argmax(target_prob)) accepted.append(corrected) break # stop after rejection return accepted # ------------------------------------------------------------------ # Main Generation Loop # ------------------------------------------------------------------ def generate( self, input_ids: list[int], max_new_tokens: int, *, verbose: bool False, ) - list[int]: tokens list(input_ids) generated: list[int] [] total_target_forwards 0 total_drafted 0 total_accepted 0 while len(generated) max_new_tokens: # Step 1: Target forward outputs self.target.forward(tokens) total_target_forwards 1 last_logits outputs[logits][-1] last_hidden outputs[hidden][-1] t_1 self.target.sample_token(last_logits) tokens.append(t_1) generated.append(t_1) if len(generated) max_new_tokens: break # Step 2: Block diffusion draft (with t_1 as condition) draft_tokens, draft_probs self._block_diffusion_draft(last_hidden, t_1) total_drafted len(draft_tokens) if verbose: print(f Draft block: {draft_tokens}) # Step 3: Target verify verify_logits, verify_probs self._verify_block(tokens, draft_tokens) total_target_forwards 1 # Step 4: Speculative accept accepted_tokens self._speculative_accept( draft_tokens, draft_probs, verify_probs, len(tokens), ) total_accepted len(accepted_tokens) for tok in accepted_tokens: if len(generated) max_new_tokens: break tokens.append(tok) generated.append(tok) if verbose: print(f Accepted: {accepted_tokens} ({len(accepted_tokens)}/{len(draft_tokens)})) if verbose: avg_accept total_accepted / max(total_drafted, 1) print(f\n Total target forwards: {total_target_forwards}) print(f Total drafted: {total_drafted}, accepted: {total_accepted}) print(f Avg acceptance rate: {avg_accept:.2%}) print(f Theoretical speedup: {max_new_tokens / max(total_target_forwards, 1):.2f}x) return generated[:max_new_tokens] # --------------------------------------------------------------------------- # Vanilla Autoregressive Decoding # --------------------------------------------------------------------------- def vanilla_generate( target: TargetModel, input_ids: list[int], max_new_tokens: int, *, verbose: bool False, ) - list[int]: tokens list(input_ids) generated: list[int] [] total_forwards 0 while len(generated) max_new_tokens: outputs target.forward(tokens) total_forwards 1 next_tok target.sample_token(outputs[logits][-1]) tokens.append(next_tok) generated.append(next_tok) if verbose: print(f Total forwards: {total_forwards}) return generated # --------------------------------------------------------------------------- # Experiments # --------------------------------------------------------------------------- def run_experiments(): input_ids [1, 2, 3, 4, 5] max_new_tokens 20 target TargetModel() dflash DFlashDecoder(targettarget) print( * 60) print(DFlash Speculative Decoding Demo) print( * 60) print(f Vocab: {VOCAB_SIZE}, Hidden: {HIDDEN_DIM}, Layers: {NUM_LAYERS}) print(f Block size: {BLOCK_SIZE}, denoise steps: {NUM_DENOISE_STEPS}) print(f Input: {input_ids}, max_new_tokens: {max_new_tokens}) print() # ---- Vanilla ---- print(--- Vanilla Autoregressive Decoding ---) np.random.seed(42) t0 time.perf_counter() vanilla_tokens vanilla_generate(target, input_ids, max_new_tokens, verboseTrue) vanilla_time time.perf_counter() - t0 print(f Generated: {vanilla_tokens}) print(f Time: {vanilla_time:.4f}s) print() # ---- DFlash ---- print(--- DFlash Speculative Decoding ---) np.random.seed(42) t0 time.perf_counter() dflash_tokens dflash.generate(input_ids, max_new_tokens, verboseTrue) dflash_time time.perf_counter() - t0 print(f Generated: {dflash_tokens}) print(f Time: {dflash_time:.4f}s) print() # ---- Comparison ---- print( * 60) print(Comparison) print( * 60) print(f Vanilla time: {vanilla_time:.4f}s) print(f DFlash time: {dflash_time:.4f}s) if dflash_time 0: print(f Speedup: {vanilla_time / dflash_time:.2f}x) print() # ---- Averaged over 10 runs ---- print(--- Averaged over 10 runs ---) vanilla_times [] dflash_times [] for seed in range(10): np.random.seed(seed) t0 time.perf_counter() vanilla_generate(target, input_ids, max_new_tokens) vanilla_times.append(time.perf_counter() - t0) np.random.seed(seed) t0 time.perf_counter() dflash.generate(input_ids, max_new_tokens) dflash_times.append(time.perf_counter() - t0) avg_v sum(vanilla_times) / len(vanilla_times) avg_d sum(dflash_times) / len(dflash_times) print(f Avg vanilla: {avg_v:.4f}s) print(f Avg DFlash: {avg_d:.4f}s) print(f Avg speedup: {avg_v / avg_d:.2f}x) if __name__ __main__: run_experiments()