Ai
4 Star 11 Fork 2

Gitee 极速下载/JAX

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
此仓库是为了提升国内下载速度的镜像仓库,每日同步一次。 原始仓库: https://github.com/google/JAX
克隆/下载
batching.py 8.06 KB
一键复制 编辑 原始数据 按行查看 历史
Dougal 提交于 2025-11-24 06:47 +08:00 . Remove dynamic shapes. Dead weight at this point.
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.interpreters import batching as _src_batching
from jax._src.interpreters.batching import (
axis_primitive_batchers as axis_primitive_batchers,
bdim_at_front as bdim_at_front,
broadcast as broadcast,
defbroadcasting as defbroadcasting,
defreducer as defreducer,
defvectorized as defvectorized,
fancy_primitive_batchers as fancy_primitive_batchers,
not_mapped as not_mapped,
primitive_batchers as primitive_batchers,
register_vmappable as register_vmappable,
unregister_vmappable as unregister_vmappable,
)
# Deprecated public names of this module, consumed by the lazy module-level
# ``__getattr__`` installed at the bottom of the file. Each entry maps the
# attribute name a user accesses to ``(message, replacement_object)``. A name
# that is missing here cannot be resolved at runtime, even if the
# ``TYPE_CHECKING`` branch below declares it for static type checkers.
_deprecations = {
  # Deprecated for JAX v0.7.1; finalize in JAX v0.9.0.
  "AxisSize": (
    "jax.interpreters.batching.AxisSize is deprecated.",
    _src_batching.AxisSize,
  ),
  "Array": (
    "jax.interpreters.batching.Array is deprecated. Use jax.Array directly.",
    _src_batching.Array,
  ),
  "BatchTrace": (
    "jax.interpreters.batching.BatchTrace is deprecated.",
    _src_batching.BatchTrace,
  ),
  "BatchTracer": (
    "jax.interpreters.batching.BatchTracer is deprecated.",
    _src_batching.BatchTracer,
  ),
  "BatchingRule": (
    "jax.interpreters.batching.BatchingRule is deprecated.",
    _src_batching.BatchingRule,
  ),
  "Elt": (
    "jax.interpreters.batching.Elt is deprecated.",
    _src_batching.Elt,
  ),
  "FromEltHandler": (
    "jax.interpreters.batching.FromEltHandler is deprecated.",
    _src_batching.FromEltHandler,
  ),
  "GetIdx": (
    "jax.interpreters.batching.GetIdx is deprecated.",
    _src_batching.GetIdx,
  ),
  "MakeIotaHandler": (
    "jax.interpreters.batching.MakeIotaHandler is deprecated.",
    _src_batching.MakeIotaHandler,
  ),
  "MapSpec": (
    "jax.interpreters.batching.MapSpec is deprecated.",
    _src_batching.MapSpec,
  ),
  "NotMapped": (
    "jax.interpreters.batching.NotMapped is deprecated.",
    _src_batching.NotMapped,
  ),
  "ToEltHandler": (
    "jax.interpreters.batching.ToEltHandler is deprecated.",
    _src_batching.ToEltHandler,
  ),
  "Vmappable": (
    "jax.interpreters.batching.Vmappable is deprecated.",
    _src_batching.Vmappable,
  ),
  # The public name is ``Zero`` (see the message and the TYPE_CHECKING
  # declarations); it was previously only reachable under the misspelled
  # key ``Zeros``, so ``jax.interpreters.batching.Zero`` failed at runtime.
  "Zero": (
    "jax.interpreters.batching.Zero is deprecated. Use jax.interpreters.ad.Zero.",
    _src_batching.Zero,
  ),
  # Legacy misspelled alias, kept so existing ``batching.Zeros`` users keep
  # working until the deprecation is finalized.
  "Zeros": (
    "jax.interpreters.batching.Zeros is deprecated. Use jax.interpreters.ad.Zero.",
    _src_batching.Zero,
  ),
  "ZeroIfMapped": (
    "jax.interpreters.batching.ZeroIfMapped is deprecated. It is an internal type.",
    _src_batching.ZeroIfMapped,
  ),
  "batch": (
    "jax.interpreters.batching.batch is deprecated. It is an internal API.",
    _src_batching.batch,
  ),
  "batch_custom_jvp_subtrace": (
    "jax.interpreters.batching.batch_custom_jvp_subtrace is deprecated. It is an internal API.",
    _src_batching.batch_custom_jvp_subtrace,
  ),
  "batch_custom_vjp_bwd": (
    "jax.interpreters.batching.batch_custom_vjp_bwd is deprecated. It is an internal API.",
    _src_batching.batch_custom_vjp_bwd,
  ),
  "batch_jaxpr": (
    "jax.interpreters.batching.batch_jaxpr is deprecated. It is an internal API.",
    _src_batching.batch_jaxpr,
  ),
  "batch_jaxpr_axes": (
    "jax.interpreters.batching.batch_jaxpr_axes is deprecated. It is an internal API.",
    _src_batching.batch_jaxpr_axes,
  ),
  "batch_subtrace": (
    "jax.interpreters.batching.batch_subtrace is deprecated. It is an internal API.",
    _src_batching.batch_subtrace,
  ),
  "broadcast_batcher": (
    "jax.interpreters.batching.broadcast_batcher is deprecated. It is an internal API.",
    _src_batching.broadcast_batcher,
  ),
  "flatten_fun_for_vmap": (
    "jax.interpreters.batching.flatten_fun_for_vmap is deprecated. It is an internal API.",
    _src_batching.flatten_fun_for_vmap,
  ),
  "from_elt": (
    "jax.interpreters.batching.from_elt is deprecated. It is an internal API.",
    _src_batching.from_elt,
  ),
  "from_elt_handlers": (
    "jax.interpreters.batching.from_elt_handlers is deprecated. It is an internal API.",
    _src_batching.from_elt_handlers,
  ),
  "is_vmappable": (
    "jax.interpreters.batching.is_vmappable is deprecated. It is an internal API.",
    _src_batching.is_vmappable,
  ),
  "make_iota": (
    "jax.interpreters.batching.make_iota is deprecated. It is an internal API.",
    _src_batching.make_iota,
  ),
  "make_iota_handlers": (
    "jax.interpreters.batching.make_iota_handlers is deprecated. It is an internal API.",
    _src_batching.make_iota_handlers,
  ),
  "matchaxis": (
    "jax.interpreters.batching.matchaxis is deprecated. It is an internal API.",
    _src_batching.matchaxis,
  ),
  "moveaxis": (
    "jax.interpreters.batching.moveaxis is deprecated. Use jax.numpy.moveaxis.",
    _src_batching.moveaxis,
  ),
  "reducer_batcher": (
    "jax.interpreters.batching.reducer_batcher is deprecated. It is an internal API.",
    _src_batching.reducer_batcher,
  ),
  "spec_types": (
    "jax.interpreters.batching.spec_types is deprecated. It is an internal API.",
    _src_batching.spec_types,
  ),
  "to_elt": (
    "jax.interpreters.batching.to_elt is deprecated. It is an internal API.",
    _src_batching.to_elt,
  ),
  "to_elt_handlers": (
    "jax.interpreters.batching.to_elt_handlers is deprecated. It is an internal API.",
    _src_batching.to_elt_handlers,
  ),
  "vectorized_batcher": (
    "jax.interpreters.batching.vectorized_batcher is deprecated. It is an internal API.",
    _src_batching.vectorized_batcher,
  ),
  "vmappables": (
    "jax.interpreters.batching.vmappables is deprecated. It is an internal API.",
    _src_batching.vmappables,
  ),
  "vtile": (
    "jax.interpreters.batching.vtile is deprecated. It is an internal API.",
    _src_batching.vtile,
  ),
  "zero_if_mapped": (
    "jax.interpreters.batching.zero_if_mapped is deprecated. It is an internal API.",
    _src_batching.zero_if_mapped,
  ),
}
import typing as _typing

