代码拉取完成,页面将自动刷新
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Iterator
import contextlib
import dataclasses
import functools
import itertools
import os.path
import re
import sysconfig
import threading
import types
from typing import NamedTuple
from jax._src.lib import xla_client
from jax._src import traceback_util
# Register this file with traceback_util's exclusion list (presumably so its
# frames are filtered out of the tracebacks shown to users).
traceback_util.register_exclusion(__file__)
# Module-local alias for the XLA traceback type used throughout this file.
Traceback = xla_client.Traceback
class Frame(NamedTuple):
    """A source location: file, function, and start/end line and column.

    Instances are built by `raw_frame_to_frame` from a code object and an
    instruction offset.
    """
    file_name: str
    # Qualified name of the function (taken from the code object's co_qualname).
    function_name: str
    start_line: int
    start_column: int
    end_line: int
    end_column: int
# Filename prefixes whose frames are treated as non-user code by
# is_user_filename (unless an inclusion also matches). _exclude_path_regex
# escapes each entry, anchors it at the start of the filename, and requires
# this list to stay non-empty. register_exclusion appends to it.
_exclude_paths: list[str] = [
    # The parent of the directory containing this file (presumably the jax/
    # package root; re-check if this file moves).
    # Attach the separator to make sure that .../jax does not end up matching
    # .../jax_triton and other packages that might have a jax prefix.
    os.path.dirname(os.path.dirname(__file__)) + os.sep,
    # Also exclude the standard library. In a non-standard Python runtime
    # these locations may differ.
    # NOTE(review): unlike the entry above, these have no trailing separator,
    # so they match any path that merely begins with the same text, including
    # directories nested under the stdlib path (on some installs that covers
    # site-packages) — confirm this is intended.
    sysconfig.get_path('stdlib'),
    os.path.dirname(contextlib.__file__),
]
@functools.cache
def _exclude_path_regex() -> re.Pattern[str]:
    """Compile one regex matching filenames that start with an excluded path.

    The compiled pattern is cached until register_exclusion clears it.
    """
    # Joining zero alternatives gives '', which matches every filename, so at
    # least one exclusion is required.
    assert _exclude_paths
    anchored = [f'^{re.escape(prefix)}' for prefix in _exclude_paths]
    return re.compile('|'.join(anchored))
def register_exclusion(path: str):
    """Treat frames whose filename starts with ``path`` as non-user frames.

    A registered inclusion still takes priority over an exclusion. Every cache
    whose result depends on the exclusion list is invalidated.

    Args:
      path: literal filename prefix (it is regex-escaped before matching).
    """
    _exclude_paths.append(path)
    _exclude_path_regex.cache_clear()
    is_user_filename.cache_clear()
    # user_frame memoizes the first user frame of a traceback, which it
    # derives from is_user_filename. Without clearing it, a traceback seen
    # before this call could keep reporting a frame that is now excluded.
    user_frame.cache_clear()
# Explicit inclusions take priority over exclude paths: a filename starting
# with any of these prefixes is treated as user code even if it also matches
# an exclusion (see is_user_filename). register_inclusion appends to it.
_include_paths: list[str] = []
@functools.cache
def _include_path_regex() -> re.Pattern[str]:
    """Compile the regex matching filenames that always count as user code.

    The pattern matches any registered inclusion prefix, and any filename
    ending in ``_test.py`` so the user-frame heuristic can be exercised from
    tests. The compiled pattern is cached until register_inclusion clears it.
    """
    patterns = [f'^{re.escape(path)}' for path in _include_paths]
    # Escape the dot. An unescaped '.' matches any character, so the old
    # pattern also accepted names such as "foo_test_py".
    patterns.append(r'_test\.py$')
    return re.compile('|'.join(patterns))
