# Mirror repository to improve download speeds within China, synced once daily.
# Original repository: https://github.com/google/JAX
# File: x64_context.py
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Context managers for toggling X64 mode.

**Deprecated: use :func:`jax.enable_x64` instead.**
"""

from contextlib import contextmanager

from jax._src import config

@contextmanager
def _enable_x64(new_val: bool = True):
  """Experimental context manager to temporarily enable X64 mode.

  .. warning::

    This context manager is deprecated as of JAX v0.8.0, and will be removed in
    JAX v0.9.0. Use :func:`jax.enable_x64` instead.

  Usage::

    >>> import numpy as np
    >>> import jax.numpy as jnp
    >>> x = np.arange(5, dtype='float64')
    >>> with _enable_x64(True):
    ...   print(jnp.asarray(x).dtype)
    ...
    float64

  See Also
  --------
  jax.experimental.disable_x64 : temporarily disable X64 mode.
  """
  # Delegate to the config-level context manager; the previous setting is
  # restored when the block exits.
  with config.enable_x64(new_val):
    yield

@contextmanager
def _disable_x64():
  """Experimental context manager to temporarily disable X64 mode.

  .. warning::

    This context manager is deprecated as of JAX v0.8.0, and will be removed in
    JAX v0.9.0. Use ``jax.enable_x64(False)`` instead.

  Usage::

    >>> import numpy as np
    >>> import jax.numpy as jnp
    >>> x = np.arange(5, dtype='float64')
    >>> with _disable_x64():
    ...   print(jnp.asarray(x).dtype)
    ...
    float32

  See Also
  --------
  jax.experimental.enable_x64 : temporarily enable X64 mode.
  """
  # Same mechanism as _enable_x64, with the flag forced to False.
  with config.enable_x64(False):
    yield

# Maps each deprecated public name to (warning message, replacement object).
_deprecations = {
  # Added for v0.8.0
  "disable_x64": (
    ("jax.experimental.x64_context.disable_x64 is deprecated in JAX v0.8.0 and will be removed"
     " in JAX v0.9.0; use jax.enable_x64(False) instead."),
    _disable_x64
  ),
  "enable_x64": (
    ("jax.experimental.x64_context.enable_x64 is deprecated in JAX v0.8.0 and will be removed"
     " in JAX v0.9.0; use jax.enable_x64(True) instead."),
    _enable_x64
  ),
}

import typing as _typing
if _typing.TYPE_CHECKING:
  enable_x64 = _enable_x64
  disable_x64 = _disable_x64
else:
  from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
  __getattr__ = _deprecation_getattr(__name__, _deprecations)
  del _deprecation_getattr
del _typing