# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for synchronizing and communication across multiple hosts."""
from __future__ import annotations
from functools import partial, lru_cache
import zlib
import contextlib
from typing import Any
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import core
from jax._src import dtypes
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src import array
from jax._src import sharding_impls
from jax._src.interpreters import pxla
from jax._src import pjit as pjit_lib
from jax._src import prng
from jax.sharding import PartitionSpec as P
from jax._src import distributed
from jax._src.util import safe_zip
from jax._src import xla_bridge
from jax._src.lib import xla_client
import numpy as np
def _psum(xs: Any) -> Any:
return jax.tree.map(lambda x: jnp.sum(x, dtype=x.dtype, axis=0), xs)
def broadcast_one_to_all(in_tree: Any, is_source: bool | None = None) -> Any:
"""Broadcast data from a source host (host 0 by default) to all other hosts.
Args:
in_tree: pytree of arrays - each array *must* have the same shape across the
hosts.
    is_source: optional bool denoting whether the caller is the source host.
      Only the source host contributes the data for the broadcast. If None,
      host 0 is used as the source.
Returns:
    A pytree matching in_tree where the leaves now all contain the data from
    the source host.
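
  Example usage (an illustrative sketch; it requires a multi-process setup, so
  it is skipped under doctest):

  >>> from jax.experimental import multihost_utils  # doctest: +SKIP
  >>> # Process 0 picks a value; afterwards every process holds the same value.
  >>> seed = np.int32(1234) if jax.process_index() == 0 else np.int32(0)  # doctest: +SKIP
  >>> seed = multihost_utils.broadcast_one_to_all(seed)  # doctest: +SKIP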
"""
if jax.process_count() == 1:
return jax.tree.map(np.asarray, in_tree)
if is_source is None:
is_source = jax.process_index() == 0
devices: np.ndarray = np.array(
jax.devices()).reshape(jax.process_count(), jax.local_device_count())
global_mesh = jax.sharding.Mesh(devices, ('processes', 'local_devices'))
pspec = P('processes')
def pre_jit(x):
if is_source:
inp = x
else:
inp = np.zeros_like(x)
inp = np.expand_dims(inp, axis=0)
return host_local_array_to_global_array(inp, global_mesh, pspec)
def post_jit(x):
return jax.device_get(x.addressable_data(0))
in_tree = jax.tree.map(pre_jit, in_tree)
with jax.set_mesh(global_mesh):
out_tree = jax.jit(_psum, out_shardings=P())(in_tree)
return jax.tree.map(post_jit, out_tree)
# Identity function is at the top level so that `process_allgather` doesn't
# recompile on every invocation.
def _identity_fn(x):
return x
def _handle_array_process_allgather(inp, tiled):
if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable:
if not tiled:
raise ValueError(
'Gathering global non-fully-addressable arrays only supports'
' tiled=True')
if isinstance(inp.sharding, sharding_impls.NamedSharding):
reps = inp.sharding.update(spec=P())
else:
reps = sharding_impls.GSPMDSharding.get_replicated(
inp.sharding._device_assignment, memory_kind=inp.sharding.memory_kind)
out = jax.jit(_identity_fn, out_shardings=reps)(inp)
else:
# All inputs here will be fully addressable.
if jax.process_count() == 1:
out = np.asarray(inp)
return np.expand_dims(out, axis=0) if not tiled else out
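    # Fully-addressable (host-local) input: lay the devices out as one row per
    # process, place the host's copy of the data on each of its local devices,
    # and let jit with a replicated out_sharding perform the cross-process
    # gather.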
devices = np.array(jax.devices()).reshape(jax.process_count(),
jax.local_device_count())
global_mesh = jax.sharding.Mesh(devices, ('processes', 'local_devices'))
pspec = P('processes')
s = jax.sharding.NamedSharding(global_mesh, pspec)
host_np_arr = np.asarray(inp)
if host_np_arr.ndim == 0 or not tiled:
host_np_arr = np.expand_dims(host_np_arr, axis=0)
aval = core.ShapedArray(host_np_arr.shape, host_np_arr.dtype)
pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping")
global_aval = pxla.mesh_local_to_global(
global_mesh, sharding_impls.get_array_mapping(pspec), aval)
bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()]
global_arr = array.make_array_from_single_device_arrays(
global_aval.shape, s, bufs)
with jax.set_mesh(global_mesh):
out = jax.jit(_identity_fn, out_shardings=P())(global_arr)
return np.asarray(out.addressable_data(0))
def process_allgather(in_tree: Any, tiled: bool = False) -> Any:
"""Gather data from across processes.
Args:
in_tree: pytree of arrays - each array _must_ have the same shape across the
hosts.
    tiled: Whether to stack or concatenate the output. Defaults to False, i.e.
      stack into a new leading axis at index 0.
Returns:
    A pytree of numpy arrays.
* If the input is a non-fully addressable jax.Array, then the data is
fully replicated.
    * If the input is a numpy array or a fully addressable jax.Array, then the
      output shape depends on the `tiled` argument: if it is False, the outputs
      are stacked, otherwise they are concatenated.
* If the input is a scalar, then the output will be stacked.
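
  Example usage (an illustrative sketch; it requires a multi-process setup, so
  it is skipped under doctest):

  >>> from jax.experimental import multihost_utils  # doctest: +SKIP
  >>> # Each process contributes a scalar metric; every process receives an
  >>> # array of shape (jax.process_count(),) containing all of them.
  >>> local_loss = np.float32(3.14)  # doctest: +SKIP
  >>> all_losses = multihost_utils.process_allgather(local_loss)  # doctest: +SKIP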
"""
def _pjit(inp):
return _handle_array_process_allgather(inp, tiled)
return jax.tree.map(_pjit, in_tree)
def sync_global_devices(name: str):
"""Creates a barrier across all hosts/devices."""
h = np.uint32(zlib.crc32(name.encode()))
assert_equal(h, f"sync_global_devices name mismatch ('{name}')")
def assert_equal(in_tree, fail_message: str = ''):
"""Verifies that all the hosts have the same tree of values."""
def concat_in_tree(x):
if isinstance(x, array.ArrayImpl) and not x.is_fully_addressable:
return np.asarray(x.addressable_data(0))
else:
x = np.asarray(x)
if x.ndim == 0:
x = np.expand_dims(x, axis=0)
return np.concat([x] * jax.process_count())
out = process_allgather(in_tree, tiled=True)
expected_in_tree = jax.tree.map(concat_in_tree, in_tree)
if not jax.tree.all(
jax.tree.map(lambda *x: np.all(np.equal(*x)), expected_in_tree, out)):
raise AssertionError(
f'{fail_message}. Expected: {out}; got: {in_tree}.')
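# Example usage of `assert_equal` (illustrative): check that all processes
# agree on the current training step before writing a checkpoint, e.g.
#   multihost_utils.assert_equal(step, fail_message='training step mismatch')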
def reached_preemption_sync_point(step_id: int) -> bool:
"""Determine whether all hosts have reached a preemption sync step.
When any host receives a preemption notice, the notice is propagated to all
hosts and triggers a synchronization protocol in the background. The
synchronization protocol calculates the maximum step ids from all hosts, and
uses the next step id (i.e., max + 1) as the safe step to save a checkpoint.
All hosts should continue training more steps until this method returns True,
indicating that the `step_id` is equal to the safe step and the hosts should
start saving a checkpoint.
To use this API, all hosts must start training from the same step and call it
at every training step. Example usage:
```
def should_save(step_id: int) -> bool:
# Should save an on-demand checkpoint for preemption
if multihost_utils.reached_preemption_sync_point(step_id):
return True
# Should save a regular checkpoint
return step_id - last_saved_checkpoint_step >= save_interval_steps
```
  A preemption notice is provided by the cluster scheduler to notify the
  application in advance, before it gets evicted. By default, SIGTERM is used
  as the signal for the preemption notice.
TODO(b/230630494): Add instructions for customized preemption notice.
Returns:
A boolean indicating whether all hosts have reached a synchronization step
after some hosts are preempted.
Raises:
RuntimeError: if preemption sync manager has not been initialized.
"""
if distributed.global_state.client is None:
return False
sync_manager = distributed.global_state.preemption_sync_manager
if sync_manager is None:
raise RuntimeError(
"Preemption sync manager has not been initialized. Make sure the"
" 'jax_enable_preemption_service' config is enabled."
)
return sync_manager.reached_sync_point(step_id)
@lru_cache
def _flatten_pspecs(name, in_tree, pspecs_thunk):
return pjit_lib.flatten_axis_resources(
name, in_tree, pspecs_thunk(), tupled_args=True)
@lru_cache
def _local_to_global_aval(local_aval, mesh, pspec):
pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping")
return pxla.mesh_local_to_global(
mesh, sharding_impls.get_array_mapping(pspec), local_aval)
@lru_cache
def _global_to_local_aval(global_aval, mesh, pspec):
pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping")
return pxla.mesh_global_to_local(
mesh, sharding_impls.get_array_mapping(pspec), global_aval)
def host_local_array_to_global_array_impl(
arr: Any, *, global_mesh: jax.sharding.Mesh, pspec: Any):
if pspec is None:
raise ValueError(
'`None` is not a valid input to the pspecs argument. Please use '
'jax.sharding.PartitionSpec() if you wanted to replicate your input.')
# If the Array is not fully addressable i.e. not host local, return it.
if isinstance(arr, array.ArrayImpl) and not arr.is_fully_addressable:
return arr
if (isinstance(arr, array.ArrayImpl) and isinstance(
arr.sharding, jax.sharding.PmapSharding)) or not hasattr(arr, 'shape'):
arr = np.array(arr)
if arr.dtype == dtypes.float0:
arr = np.zeros(arr.shape, dtype=np.dtype(bool))
dtype = arr.dtype
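  # PRNG key arrays carry an extended dtype; unwrap to the underlying base
  # array here and re-wrap the result with the same key implementation at the
  # end.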
if is_prng_key_array := isinstance(arr, prng.PRNGKeyArray):
arr = arr._base_array
local_sharding = jax.sharding.NamedSharding(global_mesh.local_mesh, pspec)
# If the input is a concrete jax.Array and the input array sharding
# matches the `local_sharding`, then there's no need to reshard and create
# copies.
if (isinstance(arr, array.ArrayImpl) and
arr.sharding.is_equivalent_to(local_sharding, arr.ndim)):
arrays = [x.data for x in arr.addressable_shards]
else:
arr = dtypes.canonicalize_value(arr)
arrays = [
arr[i] for i in local_sharding.devices_indices_map(arr.shape).values()
]
global_aval = _local_to_global_aval(
core.ShapedArray(arr.shape, arr.dtype), global_mesh, pspec)
out = pxla.batched_device_put(
global_aval, jax.sharding.NamedSharding(global_mesh, pspec),
arrays, list(global_mesh.local_mesh.devices.flat))
if is_prng_key_array:
return prng.PRNGKeyArray(dtype._impl, out)
return out
def host_local_array_to_global_array(
local_inputs: Any, global_mesh: jax.sharding.Mesh, pspecs: Any):
r"""Converts a host local value to a globally sharded jax.Array.
  This function takes host-local data (which might differ across hosts) and
  populates a global array with it, where each device on each host gets the
  appropriate slice of the data according to the sharding defined by
  global_mesh/pspecs.
For example:
>>> global_mesh = jax.sharding.Mesh(jax.devices(), 'x')
>>> pspecs = jax.sharding.PartitionSpec('x')
>>> host_id = jax.process_index()
  >>> arr = host_local_array_to_global_array(np.arange(4) * host_id, global_mesh, pspecs)  # NB: assumes jax.local_device_count() divides 4.  # doctest: +SKIP
  The resulting array will have the shape (4 * num_processes,) and the
  distributed value (0, 0, 0, 0, 0, 1, 2, 3, 0, 2, 4, 6, ... ), where each
  slice np.arange(4) * host_id is partitioned across the corresponding host's
  devices.
Similarly:
>>> mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(jax.process_count(), jax.local_device_count()), ['host', 'dev'])
>>> pspecs = jax.sharding.PartitionSpec('host')
>>> host_id = jax.process_index()
>>> arr = host_local_array_to_global_array(np.arange(4) * host_id, mesh, pspecs) # doctest: +SKIP
  will create the same distributed value (0, 0, 0, 0, 0, 1, 2, 3, 0, 2, 4, 6,
  ...), however each slice np.arange(4) * host_id will be *replicated* across
  the corresponding host's devices.
On the other hand, if pspecs = PartitionSpec(), which means
replication across all axes, then this snippet:
>>> pspecs = jax.sharding.PartitionSpec()
>>> arr = host_local_array_to_global_array(np.arange(4), mesh, pspecs) # doctest: +SKIP
  will produce an array of shape (4,) whose value (0, 1, 2, 3) is replicated
  across all hosts and devices.
  It is undefined behavior if the local_inputs are not identical across hosts
  while the pspec indicates data replication.
  You can use this function to transition to jax.Array. Using jax.Array with
  pjit has the same semantics as using GDA with pjit, i.e. all jax.Array
  inputs to pjit should be globally shaped.
  If you are currently passing host local values to pjit, you can use this
  function to convert your host local values to global Arrays and then pass
  those to pjit.
Example usage.
>>> from jax.experimental import multihost_utils # doctest: +SKIP
>>>
>>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) # doctest: +SKIP
>>>
  >>> with global_mesh:  # doctest: +SKIP
  ...   global_out = pjitted_fun(global_inputs)  # doctest: +SKIP
  >>>
  >>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, global_mesh, out_pspecs)  # doctest: +SKIP
  Please note that this function requires the global mesh to be a contiguous
  mesh, meaning that the devices that belong to each host should form a subcube
  in this mesh. To move local data to a global array with a non-contiguous
  mesh, use jax.make_array_from_callback or
  jax.make_array_from_single_device_arrays instead.
Args:
local_inputs: A Pytree of host local values.
    global_mesh: A jax.sharding.Mesh object. The mesh must be contiguous, that
      is, each host's devices must form a subcube in this mesh.
pspecs: A Pytree of jax.sharding.PartitionSpec's.
Returns:
A pytree of global arrays.
"""
flat_inps, in_tree = tree_flatten(local_inputs)
in_pspecs = _flatten_pspecs('input pspecs', in_tree,
pjit_lib.hashable_pytree(pspecs))
out_flat = [
host_local_array_to_global_array_p.bind(inp, global_mesh=global_mesh,
pspec=in_spec)
for inp, in_spec in safe_zip(flat_inps, in_pspecs)
]
return tree_unflatten(in_tree, out_flat)
host_local_array_to_global_array_p = core.Primitive('host_local_array_to_global_array')
host_local_array_to_global_array_p.def_impl(host_local_array_to_global_array_impl)
def ltg_abstract_eval(arr, *, global_mesh, pspec):
return _local_to_global_aval(
core.ShapedArray(arr.shape, arr.dtype), global_mesh, pspec)
host_local_array_to_global_array_p.def_abstract_eval(ltg_abstract_eval)
ad.deflinear2(host_local_array_to_global_array_p,
lambda ct, _, **params: (
host_local_array_to_global_array_p.bind(ct, **params),))
def ltg_batcher(insert_axis, axis_data, vals_in, dims_in, global_mesh, pspec):
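  # Batching rule: account for the new batch dimension by inserting an entry
  # (the spmd axis name if one is set, otherwise None) into the PartitionSpec
  # at the batched position before rebinding the primitive.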
del insert_axis
x, = vals_in
d, = dims_in
new_parts = None if axis_data.spmd_name is None else axis_data.spmd_name
new_pspec = list(pspec)
if d is not None:
new_pspec.insert(d, new_parts)
new_pspec = P(*new_pspec)
y = host_local_array_to_global_array_p.bind(
x, global_mesh=global_mesh, pspec=new_pspec)
return y, d
batching.fancy_primitive_batchers[host_local_array_to_global_array_p] = partial(
ltg_batcher, False)
def _ltg_lowering(ctx, x, *, global_mesh, pspec):
return [x]
mlir.register_lowering(host_local_array_to_global_array_p, _ltg_lowering)
def global_array_to_host_local_array_impl(
arr: Any, *, global_mesh: jax.sharding.Mesh, pspec: Any):
if pspec is None:
raise ValueError(
'`None` is not a valid input to the pspecs argument. Please use '
'jax.sharding.PartitionSpec() if you wanted to replicate your input.')
# If the Array is already fully addressable i.e. host local, return it.
if isinstance(arr, array.ArrayImpl) and arr.is_fully_addressable:
return arr
if not hasattr(arr, 'shape'):
arr = np.array(arr)
if arr.dtype == dtypes.float0:
arr = np.zeros(arr.shape, dtype=np.dtype(bool))
dtype = arr.dtype
if is_prng_key_array := isinstance(arr, prng.PRNGKeyArray):
arr = arr._base_array
global_sharding = jax.sharding.NamedSharding(global_mesh, pspec)
local_sharding = jax.sharding.NamedSharding(global_mesh.local_mesh, pspec)
local_aval = _global_to_local_aval(
core.ShapedArray(arr.shape, arr.dtype), global_mesh, pspec)
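  # For a jax.Array input, reshard to the global sharding if necessary and
  # rebuild a host-local Array view from its per-device buffers; other inputs
  # (e.g. numpy values from AD) are sharded across the local devices below.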
if isinstance(arr, array.ArrayImpl):
if arr.sharding.is_equivalent_to(global_sharding, arr.ndim):
arrays = arr._arrays
else:
resharded_array = jax.device_put(arr, global_sharding)
arrays = resharded_array._arrays
out = array.ArrayImpl(local_aval, local_sharding, arrays, committed=True)
if is_prng_key_array:
return prng.PRNGKeyArray(dtype._impl, out)
return out
else:
# numpy array can show up here during AD.
arr = dtypes.canonicalize_value(arr)
arrays = [
arr[i] for i in local_sharding.devices_indices_map(arr.shape).values()
]
return pxla.batched_device_put(
local_aval, local_sharding, arrays,
list(global_mesh.local_mesh.devices.flat))
def global_array_to_host_local_array(
global_inputs: Any, global_mesh: jax.sharding.Mesh, pspecs: Any):
r"""Converts a global `jax.Array` to a host local `jax.Array`.
  You can use this function to transition to `jax.Array`. Using `jax.Array` with
  pjit has the same semantics as using GDA with pjit, i.e. all `jax.Array`
  inputs to pjit should be globally shaped and the outputs from pjit will also
  be globally shaped `jax.Array`s.
You can use this function to convert the globally shaped `jax.Array` output
from pjit to host local values again so that the transition to jax.Array can
be a mechanical change.
Example usage:
>>> from jax.experimental import multihost_utils # doctest: +SKIP
>>>
>>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) # doctest: +SKIP
>>>
  >>> with global_mesh:  # doctest: +SKIP
  ...   global_out = pjitted_fun(global_inputs)  # doctest: +SKIP
  >>>
  >>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, global_mesh, out_pspecs)  # doctest: +SKIP
Args:
global_inputs: A Pytree of global jax.Array's.
    global_mesh: A :class:`jax.sharding.Mesh` object. The mesh must be
      contiguous, meaning each host's local devices must form a subcube in it.
pspecs: A Pytree of :class:`jax.sharding.PartitionSpec` objects.
Returns:
A Pytree of host local arrays.
"""
flat_inps, out_tree = tree_flatten(global_inputs)
out_pspecs = _flatten_pspecs('output pspecs', out_tree,
pjit_lib.hashable_pytree(pspecs))
out_flat = [
global_array_to_host_local_array_p.bind(inp, global_mesh=global_mesh,
pspec=o)
for inp, o in safe_zip(flat_inps, out_pspecs)
]
return tree_unflatten(out_tree, out_flat)
global_array_to_host_local_array_p = core.Primitive('global_array_to_host_local_array')
global_array_to_host_local_array_p.def_impl(global_array_to_host_local_array_impl)
def gtl_abstract_eval(arr, *, global_mesh, pspec):
return _global_to_local_aval(
core.ShapedArray(arr.shape, arr.dtype), global_mesh, pspec)
global_array_to_host_local_array_p.def_abstract_eval(gtl_abstract_eval)
ad.deflinear2(global_array_to_host_local_array_p,
lambda ct, _, **params: (
global_array_to_host_local_array_p.bind(ct, **params),))
batching.defvectorized(global_array_to_host_local_array_p)
def _gtl_lowering(ctx, x, *, global_mesh, pspec):
return [x]
mlir.register_lowering(global_array_to_host_local_array_p, _gtl_lowering)
def _live_devices(client, devices: list[xla_client.Device]) -> dict[xla_client.Device, int]:
"""Returns the subset of the provided devices that are live and healthy."""
process_ids = {d.process_index for d in devices}
if xla_bridge.process_index() not in process_ids:
    # A process can only participate in a live_devices call if it hosts some of
    # the provided devices.
raise ValueError('Provided devices do not have any local devices.')
live_process_ids = client.get_live_nodes(list(process_ids))
return {
d: live_process_ids[d.process_index]
for d in devices
if d.process_index in live_process_ids
}
class _LiveDevices:
"""A context manager for atomically running code on the set of live devices.
THIS API IS UNDER ACTIVE DEVELOPMENT AND IS NOT STABLE.
# Overview
`live_devices` is a low-level primitive that can be used to make
multi-controller JAX programs fault tolerant. A multi-controller JAX program
runs across many devices, and the machines that host these devices might fail.
`live_devices` is a context manager that yields the current set of healthy
devices, allowing you to run JAX code on the healthy devices while ignoring
the failed ones.
Concretely, `live_devices` is a context manager. You provide it the set of
devices you are interested in, and it yields the subset of these devices that
are live. In the body of the `with` statement, you can execute arbitrary JAX
code using the set of live devices.
# Example Usage
try:
with jax.live_devices(jax.devices()) as devices:
# Run JAX code here with devices.
pass
except:
# A device died while executing the with statement above.
pass
else:
# The with statement executed successfully.
pass
# Barrier Semantics
  It's important that every process agrees on which devices are live, to keep
  the processes' behavior from diverging. For example, imagine a set of
  processes trying to run an AllGather while disagreeing on which devices
  should participate in it. This would be buggy.
To ensure that every process agrees on the set of live devices, the
`live_devices` context manager has barrier-like semantics. Consider an
invocation `with live_devices(devices)` where `devices` includes devices
across a set of processes P. The invocation acts as a barrier, waiting for
every process in P to call `with live_devices(devices)`. Afterwards,
`live_devices` returns the same set of live devices `A` to all the processes
in P. This ensures that every process agrees on the set of live devices.
`live_devices` does not actually act as a barrier for *every* process in P
because some processes in P might have failed. Instead, the `live_devices`
function waits only for the processes with a device in the returned set of
live devices A.
# An Example
Imagine we have four processes, each with two devices:
Process A: Devices 1 and 2
Process B: Devices 3 and 4
Process C: Devices 5 and 6
Process D: Devices 7 and 8
Further imagine that process D fails and that every process calls `with
live_devices(jax.devices())`. The invocation returns devices 1, 2, 3, 4, 5,
and 6. Because these devices are hosted by processes A, B, and C, the call to
`live_devices` acts as a barrier across processes A, B, and C. Process D,
which failed, is ignored.
# Atomicity
`live_devices` also provides the following transaction-like atomicity
property. When a process exits the body of a `with jax.live_devices(...) as
devices:` block, there are two possibilities.
1. All processes in `devices` successfully executed all code in the block
without any exceptions being raised.
  2. None of the processes in `devices` successfully executed the code in the
     block, and all of them will raise an exception.
Consider the following code.
try:
with jax.live_devices(...) as devices:
pass
except:
pass # A
else:
pass # B
The atomicity property says that either every process with devices in
`devices` will enter the except branch (A) or every process with devices in
`devices` will enter the else branch (B). It is impossible for some processes
to enter A and others to enter B.
TODO: mwhittaker - Link to formal live devices semantics.
Args:
devices: A list of devices. The provided devices must include at least one
local device.
Returns:
The subset of the provided devices that are live and healthy.
Raises:
RuntimeError: If the distributed runtime was not initialized.
ValueError: If no local devices are provided.
"""
def __init__(self):
self.devices = None
@contextlib.contextmanager
def __call__(self, devices):
client = distributed.global_state.client
if client is None:
raise RuntimeError('Distributed JAX not initialized.')
if not devices:
# TODO(mwhittaker): Make devices optional. If it's not provided, use
# jax.devices() as a default.
raise ValueError('No devices provided.')
if self.devices is None:
self.devices = _live_devices(client, devices)
exception = None
try:
alive = list(self.devices.keys())
alive.sort(key=lambda d: d.id)
yield alive
except Exception as e:
exception = e
finally:
old_devices = self.devices
new_devices = _live_devices(client, devices)
self.devices = new_devices
if exception:
raise exception
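      # The devices observed live on entry must still be reported live on exit;
      # otherwise a failure occurred while the block was running, and we surface
      # it as an error (see the atomicity property in the class docstring).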
if not old_devices.items() <= new_devices.items():
raise ValueError(f'{old_devices} is not a subset of {new_devices}')
live_devices = _LiveDevices()