def register_inclusion(path: str):
    """Treat frames whose filename starts with ``path`` as user frames.

    Inclusions take priority over exclusions. Every cache whose result depends
    on the inclusion list is invalidated.

    Args:
      path: literal filename prefix (it is regex-escaped before matching).
    """
    _include_paths.append(path)
    _include_path_regex.cache_clear()
    is_user_filename.cache_clear()
    # user_frame memoizes the first user frame of a traceback, which it
    # derives from is_user_filename. Without clearing it, a traceback seen
    # before this call could keep skipping a frame that now counts as user
    # code.
    user_frame.cache_clear()
class Scope(NamedTuple):
    """A named scope entry of a NameStack, rendered as its own path component."""
    name: str

    def wrap(self, stack: list[str]):
        """Add this scope's name to the end of ``stack`` (mutated in place)."""
        stack.extend((self.name,))
class Transform(NamedTuple):
    """A transformation entry of a NameStack, rendered as ``name(...)``."""
    name: str

    def wrap(self, stack: list[str]):
        """Wrap the last component of ``stack`` as ``name(component)``.

        When ``stack`` is empty, ``name()`` is appended instead. ``stack`` is
        mutated in place.
        """
        if not stack:
            stack.append(f'{self.name}()')
            return
        innermost = stack[-1]
        stack[-1] = f'{self.name}({innermost})'
@dataclasses.dataclass(frozen=True)
class NameStack:
    """Immutable stack of Scope/Transform entries describing a name path.

    Every operation returns a new NameStack. ``str()`` renders it as a
    '/'-separated path in which each Transform wraps the component that
    follows it, e.g. entries (Scope a, Transform jvp, Scope b) -> 'a/jvp(b)'.
    """
    stack: tuple[Scope | Transform, ...] = ()

    def extend(self, name: str) -> NameStack:
        """Return a copy with a new Scope ``name`` on top."""
        return NameStack(tuple(self.stack) + (Scope(name),))

    def transform(self, transform_name: str) -> NameStack:
        """Return a copy with a new Transform ``transform_name`` on top."""
        return NameStack(tuple(self.stack) + (Transform(transform_name),))

    def __getitem__(self, idx: slice) -> NameStack:
        sliced = self.stack[idx]
        return NameStack(sliced)

    def __len__(self):
        return len(self.stack)

    def __add__(self, other: NameStack) -> NameStack:
        combined = self.stack + other.stack
        return NameStack(combined)

    def __radd__(self, other: NameStack) -> NameStack:
        combined = other.stack + self.stack
        return NameStack(combined)

    def __str__(self) -> str:
        # Visit entries innermost-first so each Transform can wrap the
        # component produced just before it, then emit outermost-first.
        parts: list[str] = []
        for entry in reversed(self.stack):
            entry.wrap(parts)
        return '/'.join(parts[::-1])
def new_name_stack(name: str = '') -> NameStack:
    """Return a fresh NameStack holding a single Scope ``name``.

    An empty ``name`` gives an empty NameStack.
    """
    base = NameStack()
    return base.extend(name) if name else base
class SourceInfo:
    """Pairs an optional traceback with a NameStack.

    This is a plain class with ``__slots__`` because that is slightly faster
    than a NamedTuple.
    """
    traceback: Traceback | None
    name_stack: NameStack

    __slots__ = ['traceback', 'name_stack']

    def __init__(self, traceback: Traceback | None, name_stack: NameStack):
        self.traceback = traceback
        self.name_stack = name_stack

    def replace(self, *, traceback: Traceback | None = None,
                name_stack: NameStack | None = None) -> SourceInfo:
        """Return a new SourceInfo with the given fields swapped in.

        An argument left as None keeps this object's value, so ``replace``
        cannot reset the traceback to None.
        """
        if traceback is None:
            traceback = self.traceback
        if name_stack is None:
            name_stack = self.name_stack
        return SourceInfo(traceback, name_stack)
def new_source_info() -> SourceInfo:
    """Return a SourceInfo with no traceback and an empty NameStack."""
    empty_stack = NameStack()
    return SourceInfo(None, empty_stack)
