Ai
4 Star 11 Fork 2

Gitee 极速下载/JAX

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
此仓库是为了提升国内下载速度的镜像仓库,每日同步一次。 原始仓库: https://github.com/google/JAX
克隆/下载
serialize_executable.py 4.52 KB
一键复制 编辑 原始数据 按行查看 历史
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pickling support for precompiled binaries."""
from __future__ import annotations
import pickle
import io
import jax
from jax._src.lib import xla_client as xc
from collections.abc import Sequence
def serialize(compiled: jax.stages.Compiled):
  """Pickles the unloaded executable behind ``compiled`` into bytes.

  The executable is pickled together with the flattened leaves of
  ``compiled.args_info`` and the ``compiled._no_kwargs`` flag. Because
  pytrees are not serializable, the input and output tree structures are
  not part of the payload; they are returned so that the caller can handle
  them properly.

  Args:
    compiled: the compiled computation to serialize.

  Returns:
    A ``(payload, in_tree, out_tree)`` tuple: the pickled bytes, the tree
    structure obtained by flattening ``compiled.args_info``, and
    ``compiled.out_tree``.

  Raises:
    ValueError: if ``compiled._executable`` has no ``_unloaded_executable``
      (or it is None), or if the unloaded executable has a truthy ``mut``
      whose ``in_mut`` is truthy.
    NotImplementedError: if ``compiled._params.const_args`` is truthy.
  """
  executable = getattr(compiled._executable, '_unloaded_executable', None)
  if executable is None:
    raise ValueError("Compilation does not support serialization")

  mut = getattr(executable, 'mut', None)
  if mut and mut.in_mut:
    raise ValueError("can't serialize with a closed-over mutable array ref")

  flat_args_info, in_tree = jax.tree_util.tree_flatten(compiled.args_info)

  # TODO(necula): deal with constants in serialized executables
  if compiled._params.const_args:
    raise NotImplementedError("serialize_executables with const_args")

  with io.BytesIO() as buffer:
    pickler = _JaxPjrtPickler(buffer)
    pickler.dump((executable, flat_args_info, compiled._no_kwargs))
    # The bytes must be read before the buffer is closed by the `with` exit.
    payload = buffer.getvalue()
  return payload, in_tree, compiled.out_tree
def deserialize_and_load(serialized,
in_tree,
out_tree,
backend: str | xc.Client | None = None,
execution_devices: Sequence[xc.Device] | None = None):
"""Constructs a jax.stages.Compiled from a serialized executable."""
if backend is None or isinstance(backend, str):
backend = jax.devices(backend)[0].client
if execution_devices is None:
execution_devices = backend.devices()
else:
device_backend = execution_devices[0].client
if device_backend != backend:
raise ValueError(
'Execution devices belong to a client other than `backend`. Got '
f'backend client: {(backend.platform, backend.platform_version)} and '
'execution devices client: '
f'{(device_backend.platform, device_backend.platform_version)}')
(unloaded_executable, args_info_flat,
no_kwargs) = _JaxPjrtUnpickler(
io.BytesIO(serialized), backend, execution_devices).load()
args_info = in_tree.unflatten(args_info_flat)
loaded_compiled_obj = unloaded_executable.load()
# TODO(necula): deal with constants in serialized executables
return jax.stages.Compiled(
loaded_compiled_obj, [], args_info, out_tree, no_kwargs=no_kwargs)
class _JaxPjrtPickler(pickle.Pickler):
device_types = (xc.Device,)
client_types = (xc.Client,)
def persistent_id(self, obj):
if isinstance(obj, xc.LoadedExecutable):
return ('exec', obj.client.serialize_executable(obj))
if isinstance(obj, xc._xla.Executable):
return ('exec', obj.serialize())
if isinstance(obj, self.device_types):
return ('device', obj.id)
if isinstance(obj, self.client_types):
return ('client',)
class _JaxPjrtUnpickler(pickle.Unpickler):
def __init__(self, file, backend, execution_devices=None):
super().__init__(file)
self.backend = backend
if execution_devices is None:
execution_devices = backend.devices()
else:
device_backend = execution_devices[0].client
if device_backend != backend:
raise ValueError(
'Execution devices belong to a client other than `backend`. Got '
f'backend client: {(backend.platform, backend.platform_version)} '
'and execution devices client: '
f'{(device_backend.platform, device_backend.platform_version)}')
self.devices_by_id = {d.id: d for d in execution_devices}
self.execution_devices = xc.DeviceList(tuple(execution_devices))
def persistent_load(self, pid):
if pid[0] == 'exec':
return self.backend.deserialize_executable(
pid[1], executable_devices=self.execution_devices)
if pid[0] == 'device':
return self.devices_by_id[pid[1]]
if pid[0] == 'client':
return self.backend
raise pickle.UnpicklingError
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/mirrors/JAX.git
git@gitee.com:mirrors/JAX.git
mirrors
JAX
JAX
main

搜索帮助