This repository is a mirror maintained to improve download speeds in China and is synced once a day. Original repository: https://github.com/google/JAX
topologies.py 2.03 KB
Peter Hawkins committed on 2025-10-06 23:14 +08:00: Rename XlaRuntimeError to JaxRuntimeError.
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
import jax
from jax.experimental import mesh_utils
from jax._src.lib import _jax
from jax._src import xla_bridge as xb
Device = _jax.Device
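

class TopologyDescription:
  # Holds the list of devices that make up a topology. The devices may be
  # physically attached or compile-only (see get_topology_desc).

  def __init__(self, devices: list[Device]):
    self.devices: list[Device] = devices


def get_attached_topology(platform=None) -> TopologyDescription:
  # Describe the devices attached to this process for `platform`
  # (the default backend when `platform` is None).
  return TopologyDescription(jax.devices(backend=platform))


def get_topology_desc(
    topology_name: str = "", platform: str | None = None, **kwargs
) -> TopologyDescription:
  # TPU, which is also the default when no platform is given, has a dedicated
  # helper. Both paths return compile-only devices, which describe the named
  # topology for compilation without requiring that hardware to be attached.
  if platform == "tpu" or platform is None:
    return TopologyDescription(
        xb.make_pjrt_tpu_topology(
            topology_name, **kwargs
        )._make_compile_only_devices()
    )
  try:
    topology = xb.make_pjrt_topology(platform, topology_name, **kwargs)
    return TopologyDescription(topology._make_compile_only_devices())  # pytype: disable=attribute-error
  except _jax.JaxRuntimeError as e:
    # Report "UNIMPLEMENTED" runtime errors as a standard NotImplementedError;
    # re-raise every other runtime error unchanged.
    msg, *_ = e.args
    if msg.startswith("UNIMPLEMENTED"):
      raise NotImplementedError(msg) from e
    else:
      raise


# -- future mesh_utils --
def make_mesh(
    topo: TopologyDescription,
    mesh_shape: Sequence[int],
    axis_names: tuple[str, ...],
    *,
    contiguous_submeshes: bool = False
) -> jax.sharding.Mesh:
  # Arrange the topology's devices into an array of shape `mesh_shape`,
  # then attach one name to each array axis.
  devices = mesh_utils.create_device_mesh(
      mesh_shape, list(topo.devices), contiguous_submeshes=contiguous_submeshes)
  return jax.sharding.Mesh(devices, axis_names)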
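`get_topology_desc` turns runtime errors whose message starts with `UNIMPLEMENTED` into a Python `NotImplementedError`, so a caller can detect an unsupported platform with an ordinary `except` clause, while any other runtime error propagates unchanged. The pattern can be sketched without JAX. In this minimal sketch, `RuntimeStatusError`, `make_topology` and `get_topology` are hypothetical stand-ins for `_jax.JaxRuntimeError`, `xb.make_pjrt_topology` and `get_topology_desc`. They are not real JAX APIs.

```python
class RuntimeStatusError(Exception):
  """Hypothetical stand-in for _jax.JaxRuntimeError; the message is args[0]."""


def make_topology(platform: str) -> list[str]:
  """Hypothetical backend call: only "cpu" is implemented in this sketch."""
  if platform == "cpu":
    return ["cpu:0", "cpu:1"]
  if platform == "fancy":
    raise RuntimeStatusError(f"UNIMPLEMENTED: no topology support for {platform}")
  raise RuntimeStatusError(f"INTERNAL: backend {platform} failed")


def get_topology(platform: str) -> list[str]:
  try:
    return make_topology(platform)
  except RuntimeStatusError as e:
    msg, *_ = e.args
    # Translate "not implemented" statuses into the standard Python exception,
    # chaining the original error so the traceback keeps the runtime details.
    if msg.startswith("UNIMPLEMENTED"):
      raise NotImplementedError(msg) from e
    # Any other runtime failure is re-raised as-is.
    raise
```

A caller can then write `except NotImplementedError:` to fall back gracefully, without matching on the text of runtime error messages.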
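`make_mesh` passes the topology's devices to `mesh_utils.create_device_mesh`, which arranges them into an array of shape `mesh_shape`. It then names that array's axes with `jax.sharding.Mesh`. The sketch below illustrates the resulting constraints in plain Python: the number of devices must equal the product of `mesh_shape`, and each mesh dimension needs exactly one axis name. `make_device_grid`, `make_named_mesh` and `NamedMesh` are hypothetical helpers, and the sketch uses a simple row-major layout. The real `create_device_mesh` may reorder devices to follow the physical interconnect, for example on TPU, which this sketch does not attempt.

```python
import math
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class NamedMesh:
  """Hypothetical stand-in for jax.sharding.Mesh."""
  axis_names: tuple[str, ...]
  shape: tuple[int, ...]
  devices: Any  # nested lists of devices, one nesting level per axis


def make_device_grid(devices: Sequence[Any], mesh_shape: Sequence[int]) -> Any:
  """Reshape `devices` into nested lists of `mesh_shape` (row-major order only)."""
  if any(d < 1 for d in mesh_shape):
    raise ValueError("mesh dimensions must be positive")
  if math.prod(mesh_shape) != len(devices):
    raise ValueError(
        f"mesh_shape {tuple(mesh_shape)} needs {math.prod(mesh_shape)} devices, "
        f"got {len(devices)}")
  if not mesh_shape:
    return devices[0]  # a 0-d grid is just the single device
  # Split the flat list into equal chunks along the leading axis, then
  # recurse on the remaining axes.
  chunk = len(devices) // mesh_shape[0]
  return [make_device_grid(devices[i * chunk:(i + 1) * chunk], mesh_shape[1:])
          for i in range(mesh_shape[0])]


def make_named_mesh(devices: Sequence[Any], mesh_shape: Sequence[int],
                    axis_names: Sequence[str]) -> NamedMesh:
  """Hypothetical analogue of make_mesh: build the grid and name its axes."""
  if len(axis_names) != len(mesh_shape):
    raise ValueError("need exactly one axis name per mesh dimension")
  return NamedMesh(tuple(axis_names), tuple(mesh_shape),
                   make_device_grid(devices, mesh_shape))
```

For example, eight devices with `mesh_shape=(2, 4)` and `axis_names=("data", "model")` give two rows of four devices. Asking for `(3, 3)` with the same eight devices raises `ValueError`.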