@functools.cache
def is_user_filename(filename: str) -> bool:
    """Guess whether ``filename`` belongs to user code rather than JAX/stdlib.

    A filename is user code if it matches an inclusion pattern. Otherwise it
    is user code only if it matches no exclusion prefix. Results are memoized
    per filename, and the register_* functions clear the cache.
    """
    if _include_path_regex().search(filename) is not None:
        return True
    return _exclude_path_regex().search(filename) is None
def raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame:
    """Convert a (code object, instruction offset) pair into a Frame.

    The pair is in the form returned by ``Traceback.raw_frames``. The
    start/end line and column come from
    ``xla_client.Traceback.code_addr2location``.
    """
    start_line, start_column, end_line, end_column = (
        xla_client.Traceback.code_addr2location(code, lasti))
    # Positional arguments follow Frame's field order.
    return Frame(code.co_filename, code.co_qualname,
                 start_line, start_column, end_line, end_column)
def user_frames(traceback: Traceback | None) -> Iterator[Frame]:
    """Lazily yield the frames of ``traceback`` that look like user code.

    A None (or falsy) traceback yields nothing.
    """
    # Guess the user's frame is the innermost frame not in the jax source tree
    # or Python stdlib. traceback_util.path_starts_with is avoided on purpose:
    # it touches the filesystem, and this runs on hot paths such as adding
    # source provenance annotations to XLA lowerings. Files ending in _test.py
    # count as user frames so this mechanism can be tested.
    if traceback:
        codes, offsets = traceback.raw_frames()
    else:
        codes, offsets = [], []
    return (raw_frame_to_frame(code, offset)
            for code, offset in zip(codes, offsets)
            if is_user_filename(code.co_filename))
@functools.lru_cache(maxsize=64)
def user_frame(traceback: Traceback | None) -> Frame | None:
    """Return the first user frame of ``traceback``, or None if it has none.

    Results are memoized in a 64-entry LRU cache keyed on the traceback
    argument.
    """
    for frame in user_frames(traceback):
        return frame
    return None
def _summarize_frame(frame: Frame) -> str:
    """Format ``frame`` as ``file:line[:column] (function)``.

    The column is omitted when it is 0.
    """
    location = f"{frame.file_name}:{frame.start_line}"
    if frame.start_column != 0:
        location += f":{frame.start_column}"
    return f"{location} ({frame.function_name})"
def summarize(source_info: SourceInfo, num_frames=1) -> str:
    """Describe up to ``num_frames`` user frames of ``source_info``.

    One frame per line, in the reverse of the order user_frames yields them,
    so the first frame found comes last. Returns '' if no user frame is found.
    """
    lines = []
    for frame in itertools.islice(user_frames(source_info.traceback),
                                  num_frames):
        # NOTE(review): user_frames yields Frame tuples, which are always
        # truthy, so the "unknown" branch looks unreachable.
        lines.append(_summarize_frame(frame) if frame else "unknown")
    lines.reverse()
    return '\n'.join(lines)
class _SourceInfoContext(threading.local):
    """Per-thread holder for the active SourceInfo.

    threading.local runs a subclass's __init__ separately in each thread the
    first time the object is used there, so every thread starts from its own
    fresh SourceInfo.
    """
    context: SourceInfo

    def __init__(self):
        self.context = new_source_info()


# Module-wide instance; the context managers in this module swap its
# `.context` attribute in and out.
_source_info_context = _SourceInfoContext()
def current() -> SourceInfo:
    """Return the current thread's SourceInfo, making sure it has a traceback.

    If the active SourceInfo has no (or a falsy) traceback, a copy carrying
    the result of ``xla_client.Traceback.get_traceback()`` is returned. The
    stored context itself is not modified.
    """
    active = _source_info_context.context
    if active.traceback:
        return active
    return active.replace(traceback=xla_client.Traceback.get_traceback())
class JaxStackTraceBeforeTransformation(Exception):
    """Carries the user's traceback from before a JAX transformation.

    UserContextManager attaches an instance of this as the ``__cause__`` of an
    exception escaping its context, so the original user stack trace is
    reported ahead of the exception raised by the transformed code.
    """
    pass


