zhiqu22 committed
Commit · 1d88f72
Parent(s): 317d82a

implement early stop

modeling_mitre.py  CHANGED  (+103 -28)
@@ -11,8 +11,6 @@ from transformers.utils import logging
 from transformers.generation import GenerationMixin
 from transformers.modeling_utils import PreTrainedModel
 from transformers.activations import ACT2FN
-from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
-from transformers.integrations.fsdp import is_fsdp_managed_module
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     Seq2SeqLMOutput,

@@ -75,10 +73,6 @@ class MitreSdpaAttention(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        """
-        Input shape: Batch x Time x Channel
-        Output objects: attn_output, attn_weights (always be None), past_key_value
-        """
         """
         1. MitreModel uses MitreSdpaAttention, which is modified from M2M100SdpaAttention.
            Notably, neither of them supports `output_attentions=True` or a non-None `layer_head_mask`,
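
The deleted docstring's note that attn_weights is always None follows from how SDPA works. A minimal standalone sketch (toy shapes, not part of the commit): torch's fused kernel returns only the attention output, never the weight matrix.

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 10, 64)   # (batch, heads, tgt_len, head_dim)
k = torch.randn(2, 8, 12, 64)   # (batch, heads, src_len, head_dim)
v = torch.randn(2, 8, 12, 64)

attn_output = F.scaled_dot_product_attention(q, k, v)  # no weights returned
attn_weights = None  # hence "attn_weights (always be None)" above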
@@ -360,6 +354,8 @@ class MitreDecoder(MitrePreTrainedModel):

         elif past_key_values_length > 0:
             # in generation
+            # this branch is only used in fairseq, not in huggingface,
+            # because here we reuse the mask through the cache.
             mask = torch.zeros(past_key_values_length + 1)
             mask = mask.to(embeds, copy=True)
             batch_mask = mask.unsqueeze(0).expand(b, -1).clone().contiguous()
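
A hedged sketch of the mask reuse the comment refers to (toy shapes and values, not the model's real code): each decoding step attends to exactly one more position, and that position is always visible, so it suffices to append a column of 0. ("not masked") to the cached mask instead of rebuilding it.

import torch

cached_mask = torch.zeros(2, 1, 1, 5)        # (batch, 1, tgt, src) from the last step
cached_mask[0, ..., :2] = float("-inf")      # e.g. two pads in sequence 0

new_col = torch.zeros(2, 1, 1, 1)            # the freshly generated position
cached_mask = torch.cat((cached_mask, new_col), dim=-1)
print(cached_mask.shape)                     # torch.Size([2, 1, 1, 6])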
@@ -374,7 +370,6 @@ class MitreDecoder(MitrePreTrainedModel):
         batch_mask = batch_mask.view(b, 1, batch_mask.shape[-2], batch_mask.shape[-1])
         return batch_mask

-
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@@ -531,7 +526,6 @@ class MitreDecoder(MitrePreTrainedModel):
                    cache_value[:, :, src_length - max_register_num:, :]
                )
                next_decoder_cache += (clipped_rep,)
-

        if past_key_values_length == 0:
            hidden_states = hidden_states[:,src_length:,:]
@@ -759,6 +753,7 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):

     @staticmethod
     def _reorder_register_cache(t, beam_idx):
+        """ a customized reorder method """
         return t.index_select(dim=0, index=beam_idx.to(t.device))

     @staticmethod
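
A toy illustration of what this reorder does (not commit code): beam search may shuffle or duplicate beams each step, and index_select along dim 0 realigns any per-beam tensor with the surviving beam order.

import torch

t = torch.tensor([[10], [11], [12], [13]])    # 4 beams' cached values
beam_idx = torch.tensor([0, 0, 3, 2])         # beam 1 died, beam 0 split in two
print(t.index_select(dim=0, index=beam_idx))  # tensor([[10], [10], [13], [12]])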
@@ -782,15 +777,32 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
     ):
         """
         Inference with beam search.
-        This code is
-
-
-
-
-
-
-
-
+        This code is improved from 'transformers.generation.utils.GenerationMixin.generate'.
+        There are **two main improved points**:
+        1. 'soft early_stop' in beam search.
+           a) problem in the vanilla version.
+              Multilingual translation models, e.g., NLLB and M2M, adopt the 'vanilla early_stop'
+              of BeamSearchScorer (the official implementation provided by HuggingFace), i.e.,
+              a sequence that is labeled 'ended' is still filled with 'pad(1)'; in other words,
+              the ended sequence is still fed into the model, resulting in a heavy memory waste.
+           b) our improvement.
+              We implement the soft early_stop to resolve the problem. Specifically, we do not change
+              anything in BeamSearchScorer, to keep the code's flexibility; rather, we remove the ended
+              sequences from the input. Then, given that the output hidden states' shape is changed,
+              we insert some placeholders to keep the shape of BeamSearchScorer's states.
+              Based on our tests, this improvement can decrease the memory cost to half of before.
+        2. mask reusing.
+           a) problem: registers need attention masks at each step.
+              A sequence possibly consists of 4 parts, i.e., pads, source tokens, registers, and target
+              tokens. In training, we mask all tokens before the registers for the generation of target
+              tokens. As a result, in generation, we cannot allow the target tokens to 'see' pads.
+              So we need masks at each step, leading to computational resource waste.
+           b) our improvement.
+              First, we truncate the source tokens to save cost.
+              Second, given that there still exist some source tokens playing the role of placeholders,
+              we modify the mask generation compared to our code in fairseq.
+              Third, in order to avoid re-generating masks, we add the mask into 'registering_cache'.
+              Then, we manage its order like the kv cache in beam search, and append a column of 0. at every step.
         """
         if generation_config != None:
             assert type(generation_config) is GenerationConfig
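
A hedged toy sketch of the 'soft early_stop' idea from the docstring (invented shapes and names, not the model's real code): finished rows are dropped from the batch fed to the model, and zero placeholders restore the full (batch * beams) shape that BeamSearchScorer expects.

import torch

num_rows, vocab = 6, 10
done_mask = torch.tensor([False, False, True, True, False, False])
num_alive = int((~done_mask).sum())

scores_running = torch.randn(num_alive, vocab)  # the model ran on 4 rows only

restored = torch.zeros(num_rows, vocab)         # placeholders for ended rows
restored[~done_mask] = scores_running           # scatter the live rows back
assert restored.shape == (num_rows, vocab)      # the scorer's expected shape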
@@ -831,13 +843,18 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
         past_key_values = None
         registering_cache= None
         attention_mask = None
+        # done_mask marks the ended sequences;
+        # (~done_mask) marks the running sequences.
+        done_mask = None

+        # we follow the style of M2M and NLLB,
+        # so we simplify the initialization of these two processors.
         logits_processor = LogitsProcessorList()
         stopping_criteria = StoppingCriteriaList()

         beam_scores = torch.zeros((batch_size, beam_size), dtype=torch.float, device=input_ids.device)
         beam_scores[:, 1:] = -1e9
-        beam_scores = beam_scores.view((batch_size * beam_size,))
+        beam_scores = beam_scores.view((batch_size * beam_size,))
         while not this_peer_finished:

             if past_key_values is not None:
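
A side note on the init above (standard beam-search detail, not specific to this commit): at step 0 every beam of a sample holds the same BOS prefix, so only beam 0 gets score 0. and the rest get -1e9, which stops top-k from picking the same token through k identical beams.

import torch

batch_size, beam_size = 2, 3
beam_scores = torch.zeros(batch_size, beam_size)
beam_scores[:, 1:] = -1e9
print(beam_scores.view(-1))
# tensor([ 0.0000e+00, -1.0000e+09, -1.0000e+09,  0.0000e+00, -1.0000e+09, -1.0000e+09])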
@@ -850,7 +867,7 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
                 attention_mask = torch.cat((attention_mask, attention_mask[..., -1:]), dim=-1)
             else:
                 decoder_input_ids_for_generation = decoder_input_ids
-
+
             outputs = self(
                 input_ids,
                 decoder_input_ids_for_generation,
@@ -859,21 +876,43 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
                 use_cache=True,
                 registering_cache=registering_cache
             )
-
             del input_ids
             input_ids = None

             past_key_values = outputs.past_key_values
             registering_cache = outputs.registering_cache
-
             next_token_logits = outputs.logits[:, -1, :].clone().float()
-
+            del outputs

+            next_token_logits = next_token_logits.to(device)
             next_token_scores = nn.functional.log_softmax(
                 next_token_logits, dim=-1
             ) # (batch_size * num_beams, vocab_size)

             next_token_scores_processed = logits_processor(decoder_input_ids, next_token_scores)
+
+            # if any sequence has ended, we have to keep the shape of the Scorer's states.
+            # Details are described at the head of this function.
+            if done_mask is not None:
+                if done_mask.any():
+                    # the placeholder for scores is '0.'
+                    restored_tensor = torch.zeros(
+                        (batch_size * beam_size, next_token_scores_processed.shape[1]),
+                        dtype=next_token_scores_processed.dtype,
+                        device=next_token_scores_processed.device
+                    )
+                    restored_tensor[~done_mask] = next_token_scores_processed
+                    next_token_scores_processed = restored_tensor
+                    # the placeholder for tokens is 'pad_token_id'
+                    restored_tokens = torch.full(
+                        (batch_size * beam_size, decoder_input_ids.shape[1]),
+                        self.generation_config.pad_token_id,
+                        dtype=decoder_input_ids.dtype,
+                        device=device
+                    )
+                    restored_tokens[~done_mask] = decoder_input_ids
+                    decoder_input_ids = restored_tokens
+
             next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
                 next_token_scores_processed
             )
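
A small sketch of the score accumulation in the last statement above (generic beam-search math, not commit-specific): each beam's running log-probability is broadcast over the vocabulary so that top-k compares full-sequence scores, log P(y_<t) + log P(y_t | y_<t), across all beams.

import torch

beam_scores = torch.tensor([-0.5, -1.2])                    # (num_beams,)
step_scores = torch.log_softmax(torch.randn(2, 5), dim=-1)  # (num_beams, vocab)
total = step_scores + beam_scores[:, None].expand_as(step_scores)
print(total.shape)                                          # torch.Size([2, 5])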
@@ -892,6 +931,7 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):

             next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
             next_tokens = next_tokens % vocab_size
+
             beam_outputs = beam_scorer.process(
                 decoder_input_ids,
                 next_token_scores,
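
A worked example of the div/mod pair above (generic, not commit-specific): the top-k indices are taken over the flattened (num_beams * vocab_size) scores, so floor-division recovers the source beam and the remainder recovers the token id.

import torch

vocab_size = 100
next_tokens = torch.tensor([7, 103, 256])                        # flat top-k indices
beam_of_origin = torch.div(next_tokens, vocab_size, rounding_mode="floor")
token_ids = next_tokens % vocab_size
print(beam_of_origin.tolist(), token_ids.tolist())               # [0, 1, 2] [7, 3, 56]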
@@ -904,15 +944,50 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
             beam_scores = beam_outputs["next_beam_scores"]
             beam_next_tokens = beam_outputs["next_beam_tokens"]
             beam_idx = beam_outputs["next_beam_indices"]
-            decoder_input_ids = torch.cat([decoder_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

-
-
-
-
-
+            # 'last_done_mask' is used for reordering the cache;
+            # details are described in the next code block.
+            if done_mask is not None:
+                last_done_mask = done_mask
+
+            # get the newest status of the sequences,
+            # then filter the beam_idx.
+            done_mask = beam_scorer._done.clone().view(-1)
+            done_mask = self._expand_inputs_for_generation(done_mask, beam_size)
+            beam_idx = beam_idx[~done_mask]

+            decoder_input_ids = torch.cat([decoder_input_ids[beam_idx, :], beam_next_tokens[~done_mask].unsqueeze(-1)], dim=-1)
+
+            # unlike the processing of tokens, the caches' order is decided by 'tokens',
+            # 'done_mask', and 'beam_idx' simultaneously.
+            if decoder_input_ids_for_generation.shape[0] < beam_next_tokens.shape[0]:
+                # Take care! If the number of running sequences is smaller than the number of input
+                # sequences, the Scorer has just decided to end some of them, but the cache still
+                # follows the last status. Therefore, we should use the last done mask rather than
+                # the newest one.
+                if (~done_mask).sum() < decoder_input_ids_for_generation.shape[0]:
+                    count_mask = last_done_mask
+                else:
+                    count_mask = done_mask
+                # For biasing the beam_idx.
+                # Example:
+                #   done_mask with beam size of 2: [f, f, t, t, f, f]
+                #   beam_idx: [0, 0, 2, 2, 4, 5]
+                #   reorder_idx: [0-0, 0-0, 4-2, 5-2]
+                prefix_sum = torch.cat([
+                    torch.zeros_like(count_mask[:1], dtype=torch.long),
+                    torch.cumsum(count_mask.long(), dim=0)
+                ], dim=0)
+                reorder_idx = beam_idx - prefix_sum[beam_idx]
+                not_done = ~done_mask[beam_idx]
+                reorder_idx = reorder_idx[not_done]
+            else:
+                reorder_idx = beam_idx

+            past_key_values = self._reorder_cache(past_key_values, reorder_idx)
+            registering_cache["register_nums"] = self._reorder_register_cache(registering_cache["register_nums"], reorder_idx)
+            if registering_cache["attention_mask"] is not None:
+                registering_cache["attention_mask"] = self._reorder_register_cache(registering_cache["attention_mask"], reorder_idx)
+
             cur_len = cur_len + 1

             if beam_scorer.is_done:
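
A runnable check of the comment's worked example (toy, not commit code): with done_mask [f, f, t, t, f, f] and beam_idx [0, 0, 2, 2, 4, 5], rows 2-3 have ended, so the surviving caches sit at shifted positions; subtracting the prefix sum of finished rows re-biases beam_idx into the packed cache.

import torch

done_mask = torch.tensor([False, False, True, True, False, False])
beam_idx = torch.tensor([0, 0, 2, 2, 4, 5])

prefix_sum = torch.cat([
    torch.zeros_like(done_mask[:1], dtype=torch.long),
    torch.cumsum(done_mask.long(), dim=0)
], dim=0)
reorder_idx = beam_idx - prefix_sum[beam_idx]
reorder_idx = reorder_idx[~done_mask[beam_idx]]
print(reorder_idx.tolist())  # [0, 0, 2, 3] == [0-0, 0-0, 4-2, 5-2]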