diff --git a/infer_engines/bash_install_sglang.sh b/infer_engines/bash_install_sglang.sh
index 7a652ef4632eae8bc79e596fe3ff8873076dcdfb..80789ed538c7eb720f6522f2c0fc9334c5958198 100644
--- a/infer_engines/bash_install_sglang.sh
+++ b/infer_engines/bash_install_sglang.sh
@@ -23,3 +23,4 @@ git apply --whitespace=nowarn $PATCH_ROOT/promote_forward_batch_init_stage.patch
 git apply --whitespace=nowarn $PATCH_ROOT/overlap_mtp.patch
 git apply --whitespace=nowarn $PATCH_ROOT/fix_sampler.patch
 git apply --whitespace=nowarn $PATCH_ROOT/send_kvcache_multi_rank.patch
+git apply --whitespace=nowarn $PATCH_ROOT/enable_retract_decode.patch
diff --git a/omni/adaptors/sglang/patches/enable_retract_decode.patch b/omni/adaptors/sglang/patches/enable_retract_decode.patch
new file mode 100644
index 0000000000000000000000000000000000000000..b02ea2bb177f1f9336cc12b5adab01f9375d1775
--- /dev/null
+++ b/omni/adaptors/sglang/patches/enable_retract_decode.patch
@@ -0,0 +1,93 @@
+diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
+index 5af463dd9..7db7b5a4e 100644
+--- a/python/sglang/srt/managers/schedule_batch.py
++++ b/python/sglang/srt/managers/schedule_batch.py
+@@ -1373,11 +1373,18 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
+             return len(self.reqs)
+         # In the decoding phase, the length of a request's KV cache should be
+         # the total length of the request minus 1
+-        return (
+-            sum(1 for req in self.reqs if req.seqlen % page_size == 0)
+-            if self.enable_overlap
+-            else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
+-        )
++        if self.spec_algorithm.is_mtp():
++            return (
++                sum(1 for req in self.reqs if ((req.seqlen + 1) % page_size == 0 or req.seqlen % page_size == 0))
++                if self.enable_overlap
++                else sum(1 for req in self.reqs if ((req.seqlen - 1) % page_size == 0 or req.seqlen % page_size == 0))
++            )
++        else:
++            return (
++                sum(1 for req in self.reqs if req.seqlen % page_size == 0)
++                if self.enable_overlap
++                else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
++            )
+ 
+     def check_decode_mem(self, buf_multiplier=1):
+         num_tokens = (
+diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
+index 21ae9e798..03774a8a0 100644
+--- a/python/sglang/srt/mem_cache/memory_pool.py
++++ b/python/sglang/srt/mem_cache/memory_pool.py
+@@ -914,31 +914,49 @@ class MLATokenToKVPool(KVCache):
+             raise RuntimeError("Only NPU backend is supported for MLA KV cache now.")
+ 
+     def get_cpu_copy(self, indices):
+-        torch.cuda.synchronize()
++        if _is_npu:
++            torch.npu.synchronize()
++        else:
++            torch.cuda.synchronize()
+         kv_cache_cpu = []
+         chunk_size = self.cpu_offloading_chunk_size
+         for layer_id in range(self.layer_num):
+             kv_cache_cpu.append([])
+             for i in range(0, len(indices), chunk_size):
+                 chunk_indices = indices[i : i + chunk_size]
+-                kv_cpu = self.kv_buffer[layer_id][chunk_indices].to(
++
++                kv_cpu_k = self.k_buffer[layer_id][chunk_indices].to(
+                     "cpu", non_blocking=True
+                 )
+-                kv_cache_cpu[-1].append(kv_cpu)
+-        torch.cuda.synchronize()
++                kv_cpu_v = self.v_buffer[layer_id][chunk_indices].to(
++                    "cpu", non_blocking=True
++                )
++                kv_cache_cpu[-1].append((kv_cpu_k, kv_cpu_v))
++        if _is_npu:
++            torch.npu.synchronize()
++        else:
++            torch.cuda.synchronize()
+         return kv_cache_cpu
+ 
+     def load_cpu_copy(self, kv_cache_cpu, indices):
+-        torch.cuda.synchronize()
++        if _is_npu:
++            torch.npu.synchronize()
++        else:
++            torch.cuda.synchronize()
+         chunk_size = self.cpu_offloading_chunk_size
+         for layer_id in range(self.layer_num):
+             for i in range(0, len(indices), chunk_size):
+                 chunk_indices = indices[i : i + chunk_size]
+-                kv_cpu = kv_cache_cpu[layer_id][i // chunk_size]
+-                assert kv_cpu.shape[0] == len(chunk_indices)
+-                kv_chunk = kv_cpu.to(self.kv_buffer[0].device, non_blocking=True)
+-                self.kv_buffer[layer_id][chunk_indices] = kv_chunk
+-        torch.cuda.synchronize()
++                kv_cpu_k, kv_cpu_v = kv_cache_cpu[layer_id][i // chunk_size]
++                assert kv_cpu_k.shape[0] == kv_cpu_v.shape[0] == len(chunk_indices)
++                k_chunk = kv_cpu_k.to(self.k_buffer[0].device, non_blocking=True)
++                v_chunk = kv_cpu_v.to(self.v_buffer[0].device, non_blocking=True)
++                self.k_buffer[layer_id][chunk_indices] = k_chunk
++                self.v_buffer[layer_id][chunk_indices] = v_chunk
++        if _is_npu:
++            torch.npu.synchronize()
++        else:
++            torch.cuda.synchronize()
+ 
+ 
+ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
+ 
\ No newline at end of file