PyTorch-Project-Framework

A high cohesion, low coupling, and plug-and-play project framework for PyTorch.

Folder Structure

  ├── configs
  |    ├── BaseConfig.py  - the base loader for all configuration files
  |    ├── BaseTest.py  - the base test class for all configuration files
  |    ├── Env.py  - the loader for environment configuration files
  |    └── Run.py  - the loader for hyperparameter configuration files
  |
  ├── datasets
  |    ├── functional  - the package of functional methods
  |    ├── BaseDataset.py  - the abstract base class of all datasets
  |    ├── BaseTest.py  - the base test class for all datasets
  |    └── ...  - any dataset of your project
  |
  ├── models
  |    ├── functional  - the package of functional methods
  |    ├── shallow  - the package of shallow methods
  |    ├── BaseModel.py  - the abstract base class of all models
  |    ├── BaseTest.py  - the base test class for all models
  |    └── ...  - any model of your project
  |
  ├── res
  |    ├── env  - JSON files of environment configuration
  |    ├── datasets  - JSON files of dataset configuration
  |    ├── models  - JSON files of model configuration
  |    └── run  - JSON files of hyperparameter configuration
  |
  ├── test
  |    ├── test_configs.py  - unit test classes for the configs package
  |    ├── test_datasets.py  - unit test classes for the datasets package
  |    ├── test_models.py  - unit test classes for the models package
  |    └── test_utils.py  - unit test classes for the utils package
  |
  ├── utils
  |    ├── common.py  - common helper methods
  |    ├── logger.py  - the logger class
  |    ├── summary.py  - the summary class
  |    └── ...  - any utils of your project
  |
  ├── main.py  - the main class of the framework
  |
  └── test_component.py  - the global test class
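configs/BaseConfig.py loads the JSON files kept under res/. The JSON examples in this README contain `//` comments, which Python's standard `json` module does not accept, so a loader has to strip them before parsing (or use a comment-tolerant parser). The sketch below is a hypothetical stdlib-only loader, not the code in BaseConfig.py: `load_cfg` and `_LINE_COMMENT` are illustrative names, and attribute access such as `cfg.input_dims` is provided here with `types.SimpleNamespace`.

```python
import json
import re
from types import SimpleNamespace

# Removes a // comment up to the end of its line. Naive on purpose: it would
# also cut a "//" that appears inside a string value, such as a URL.
_LINE_COMMENT = re.compile(r"//[^\n]*")


def load_cfg(text):
    """Hypothetical loader: strip // comments, parse JSON, expose keys as attributes."""
    cleaned = _LINE_COMMENT.sub("", text)
    # object_hook turns every decoded JSON object (nested ones too) into a namespace.
    return json.loads(cleaned, object_hook=lambda d: SimpleNamespace(**d))


text = """
{
    "name": "YourModel", // must match the model class name
    "input_dims": 256,
    "output_dims": 1
}
"""
cfg = load_cfg(text)
print(cfg.name, cfg.input_dims)  # YourModel 256
```

In practice the text would be read from a file under res/ rather than from an inline string.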

Main Components

Datasets

  • Base dataset

    The base dataset is an abstract class that every dataset you create must inherit, because all datasets share a lot of common functionality. It mainly provides:

    • more - add or update dataset-specific configuration
    • load - load the dataset
    • _recover - split a single data item
    • split - create the training set and the test set
  • Your dataset

    This is where you implement your dataset. You should:

    • Create your dataset class and inherit the BaseDataset class
    • Override the load method
    • Override other methods if you need a special implementation
    • Import your dataset class in datasets/__init__.py
    • Create a JSON file for your dataset's configuration in res/datasets/
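The load/split contract of the base dataset can be pictured with a small stdlib-only sketch. `SketchDataset`, `ToyDataset`, and the interleaved fold assignment are hypothetical and are not taken from datasets/BaseDataset.py; the only details borrowed from this README are that `load` returns a data dictionary plus the amount of data, and that the dataset configuration carries a `cross_folds` value for K-fold cross-validation.

```python
from abc import ABC, abstractmethod


class SketchDataset(ABC):
    """Hypothetical stand-in for datasets.BaseDataset; not the framework's code."""

    def __init__(self, cross_folds):
        self.cross_folds = cross_folds
        # load() returns (data dictionary, number of samples), as in the README example.
        self.data, self.count = self.load()

    @abstractmethod
    def load(self):
        """Return ({'source': ..., 'target': ...}, data_count)."""

    def split(self, fold):
        """Return (train_indices, test_indices) for one cross-validation fold.

        Sample i is assigned to test fold i % cross_folds; this interleaved
        assignment is an illustrative assumption.
        """
        if not 0 <= fold < self.cross_folds:
            raise ValueError(f"fold must be in [0, {self.cross_folds})")
        test = [i for i in range(self.count) if i % self.cross_folds == fold]
        train = [i for i in range(self.count) if i % self.cross_folds != fold]
        return train, test


class ToyDataset(SketchDataset):
    def load(self):
        source = [[float(i)] for i in range(4)]
        target = [i % 2 for i in range(4)]
        return {'source': source, 'target': target}, len(source)


dataset = ToyDataset(cross_folds=2)
print(dataset.split(0))  # ([1, 3], [0, 2])
```

Every sample lands in exactly one test fold, so iterating `fold` over `range(cross_folds)` covers the whole dataset once.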

Models

  • Base model

    The base model is an abstract class that every model you create must inherit, because all models share a lot of common functionality. It mainly provides:

    • check_cfg - filter datasets
    • train - the training step
    • test - the testing step
    • load - load a previously trained model
    • save - save the model
  • Your model

    This is where you implement your model. You should:

    • Create your model class and inherit the BaseModel class
    • Override the train and test methods
    • Override other methods if you need a special implementation
    • Import your model class in models/__init__.py
    • Create a JSON file for your model's configuration in res/models/
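To illustrate why `train` and `test` must be overridden, here is a hypothetical stdlib-only sketch that models the abstract base class with Python's `abc` module. `SketchModel`, `ConstantModel`, and `IncompleteModel` are illustrative names, and the constant predictor stands in for a real PyTorch network; the framework's actual BaseModel may be structured differently.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace


class SketchModel(ABC):
    """Hypothetical stand-in for models.BaseModel; not the framework's code."""

    def __init__(self, cfg, data_cfg, run):
        # The three configuration objects described in this README.
        self.cfg, self.data_cfg, self.run = cfg, data_cfg, run

    @abstractmethod
    def train(self, batch_idx, sample_dict):
        """Run one training step and return a dictionary of losses."""

    @abstractmethod
    def test(self, batch_idx, sample_dict):
        """Run one test step and return a dictionary of results to save."""


class ConstantModel(SketchModel):
    """Toy model that always predicts cfg.constant; stands in for a real network."""

    def _predict(self, sample_dict):
        return [self.cfg.constant] * len(sample_dict['target'])

    def train(self, batch_idx, sample_dict):
        pairs = zip(self._predict(sample_dict), sample_dict['target'])
        # Mean absolute error, the quantity nn.L1Loss computes with reduction='mean'.
        errors = [abs(p - t) for p, t in pairs]
        return {'loss': sum(errors) / len(errors)}

    def test(self, batch_idx, sample_dict):
        return {'target': sample_dict['target'], 'predict': self._predict(sample_dict)}


class IncompleteModel(SketchModel):
    def train(self, batch_idx, sample_dict):
        return {'loss': 0.0}
    # test() is not overridden, so this class is still abstract.


model = ConstantModel(SimpleNamespace(constant=1.0), data_cfg=None, run=None)
print(model.train(0, {'target': [0.0, 2.0]}))  # {'loss': 1.0}

try:
    IncompleteModel(SimpleNamespace(), data_cfg=None, run=None)
except TypeError:
    print('IncompleteModel cannot be instantiated without test()')
```

Declaring the steps with `@abstractmethod` makes a forgotten override fail at construction time instead of in the middle of a training run.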

How to Use

To use this framework, do the following:

  • Dataset

    • In the datasets folder, create a class that inherits the BaseDataset class

       # YourDataset.py
       import datasets

       class YourDataset(datasets.BaseDataset):
           def __init__(self, cfg, **kwargs):
               super(YourDataset, self).__init__(cfg, **kwargs)
    • Override the load method to load your dataset

       # In YourDataset class
       # (requires `import numpy as np` at the top of YourDataset.py)
       def load(self):
           """
           Load your dataset here.
           The parameters in `cfg` are loaded from the JSON file of your dataset's configuration.
           For example:
           - Create 4 random images of size (depth, height, width) as source data
           - Create 4 random labels as target data
           Return the data dictionary and the amount of data
           """
      
           data_count = 4
           source = np.random.rand(data_count, self.cfg.depth, self.cfg.height, self.cfg.width)
           target = np.random.randint(0, self.cfg.label_count, (data_count, 1))
      
           return {'source': source, 'target': target}, data_count
    • Add your dataset name to datasets/__init__.py

      from .YourDataset import YourDataset
    • Create a JSON file for your dataset's configuration in res/datasets/

      {
          "name": "YourDataset", // must match your dataset class name
          // Any parameters your `YourDataset` class needs
          // For example, the image size, the number of labels, and the number of cross-validation folds
          "depth": 3,
          "height": 128,
          "width": 128,
          "label_count": 2,
          "cross_folds": 2
      }
  • Model

    • In the models folder, create a class that inherits the BaseModel class

       # YourModel.py
       import torch
       import torch.nn as nn
       import models

       class YourModel(models.BaseModel):
           def __init__(self, cfg, data_cfg, run, **kwargs):
               super(YourModel, self).__init__(cfg, data_cfg, run, **kwargs)
      
               # The parameters in `cfg` are loaded from the JSON file of your model's configuration
               # The parameters in `data_cfg` are loaded from the JSON file of the dataset's configuration
               # The parameters in `run` are loaded from the JSON file of the hyperparameter configuration
      
               # Create the model, criterion, optimizer, etc.
               # For example:
               # - model: Linear
               # - criterion: L1 loss
               # - optimizer: Adam
               self.model = nn.Linear(self.cfg.input_dims, self.cfg.output_dims).to(self.device)
               self.criterion = nn.L1Loss().to(self.device)
               self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.run.lr, betas=(self.run.b1, self.run.b2))
    • Override the train and test methods to implement the training and testing logic

      # In YourModel class
      def train(self, batch_idx, sample_dict):
          """
          batch_idx: the index of the batch
          sample_dict: the dictionary of training data
      
          Implement the logic of the training process
          For example:
              source -> [model] -> predict -> [criterion] (+target) -> loss
          Return a dictionary of losses
          """
          source = sample_dict['source'].to(self.device)
          target = sample_dict['target'].to(self.device)
      
          self.model.train()
          self.optimizer.zero_grad()
          predict = self.model(source)
          loss = self.criterion(predict, target)
          loss.backward()
          self.optimizer.step()
      
          # Compute anything else you need here
      
          return {'loss': loss}
      
      def test(self, batch_idx, sample_dict):
          """
          batch_idx: the index of the batch
          sample_dict: the dictionary of test data
      
          Implement the logic of the testing process
          For example:
              source -> [model] -> predict
          Return a dictionary of the data you want to save
          """
          source = sample_dict['source'].to(self.device)
          target = sample_dict['target'].to(self.device)
      
          self.model.eval()
          with torch.no_grad():  # gradients are not needed for evaluation
              predict = self.model(source)
      
          # Compute anything else you need here
      
          return {'target': target, 'predict': predict}
    • Add your model name to models/__init__.py

      from .YourModel import YourModel
    • Create a JSON file for your model's configuration in res/models/

      {
          "name": "YourModel", // must match your model class name
          // Any parameters your `YourModel` class needs
          // For example, the input and output dimensions
          "input_dims": 256,
          "output_dims": 1
      }
  • Hyperparameter

    • Create a JSON file for your hyperparameter configuration in res/run/

      {
          "name": "YourHP",
          // Basic hyperparameters
          "batch_size": 32,
          "epochs": 200,
          "save_step": 10,
          // Any hyperparameters needed to create the optimizer in `YourModel` or elsewhere
          // For example, the learning rate and the Adam betas read by `YourModel`
          "lr": 2e-4,
          "b1": 0.9,
          "b2": 0.999
      }
  • Run main.py to start training or testing

    • Train with the configuration files res/datasets/yourdataset.json, res/models/yourmodel.json, and res/run/yourhp.json on GPU 0:

      python3 -m main -d "yourdataset" -m "yourmodel" -r "yourhp" -g 0

    Every save_step epochs, the trained model and the data returned by test are saved in the folder save/[yourmodel]-[yourhp]-[yourdataset]-[index of cross-validation].

    • To test the model saved at epoch 10:

      python3 -m main -d "yourdataset" -m "yourmodel" -r "yourhp" -g 0 -t 10
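The relationship between `epochs`, `save_step`, and the `-t` option can be sketched as a plain-Python driver loop. This is an assumption about how main.py might iterate, not its actual code: `train_loop`, `save_dir`, and `CountingModel` are hypothetical names, epochs are counted from 1, and the fold index in the folder name is assumed to start at 0.

```python
def save_dir(model_name, hp_name, dataset_name, fold):
    """Folder pattern described in this README: save/[model]-[hp]-[dataset]-[fold]."""
    return f"save/{model_name}-{hp_name}-{dataset_name}-{fold}"


def train_loop(model, batches, epochs, save_step):
    """Hypothetical driver: returns the epochs whose checkpoints would be saved.

    Epochs are counted from 1 here; that is an assumption, not the framework's rule.
    """
    saved_epochs = []
    for epoch in range(1, epochs + 1):
        for batch_idx, sample_dict in enumerate(batches):
            model.train(batch_idx, sample_dict)
        if epoch % save_step == 0:
            # A real driver would write the model and the test outputs to save_dir(...) here.
            saved_epochs.append(epoch)
    return saved_epochs


class CountingModel:
    """Toy model that only counts how often train() is called."""

    def __init__(self):
        self.steps = 0

    def train(self, batch_idx, sample_dict):
        self.steps += 1
        return {'loss': 0.0}


model = CountingModel()
saved = train_loop(model, batches=[{}, {}], epochs=200, save_step=10)
print(len(saved), saved[:3], model.steps)  # 20 [10, 20, 30] 400
print(save_dir("yourmodel", "yourhp", "yourdataset", 0))  # save/yourmodel-yourhp-yourdataset-0
```

With the example hyperparameters (200 epochs, save_step 10), checkpoints would exist for epochs 10, 20, and so on up to 200, so `-t 10` refers to an epoch that was actually saved.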

Contributing

Any kind of enhancement or contribution is welcome.

License

The code is licensed under the MIT license.

MIT License

Copyright (c) 2019-2020 Mingyuan Luo

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
