diff --git a/research/baichuan2/baichuan2_13b.py b/research/baichuan2/baichuan2_13b.py index baf4883a5fb269c25b38b4fb19699c09325f6a4d..55d6ea37051e04129a0eeb34133385749ba4f8b8 100644 --- a/research/baichuan2/baichuan2_13b.py +++ b/research/baichuan2/baichuan2_13b.py @@ -368,6 +368,12 @@ class Baichuan13BV2Model(Baichuan2PreTrainedModel): else: alibi_tensor = self.gather(self.alibi_tensor, batch_valid_length, 2) alibi_tensor = self.transpose(alibi_tensor, (2, 1, 0, 3)) + alibi_tensor = self.slice(alibi_tensor, + (0, 0, 0, 0), + (alibi_tensor.shape[0], alibi_tensor.shape[1], alibi_tensor.shape[2], + block_tables.shape[1] * self.block_size), + (1, 1, 1, 1), + ) # tokens: [bs, seq/1] h = self.tok_embeddings(tokens) h = self.reshape(h, (bs, seq_len, self.hidden_size))