2 Star 16 Fork 7

算法美食屋/torchkeras

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
贡献代码
同步代码
取消
提示: 由于 Git 不支持空文件夾,创建文件夹后会生成空的 .keep 文件
Loading...
README
Apache-2.0

炼丹师,这是你的梦中情炉吗?🌹🌹

English | 简体中文

torchkeras 是一个通用的pytorch模型训练模版工具,按照如下目标进行设计和实现:

  • 好看 (代码优雅,日志美丽,自带可视化)

  • 好用 (使用方便,支持 进度条、评估指标、early-stopping等常用功能,支持tensorboard,wandb回调函数等扩展功能)

  • 好改 (修改简单,核心代码模块化,仅约200行,并提供丰富的修改使用案例)


1,炼丹之痛 😭😭

无论是学术研究还是工业落地,pytorch几乎都是目前炼丹的首选框架。

pytorch的胜出不仅在于其简洁一致的api设计,更在于其生态中丰富和强大的模型库。

但是我们会发现不同的pytorch模型库提供的训练和验证代码非常不一样。

torchvision官方提供的范例代码主要是一个关联了非常多依赖函数的train_one_epoch和evaluate函数,针对检测和分割各有一套。

yolo系列的主要是支持ddp模式的各种风格迥异的Trainer,每个不同的yolo版本都会改动很多导致不同yolo版本之间都难以通用。

抱抱脸的transformers库在借鉴了pytorch_lightning的基础上也搞了一个自己的Trainer,但与pytorch_lightning并不兼容。

非常有名的facebook的目标检测库detectron2, 也是搞了一个它自己的Trainer,配合一个全局的cfg参数设置对象来训练模型。

还有我用的比较多的语义分割的segmentation_models.pytorch这个库,设计了一个TrainEpoch和一个ValidEpoch来做训练和验证。

在学习和使用这些不同的pytorch模型库时,尝试阅读理解和改动这些训练和验证相关的代码让我受到了一万点伤害。

有些设计非常糟糕,嵌套了十几层,有些实现非常dirty,各种带下划线的私有变量满天飞。

让你每次想要改动一下加入一些自己想要的功能时就感到望而却步。

我不就想finetune一下模型嘛,何必拿这么多垃圾代码搞我?


2,梦中情炉 🤗🤗

这一切的苦不由得让我怀念起tensorflow中keras的美好了。

还记得keras那compile, fit, evalute三连击吗?一切都像行云流水般自然,真正的for humans。

而且你看任何用keras实现的模型库,训练和验证都几乎可以用这一套相同的接口,没有那么多莫名奇妙的野生Trainer。

我能否基于pytorch打造一个接口和keras一样简洁易用,功能强大,但是实现代码非常简短易懂,便于修改的模型训练工具呢?

从2020年7月左右发布1.0版本到最近发布的3.86版本,我陆陆续续在工作中一边使用一边打磨一个工具,总共提交修改了70多次。

现在我感觉我细心雕琢的这个作品终于长成了我心目中接近完美的样子。

她有一个美丽的名字:torchkeras.

是的,她兼具torch的灵动,也有keras的优雅~

并且她的美丽,无与伦比~

她,就是我的梦中情炉~ 🤗🤗


3,使用方法 🍊🍊

安装torchkeras

pip install torchkeras

通过使用torchkeras,你不需要写自己的pytorch模型训练循环。你只要做这样两步就可以了。

(1) 创建你的模型结构net,然后把它和损失函数传入torchkeras.KerasModel构建一个model。

(2) 使用model的fit方法在你的训练数据和验证数据上进行训练,训练数据和验证数据需要封装成两个DataLoader.

核心使用代码就像下面这样:

import torch 
import torchkeras
import torchmetrics
model = torchkeras.KerasModel(net,
                              loss_fn = nn.BCEWithLogitsLoss(),
                              optimizer= torch.optim.Adam(net.parameters(),lr = 1e-4),
                              metrics_dict = {"acc":torchmetrics.Accuracy(task='binary')}
                             )
dfhistory=model.fit(train_data=dl_train, 
                    val_data=dl_val, 
                    epochs=20, 
                    patience=3, 
                    ckpt_path='checkpoint',
                    monitor="val_acc",
                    mode="max",
                    plot=True
                   )

在jupyter notebook中执行训练代码,你将看到类似下面的动态可视化图像和训练日志进度条。

除此之外,torchkeras还提供了一个VLog类,方便你在任意的训练逻辑中使用动态可视化图像和日志进度条。

import time
import math,random
from torchkeras import VLog

