代码拉取完成,页面将自动刷新
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""generation config."""
import copy
import os
import json
from typing import Any, Dict, Union, Optional
from mindformers.models.configuration_utils import PretrainedConfig
from mindformers.tools.logger import logger
__all__ = ["GenerationConfig"]

# File name that `GenerationConfig.from_pretrained` looks for when no explicit
# `config_file_name` is given.
GENERATION_CONFIG_NAME = "generation_config.json"
class GenerationConfig:
    r"""
    Class that holds a configuration for a generation task.

    Some parameters have specific functions, see the table below for details:

    +------------------------------------------------------------+------------------------------+
    | Functional classification                                  | Configuration parameter      |
    +============================================================+==============================+
    | Parameters that control the length of the output           | max_length                   |
    |                                                            +------------------------------+
    |                                                            | max_new_tokens               |
    |                                                            +------------------------------+
    |                                                            | min_length                   |
    |                                                            +------------------------------+
    |                                                            | min_new_tokens               |
    +------------------------------------------------------------+------------------------------+
    | Parameters that control the generation strategy used       | do_sample                    |
    |                                                            +------------------------------+
    |                                                            | use_past                     |
    +------------------------------------------------------------+------------------------------+
    | Parameters for manipulation of the model output logits     | temperature                  |
    |                                                            +------------------------------+
    |                                                            | top_k                        |
    |                                                            +------------------------------+
    |                                                            | top_p                        |
    |                                                            +------------------------------+
    |                                                            | repetition_penalty           |
    |                                                            +------------------------------+
    |                                                            | encoder_repetition_penalty   |
    |                                                            +------------------------------+
    |                                                            | renormalize_logits           |
    +------------------------------------------------------------+------------------------------+
    | Parameters that define the output variables of `generate`  | output_scores                |
    |                                                            +------------------------------+
    |                                                            | output_logits                |
    |                                                            +------------------------------+
    |                                                            | return_dict_in_generate      |
    +------------------------------------------------------------+------------------------------+
    | Special tokens that can be used at generation time         | pad_token_id                 |
    |                                                            +------------------------------+
    |                                                            | bos_token_id                 |
    |                                                            +------------------------------+
    |                                                            | eos_token_id                 |
    +------------------------------------------------------------+------------------------------+

    Args:
        **kwargs (Any): Arguments of text generation.

            - max_length (int, optional): The maximum length the generated tokens can have.
              Corresponds to the length of the input prompt + `max_new_tokens`.
              If `max_new_tokens` is also set, the effect of `max_length` is overridden by `max_new_tokens`.
              Default: ``20``.
            - max_decode_length (int, optional): Legacy alias of `max_length`; ignored when `max_length`
              is also given. It is not stored as a separate attribute.
            - max_new_tokens (int, optional): The maximum numbers of tokens to generate,
              ignoring the number of tokens in the prompt. Default: ``None``.
            - min_length (int, optional): The minimum length of the sequence to be generated.
              Corresponds to the length of the input prompt + `min_new_tokens`.
              If `min_new_tokens` is also set, the effect of `min_length` is overridden by `min_new_tokens`.
              Default: ``0``.
            - min_new_tokens (int, optional): The minimum numbers of tokens to generate,
              ignoring the number of tokens in the prompt. Default: ``None``.
            - num_beams (int, optional): The number of beams. Default: ``1``.
            - do_sample (bool, optional): Whether to use sampling; ``True`` means using sampling decoding,
              ``False`` means using greedy decoding. Default: ``False``.
            - use_past (bool, optional): Whether the model should use the past last key/values attentions
              (if applicable to the model) to speed up decoding. Default: ``False``.
            - temperature (float, optional): The value used to modulate the next token probabilities.
              Default: ``1.0``.
            - top_k (int, optional): The number of highest probability vocabulary tokens to keep for
              top-k-filtering. Default: ``50``.
            - top_p (float, optional): If set to ``float < 1``, only the smallest set of most probable tokens
              with probabilities that add up to `top_p` or higher are kept for generation. Default: ``1.0``.
            - repetition_penalty (float, optional): The parameter for repetition penalty. 1.0 means no penalty.
              Greater than 1.0 means that repetition is penalized and less than 1.0 means that repetition is
              rewarded. See `this paper <https://arxiv.org/pdf/1909.05858.pdf>`_ for more details.
              Default: ``1.0``.
            - encoder_repetition_penalty (float, optional): The parameter for encoder_repetition_penalty.
              An exponential penalty on sequences that are not in the original input.
              1.0 means no penalty. Greater than 1.0 means that repetition is penalized and less than 1.0
              means that repetition is rewarded. Default: ``1.0``.
            - renormalize_logits (bool, optional): Whether to renormalize the logits after applying all the
              logits processors or warpers (including the custom ones). It's highly recommended to set this
              flag to `True` as the search algorithms suppose the score logits are normalized but some logit
              processors or warpers break the normalization. Default: ``False``.
            - output_scores (bool, optional): Whether to return the prediction scores before softmax.
              Default: ``False``.
            - output_logits (bool, optional): Whether to return the unprocessed prediction logit scores.
              Default: ``False``.
            - return_dict_in_generate (bool, optional): Whether to return a dictionary output instead of a
              tuple with output_ids. Only when this is set to True, can generate other output items besides
              output_ids. Default: ``False``.
            - pad_token_id (int, optional): The id of the padding token. Default: ``None``.
            - bos_token_id (int, optional): The id of the beginning-of-sequence token. Default: ``None``.
            - eos_token_id (Union[int, List[int]], optional): The id of the end-of-sequence token.
              Optionally, use a list to set multiple *end-of-sequence* tokens. A single int is stored as a
              one-element list. Default: ``[]``.
            - parallel_decoding (optional): Parallel-decoding option; the accepted values are defined by the
              decoding code that consumes it. Default: ``False``.
            - window_size (int, optional): Parallel-decoding setting. Default: ``5``.
            - level (int, optional): Parallel-decoding setting. Default: ``5``.
            - guess_set_size (int, optional): Parallel-decoding setting. Default: ``3``.
            - _from_model_config (bool, optional): Internal flag. When ``True``, keyword arguments not listed
              above are discarded instead of being stored as attributes. Default: ``False``.

            Any other keyword argument is stored as an attribute of the same name (unless
            `_from_model_config` is ``True``).

    Returns:
        Instance of GenerationConfig.

    Examples:
        >>> from mindformers.generation import GenerationConfig
        >>> config = GenerationConfig()
        >>> print(config)
        {'max_length': 20, 'max_new_tokens': None, 'min_length': 0, 'min_new_tokens': None, 'num_beams': 1,
        'do_sample': False, 'use_past': False, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'repetition_penalty':
        1.0, 'encoder_repetition_penalty': 1.0, 'renormalize_logits': False, 'return_dict_in_generate': False,
        'output_scores': False, 'output_logits': False, 'pad_token_id': None, 'bos_token_id': None, 'eos_token_id':
        [], 'parallel_decoding': False, 'window_size': 5, 'level': 5, 'guess_set_size': 3,
        '_from_model_config': False}
        >>> config = GenerationConfig(max_length=100, min_length=10, do_sample=True, top_k=5, top_p=0.8)
        >>> print(config)
        {'max_length': 100, 'max_new_tokens': None, 'min_length': 10, 'min_new_tokens': None, 'num_beams': 1,
        'do_sample': True, 'use_past': False, 'temperature': 1.0, 'top_k': 5, 'top_p': 0.8, 'repetition_penalty':
        1.0, 'encoder_repetition_penalty': 1.0, 'renormalize_logits': False, 'return_dict_in_generate': False,
        'output_scores': False, 'output_logits': False, 'pad_token_id': None, 'bos_token_id': None, 'eos_token_id':
        [], 'parallel_decoding': False, 'window_size': 5, 'level': 5, 'guess_set_size': 3,
        '_from_model_config': False}
    """

    def __init__(self, **kwargs):
        # max generate length; `max_decode_length` is a legacy alias that `max_length` overrides
        self.max_length = kwargs.pop("max_decode_length", 20)
        self.max_length = kwargs.pop("max_length", self.max_length)
        self.max_new_tokens = kwargs.pop("max_new_tokens", None)
        self.min_length = kwargs.pop("min_length", 0)
        self.min_new_tokens = kwargs.pop("min_new_tokens", None)
        # number of beams
        self.num_beams = kwargs.pop("num_beams", 1)
        # do sample or not
        self.do_sample = kwargs.pop("do_sample", False)
        # incremental infer
        self.use_past = kwargs.pop("use_past", False)
        # logits processors
        self.temperature = kwargs.pop("temperature", 1.0)
        self.top_k = kwargs.pop("top_k", 50)
        self.top_p = kwargs.pop("top_p", 1.0)
        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
        self.encoder_repetition_penalty = kwargs.pop("encoder_repetition_penalty", 1.0)
        self.renormalize_logits = kwargs.pop("renormalize_logits", False)
        # dictionary output
        self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
        self.output_scores = kwargs.pop("output_scores", False)
        self.output_logits = kwargs.pop("output_logits", False)
        # Special tokens that can be used at generation time
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        # A single int id is stored as a one-element list.
        self.eos_token_id = self._normalize_eos_token_id(kwargs.pop("eos_token_id", []))
        # parallel_decoding
        self.parallel_decoding = kwargs.pop("parallel_decoding", False)
        self.window_size = kwargs.pop("window_size", 5)
        self.level = kwargs.pop("level", 5)
        self.guess_set_size = kwargs.pop("guess_set_size", 3)
        # Set by `from_model_config`; controls whether leftover kwargs are kept below.
        self._from_model_config = kwargs.pop("_from_model_config", False)
        # Additional attributes without default values
        if not self._from_model_config:
            # We don't want to copy arbitrary values from the model config when this
            # `GenerationConfig` is being built from a model's configuration.
            for key, value in kwargs.items():
                try:
                    setattr(self, key, value)
                except AttributeError:
                    logger.error("Can't set %s with value %s for %s", key, value, self)
                    raise

    def __str__(self) -> str:
        return str(self.__dict__)

    @staticmethod
    def _normalize_eos_token_id(eos_token_id):
        """Wrap a single int `eos_token_id` in a list; return any other value unchanged."""
        if isinstance(eos_token_id, int):
            return [eos_token_id]
        return eos_token_id

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
        """
        Instantiates a [`GenerationConfig`] from a Python dictionary of parameters.

        Args:
            config_dict (`Dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object.
            kwargs (`Dict[str, Any]`):
                Additional parameters from which to initialize the configuration object. They take
                precedence over entries of `config_dict`. The special key ``return_unused_kwargs``
                (default ``False``) requests that unused keyword arguments are returned as well.

        Returns:
            [`GenerationConfig`], the configuration object instantiated from those parameters, or a tuple
            ``(GenerationConfig, Dict[str, Any])`` of the config and the unused kwargs when
            ``return_unused_kwargs=True``.
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
        config = cls(**{**config_dict, **kwargs})
        # `update` reports which kwargs do not correspond to an attribute of the config.
        unused_kwargs = config.update(**kwargs)
        if return_unused_kwargs:
            return config, unused_kwargs
        return config

    @classmethod
    def from_pretrained(
            cls,
            pretrained_model_name: Union[str, os.PathLike],
            config_file_name: Optional[Union[str, os.PathLike]] = None,
            **kwargs,
    ) -> "GenerationConfig":
        r"""
        Instantiate a [`GenerationConfig`] from a generation configuration file.

        The JSON file ``os.path.join(pretrained_model_name, config_file_name)`` is loaded when it exists;
        otherwise a configuration built only from `kwargs` (on top of the defaults) is returned.

        Args:
            pretrained_model_name (Union[str, os.PathLike]): Directory that contains the configuration file.
            config_file_name (Union[str, os.PathLike], optional): Name of the configuration file inside
                `pretrained_model_name`. Default: ``None``, which means ``"generation_config.json"``.
            kwargs (Dict[str, Any]): Values that override those read from the file. Pass
                ``return_unused_kwargs=True`` to also get back the keyword arguments that were not used.

        Returns:
            [`GenerationConfig`], or a tuple ``(GenerationConfig, Dict[str, Any])`` when
            ``return_unused_kwargs=True``.

        Raises:
            EnvironmentError: If the configuration file exists but is not valid UTF-8 encoded JSON.

        Examples:
            >>> config = GenerationConfig.from_pretrained("./checkpoint_dir", max_new_tokens=32)
        """
        if config_file_name is None:
            config_file_name = GENERATION_CONFIG_NAME
        config_path = os.path.join(pretrained_model_name, config_file_name)
        resolved_config_file = os.path.realpath(config_path)
        if not os.path.exists(resolved_config_file):
            # No file: build from kwargs alone. Going through `from_dict` keeps the handling of
            # `return_unused_kwargs` identical to the file-based path (it must not become an attribute).
            return cls.from_dict({}, **kwargs)
        logger.info("Loading Generate config from %s", resolved_config_file)
        try:
            # Load config dict
            config_dict = cls._dict_from_json_file(resolved_config_file)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            raise EnvironmentError(
                f"It looks like the config file at '{resolved_config_file}' "
                "is not a valid JSON file."
            ) from e
        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def from_model_config(cls, model_config: "PretrainedConfig") -> "GenerationConfig":
        """
        Instantiates a [`GenerationConfig`] from a [`PretrainedConfig`].

        This function is useful to convert legacy [`PretrainedConfig`] objects,
        which may contain generation parameters, into a stand-alone [`GenerationConfig`].
        Only the generation parameters known to [`GenerationConfig`] are taken over; other
        entries of the model config are discarded.

        Args:
            model_config (`PretrainedConfig`):
                The model config that will be used to instantiate the generation config.

        Returns:
            [`GenerationConfig`], the configuration object instantiated from those parameters.
        """
        config_dict = model_config.to_dict()
        # The flag is forced to True below; drop any value carried by the model config.
        config_dict.pop("_from_model_config", None)
        return cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)

    def update(self, **kwargs):
        """
        Updates attributes of this class instance with attributes from `kwargs`
        if they match existing attributes, returning all the unused kwargs.

        A single int `eos_token_id` is wrapped in a list, matching `__init__`.

        Args:
            kwargs (`Dict[str, Any]`):
                Dictionary of attributes to tentatively update this class.

        Returns:
            `Dict[str, Any]`, dictionary containing all the key-value pairs
            that were not used to update the instance.
        """
        unused_kwargs = {}
        for key, value in kwargs.items():
            if not hasattr(self, key):
                unused_kwargs[key] = value
                continue
            if key == "eos_token_id":
                value = self._normalize_eos_token_id(value)
            setattr(self, key, value)
        return unused_kwargs

    def to_dict(self) -> Dict[str, Any]:
        """Return a deep copy of the instance attributes as a dictionary."""
        return copy.deepcopy(self.__dict__)

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        """Read a UTF-8 encoded JSON file and return the decoded object."""
        with open(json_file, "r", encoding="utf-8") as reader:
            return json.load(reader)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。