diff --git a/modellink/training.py b/modellink/training.py index bf2f72eb63b3072093d81d1aa32fc830bab5c34e..ff0cd983b697b895009c6df15954a7cdb51b7a86 100644 --- a/modellink/training.py +++ b/modellink/training.py @@ -85,6 +85,21 @@ def model_provider_func_wrapper(model_provider_func): model = get_peft_model(model, lora_config) model.add_module('module', model.get_base_model()) + + def _hook(_module, _x_in, _x_out): + """ Extract the feature map of model""" + _x_out.requires_grad_(True) + + def _create_hooks(_model, layer): + """ Make the hooks function""" + for name, module in _model.named_modules(): + if isinstance(module, megatron.core.tensor_parallel.layers.VocabParallelEmbedding): + _name = name.split('.')[-1] + if _name in layer: + module.register_forward_hook(_hook) + if args.recompute_method == 'block' and args.recompute_granularity == 'full': + _create_hooks(model, args.lora_register_forward_hook) + model.print_trainable_parameters() for module in model.modules(): # LoRA Linear Layer need all reduce