Ai
4 Star 11 Fork 2

Gitee 极速下载/JAX

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
此仓库是为了提升国内下载速度的镜像仓库,每日同步一次。 原始仓库: https://github.com/google/JAX
克隆/下载
abstract_arrays.py 6.02 KB
一键复制 编辑 原始数据 按行查看 历史
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import numpy as np
from jax._src import config
from jax._src import core
from jax._src import literals
from jax._src import dtypes
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)
ShapedArray = core.ShapedArray
AbstractToken = core.AbstractToken
abstract_token = core.abstract_token
canonicalize_shape = core.canonicalize_shape
numpy_scalar_types: set[type] = { # pylint: disable=g-bare-generic
dtypes.int4, np.int8, np.int16, np.int32, np.int64,
dtypes.uint4, np.uint8, np.uint16, np.uint32, np.uint64,
np.complex64, np.complex128,
np.bool_, np.longlong, np.intc,
} | {np.dtype(dt).type for dt in dtypes._float_types}
if dtypes.int2 is not None:
assert dtypes.uint2 is not None
numpy_scalar_types.add(dtypes.int2)
numpy_scalar_types.add(dtypes.uint2)
array_types: set[type] = {literals.TypedNdArray, np.ndarray} | numpy_scalar_types # pylint: disable=g-bare-generic
def masked_array_error(*args, **kwargs):
raise ValueError(
"numpy masked arrays are not supported as direct inputs to JAX functions."
" Use arr.filled() to convert the value to a standard numpy array.")
core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype), sharding=None)
core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
def _make_shaped_array_for_typed_ndarray(
x: literals.TypedNdArray,
) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape, dtype, sharding=None, weak_type=x.weak_type)
core.pytype_aval_mappings[literals.TypedNdArray] = _make_shaped_array_for_typed_ndarray
def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
dtype = np.dtype(x)
dtypes.check_valid_dtype(dtype)
shape = np.shape(x)
return ShapedArray(shape, dtypes.canonicalize_dtype(dtype), sharding=None)
for t in numpy_scalar_types:
core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
core.literalable_types.update(array_types)
core.literalable_types.add(literals.TypedNdArray)
_int32_min = np.iinfo(np.int32).min
_int32_max = np.iinfo(np.int32).max
_int64_min = np.iinfo(np.int64).min
_int64_max = np.iinfo(np.int64).max
# Note: all python scalar types are weak except bool, because bool only
# comes in a single width.
_bool_aval = ShapedArray((), dtype=np.dtype(bool))
_int32_aval = ShapedArray((), dtype=np.dtype(np.int32), weak_type=True)
_int64_aval = ShapedArray((), dtype=np.dtype(np.int64), weak_type=True)
_float32_aval = ShapedArray((), dtype=np.dtype(np.float32), weak_type=True)
_float64_aval = ShapedArray((), dtype=np.dtype(np.float64), weak_type=True)
_complex64_aval = ShapedArray((), dtype=np.dtype(np.complex64), weak_type=True)
_complex128_aval = ShapedArray((), dtype=np.dtype(np.complex128), weak_type=True)
core.pytype_aval_mappings[bool] = lambda v: _bool_aval
def _int_aval(value):
if config.enable_x64.value:
if value < _int64_min or value > _int64_max:
raise OverflowError(f"Python int {value} too large to convert to int64")
return _int64_aval
else:
if value < _int32_min or value > _int32_max:
raise OverflowError(f"Python int {value} too large to convert to int32")
return _int32_aval
core.pytype_aval_mappings[int] = _int_aval
_float_aval = lambda v: _float64_aval if config.enable_x64.value else _float32_aval
core.pytype_aval_mappings[float] = _float_aval
_complex_aval = lambda v: _complex128_aval if config.enable_x64.value else _complex64_aval
core.pytype_aval_mappings[complex] = _complex_aval
core.literalable_types.update(dtypes.python_scalar_types)
def _aval_for_typed_scalar(x):
return ShapedArray((), x.dtype, weak_type=True, sharding=None)
for t in literals.typed_scalar_types:
core.pytype_aval_mappings[t] = _aval_for_typed_scalar
core.literalable_types.update(literals.typed_scalar_types)
def _canonicalize_ndarray_dtype(x):
dtype = dtypes.canonicalize_dtype(x.dtype)
return literals.TypedNdArray(np.asarray(x, dtype), weak_type=False)
def _canonicalize_masked_array_dtype(x):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
"Use arr.filled() to convert the value to a standard numpy array.")
dtypes.canonicalize_value_handlers.update(
(t, _canonicalize_ndarray_dtype) for t in numpy_scalar_types)
dtypes.canonicalize_value_handlers[literals.TypedNdArray] = lambda x: x
dtypes.canonicalize_value_handlers[np.ndarray] = _canonicalize_ndarray_dtype
dtypes.canonicalize_value_handlers[np.ma.MaskedArray] = _canonicalize_masked_array_dtype
def _canonicalize_python_scalar(literal_type, typ):
def canonicalize_scalar(x):
return literal_type(x, dtypes.scalar_type_to_dtype(typ, x)) # pytype: disable=wrong-arg-types
return canonicalize_scalar
dtypes.canonicalize_value_handlers[bool] = lambda x: x
dtypes.canonicalize_value_handlers[int] = _canonicalize_python_scalar(
literals.TypedInt, int)
dtypes.canonicalize_value_handlers[float] = _canonicalize_python_scalar(
literals.TypedFloat, float)
dtypes.canonicalize_value_handlers[complex] = _canonicalize_python_scalar(
literals.TypedComplex, complex)
dtypes.canonicalize_value_handlers[literals.TypedInt] = lambda x: x
dtypes.canonicalize_value_handlers[literals.TypedFloat] = lambda x: x
dtypes.canonicalize_value_handlers[literals.TypedComplex] = lambda x: x
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/mirrors/JAX.git
git@gitee.com:mirrors/JAX.git
mirrors
JAX
JAX
main

搜索帮助