# Appended (after a blank line) to the formatted original exception in the
# message of the JaxStackTraceBeforeTransformation built by
# UserContextManager.__exit__.
_message = (
    'The preceding stack trace is the source of the JAX operation that, once '
    'transformed by JAX, triggered the following exception.\n'
    '\n--------------------')
def has_user_context(e):
    """Check whether ``e``'s ``__cause__`` chain holds a user traceback.

    Args:
      e: an exception instance, or None (which yields False).

    Returns:
      True if ``e`` or any exception reached by following ``__cause__`` is a
      JaxStackTraceBeforeTransformation, otherwise False.
    """
    # ``__cause__`` can be assigned freely (e.g. ``raise err from err``), so
    # the chain may contain a cycle. Track visited exceptions so a cycle ends
    # the walk instead of hanging inside an exception handler.
    seen: set[int] = set()
    while e is not None and id(e) not in seen:
        if isinstance(e, JaxStackTraceBeforeTransformation):
            return True
        seen.add(id(e))
        e = e.__cause__
    return False
class UserContextManager:
    """Installs a traceback and/or name stack as the current SourceInfo.

    On exit the previous SourceInfo is restored. If an exception escapes and
    this manager was given a traceback, the exception is re-chained so that
    its ``__cause__`` becomes a JaxStackTraceBeforeTransformation carrying the
    filtered traceback. This is skipped when such a cause is already in the
    exception's ``__cause__`` chain. The exception is never suppressed.
    """
    __slots__ = ['traceback', 'name_stack', 'prev']

    def __init__(self, traceback: Traceback | None, *,
                 name_stack: NameStack | None = None):
        # A None here keeps the current value on entry, because
        # SourceInfo.replace treats None as "unchanged".
        self.traceback = traceback
        self.name_stack = name_stack

    def __enter__(self):
        # The previous context is saved on the instance. Re-entering the same
        # instance before it exits would overwrite it.
        self.prev = _source_info_context.context
        _source_info_context.context = _source_info_context.context.replace(
            traceback=self.traceback, name_stack=self.name_stack)

    def __exit__(self, exc_type, exc_value, traceback):
        _source_info_context.context = self.prev
        if exc_type is None or exc_value is None:
            return
        # Nothing to attach, or a user traceback is already attached (e.g. by
        # an inner UserContextManager).
        if self.traceback is None or has_user_context(exc_value):
            return
        filtered_tb = traceback_util.filter_traceback(self.traceback.as_python_traceback())
        if filtered_tb:
            msg = traceback_util.format_exception_only(exc_value)
            msg = f'{msg}\n\n{_message}'
            exp = JaxStackTraceBeforeTransformation(msg).with_traceback(filtered_tb)
            # Splice `exp` into the chain. It inherits exc_value's existing
            # cause/context and becomes exc_value's new direct cause.
            exp.__context__ = exc_value.__context__
            exp.__cause__ = exc_value.__cause__
            exp.__suppress_context__ = exc_value.__suppress_context__
            exc_value.__context__ = None
            exc_value.__cause__ = exp
        # Implicitly returning None lets the original exception propagate.


# Public alias used as `user_context(traceback, name_stack=...)`.
user_context = UserContextManager
def current_name_stack() -> NameStack:
    """Return the NameStack of the current thread's active SourceInfo."""
    active = _source_info_context.context
    return active.name_stack