epochs = 10
batchs = 30

#0, 指定监控北极星指标,以及指标优化方向
vlog = VLog(epochs, monitor_metric='val_loss', monitor_mode='min') 

#1, log_start 初始化动态图表
vlog.log_start() 

for epoch in range(epochs):
    
    #train
    for step in range(batchs):
        
        #2, log_step 更新step级别日志信息,打日志,并用小进度条显示进度
        vlog.log_step({'train_loss':100-2.5*epoch+math.sin(2*step/batchs)}) 
        time.sleep(0.05)
        
    #eval    
    for step in range(20):
        
        #3, log_step 更新step级别日志信息,指定training=False说明在验证模式,只打日志不更新小进度条
        vlog.log_step({'val_loss':100-2*epoch+math.sin(2*step/batchs)},training=False)
        time.sleep(0.05)
        
    #4, log_epoch 更新epoch级别日志信息,每个epoch刷新一次动态图表和大进度条进度
    vlog.log_epoch({'val_loss':100 - 2*epoch+2*random.random()-1,
                    'train_loss':100-2.5*epoch+2*random.random()-1})  

# 5, log_end 调整坐标轴范围,输出最终指标可视化图表
vlog.log_end()

4,主要特性 🍉🍉

torchkeras 支持以下这些功能特性,稳定支持这些功能的起始版本以及这些功能借鉴或者依赖的库的来源见下表。

功能 稳定支持起始版本 依赖或借鉴库
✅ 训练进度条 3.0.0 依赖tqdm,借鉴keras
✅ 训练评估指标 3.0.0 借鉴pytorch_lightning
✅ notebook中训练自带可视化 3.8.0 借鉴fastai
✅ early stopping 3.0.0 借鉴keras
✅ gpu training 3.0.0 依赖accelerate
✅ multi-gpus training(ddp) 3.6.0 依赖accelerate
✅ fp16/bf16 training 3.6.0 依赖accelerate
✅ tensorboard callback 3.7.0 依赖tensorboard
✅ wandb callback 3.7.0 依赖wandb
✅ VLog 3.9.5 依赖matplotlib

5,基本范例 🌰🌰

以下范例是torchkeras的基础范例,演示了torchkeras的主要功能。

包括基础训练,使用wandb可视化,使用wandb调参,使用tensorboard可视化,使用多GPU的ddp模式训练,通用的VLog动态日志可视化等。

example notebook kaggle链接
①基础范例 🔥🔥 basic example
Open In Kaggle

②wandb可视化 🔥🔥🔥 wandb demo
Open In Kaggle

③wandb自动化调参🔥🔥 wandb sweep demo
Open In Kaggle

④tensorboard可视化 tensorboard example
⑤ddp/tpu训练范例 ddp tpu examples
Open In Kaggle

⑥VLog动态日志可视化范例🔥🔥🔥 VLog example

6,进阶范例 🔥🔥

在炼丹实践中,遇到的数据集结构或者训练推理逻辑往往会千差万别。

例如我们可能会遇到多输入多输出结构,或者希望在训练过程中计算并打印一些特定的指标等等。

这时候炼丹师可能会倾向于使用最纯粹的pytorch编写自己的训练循环。

实际上,torchkeras提供了极致的灵活性来让炼丹师掌控训练过程的每个细节。

从这个意义上说,torchkeras更像是一个训练代码模版。

这个模版由低到高由StepRunner,EpochRunner 和 KerasModel 三个类组成。

在绝大多数场景下,用户只需要在StepRunner上稍作修改并覆盖掉,就可以实现自己想要的训练推理逻辑。

就像下面这段代码范例,这是一个多输入的例子,并且嵌入了特定的accuracy计算逻辑。

这段代码的完整范例,见examples下的CRNN_CTC验证码识别。


import torch.nn.functional as F 
from torchkeras import KerasModel
from accelerate import Accelerator

#我们覆盖KerasModel的StepRunner以实现自定义训练逻辑。
#注意这里把acc指标的结果写在了step_losses中以便和loss一样在Epoch上求平均,这是一个非常灵活而且有用的写法。