if _typing.TYPE_CHECKING:
  # This branch is only evaluated by static type checkers, which cannot see
  # through the lazy module ``__getattr__`` used at runtime. Re-export every
  # deprecated name explicitly; the ``X as X`` form marks it as a public
  # re-export (PEP 484), matching the convention of the imports above.
  from jax._src.interpreters.batching import (
    Array as Array,
    AxisSize as AxisSize,
    BatchTrace as BatchTrace,
    BatchTracer as BatchTracer,
    BatchingRule as BatchingRule,
    Elt as Elt,
    FromEltHandler as FromEltHandler,
    GetIdx as GetIdx,
    MakeIotaHandler as MakeIotaHandler,
    MapSpec as MapSpec,
    NotMapped as NotMapped,
    ToEltHandler as ToEltHandler,
    Vmappable as Vmappable,
    Zero as Zero,
    ZeroIfMapped as ZeroIfMapped,
    batch as batch,
    batch_custom_jvp_subtrace as batch_custom_jvp_subtrace,
    batch_custom_vjp_bwd as batch_custom_vjp_bwd,
    batch_jaxpr as batch_jaxpr,
    batch_jaxpr_axes as batch_jaxpr_axes,
    batch_subtrace as batch_subtrace,
    broadcast_batcher as broadcast_batcher,
    flatten_fun_for_vmap as flatten_fun_for_vmap,
    from_elt as from_elt,
    from_elt_handlers as from_elt_handlers,
    is_vmappable as is_vmappable,
    make_iota as make_iota,
    make_iota_handlers as make_iota_handlers,
    matchaxis as matchaxis,
    moveaxis as moveaxis,
    reducer_batcher as reducer_batcher,
    spec_types as spec_types,
    to_elt as to_elt,
    to_elt_handlers as to_elt_handlers,
    vectorized_batcher as vectorized_batcher,
    vmappables as vmappables,
    vtile as vtile,
    zero_if_mapped as zero_if_mapped,
  )
else:
  # At runtime, deprecated names are resolved on demand by a module-level
  # ``__getattr__`` (PEP 562) built from the ``_deprecations`` table.
  # Presumably it warns on each access; see jax._src.deprecations to confirm.
  from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
  __getattr__ = _deprecation_getattr(__name__, _deprecations)
  # Keep the helper out of the module's public namespace.
  del _deprecation_getattr
del _typing
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/mirrors/JAX.git
git@gitee.com:mirrors/JAX.git
mirrors
JAX
JAX
main

搜索帮助