class ExtendNameStackContextManager(contextlib.ContextDecorator):
    """Pushes a Scope named ``name`` onto the current thread's name stack.

    Use it as ``with extend_name_stack('name') as ns:`` (``ns`` is the new
    NameStack) or as a decorator ``@extend_name_stack('name')``. The previous
    SourceInfo is restored on exit.
    """
    __slots__ = ['name', 'prev']

    def __init__(self, name: str):
        self.name = name

    def _recreate_cm(self):
        # ContextDecorator runs each call of a decorated function inside
        # ``with self._recreate_cm():``, and the base implementation returns
        # ``self``. __enter__ stores the previous context on the instance, so
        # one shared instance lets recursive or multi-thread calls overwrite
        # ``self.prev``, and the outer __exit__ then restores the wrong
        # context. A fresh instance per call keeps each saved context
        # separate. contextlib uses this same hook to make
        # @contextmanager-based managers reusable as decorators.
        return type(self)(self.name)

    def __enter__(self):
        self.prev = prev = _source_info_context.context
        name_stack = prev.name_stack.extend(self.name)
        _source_info_context.context = prev.replace(name_stack=name_stack)
        return name_stack

    def __exit__(self, exc_type, exc_value, traceback):
        _source_info_context.context = self.prev


extend_name_stack = ExtendNameStackContextManager
class SetNameStackContextManager(contextlib.ContextDecorator):
    """Replaces the current thread's name stack with ``name_stack``.

    Works as a ``with`` statement or as a decorator. The previous SourceInfo
    is restored on exit.
    """
    __slots__ = ['name_stack', 'prev']

    def __init__(self, name_stack: NameStack):
        self.name_stack = name_stack

    def _recreate_cm(self):
        # ContextDecorator runs each call of a decorated function inside
        # ``with self._recreate_cm():``, and the base implementation returns
        # ``self``. __enter__ stores the previous context on the instance, so
        # one shared instance lets recursive or multi-thread calls overwrite
        # ``self.prev``, and the outer __exit__ then restores the wrong
        # context. A fresh instance per call keeps each saved context
        # separate.
        return type(self)(self.name_stack)

    def __enter__(self):
        self.prev = prev = _source_info_context.context
        _source_info_context.context = prev.replace(name_stack=self.name_stack)

    def __exit__(self, exc_type, exc_value, traceback):
        _source_info_context.context = self.prev


set_name_stack = SetNameStackContextManager
# TODO(mattjj,phawkins): figure out why the simpler implementation
#     reset_name_stack = lambda: SetNameStackContextManager(NameStack())
# doesn't work. Luckily this context manager isn't called much, so the
# performance shouldn't matter. See the blame commit message for a repro.
# NOTE(review): a plausible cause (unconfirmed) is that
# contextlib.ContextDecorator reuses one instance for every call of a
# decorated function, so the saved `prev` can be overwritten by recursive or
# concurrent calls. @contextlib.contextmanager recreates its manager per call.
@contextlib.contextmanager
def reset_name_stack() -> Iterator[None]:
    """Run the enclosed code (or decorated function) with an empty NameStack."""
    empty = NameStack()
    with set_name_stack(empty):
        yield
class TransformNameStackContextManager(contextlib.ContextDecorator):
    """Pushes a Transform named ``name`` onto the current thread's name stack.

    In the rendered name stack the transform wraps the component after it, as
    ``name(...)``. Use it as ``with transform_name_stack('name') as ns:``
    (``ns`` is the new NameStack) or as a decorator. The previous SourceInfo
    is restored on exit.
    """
    __slots__ = ['name', 'prev']

    def __init__(self, name: str):
        self.name = name

    def _recreate_cm(self):
        # ContextDecorator runs each call of a decorated function inside
        # ``with self._recreate_cm():``, and the base implementation returns
        # ``self``. __enter__ stores the previous context on the instance, so
        # one shared instance lets recursive or multi-thread calls overwrite
        # ``self.prev``, and the outer __exit__ then restores the wrong
        # context. A fresh instance per call keeps each saved context
        # separate.
        return type(self)(self.name)

    def __enter__(self):
        self.prev = prev = _source_info_context.context
        name_stack = prev.name_stack.transform(self.name)
        _source_info_context.context = prev.replace(name_stack=name_stack)
        return name_stack

    def __exit__(self, exc_type, exc_value, traceback):
        _source_info_context.context = self.prev


transform_name_stack = TransformNameStackContextManager
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。