diff --git a/docs/recommender/docs/source_en/online_learning.md b/docs/recommender/docs/source_en/online_learning.md
index 28fae0ed3ab036165359125ff52976a8b65a3efa..baa158dcb4e66bf971ea411b83dcd31f4e89b699 100644
--- a/docs/recommender/docs/source_en/online_learning.md
+++ b/docs/recommender/docs/source_en/online_learning.md
@@ -1,3 +1,212 @@
# Online Learning

## Overview

The real-time update of a recommendation network model is one of its important technical indicators, and online learning can effectively improve the timeliness of model updates.

Key differences between online learning and offline training:

1. The dataset for online learning is streaming data with no fixed dataset size or number of epochs, while the dataset for offline training has a fixed size and number of epochs.
2. Online learning runs as a resident service, while an offline training task exits when training ends.
3. Online learning needs to collect and store training data, and drive the training process after a fixed amount of data has been collected or a fixed time window has elapsed.

## Overall Architecture

The user's streaming training data is pushed to Kafka. MindPandas reads the data from Kafka, performs the feature-engineering transformation, and writes the result to the feature storage engine. MindData reads the data from the storage engine as training data, and MindSpore, running as a resident service, continuously receives data and performs training. The overall process is shown in the following figure:

*Figure: overall architecture of online learning*

## Use Constraints

- Python 3.8 or above is required.
- Currently, only GPU training and the Linux operating system are supported.

## Python Package Dependencies

- mindpandas v0.1.0
- mindspore_rec v0.2.0
- kafka-python v2.0.2

## Example

The following example shows the online learning process of training Wide&Deep on the Criteo dataset. The sample code is located at [Online Learning](https://gitee.com/mindspore/recommender/tree/master/examples/online_learning).

MindSpore Recommender provides a specialized model class `RecModel` for online learning, which, combined with MindPandas and the real-time data source Kafka for data reading and feature processing, implements a simple online learning flow.
First, define a custom dataset for real-time data processing, where the constructor parameter `receiver` is of type `DataReceiver` in MindPandas and receives real-time data, and `__getitem__` reads one piece of data at a time.

```python
class StreamingDataset:
    def __init__(self, receiver):
        self.data_ = []
        self.receiver_ = receiver
        self.recv_data_cnt_ = 0

    def __getitem__(self, item):
        # Block until a batch of data arrives from the MindPandas channel
        while not self.data_:
            data = self.receiver_.recv()
            self.recv_data_cnt_ += 1
            if data is not None:
                self.data_ = data.tolist()

        # Each row holds the (id, weight, label) columns of one sample
        last_row = self.data_.pop()
        return np.array(last_row[0], dtype=np.int32), np.array(last_row[1], dtype=np.float32), np.array(last_row[2], dtype=np.float32)
```

The custom dataset above is then encapsulated into the online dataset required by `RecModel`.
```python
import mindspore.dataset as ds
from mindpandas.channel import DataReceiver
from mindspore_rec import RecModel as Model

receiver = DataReceiver(address=config.address, namespace=config.namespace,
                        dataset_name=config.dataset_name, shard_id=0)
stream_dataset = StreamingDataset(receiver)

dataset = ds.GeneratorDataset(stream_dataset, column_names=["id", "weight", "label"])
dataset = dataset.batch(config.batch_size)

train_net, _ = GetWideDeepNet(config)
train_net.set_train()

model = Model(train_net)
```

After configuring the export strategy for the model checkpoint, start the online training process.

```python
ckptconfig = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory="./ckpt", config=ckptconfig)

model.online_train(dataset, callbacks=[TimeMonitor(1), callback, ckpoint_cb], dataset_sink_mode=True)
```

The following describes the startup process of each module involved in online learning:

### Downloading Kafka

```bash
wget https://archive.apache.org/dist/kafka/3.2.0/kafka_2.13-3.2.0.tgz

tar -xzf kafka_2.13-3.2.0.tgz

cd kafka_2.13-3.2.0
```

To install other versions, please refer to <https://archive.apache.org/dist/kafka/>.

### Starting kafka-zookeeper

```bash
bin/zookeeper-server-start.sh config/zookeeper.properties
```

### Starting kafka-server

Open another command terminal and start the Kafka service.

```bash
bin/kafka-server-start.sh config/server.properties
```

### Starting kafka_client

kafka_client needs to be started only once; it uses Kafka to set the number of partitions of the corresponding topic.

```bash
python kafka_client.py
```

### Starting the Distributed Computing Engine

```bash
yrctl start --master --address $MASTER_HOST_IP

# Parameter description
--master: indicates that the current host is the master node; non-master nodes do not need to specify the '--master' parameter
--address: IP address of the master node
```

### Starting the Data Producer

The producer is used to simulate an online learning scenario by writing a local Criteo dataset to Kafka for use by the consumer. The current sample uses multiple processes to read two files and write their data to Kafka.

```bash
python producer.py --file1=$CRITEO_DATASET_FILE_PATH --file2=$CRITEO_DATASET_FILE_PATH

# Parameter description
--file1: local disk path of the Criteo dataset file
--file2: local disk path of the Criteo dataset file
```

### Starting the Data Consumer

```bash
python consumer.py --num_shards=$DEVICE_NUM --address=$LOCAL_HOST_IP --dataset_name=$DATASET_NAME \
                   --max_dict=$PATH_TO_VAL_MAX_DICT --min_dict=$PATH_TO_CAT_TO_ID_DICT --map_dict=$PATH_TO_VAL_MAP_DICT

# Parameter description
--num_shards: number of device cards on the training side; set to 1 for single-card training and 8 for 8-card training
--address: address of the current sender
--dataset_name: dataset name
--namespace: channel name
--max_dict: maximum-value dictionary of the dense feature columns
--min_dict: minimum-value dictionary of the dense feature columns
--map_dict: dictionary of the sparse feature columns
```

The consumer needs three dataset-related files for feature engineering of the Criteo dataset: `all_val_max_dict.pkl`, `all_val_min_dict.pkl`, and `cat2id_dict.pkl`. `$PATH_TO_VAL_MAX_DICT`, `$PATH_TO_CAT_TO_ID_DICT`, and `$PATH_TO_VAL_MAP_DICT` are the absolute paths of these files on the environment.
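These dictionaries are plain pickled Python objects. As an illustration only (a minimal sketch, not the consumer's actual implementation, and the paths are placeholders), they can be loaded and inspected as follows:

```python
import pickle

# Minimal sketch (not the consumer's actual code): the three feature-engineering
# dictionaries are ordinary pickled Python objects and can be loaded directly.
def load_dict(path):
    with open(path, "rb") as f:
        return pickle.load(f)

max_dict = load_dict("/path/to/all_val_max_dict.pkl")  # per-column maxima of the dense features
min_dict = load_dict("/path/to/all_val_min_dict.pkl")  # per-column minima of the dense features
cat2id = load_dict("/path/to/cat2id_dict.pkl")         # sparse feature value -> id mapping
```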
For how these three .pkl files are produced, see [process_data.py](https://gitee.com/mindspore/recommender/blob/master/datasets/criteo_1tb/process_data.py), which converts the original Criteo dataset into the corresponding .pkl files.

### Starting Online Training

The config is in yaml form; see [default_config.yaml](https://gitee.com/mindspore/recommender/blob/master/examples/online_learning/default_config.yaml).

Single-card training:

```bash
python online_train.py --address=$LOCAL_HOST_IP --dataset_name=criteo

# Parameter description:
--address: local host IP, required for receiving training data from MindPandas
--dataset_name: dataset name, consistent with the consumer module
```

Starting multi-card training in MPI mode:

```bash
bash mpirun_dist_online_train.sh [$RANK_SIZE] [$LOCAL_HOST_IP]

# Parameter description:
RANK_SIZE: number of cards for multi-card training
LOCAL_HOST_IP: local host IP for MindPandas to receive training data
```

Starting multi-card training via dynamic networking:

```bash
bash run_dist_online_train.sh [$WORKER_NUM] [$SHED_HOST] [$SCHED_PORT] [$LOCAL_HOST_IP]

# Parameter description:
WORKER_NUM: number of cards for multi-card training
SHED_HOST: IP of the Scheduler role required for MindSpore dynamic networking
SCHED_PORT: port of the Scheduler role required for MindSpore dynamic networking
LOCAL_HOST_IP: local host IP, required for receiving training data from MindPandas
```

When training starts successfully, the following log is output. `epoch` and `step` indicate the epoch and step of the current training step, and `wide_loss` and `deep_loss` are the training loss values of the Wide&Deep network.

```text
epoch: 1, step: 1, wide_loss: 0.66100323, deep_loss: 0.72502613
epoch: 1, step: 2, wide_loss: 0.46781272, deep_loss: 0.5293098
epoch: 1, step: 3, wide_loss: 0.363207, deep_loss: 0.42204413
epoch: 1, step: 4, wide_loss: 0.3051032, deep_loss: 0.36126155
epoch: 1, step: 5, wide_loss: 0.24045062, deep_loss: 0.29395688
epoch: 1, step: 6, wide_loss: 0.24296054, deep_loss: 0.29386574
epoch: 1, step: 7, wide_loss: 0.20943595, deep_loss: 0.25780612
epoch: 1, step: 8, wide_loss: 0.19562452, deep_loss: 0.24153553
epoch: 1, step: 9, wide_loss: 0.16500896, deep_loss: 0.20854339
epoch: 1, step: 10, wide_loss: 0.2188702, deep_loss: 0.26011512
epoch: 1, step: 11, wide_loss: 0.14963374, deep_loss: 0.18867904
```

diff --git a/docs/recommender/docs/source_zh_cn/online_learning.md b/docs/recommender/docs/source_zh_cn/online_learning.md
index 9ce12324e090e294110574b2536cab5f1a5df702..3baedca18dc084b6f567af7ef1b44b9038628e63 100644
--- a/docs/recommender/docs/source_zh_cn/online_learning.md
+++ b/docs/recommender/docs/source_zh_cn/online_learning.md
@@ -96,7 +96,7 @@ tar -xzf kafka_2.13-3.2.0.tar.gz
 cd kafka_2.13-3.2.0
 ```
 
-如需安装其他版本,请参照https://archive.apache.org/dist/kafka/
+如需安装其他版本,请参照<https://archive.apache.org/dist/kafka/>。
 
 ### 启动kafka-zookeeper
 
 ```bash
 bin/zookeeper-server-start.sh config/zookeeper.properties
 ```
@@ -157,11 +157,11 @@ python consumer.py --num_shards=$DEVICE_NUM --address=$LOCAL_HOST_IP --datase
 --map_dict: 稀疏特征列的字典
 ```
 
-consumer为criteo数据集进行特征工程需要3个数据集相关文件: `all_val_max_dict.pkl`, `all_val_min_dict.pkl`, `cat2id_dict.pkl`, `$PATH_TO_VAL_MAX_DICT`, `$PATH_TO_CAT_TO_ID_DICT`, `$PATH_TO_VAL_MAP_DICT` 分别为这些文件在环境上的绝对路径。这3个pkl文件具体生产方法可以参考[process_data.py](https://gitee.com/mindspore/recommender/blob/master/datasets/criteo_1tb/process_data.py),对原始criteo数据集做转换生产对应的.pkl文件。
+consumer为criteo数据集进行特征工程需要3个数据集相关文件:`all_val_max_dict.pkl`、`all_val_min_dict.pkl`、`cat2id_dict.pkl`、`$PATH_TO_VAL_MAX_DICT`、`$PATH_TO_CAT_TO_ID_DICT`、`$PATH_TO_VAL_MAP_DICT` 分别为这些文件在环境上的绝对路径。这3个pkl文件具体生产方法可以参考[process_data.py](https://gitee.com/mindspore/recommender/blob/master/datasets/criteo_1tb/process_data.py),对原始criteo数据集做转换生产对应的.pkl文件。
 
 ### 启动在线训练
 
-config采用yaml的形式,见[default_config.yaml](https://gitee.com/mindspore/recommender/blob/master/examples/online_learning/default_config.yaml)
+config采用yaml的形式,见[default_config.yaml](https://gitee.com/mindspore/recommender/blob/master/examples/online_learning/default_config.yaml)。
 
 单卡训练:

diff --git a/tutorials/experts/source_en/parallel/other_features.rst b/tutorials/experts/source_en/parallel/other_features.rst
index 579046e4bd089772bd2f91eafb10f45c116052c9..5ca5a68d1851595fada807cf6e8838d9d99ab2f7 100644
--- a/tutorials/experts/source_en/parallel/other_features.rst
+++ b/tutorials/experts/source_en/parallel/other_features.rst
@@ -10,6 +10,7 @@ Other Features
 
    sharding_propagation
    parameter_server_training
+   pynative_shard_function_parallel
    ms_operator
 
 `Sharding Propagation <https://www.mindspore.cn/tutorials/experts/en/master/parallel/sharding_propagation.html>`__
@@ -68,8 +69,8 @@ dimension, and each device reads a part of the picture. This special
 performance supports splitting datasets into specific dimensions to
 meet training requirements in the field of large-format image
 processing.
 
-Functional Operator Splitting
------------------------------
+`Functional Operator Splitting <https://www.mindspore.cn/tutorials/experts/en/master/parallel/pynative_shard_function_parallel.html>`__
+---------------------------------------------------------------------------------------------------------------------------------------
 
 In dynamic graph mode, you specify that a part of the network structure
 executes in graph mode and performs various parallel operations.

diff --git a/tutorials/experts/source_en/parallel/pynative_shard_function_parallel.md b/tutorials/experts/source_en/parallel/pynative_shard_function_parallel.md
new file mode 100644
index 0000000000000000000000000000000000000000..8502234c7fe1994869f25a4d6255b36153a259c1
--- /dev/null
+++ b/tutorials/experts/source_en/parallel/pynative_shard_function_parallel.md
@@ -0,0 +1,306 @@
# Functional Operator Sharding

## Overview

Dynamic graph mode supports richer syntax and is more flexible to use, but currently MindSpore's dynamic graph mode does not support the various features of automatic parallelism. Drawing on the design concept of JAX's pmap, we designed the shard function so that, in dynamic graph mode, a specified part of the network is executed in graph mode and performs various parallel operations.

## Basic Principle

In MindSpore dynamic graph mode, you can specify that a segment of code is compiled and executed in graph mode by using the `@jit` decorator. During forward execution, the executed operators and subgraphs are recorded, and after the forward execution completes, the whole graph obtained is automatically differentiated to obtain the reverse graph, as shown in the following diagram:

*Figure 1: Schematic diagram of the @jit decorator implementation*

The shard function follows this pattern, except that the part compiled and executed in graph mode also performs operator-level model parallelism.

## Operation Practices

> You can download the full sample code here:
>
> <https://gitee.com/mindspore/docs/tree/master/docs/sample_code/shard_function_parallel>.
The directory structure is as follows:

```text
└─sample_code
    ├─shard_function_parallel
        ├── rank_table_8pcs.json
        ├── run_shard_function_example.sh
        └── shard_function_example.py
```

The role of each file is as follows:

- shard_function_example.py: the shard function sample code, which describes how to use the shard function to specify the part of the code to execute in parallel.
- rank_table_8pcs.json: 8-card configuration file for RANK_TABLE_FILE.
- run_shard_function_example.sh: startup script for the shard function example.

### Interface Introduction

```python
def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
    return shard_fn(fn, in_strategy, out_strategy, device, level)
```

`in_strategy(tuple)`: specifies the shard strategy of the input `Tensor`s. Each element is a tuple indicating the shard strategy of the corresponding input `Tensor`, and the length of each tuple must equal the dimension of that `Tensor`, indicating how each dimension is sliced. `None` can also be passed in, in which case the corresponding shard strategy is automatically derived.

`out_strategy(None, tuple)`: specifies the shard strategy of the output `Tensor`. The usage is the same as for `in_strategy`, and the default value is None. This shard strategy is not enabled yet and will be opened later. In deep learning models, the output strategy is replaced with data parallelism (False) or repeated computation (True), according to the value of full_batch.

`parameter_plan(None, dict)`: specifies the shard strategy of each parameter. When a dictionary is passed in, the key is the parameter name of str type and the value is a 1-D integer tuple indicating the corresponding shard strategy. If the parameter name is wrong or the corresponding parameter already has a shard strategy, the setting of that parameter is skipped. Default: None, which means no setting.

`device(string)`: specifies the device to execute on. The optional values are `Ascend`, `GPU`, and `CPU`. Default: `Ascend`. Currently not enabled, will be opened later.

`level(int)`: specifies the search strategy for all operators. The shard strategies of the input and output `Tensor`s are specified by the user, and the shard strategies of the remaining operators are obtained by framework search. This parameter specifies the objective function of that search; the optional values are 0, 1, and 2, representing maximizing the computation-communication ratio, minimizing memory consumption, and maximizing running speed, respectively. Default: 0. Currently not enabled, will be opened later.

### Importing Relevant Packages and Setting Execution Mode

As mentioned earlier, the shard function executes a part of the network in graph mode with operator-level model parallelism while the rest runs in dynamic graph mode, so when using the shard function you need to set the execution mode to PyNative.

```python
import mindspore as ms
from mindspore.communication import init


ms.set_context(mode=ms.PYNATIVE_MODE)
init()
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.AUTO_PARALLEL,
                             search_mode="sharding_propagation", device_num=8)
```

> Functional operator sharding is currently supported only when the parallel mode is "auto_parallel" and the strategy search algorithm is "sharding_propagation".
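Before moving on, the short sketch below illustrates the arithmetic behind `in_strategy`: each dimension of an input is divided evenly by its strategy value, and the product of the strategy values matches the number of participating devices. The `local_shape` helper is purely illustrative and not part of the MindSpore API:

```python
# Illustrative helper only, not a MindSpore API: derive the per-device slice
# shape implied by one in_strategy tuple.
def local_shape(global_shape, strategy):
    assert len(global_shape) == len(strategy), "strategy length must equal tensor rank"
    assert all(dim % cut == 0 for dim, cut in zip(global_shape, strategy))
    return tuple(dim // cut for dim, cut in zip(global_shape, strategy))

# With 8 devices, a strategy of (4, 2) for a (32, 128) input produces
# 4 * 2 = 8 slices, so each device holds one (8, 64) slice.
print(local_shape((32, 128), (4, 2)))  # (8, 64)
```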
### Specifying the Output Layout

Specifying the output layout as data parallelism or repeated computation is supported, and it is controlled through the `dataset_strategy` or `full_batch` attribute in auto_parallel_context, set as follows:

```python
# Set the output layout via dataset_strategy, which is recommended
ms.set_auto_parallel_context(dataset_strategy="full_batch")     # The dataset is not sliced and the output tensor of the shard is not sliced (default configuration)
ms.set_auto_parallel_context(dataset_strategy="data_parallel")  # The dataset is sliced in data parallelism and the output tensor of the shard is also sliced in data parallelism

# Set the output layout via full_batch; this attribute will be deprecated soon
ms.set_auto_parallel_context(full_batch=True)   # The dataset is not sliced and the output tensor of the shard is not sliced (default configuration)
ms.set_auto_parallel_context(full_batch=False)  # The dataset is sliced in data parallelism and the output tensor of the shard is also sliced in data parallelism
```

### Using the Shard Function in a Cell

There are currently two ways to use the shard function. The following network is used as an example.

```python
import mindspore.nn as nn

class BasicBlock(nn.Cell):
    def __init__(self):
        super(BasicBlock, self).__init__()
        self.dense1 = nn.Dense(32, 32)
        self.gelu = nn.GELU()
        self.dense2 = nn.Dense(32, 32)

    def construct(self, x):
        # two-dimensional input x
        x = self.dense1(x)
        x = self.gelu(x)
        x = self.dense2(x)
        return x

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.block1 = BasicBlock()
        self.block2 = BasicBlock()
        self.block3 = BasicBlock()

    def construct(self, x):
        # All three blocks are executed in PyNative mode.
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return x
```

- Self-call via the member method `shard` of Cell:

    ```python
    class Net1(Net):
        def __init__(self):
            super(Net1, self).__init__()
            # slice the input along the second axis and make the output data-parallel
            self.block1.shard(in_strategy=((1, 8),),
                              parameter_plan={'self.block1.dense2.weight': (8, 1)})

        def construct(self, x):
            # block1 is executed in GRAPH mode.
            x = self.block1(x)
            # block2 and block3 are executed in PyNative mode.
            x = self.block2(x)
            x = self.block3(x)
            return x
    ```

- Using the functional interface `mindspore.shard`. Since the return value of `shard` is a function, you cannot assign it to an attribute that holds an instantiated Cell, because MindSpore does not support reassigning a Cell-typed attribute to another type:

    ```python
    class NetError(Net):
        def __init__(self):
            super(NetError, self).__init__()
            self.block1 = ms.shard(self.block1, in_strategy=((8, 1),),
                                   parameter_plan={'self.block1.dense2.weight': (8, 1)})

        def construct(self, x):
            x = self.block1(x)
            x = self.block2(x)
            x = self.block3(x)
            return x
    ```

    An error will be reported after execution:

    ```text
    TypeError: For 'Cell', the type of block1 should be cell, but got function.
+ ``` + + The correct use is as follows: + + ```python + class Net2(Net): + def __init__(self): + # set the return function of shard a different name with the Cell instance + self.block1_graph = ms.shard(self.block1, in_strategy=((8, 1),), + parameter_plan={'self.block1.dense2.weight': (8, 1)}) + self.block2.shard(in_strategy=((1, 8),)) + + def construct(self, x): + # block1 is executed as GRAPH with input sliced along the first dimension + x = self.block1_graph(x) + # block2 is executed as GRAPH as well. + x = self.block2(x) + # block3 is executed as PyNative mode. + x = self.block3(x) + return x + ``` + +### function Uses Shard Function + +- function can using ops.shard for shard function. Taking the matmul+bias_add+relu function as an example, the use is as follows: + + ```python + import numpy as np + + import mindspore as ms + import mindspore.ops as ops + from mindspore import Tensor + + ms.set_auto_parallel_context(dataset_strategy="full_batch") # Here is an example where the dataset is unsliced and the output tensor of the shard is unsliced + + def dense_relu(x, weight, bias): + x = ops.matmul(x, weight) + x = ops.bias_add(x, bias) + x = ops.relu(x) + return x + + x = Tensor(np.random.uniform(0, 1, (32, 128)), ms.float32) + weight = Tensor(np.random.uniform(0, 1, (128, 10)), ms.float32) + bias = Tensor(np.random.uniform(0, 1, (10,)), ms.float32) + + # Specify the shard strategy for x as (4, 2) and shard strategy of weight and bias as None via in_strategy, indicating automatic derivation generation. + result = ms.shard(dense_relu, in_strategy=((4, 2), None, None))(x, weight, bias) + print('result.shape:', result.shape) + ``` + + > It is noted that the initialization of parameters depends on the Cell parameter management, and when the fn type passed into the shard is function, its definition should not contain parameters (e.g. Conv2D and Dense). + +### Running the Code + +Currently MindSpore can pull up distributed parallel tasks by both multi-process start and mpirun. + +#### Starting via Multi-process + +When executed on Ascend and there is no sub-Group communication, distributed parallelism can be initiated by means of multi-process. + +> Model parallelism generates sub-Group communication when the number of parts that an object is cut is smaller than the number of cards or with at least two dimensions cut. +> +> That is, when started by this method, the communication generated by the model parallelism inside `shard` can only occur inside `world group`, so the specified shard strategy can currently only support slicing one dimension. + +The above code needs to be configured with distributed variables before it can run. The Ascend environment should configure with RANK_TABLE_FILE, RANK_ID and DEVICE_ID. Please refer to [here](https://www.mindspore.cn/tutorials/experts/en/master/parallel/train_ascend.html#configuring-distributed-environment-variables) for the configuration process. + +Environment variables related to Ascend distributed are: + +- RANK_TABLE_FILE: The path to the network information file. The rank_table_file file can be generated by using hccl_tools.py in the models code bin, which can be obtained from [here](https://gitee.com/mindspore/models/tree/master/utils/hccl_tools). +- DEVICE_ID: The actual serial number of the current card on the machine. +- RANK_ID: The logical serial number of the current card. 
The startup script `run_shard_function_example.sh` exports these variables for each process and launches the example:

```bash
#!/bin/bash
set -e
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_shard_function_example.sh RANK_SIZE RANK_TABLE_FILE"
echo "For example: bash run_shard_function_example.sh 8 rank_table_8pcs.json"
echo "It is better to use the absolute path."
echo "This example is expected to run on the Ascend environment."
echo "=============================================================================================================="
if [ $# != 2 ]
then
    echo "Usage: bash run_shard_function_example.sh RANK_SIZE RANK_TABLE_FILE"
    exit 1
fi
RANK_SIZE=$1
RANK_TABLE_FILE=$2
test_dist_8pcs()
{
    export RANK_TABLE_FILE=${RANK_TABLE_FILE}
    export RANK_SIZE=8
}
test_dist_${RANK_SIZE}pcs

for((i=0;i<${RANK_SIZE};i++))
do
    rm -rf device$i
    mkdir device$i
    cp ./shard_function_example.py ./device$i
    cd ./device$i
    export DEVICE_ID=$i
    export RANK_ID=$i
    echo "start training for device $i"
    env > env$i.log
    python ./shard_function_example.py > train.log$i 2>&1 &
    cd ../
done
echo "The program launch succeed, the log is under device0/train.log0."
```

After placing RANK_TABLE_FILE in the current directory, run the following command, which requires 8 Ascend 910 devices:

```bash
bash run_shard_function_example.sh 8 rank_table_8pcs.json
```

During execution, the framework automatically performs operator-level model parallelism on the input function of `shard`, and the parallel strategy of each operator is obtained by framework search; the whole process is transparent to the user. The graphs can be saved as follows:

```python
ms.set_context(save_graphs=True)
```

In `step_parallel_end.ir`, you can see the specific parallel strategy of each operator.

#### Starting via mpirun

On Ascend and GPU, distributed parallelism can be started via mpirun. **This startup method supports the creation of sub-Group communication.** Run the command as follows:

```bash
mpirun -n ${DEVICE_NUM} --allow-run-as-root python ${PYTHON_SCRIPT_PATH}
```

Taking the sample code as an example, the corresponding 8-card command is:

```bash
mpirun -n 8 --allow-run-as-root python shard_function_example.py
```

> It should be noted that when starting with mpirun on Ascend with a large number of sub-Groups, a communication-domain creation failure may be reported, with an error message such as "Ascend collective Error: "HcclCommInitRootInfo failed. | Error Number 2". You can reduce `max_device_memory` in `context` to reserve enough memory for hccl to create communication domains.

## Usage Restrictions

- The execution mode should be set to `PYNATIVE_MODE`, the parallel configuration to `AUTO_PARALLEL`, and `search_mode` to `sharding_propagation`.
- Nested `vmap` is supported; when used together, `shard` must be outside and `vmap` inside, as shown in the sketch after this list.
- Nested `shard` is not supported.
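To make the nesting rule concrete, here is a minimal hedged sketch; it assumes the same 8-card PyNative/auto_parallel setup configured earlier on this page, and the names `dense`, `vmapped`, and `sharded` are illustrative only:

```python
import numpy as np
import mindspore as ms
import mindspore.ops as ops

# Hedged sketch of the allowed nesting order: shard outside, vmap inside.
# Assumes init() and the 8-card auto_parallel/sharding_propagation context
# from the earlier sections have already been set up.
def dense(x, w):
    return ops.relu(ops.matmul(x, w))

# vmap maps `dense` over the leading axis of x, while w is broadcast.
vmapped = ms.vmap(dense, in_axes=(0, None), out_axes=0)
# shard wraps the vmapped function; the second dimension of the 3-D input
# is sliced into 8 parts, and the strategy for w is derived (None).
sharded = ms.shard(vmapped, in_strategy=((1, 8, 1), None))

x = ms.Tensor(np.random.uniform(0, 1, (16, 32, 128)), ms.float32)
w = ms.Tensor(np.random.uniform(0, 1, (128, 10)), ms.float32)
print('output shape:', sharded(x, w).shape)  # (16, 32, 10)
```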
diff --git a/tutorials/experts/source_zh_cn/parallel/pynative_shard_function_parallel.md b/tutorials/experts/source_zh_cn/parallel/pynative_shard_function_parallel.md
index e647794d9227ba30e4e51dcf899adbcd9b53cd64..e0c486c6c7b78dd44784b42ae2768983da47677b 100644
--- a/tutorials/experts/source_zh_cn/parallel/pynative_shard_function_parallel.md
+++ b/tutorials/experts/source_zh_cn/parallel/pynative_shard_function_parallel.md
@@ -47,15 +47,15 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
     return shard_fn(fn, in_strategy, out_strategy, device, level)
 ```
 
-`in_strategy(tuple)`: 指定输入`Tensor`的切分策略,每个元素为元组,表示对应输入`Tensor`的切分策略,每个元组的长度要与对应`Tensor`的维度相等,表示每个维度如何切分,可以传入`None`,对应的切分策略将自动推导生成。
+`in_strategy(tuple)`:指定输入`Tensor`的切分策略,每个元素为元组,表示对应输入`Tensor`的切分策略,每个元组的长度要与对应`Tensor`的维度相等,表示每个维度如何切分,可以传入`None`,对应的切分策略将自动推导生成。
 
-`out_strategy(None, tuple)`: 指定输出`Tensor`的切分策略,用法和`in_strategy`相同,默认值为None,目前尚未使能,后续会开放。在深度学习模型中,输出策略会根据full_batch的值,被替换为数据并行(False)和重复计算(True)。
+`out_strategy(None, tuple)`:指定输出`Tensor`的切分策略,用法和`in_strategy`相同,默认值为None,目前尚未使能,后续会开放。在深度学习模型中,输出策略会根据full_batch的值,被替换为数据并行(False)和重复计算(True)。
 
-`parameter_plan(None, dict)`: 指定各参数的切分策略,传入字典时,键是str类型的参数名,值是1维整数tuple表示相应的切分策略, 如果参数名错误或对应参数已经设置了切分策略,该参数的设置会被跳过。默认值:None,表示不设置。
+`parameter_plan(None, dict)`:指定各参数的切分策略,传入字典时,键是str类型的参数名,值是1维整数tuple表示相应的切分策略,如果参数名错误或对应参数已经设置了切分策略,该参数的设置会被跳过。默认值:None,表示不设置。
 
-`device(string)`: 指定执行的设备,可选范围`Ascend`、`GPU`和`CPU`,默认为`Ascend`,目前尚未使能,后续会开放。
+`device(string)`:指定执行的设备,可选范围`Ascend`、`GPU`和`CPU`,默认为`Ascend`,目前尚未使能,后续会开放。
 
-`level(int)`: 指定全部算子搜索策略,输入输出`Tensor`的切分策略由用户指定,其余算子的切分策略会由框架搜索得到,此参数指定搜索时的目标函数,可选范围为0、1、2,分别代表最大化计算通信比、内存消耗最小、最大化运行速度,默认为0,目前尚未使能,后续会开放。
+`level(int)`:指定全部算子搜索策略,输入输出`Tensor`的切分策略由用户指定,其余算子的切分策略会由框架搜索得到,此参数指定搜索时的目标函数,可选范围为0、1、2,分别代表最大化计算通信比、内存消耗最小、最大化运行速度,默认为0,目前尚未使能,后续会开放。
 
 ### 导入相关包并设定执行模式
 
@@ -157,13 +157,13 @@ class Net(nn.Cell):
         return x
     ```
 
-    如此执行会遇到报错
+    如此执行会遇到报错:
 
     ```text
     TypeError: For 'Cell', the type of block1 should be cell, but got function.
     ```
 
-    正确使用方式如下
+    正确使用方式如下:
 
     ```python
     class Net2(Net):
         def __init__(self):
@@ -227,7 +227,7 @@ print('result.shape:', result.shape)
 
 上述代码需要在配置分布式变量后才可以运行。Ascend环境需要配置RANK_TABLE_FILE、RANK_ID和DEVICE_ID。配置的过程请参考[此处](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/train_ascend.html#配置分布式环境变量)。
 
-Ascend分布式相关的环境变量有: 
+Ascend分布式相关的环境变量有:
 
 - RANK_TABLE_FILE:组网信息文件的路径。rank_table_file文件可以使用models代码仓中的hccl_tools.py生成,可以从[此处](https://gitee.com/mindspore/models/tree/master/utils/hccl_tools)获取。
 - DEVICE_ID:当前卡在机器上的实际序号。