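# (Residue from the code-hosting web page, not part of the source:
# "Code pull complete; the page will refresh automatically.")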
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Callable
import functools
import os
import traceback
import types
from typing import Any, TypeVar, cast
from jax._src import config
from jax._src import util
from jax._src.lib import _jax
C = TypeVar("C", bound=Callable[..., Any])
_exclude_paths: list[str] = []
def register_exclusion(path: str):
  """Registers ``path`` as a JAX-internal path prefix.

  Frames whose source file is ``path`` or lies under it are excluded from
  filtered tracebacks. The prefix is recorded locally in ``_exclude_paths``.
  It is also forwarded to ``_jax.add_exclude_path`` when the installed
  jaxlib provides that function.
  """
  _exclude_paths.append(path)
  # TODO(nbasile): Remove hasattr checks after jaxlib 0.8.1 release
  if not hasattr(_jax, "add_exclude_path"):
    return
  _jax.add_exclude_path(path)
# Frames from this module and from `jax._src.util` are treated as JAX-internal.
register_exclusion(__file__)
register_exclusion(util.__file__)

# Appended to the message of the `UnfilteredStackTrace` that `api_boundary`
# attaches as a cause in "remove_frames" mode.
_jax_message_append = (
    "The stack trace below excludes JAX-internal frames.\n"
    "The preceding is the original exception that occurred, unmodified.\n"
    "\n"
    "--------------------"
)
def _path_starts_with(path: str, path_prefix: str) -> bool:
path = os.path.abspath(path)
path_prefix = os.path.abspath(path_prefix)
try:
common = os.path.commonpath([path, path_prefix])
except ValueError:
# path and path_prefix are both absolute, the only case will raise a
# ValueError is different drives.
# https://docs.python.org/3/library/os.path.html#os.path.commonpath
return False
try:
return common == path_prefix or os.path.samefile(common, path_prefix)
except OSError:
# One of the paths may not exist.
return False
def include_frame(f: types.FrameType) -> bool:
  """Returns True if frame ``f`` should be kept in a filtered traceback.

  The decision depends only on the filename of the frame's code object; see
  ``include_filename``.
  """
  code = f.f_code
  return include_filename(code.co_filename)
def include_filename(filename: str) -> bool:
  """Returns True unless ``filename`` falls under a registered exclusion path."""
  return all(
      not _path_starts_with(filename, prefix) for prefix in _exclude_paths)
# When scanning stack traces, we might encounter frames from cpython that are
# removed from printed stack traces, such as frames from parts of importlib. We
# ignore these frames heuristically based on source and name match.
def _ignore_known_hidden_frame(f: types.FrameType) -> bool:
return 'importlib._bootstrap' in f.f_code.co_filename
def _add_tracebackhide_to_hidden_frames(tb: types.TracebackType):
  """Marks the JAX-internal frames of ``tb`` with ``__tracebackhide__ = True``.

  A frame is marked when ``include_frame`` rejects it, unless it is an
  ``api_boundary`` wrapper frame. Those wrapper frames already set that local
  themselves. The traceback chain itself is left unchanged.
  """
  hidden = (
      frame for frame, _ in traceback.walk_tb(tb)
      if not (include_frame(frame) or _is_reraiser_frame(frame)))
  for frame in hidden:
    frame.f_locals["__tracebackhide__"] = True
def filter_traceback(tb: types.TracebackType) -> types.TracebackType | None:
  """Returns a new traceback chain holding only the frames of ``tb`` to keep.

  Frames accepted by ``include_frame`` are kept in their original order.
  Returns None if no frame survives. Each rebuilt entry uses the frame's
  ``f_lasti`` and the line number recorded in ``tb``.
  """
  kept = [(frame, lineno) for frame, lineno in traceback.walk_tb(tb)
          if include_frame(frame)]
  filtered = None
  # Traceback chains link outer -> inner, so build from the innermost entry.
  for frame, lineno in reversed(kept):
    filtered = types.TracebackType(filtered, frame, frame.f_lasti, lineno)
  return filtered
def _add_call_stack_frames(tb: types.TracebackType) -> types.TracebackType:
  """Returns ``tb`` extended upward with the live callers of ``tb.tb_frame``.

  Caller frames accepted by ``include_frame`` are prepended to ``tb``;
  importlib bootstrap frames are skipped entirely.

  The walk tries not to climb past the exec/eval point of a REPL such as
  IPython. It stops after the first contiguous run of module-level
  (``<module>``) frames, if it reaches any. This heuristic can stop earlier
  than the REPL boundary. For example, if module A's top level appears on the
  stack and A was imported from inside a function F elsewhere, the result is
  truncated at F's frame.
  """
  extended = tb
  seen_module_frame = False
  for frame, lineno in traceback.walk_stack(tb.tb_frame):
    if _ignore_known_hidden_frame(frame):
      continue
    is_module_frame = frame.f_code.co_name == '<module>'
    if seen_module_frame and not is_module_frame:
      # Left the first run of module-level frames: stop here.
      break
    if include_frame(frame):
      extended = types.TracebackType(extended, frame, frame.f_lasti, lineno)
    seen_module_frame = seen_module_frame or is_module_frame
  return extended
def _is_reraiser_frame(f: traceback.FrameSummary | types.FrameType) -> bool:
if isinstance(f, traceback.FrameSummary):
filename, name = f.filename, f.name
else:
filename, name = f.f_code.co_filename, f.f_code.co_name
return filename == __file__ and name == 'reraise_with_filtered_traceback'
def _is_under_reraiser(e: BaseException) -> bool:
  """Returns True if an outer ``api_boundary`` wrapper is already on the stack.

  Walks the live call stack upward, starting from the caller of the outermost
  frame recorded in ``e.__traceback__``. When called from an ``except`` block,
  that recorded frame is the handling frame itself, which is skipped. The
  result says whether any caller is a ``reraise_with_filtered_traceback``
  frame from this module. Nested boundaries use this to leave filtering to
  the outermost boundary.

  Frames are followed directly through ``f_back`` rather than via
  ``traceback.extract_stack``, for two reasons. First, ``extract_stack``
  truncates the walk to ``sys.tracebacklimit`` when that is set, so outer
  boundaries could be missed and the exception filtered more than once.
  Second, it eagerly reads each frame's source line, which is wasted work
  here.

  Returns False if ``e`` has no traceback.
  """
  tb = e.__traceback__
  if tb is None:
    return False
  # Start above the recorded frame itself, mirroring the old `tb[:-1]` slice.
  frame = tb.tb_frame.f_back
  while frame is not None:
    if _is_reraiser_frame(frame):
      return True
    frame = frame.f_back
  return False
def format_exception_only(e: BaseException) -> str:
  """Returns the exception-only part of ``e``'s standard formatting, stripped.

  This is what ``traceback.format_exception_only`` produces for ``e``, joined
  into one string with surrounding whitespace removed.
  """
  lines = traceback.format_exception_only(type(e), e)
  joined = ''.join(lines)
  return joined.strip()
class UnfilteredStackTrace(Exception):
  """Exception that carries the unfiltered stack trace of a JAX-wrapped error.

  In "remove_frames" mode, ``api_boundary`` installs an instance as the
  ``__cause__`` of the user-facing exception. The instance holds the original,
  unfiltered traceback.
  """
_simplified_tb_msg = ("For simplicity, JAX has removed its internal frames from the "
"traceback of the following exception. Set "
"JAX_TRACEBACK_FILTERING=off to include these.")
class SimplifiedTraceback(Exception):
def __str__(self):
return _simplified_tb_msg
SimplifiedTraceback.__module__ = "jax.errors"
def _running_under_ipython() -> bool:
"""Returns true if we appear to be in an IPython session."""
try:
get_ipython() # type: ignore
return True
except NameError:
return False
def _ipython_supports_tracebackhide() -> bool:
  """Returns True if the installed IPython honors ``__tracebackhide__``.

  Support is assumed from IPython 7.17 onward. IPython is imported lazily
  because this is only consulted once we already appear to be inside IPython.
  """
  import IPython  # pytype: disable=import-error
  minimum_version = (7, 17)
  return IPython.version_info[:2] >= minimum_version
def _filtering_mode() -> str:
  """Returns the effective traceback filtering mode.

  Reads the ``traceback_filtering`` config option. ``None`` and ``"auto"``
  resolve to ``"tracebackhide"`` when running under an IPython that honors
  ``__tracebackhide__``, and to ``"quiet_remove_frames"`` otherwise. Any other
  configured value is returned unchanged; validation happens in the caller.
  """
  configured = config.traceback_filtering.value
  if configured is not None and configured != "auto":
    return configured
  if _running_under_ipython() and _ipython_supports_tracebackhide():
    return "tracebackhide"
  return "quiet_remove_frames"
def api_boundary(
    fun: C, *,
    repro_api_name: str | None = None,
    repro_user_func: bool = False) -> C:
  '''Wraps ``fun`` to form a boundary for filtering exception tracebacks.

  When an ``Exception`` propagates out of ``fun``, the wrapper
  post-processes it according to the mode returned by ``_filtering_mode()``
  (from the ``JAX_TRACEBACK_FILTERING`` setting). It then re-raises the same
  exception object. ``BaseException``s that are not ``Exception``s pass
  through untouched. The modes are:

  * ``"off"``: the exception is re-raised unchanged.
  * ``"tracebackhide"``: each JAX-internal frame in the traceback gets a
    ``__tracebackhide__ = True`` local, so tools that honor it (such as
    IPython) hide those frames. The traceback itself is not changed.
  * ``"quiet_remove_frames"``: the exception's traceback is replaced by one
    with JAX-internal frames removed, and an explanatory note is added.
  * ``"remove_frames"``: the traceback is filtered as above. In addition, an
    :class:`UnfilteredStackTrace` is installed as the exception's
    ``__cause__``. It carries the original, unfiltered traceback, extended
    upward with caller frames. The exception's previous
    ``__cause__``/``__context__``/``__suppress_context__`` are moved onto that
    ``UnfilteredStackTrace``.
  * any other value: ``ValueError`` is raised.

  This boundary annotation works in composition with itself. The topmost frame
  corresponding to an :func:`~api_boundary` is the one below which stack traces
  are filtered. In other words, if ``api_boundary(f)`` calls
  ``api_boundary(g)``, directly or indirectly, the filtered stack trace provided
  is the same as if ``api_boundary(f)`` were to simply call ``g`` instead.

  This annotation is primarily useful in wrapping functions output by JAX's
  transformations. For example, consider ``g = jax.jit(f)``. When ``g`` is
  called, JAX's JIT compilation machinery is invoked, which in turn calls ``f``
  in order to trace and translate it. If the function ``f`` raises an exception,
  the stack unwinds through JAX's JIT internals up to the original call site of
  ``g``. Because the function returned by :func:`~jax.jit` is annotated as an
  :func:`~api_boundary`, such an exception is accompanied by an additional
  traceback that excludes the frames specific to JAX's implementation.

  Args:
    fun: the callable to wrap.
    repro_api_name: when this or ``repro_user_func`` is set and the optional
      ``repro`` module was imported, the wrapper is further wrapped with
      ``repro.boundary``. For the "repro" kwargs, see the comments for
      `repro.boundary`.
    repro_user_func: see ``repro_api_name``.

  Returns:
    A wrapper carrying ``fun``'s metadata (via ``functools.wraps``), possibly
    further wrapped by ``repro.boundary``.
  '''
  @functools.wraps(fun)
  def reraise_with_filtered_traceback(*args, **kwargs):
    # Hides this wrapper frame from tools that honor `__tracebackhide__`.
    # `_is_reraiser_frame` recognizes wrapper frames by this function's name,
    # so the name must not change.
    __tracebackhide__ = True
    try:
      return fun(*args, **kwargs)
    except Exception as e:
      mode = _filtering_mode()
      # An outer boundary is on the stack: let it do the filtering.
      if _is_under_reraiser(e) or mode == "off":
        raise
      if mode == "tracebackhide":
        _add_tracebackhide_to_hidden_frames(e.__traceback__)
        raise
      tb = e.__traceback__
      try:
        # Replace the traceback on `e` in place with the filtered one.
        e.with_traceback(filter_traceback(tb))
        if mode == "quiet_remove_frames":
          e.add_note("--------------------\n" + _simplified_tb_msg)
        else:
          if mode == "remove_frames":
            msg = format_exception_only(e)
            msg = f'{msg}\n\n{_jax_message_append}'
            jax_error = UnfilteredStackTrace(msg)
            # The original traceback, extended upward with caller frames.
            jax_error.with_traceback(_add_call_stack_frames(tb))
          else:
            # NOTE(review): `e`'s traceback has already been filtered above by
            # the time an invalid mode is rejected here.
            raise ValueError(f"JAX_TRACEBACK_FILTERING={mode} is not a valid value.")
          # Splice `jax_error` into the chain. It inherits `e`'s original
          # cause/context and becomes `e`'s direct cause.
          jax_error.__cause__ = e.__cause__
          jax_error.__context__ = e.__context__
          jax_error.__suppress_context__ = e.__suppress_context__
          e.__cause__ = jax_error
          e.__context__ = None
          del jax_error
        # Bare `raise` re-raises `e`, which now carries the filtered traceback.
        raise
      finally:
        # Presumably drops these locals so this frame does not keep the
        # traceback (and thus a frame/traceback reference cycle) alive.
        del mode, tb
  # `repro` is the module-level optional import defined after this function.
  if (repro_api_name or repro_user_func) and repro:
    reraise_with_filtered_traceback = repro.boundary(
        reraise_with_filtered_traceback, api_name=repro_api_name,
        is_user=repro_user_func)
  return cast(C, reraise_with_filtered_traceback)
# The repro tooling is optional. When it cannot be imported, `repro` is None
# and `repro_is_enabled` always reports False.
try:
  # TODO: import from the final location
  from jax._src import repro  # type: ignore
  repro_is_enabled = repro.is_enabled
except ImportError:
  repro = None  # type: ignore

  def repro_is_enabled():  # type: ignore
    return False
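# (Residue from the code-hosting web page, not part of the source:
# "Content here may be unsuitable for display, so the page does not show it.
# You can review and edit it with the relevant editing tools. If you confirm
# the content contains no inappropriate language, pure advertising or traffic
# diversion, violence, vulgar or pornographic material, infringement, piracy,
# false or worthless content, or anything violating national laws and
# regulations, you may submit an appeal and it will be handled as soon as
# possible." The file may therefore be truncated after this point.)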