Ai
4 Star 11 Fork 2

Gitee 极速下载/JAX

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
此仓库是为了提升国内下载速度的镜像仓库,每日同步一次。 原始仓库: https://github.com/google/JAX
克隆/下载
jaxpr_util.py 10.32 KB
一键复制 编辑 原始数据 按行查看 历史
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for the Jaxpr IR."""
from __future__ import annotations
from collections import Counter, defaultdict
from collections.abc import Callable
import gzip
import itertools
import json
import logging
import types
from typing import Any, Union
from collections.abc import Iterator
from jax._src import config
from jax._src import core
from jax._src import path
from jax._src import util
from jax._src import source_info_util
from jax._src.lib import xla_client
# Rebind `map` and `zip` to `util.safe_map` and `util.safe_zip` for the rest of
# this module; the builtins stay reachable as `unsafe_map` and `unsafe_zip`.
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def _all_eqns(
jaxpr: core.Jaxpr, visited: set[core.Jaxpr] | None,
) -> Iterator[tuple[core.Jaxpr, core.JaxprEqn]]:
for eqn in jaxpr.eqns:
yield (jaxpr, eqn)
for subjaxpr in core.subjaxprs(jaxpr):
if visited is None:
yield from _all_eqns(subjaxpr, visited)
elif subjaxpr not in visited:
visited.add(subjaxpr)
yield from _all_eqns(subjaxpr, visited)
def all_eqns(
    jaxpr: core.Jaxpr, revisit_inner_jaxprs: bool = True
) -> Iterator[tuple[core.Jaxpr, core.JaxprEqn]]:
  """Yields `(owner, eqn)` pairs for `jaxpr` and all jaxprs nested inside it.

  Args:
    jaxpr: the jaxpr to walk.
    revisit_inner_jaxprs: if True, a sub-jaxpr is walked every time it is
      encountered; if False, each distinct sub-jaxpr is walked only once.

  Yields:
    Pairs of the jaxpr that directly contains an equation and that equation.
  """
  if revisit_inner_jaxprs:
    seen = None
  else:
    seen = set()
  yield from _all_eqns(jaxpr, seen)
def collect_eqns(jaxpr: core.Jaxpr, key: Callable):
  """Groups the equations yielded by `all_eqns(jaxpr)` by `key(eqn)`.

  Args:
    jaxpr: the jaxpr whose equations are grouped.
    key: a function from an equation to a hashable grouping key.

  Returns:
    A dict mapping each key to the list of equations that produced it, in
    the order the equations were encountered.
  """
  groups: dict[Any, list[core.JaxprEqn]] = {}
  for _, eqn in all_eqns(jaxpr):
    groups.setdefault(key(eqn), []).append(eqn)
  return groups
def histogram(jaxpr: core.Jaxpr, key: Callable,
              key_fmt: Callable = lambda x: x):
  """Counts the equations of `jaxpr` per grouping key.

  Args:
    jaxpr: the jaxpr whose equations are counted.
    key: a function from an equation to a hashable grouping key.
    key_fmt: applied to each grouping key to produce the key of the result;
      the identity by default.

  Returns:
    A dict mapping each formatted key to the number of equations in its group.
  """
  counts = {}
  for group_key, eqns in collect_eqns(jaxpr, key).items():
    counts[key_fmt(group_key)] = len(eqns)
  return counts
def primitives(jaxpr: core.Jaxpr):
  """Returns a histogram of the equations of `jaxpr` keyed by primitive name."""
  def primitive_name(eqn):
    return eqn.primitive.name
  return histogram(jaxpr, primitive_name)
def primitives_by_source(jaxpr: core.Jaxpr):
  """Returns a histogram keyed by `'<primitive name> @ <source summary>'`.

  The source summary is `source_info_util.summarize(eqn.source_info)`.
  """
  def name_and_source(eqn):
    where = source_info_util.summarize(eqn.source_info)
    return (eqn.primitive.name, where)

  def render(parts):
    return ' @ '.join(parts)

  return histogram(jaxpr, name_and_source, render)
def primitives_by_shape(jaxpr: core.Jaxpr):
  """Returns a histogram keyed by `'<primitive name> :: <output descriptions>'`.

  Each output variable is described by `var.aval.str_short()`, or by `'*'`
  when it is a `core.DropVar`; the descriptions are joined with spaces.
  """
  def describe(var):
    if isinstance(var, core.DropVar):
      return '*'
    return var.aval.str_short()

  def name_and_outputs(eqn):
    name = eqn.primitive.name
    outputs = ' '.join(describe(var) for var in eqn.outvars)
    return (name, outputs)

  def render(parts):
    return ' :: '.join(parts)

  return histogram(jaxpr, name_and_outputs, render)
def source_locations(jaxpr: core.Jaxpr):
  """Returns a histogram of the equations of `jaxpr` keyed by source summary.

  The key of each equation is `source_info_util.summarize(eqn.source_info)`.
  """
  return histogram(
      jaxpr, lambda eqn: source_info_util.summarize(eqn.source_info))
# An equation, or None. `var_defs_and_refs` uses None for definitions and uses
# that do not come from an equation (a jaxpr's constvars, invars and outvars).
MaybeEqn = Union[core.JaxprEqn, None]
def var_defs_and_refs(jaxpr: core.Jaxpr):
  """Records where each variable of `jaxpr` is defined and where it is used.

  Constvars and invars are recorded as defined by None; a variable produced
  by an equation is recorded as defined by that equation. `core.DropVar`
  outputs are not recorded, and `core.Literal` operands are ignored. A use by
  the jaxpr's own outvars is recorded as None.

  Args:
    jaxpr: the jaxpr to analyze; the jaxprs from `core.subjaxprs(jaxpr)` are
      analyzed recursively.

  Returns:
    Let `res` be the list of `(var, defining_eqn, using_eqns)` triples for
    `jaxpr`. If the mapped sub-jaxpr results are empty, the tuple
    `(jaxpr, res)` is returned. Otherwise the result is a list whose first
    item is `(jaxpr, res)`, followed by this function's result for each
    sub-jaxpr (each of which is itself a tuple or a list). The emptiness test
    relies on the module-level `map` returning a sequence.
  """
  definitions: dict[core.Var, MaybeEqn] = {}
  uses: dict[core.Var, list[MaybeEqn]] = {}

  def note_use(atom: core.Atom, eqn: MaybeEqn):
    if isinstance(atom, core.Literal):
      return
    assert atom in definitions, atom
    assert atom in uses, atom
    uses[atom].append(eqn)

  def note_def(var: core.Var, eqn: MaybeEqn):
    # The single-definition checks run before drop variables are filtered out.
    assert var not in definitions, var
    assert var not in uses, var
    if isinstance(var, core.DropVar):
      return
    definitions[var] = eqn
    uses[var] = []

  for var in [*jaxpr.constvars, *jaxpr.invars]:
    note_def(var, None)

  for eqn in jaxpr.eqns:
    for operand in eqn.invars:
      note_use(operand, eqn)
    for result in eqn.outvars:
      note_def(result, eqn)

  for output in jaxpr.outvars:
    note_use(output, None)

  res = [(var, eqn, uses[var]) for var, eqn in definitions.items()]
  own = (jaxpr, res)
  nested = map(var_defs_and_refs, core.subjaxprs(jaxpr))
  if nested:
    return [own, *nested]
  return own
def vars_by_fanout(jaxpr: core.Jaxpr):
  """Counts, for each variable, how many times it is used.

  Args:
    jaxpr: the jaxpr to analyze, together with its nested jaxprs.

  Returns:
    A flat list of `(jaxpr, histogram)` pairs, one for `jaxpr` and one for
    each nested jaxpr reported by `var_defs_and_refs`. Each histogram maps a
    label describing a variable's definition (`'<var> <- invar'` or
    `'<var> <- <primitive> @ <source summary>'`) to its number of uses.
  """
  def fmt_key(var, eqn):
    if eqn is None:
      return f'{var} <- invar'
    src = source_info_util.summarize(eqn.source_info)
    return f'{var} <- {eqn.primitive.name} @ {src}'

  def hist(reads):
    return {fmt_key(var, var_def): len(var_refs)
            for var, var_def, var_refs in reads}

  def flatten(entry):
    # `var_defs_and_refs` returns a bare `(jaxpr, reads)` tuple for a jaxpr
    # without sub-jaxprs and a (possibly nested) list of such entries
    # otherwise. Unpacking the result directly fails for the bare tuple and
    # for nested lists, so it is normalized to a flat stream of tuples.
    if isinstance(entry, tuple):
      yield entry
    else:
      for item in entry:
        yield from flatten(item)

  return [(j, hist(reads)) for j, reads in flatten(var_defs_and_refs(jaxpr))]
def print_histogram(histogram: dict[Any, int]):
  """Prints one `<count> <name>` line per entry, largest count first.

  Counts are right-aligned to the width of the widest count. Entries with
  equal counts are ordered by name, descending, when the names are mutually
  orderable; otherwise they keep the histogram's own order. An empty
  histogram prints nothing.

  Args:
    histogram: a mapping from names to integer counts.
  """
  # `default=0` keeps an empty histogram from raising inside `max`.
  count_width = max((len(str(count)) for count in histogram.values()),
                    default=0)
  count_fmt = '{:>' + str(count_width) + 'd}'
  pairs = [(count, name) for name, count in histogram.items()]
  try:
    ordered = sorted(pairs, reverse=True)
  except TypeError:
    # Names sharing a count may not be comparable with each other (keys are
    # arbitrary objects); fall back to a stable sort on the count alone.
    ordered = sorted(pairs, key=lambda pair: pair[0], reverse=True)
  for count, name in ordered:
    print(count_fmt.format(count), name)
# Workspace root that `pprof_equation_profile` falls back to when its caller
# passes no (or an empty) `workspace_root`. None leaves names unqualified.
DEFAULT_WORKSPACE_ROOT: str | None = None
def _strip_workspace_root(filename: str, workspace_root: str) -> str:
i = filename.rfind(workspace_root)
return filename[i+len(workspace_root):] if i >= 0 else filename
def _pprof_profile(
    profile: dict[tuple[xla_client.Traceback | None, core.Primitive], int],
    workspace_root: str | None = None,
) -> bytes:
  """Converts a profile into a compressed pprof protocol buffer.

  The input profile is a map from (traceback, primitive) pairs to counts.

  Args:
    profile: the counts to encode. A None traceback produces a sample with no
      locations; otherwise only frames whose file passes
      `source_info_util.is_user_filename` are kept.
    workspace_root: if not None, filenames are stripped of it and function
      names are prefixed with the dotted form of the stripped filename.

  Returns:
    The gzip-compressed result of `xla_client._xla.json_to_pprof_profile`.
  """
  def new_id_table():
    # Every previously unseen key is assigned the next integer, from 1.
    return defaultdict(itertools.count(1).__next__)

  string_ids: defaultdict[str, int] = new_id_table()
  function_ids: defaultdict[types.CodeType, int] = new_id_table()
  location_ids: defaultdict[tuple[types.CodeType, int], int] = new_id_table()

  # The empty string is pinned to id 0; all other strings get ids on first use.
  string_ids[""] = 0
  primitive_key = string_ids["primitive"]

  samples = []
  for (tb, primitive), count in profile.items():
    frame_ids = []
    if tb is not None:
      for code, lasti in zip(*tb.raw_frames()):
        if source_info_util.is_user_filename(code.co_filename):
          frame_ids.append(location_ids[(code, lasti)])
    label = {"key": primitive_key, "str": string_ids[primitive.name]}
    samples.append({
        "location_id": frame_ids,
        "value": [count],
        "label": [label],
    })

  locations = []
  for (code, lasti), loc_id in location_ids.items():
    line_entry = {
        "function_id": function_ids[code],
        "line": xla_client.Traceback.code_addr2line(code, lasti),
    }
    locations.append({"id": loc_id, "line": [line_entry]})

  functions = []
  for code, func_id in function_ids.items():
    filename = code.co_filename
    name = code.co_qualname
    if workspace_root is not None:
      filename = _strip_workspace_root(filename, workspace_root)
      module = filename.removesuffix('.py').replace('/', '.')
      name = f"{module}.{name}"
    functions.append({
        "id": func_id,
        "name": string_ids[name],
        "filename": string_ids[filename],
        "start_line": code.co_firstlineno,
    })

  sample_type = [
      {"type": string_ids["equations"], "unit": string_ids["count"]}
  ]

  # This is the JSON encoding of a pprof profile protocol buffer. See:
  # https://github.com/google/pprof/blob/master/proto/profile.proto for a
  # description of the format.
  # The string table lists the keys in insertion order, so each string's id
  # is also its index in the table.
  payload = {
      "string_table": list(string_ids.keys()),
      "location": locations,
      "function": functions,
      "sample_type": sample_type,
      "sample": samples,
  }
  json_profile = json.dumps(payload)
  return gzip.compress(xla_client._xla.json_to_pprof_profile(json_profile))
def pprof_equation_profile(jaxpr: core.Jaxpr, *,
                           workspace_root: str | None = None) -> bytes:
  """Generates a pprof profile that maps jaxpr equations to Python stack traces.

  By visualizing the profile using pprof, one can identify Python code that is
  responsible for yielding large numbers of jaxpr equations.

  Args:
    jaxpr: a Jaxpr.
    workspace_root: the root of the workspace. If specified, function names
      will be fully qualified, with respect to the workspace root. When it is
      None or empty, `DEFAULT_WORKSPACE_ROOT` is used instead.

  Returns:
    A gzip-compressed pprof Profile protocol buffer, suitable for passing to
    pprof tool for visualization.
  """
  counts: Counter = Counter()
  # Each distinct inner jaxpr contributes its equations only once.
  for _, eqn in all_eqns(jaxpr, revisit_inner_jaxprs=False):
    counts[(eqn.source_info.traceback, eqn.primitive)] += 1
  root = workspace_root if workspace_root else DEFAULT_WORKSPACE_ROOT
  return _pprof_profile(counts, root)
def eqns_using_var_with_invar_index(jaxpr: core.Jaxpr, invar: core.Var) -> Iterator[tuple[core.JaxprEqn, int]]:
  """Find all the equations which use invar and the positional index of its binder

  Only the equations directly in `jaxpr.eqns` are examined. An equation that
  uses `invar` more than once is reported once, with the index of the first
  use.

  Yields:
    `(eqn, index)` pairs where `eqn.invars[index] == invar`.
  """
  for eqn in jaxpr.eqns:
    first_use = next(
        (pos for pos, operand in enumerate(eqn.invars) if operand == invar),
        None,
    )
    if first_use is not None:
      yield eqn, first_use
def jaxpr_and_binder_in_params(params, index: int) -> Iterator[tuple[core.Jaxpr, core.Var]]:
  """Finds the jaxprs stored in `params` and their input variable at `index`.

  Each value of `params`, and each item of a value that is a tuple, is
  inspected; `core.Jaxpr` instances are used directly and `core.ClosedJaxpr`
  instances are unwrapped to their `.jaxpr`. Other values are ignored.

  Args:
    params: a mapping, such as an equation's params.
    index: the position to look up in each found jaxpr's `invars`.

  Yields:
    `(jaxpr, jaxpr.invars[index])` for each jaxpr found.

  Raises:
    RuntimeError: if a found jaxpr has no input variable at `index`.
  """
  for val in params.values():
    candidates = val if isinstance(val, tuple) else (val,)
    for candidate in candidates:
      if isinstance(candidate, core.Jaxpr):
        inner = candidate
      elif isinstance(candidate, core.ClosedJaxpr):
        inner = candidate.jaxpr
      else:
        continue
      if index >= len(inner.invars):
        raise RuntimeError(f"Failed to find index {index} in jaxpr.invars while building report")
      yield inner, inner.invars[index]
def eqns_using_var(jaxpr: core.Jaxpr, invar: core.Var) -> Iterator[core.JaxprEqn]:
  """Find the leaf equations using a variable

  An equation that uses `invar` and carries jaxprs in its params is not
  yielded itself; the search continues inside each of those jaxprs with the
  input variable at the same position. Equations without inner jaxprs are
  yielded as they are.
  """
  # The complexity of this call is because the invar might originate from a nested jaxpr
  for eqn, position in eqns_using_var_with_invar_index(jaxpr, invar):
    nested = tuple(jaxpr_and_binder_in_params(eqn.params, position))
    if not nested:
      # There is no deeper jaxpr to explore, so this equation is a leaf.
      yield eqn
      continue
    for inner_jaxpr, inner_var in nested:
      yield from eqns_using_var(inner_jaxpr, inner_var)
# Source of the sequential ids that `maybe_dump_jaxpr_to_file` embeds in the
# names of the files it writes; starts at 0.
_jaxpr_id_counter = itertools.count()
def maybe_dump_jaxpr_to_file(
    fun_name: str, jaxpr: core.Jaxpr
) -> str | None:
  """Maybe dumps the `jaxpr` to a file.

  Nothing is written unless `path.make_jax_dump_dir` returns a directory for
  `config.jax_dump_ir_to` and the comma-separated `config.jax_dump_ir_modes`
  contains "jaxpr" and/or "eqn_count_pprof". With "jaxpr", the pretty-printed
  jaxpr is written to `jax_<id>_<fun_name>.jaxpr.txt`; with "eqn_count_pprof",
  the equation-count pprof profile is written to
  `jax_<id>_<fun_name>.eqn_count_pprof`. Both files share the same id.

  Args:
    fun_name: The name of the function whose jaxpr is being dumped.
    jaxpr: The jaxpr to dump.

  Returns:
    `fun_name` if dumping was enabled and at least one file was written, or
    None if no file was dumped. Note that this is the function name, not a
    filesystem path.
  """
  out_dir = path.make_jax_dump_dir(config.jax_dump_ir_to.value)
  if not out_dir:
    return None
  modes = config.jax_dump_ir_modes.value.split(",")
  dump_text = "jaxpr" in modes
  dump_pprof = "eqn_count_pprof" in modes
  if not (dump_text or dump_pprof):
    return None
  # Drawn only once something will be written, so ids stay consecutive.
  dump_id = next(_jaxpr_id_counter)
  if dump_text:
    logger.info("Dumping jaxpr for %s to %s.", fun_name, out_dir)
    jaxpr_path = out_dir / f"jax_{dump_id:06d}_{fun_name}.jaxpr.txt"
    jaxpr_path.write_text(jaxpr.pretty_print())
  if dump_pprof:
    logger.info("Dumping eqn count pprof for %s to %s.", fun_name, out_dir)
    eqn_prof_path = out_dir / f"jax_{dump_id:06d}_{fun_name}.eqn_count_pprof"
    eqn_prof_path.write_bytes(pprof_equation_profile(jaxpr))
  return fun_name
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/mirrors/JAX.git
git@gitee.com:mirrors/JAX.git
mirrors
JAX
JAX
main

搜索帮助