class StepRunner:
    def __init__(self, net, loss_fn, accelerator=None, stage = "train", metrics_dict = None, 
                 optimizer = None, lr_scheduler = None
                 ):
        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage
        self.optimizer,self.lr_scheduler = optimizer,lr_scheduler
        self.accelerator = accelerator if accelerator is not None else Accelerator()
        if self.stage=='train':
            self.net.train() 
        else:
            self.net.eval()
    
    def __call__(self, batch):
        
        images, targets, input_lengths, target_lengths = batch
        
        #loss
        preds = self.net(images)
        preds_log_softmax = F.log_softmax(preds, dim=-1)
        loss = F.ctc_loss(preds_log_softmax, targets, input_lengths, target_lengths)
        acc = eval_acc(targets,preds)
            

        #backward()
        if self.optimizer is not None and self.stage=="train":
            self.accelerator.backward(loss)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()
            
            
        all_loss = self.accelerator.gather(loss).sum()
        
        #losses (or plain metric that can be averaged)
        step_losses = {self.stage+"_loss":all_loss.item(),
                       self.stage+'_acc':acc}
        
        #metrics (stateful metric)
        step_metrics = {}
        if self.stage=="train":
            if self.optimizer is not None:
                step_metrics['lr'] = self.optimizer.state_dict()['param_groups'][0]['lr']
            else:
                step_metrics['lr'] = 0.0
        return step_losses,step_metrics
    
#覆盖掉默认StepRunner 
KerasModel.StepRunner = StepRunner 

可以看到,这种修改实际上是非常简单并且灵活的,保持每个模块的输出与原始实现格式一致就行,中间处理逻辑根据需要灵活调整。

同理,用户也可以修改并覆盖EpochRunner来实现自己的特定逻辑,但我一般很少遇到有这样需求的场景。

examples目录下的范例库包括了使用torchkeras对一些非常常用的库中的模型进行训练的例子。

例如:

  • torchvision
  • transformers
  • segmentation_models_pytorch
  • ultralytics
  • timm

如果你想掌握一个东西,那么就去使用它,如果你想真正理解一个东西,那么尝试去改变它。 ———— 爱因斯坦

example 使用模型库 notebook
RL
强化学习——Q-Learning 🔥🔥 - Q-learning
强化学习——DQN - DQN
Tabular
二分类——LightGBM - LightGBM
多分类——FTTransformer🔥🔥🔥🔥🔥 - FTTransformer
二分类——FM - FM
二分类——DeepFM - DeepFM
二分类——DeepCross - DeepCross
CV
图片分类——Resnet - Resnet
语义分割——UNet - UNet
目标检测——SSD - SSD
文字识别——CRNN 🔥🔥 - CRNN-CTC
目标检测——FasterRCNN torchvision FasterRCNN
语义分割——DeepLabV3++ segmentation_models_pytorch Deeplabv3++
实例分割——MaskRCNN detectron2 MaskRCNN
图片分类——SwinTransformer timm Swin
目标检测——YOLOv8 🔥🔥🔥 ultralytics YOLOv8_Detect
实例分割——YOLOv8 🔥🔥🔥 ultralytics YOLOv8_Segment
NLP
序列翻译——Transformer🔥🔥 - Transformer
文本生成——Llama🔥 - Llama
文本分类——BERT transformers BERT
命名实体识别——BERT transformers BERT_NER
LLM微调——ChatGLM2_LoRA 🔥🔥🔥 transformers ChatGLM2_LoRA
LLM微调——ChatGLM2_AdaLoRA 🔥 transformers ChatGLM2_AdaLoRA
LLM微调——ChatGLM2_QLoRA transformers ChatGLM2_QLoRA_Kaggle
LLM微调——BaiChuan13B_QLoRA transformers BaiChuan13B_QLoRA
LLM微调——BaiChuan13B_NER 🔥🔥🔥 transformers BaiChuan13B_NER
LLM微调——BaiChuan13B_MultiRounds 🔥 transformers BaiChuan13B_MultiRounds
LLM微调——Qwen7B_MultiRounds 🔥🔥🔥 transformers Qwen7B_MultiRounds
LLM微调——BaiChuan2_13B 🔥 transformers BaiChuan2_13B

7,鼓励和联系作者 🎈🎈

如果本项目对你有所帮助,想鼓励一下作者,记得给本项目加一颗星星star⭐️,并分享给你的朋友们喔😊!

如果在torchkeras的使用中遇到问题,可以在项目中提交issue。

如果想要获得更快的反馈或者与其他torchkeras用户小伙伴进行交流,

可以在公众号算法美食屋后台回复关键字:加群

Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

简介

Pytorch❤️ Keras 😋😋 展开 收起
README
Apache-2.0
取消

发行版

暂无发行版

贡献者

全部

近期动态

不能加载更多了
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/Python_Ai_Road/torchkeras.git
git@gitee.com:Python_Ai_Road/torchkeras.git
Python_Ai_Road
torchkeras
torchkeras
master

搜索帮助