# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import Counter, defaultdict, deque, namedtuple
from collections.abc import (Callable, Collection, Hashable, Iterable, Iterator,
Sequence, MutableSet, MutableMapping)
from contextlib import contextmanager
from dataclasses import dataclass
import functools
from functools import partial, total_ordering
import gc
import inspect
import itertools as it
import math
import operator
import threading
import types
from typing import (Any, ClassVar, Generic, NamedTuple, TypeVar,
overload, Union, TYPE_CHECKING)
import warnings
import weakref
import numpy as np
from jax._src import dtypes
from jax._src import config
from jax._src import effects
from jax._src import mesh as mesh_lib
from jax._src.mesh import AxisType
from jax._src.partition_spec import PartitionSpec as P
from jax._src.errors import (
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
from jax._src import linear_util as lu
from jax._src.tree_util import tree_map
from jax._src import source_info_util
from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
tuple_delete, cache,
HashableFunction, HashableWrapper, weakref_lru_cache,
partition_list, StrictABCMeta, foreach,
weakref_cache_key_types, set_module)
import jax._src.pretty_printer as pp
from jax._src.named_sharding import NamedSharding
from jax._src.sharding import Sharding
from jax._src.layout import Format, AutoLayout
from jax._src.memory import Space as MemorySpace
from jax._src.lib import _jax
from jax._src.lib import jax_jit
from jax._src.lib import jaxlib_extension_version
from jax._src.lib import xla_client
from jax._src import traceback_util
from jax._src.typing import Array, ArrayLike, DimSize, Shape
from jax._src import xla_metadata_lib
traceback_util.register_exclusion(__file__)
zip, unsafe_zip = safe_zip, zip
map, unsafe_map = safe_map, map
config_ext = xla_client._xla.config
PyTree = Any
_TRACER_ERROR_NUM_TRACEBACK_FRAMES = config.int_flag(
'jax_tracer_error_num_traceback_frames',
config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
help='Set the number of stack frames in JAX tracer error messages.'
)
def identity(x): return x
# -------------------- jaxprs --------------------
Effect = effects.Effect
Effects = effects.Effects
EffectTypeSet = effects.EffectTypeSet
no_effects: Effects = effects.no_effects
DebugInfo = lu.DebugInfo
InitialResultPaths = lu.InitialResultPaths
initial_result_paths = lu.initial_result_paths
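# For orientation: a jaxpr is the intermediate representation JAX produces by
# tracing a Python function. A minimal sketch of what the classes below hold
# (assumes a standard `jax` install; `f` is illustrative):
#
#   import jax
#   import jax.numpy as jnp
#
#   def f(x):
#     return jnp.sin(x) * 2.0
#
#   closed = jax.make_jaxpr(f)(1.0)  # a ClosedJaxpr wrapping a Jaxpr
#   print(closed.jaxpr)              # roughly:
#                                    # { lambda ; a:f32[]. let
#                                    #     b:f32[] = sin a
#                                    #     c:f32[] = mul b 2.0
#                                    #   in (c,) }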
class Jaxpr:
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
'_effects', '_debug_info', '_is_high']
_constvars: list[Var]
_invars: list[Var]
_outvars: list[Atom]
_eqns: list[JaxprEqn]
_effects: Effects
_debug_info: DebugInfo
_is_high: bool
@property
def constvars(self) -> list[Var]:
return self._constvars
@property
def invars(self) -> list[Var]:
return self._invars
@property
def outvars(self) -> list[Atom]:
return self._outvars
@property
def eqns(self) -> list[JaxprEqn]:
return self._eqns
@property
def effects(self) -> Effects:
return self._effects
@property
def debug_info(self) -> DebugInfo:
return self._debug_info
@property
def is_high(self) -> bool:
return self._is_high
@property
def in_avals(self):
return [v.aval for v in self.invars]
@property
def in_aval_qdds(self) -> list[AbstractValue | AvalQDD]:
return [v.aval if v.initial_qdd is None else AvalQDD(v.aval, v.initial_qdd)
for v in self.invars]
@property
def final_aval_qdds(self) -> list[AbstractValue | AvalQDD]:
return [v.aval if v.final_qdd is None else AvalQDD(v.aval, v.final_qdd)
for v in self.invars]
@property
def out_avals(self):
return [v.aval for v in self.outvars]
def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
effects: Effects = no_effects,
# We want all calls to pass a DebugInfo object, but for backwards
# compatibility we have to allow calls when the debug_info
# is missing.
debug_info: DebugInfo = None, # type: ignore[annotation-type-mismatch,assignment]
is_high: bool = False,
):
"""
Args:
constvars: list of variables introduced for constants. Array constants are
replaced with such variables while scalar constants are kept inline.
invars: list of input variables. Together, `constvars` and `invars` are
the inputs to the Jaxpr.
outvars: list of output atoms.
eqns: list of equations.
effects: set of effects. The effects on a jaxpr are a superset of the
union of the effects for each equation.
debug_info: debugging information.
"""
self._constvars = list(constvars)
self._invars = list(invars)
self._outvars = list(outvars)
self._eqns = list(eqns)
self._effects = effects
# TODO(https://github.com/jax-ml/jax/issues/26480)
debug_info = debug_info or lu._missing_debug_info("core.Jaxpr")
self._debug_info = debug_info.resolve_result_paths()
config.enable_checks.value and self._debug_info.assert_arg_names(len(invars))
config.enable_checks.value and self._debug_info.assert_result_paths(len(outvars))
self._is_high = is_high
def __str__(self):
return str(self.pretty_print())
__repr__ = __str__
def pretty_print(self, *, source_info=False, print_shapes=True,
custom_pp_eqn_rules=True, name_stack=False,
print_effects: bool = False, **kwargs):
doc = pp_toplevel_jaxpr(
self, source_info=source_info, print_shapes=print_shapes,
custom_pp_eqn_rules=custom_pp_eqn_rules, name_stack=name_stack,
print_effects=print_effects)
return doc.format(**kwargs)
def _repr_pretty_(self, p, cycle):
return p.text(self.pretty_print(use_color=True))
def replace(self, **kwargs):
debug_default = self.debug_info
if (kwargs.get('invars', self.invars) != self.invars or
kwargs.get('outvars', self.outvars) != self.outvars):
debug_default = debug_default.with_unknown_names()
jaxpr = Jaxpr(
constvars=kwargs.pop("constvars", self.constvars),
invars=kwargs.pop("invars", self.invars),
outvars=kwargs.pop("outvars", self.outvars),
eqns=kwargs.pop("eqns", self.eqns),
effects=kwargs.pop("effects", self.effects),
debug_info=kwargs.pop("debug_info", debug_default),
is_high=kwargs.pop("is_high", self.is_high),
)
if kwargs:
raise ValueError(f"Unknown keyword arguments: {kwargs}")
return jaxpr
weakref_cache_key_types.add(Jaxpr)
def join_effects(*effects: Effects) -> Effects:
return set().union(*effects) if effects else no_effects
def jaxprs_in_params(params) -> Iterator[Jaxpr]:
for val in params.values():
vals = val if isinstance(val, tuple) else (val,)
for v in vals:
if isinstance(v, Jaxpr):
yield v
elif isinstance(v, ClosedJaxpr):
yield v.jaxpr
def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]:
"""Generator for all subjaxprs found in the params of jaxpr.eqns.
Does not descend recursively into the found subjaxprs.
"""
for eqn in jaxpr.eqns:
yield from jaxprs_in_params(eqn.params)
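# Hedged usage sketch for the generator above (`my_jaxpr` is illustrative):
#
#   for sub in subjaxprs(my_jaxpr):   # e.g. cond branches, scan bodies
#     print(len(sub.eqns))            # inspect each direct sub-jaxpr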
class ClosedJaxpr:
__slots__ = ['__weakref__', '_jaxpr', '_consts']
_jaxpr: Jaxpr
_consts: list[Any]
jaxpr = property(lambda self: self._jaxpr)
consts = property(lambda self: self._consts)
literals = consts
constvars = property(lambda self: self._jaxpr.constvars)
invars = property(lambda self: self._jaxpr.invars)
outvars = property(lambda self: self._jaxpr.outvars)
eqns = property(lambda self: self._jaxpr.eqns)
effects = property(lambda self: self._jaxpr.effects)
debug_info = property(lambda self: self._jaxpr.debug_info)
is_high = property(lambda self: self._jaxpr.is_high)
def __init__(self, jaxpr: Jaxpr, consts: Sequence):
assert len(consts) == len(jaxpr.constvars)
# assert not any(isinstance(c, Tracer) for c in consts) # TODO(mattjj): enable
self._jaxpr = jaxpr
self._consts = list(consts)
@property
def in_avals(self):
return [v.aval for v in self.invars]
@property
def in_aval_qdds(self) -> list[AbstractValue | AvalQDD]:
return [v.aval if v.initial_qdd is None else AvalQDD(v.aval, v.initial_qdd)
for v in self.invars]
@property
def final_aval_qdds(self) -> list[AbstractValue | AvalQDD]:
return [v.aval if v.final_qdd is None else AvalQDD(v.aval, v.final_qdd)
for v in self.invars]
@property
def out_avals(self):
return [v.aval for v in self.outvars]
def map_jaxpr(self, f):
return ClosedJaxpr(f(self.jaxpr), self.consts)
def replace(self, *, jaxpr=None, consts=None):
jaxpr = self.jaxpr if jaxpr is None else jaxpr
consts = self.consts if consts is None else consts
return ClosedJaxpr(jaxpr, consts)
def __str__(self): return str(self.jaxpr)
def __repr__(self): return repr(self.jaxpr)
def pretty_print(self, *, source_info=False, print_shapes=True,
name_stack=False, custom_pp_eqn_rules=True,
print_effects=False, **kwargs):
return self.jaxpr.pretty_print(
source_info=source_info,
print_shapes=print_shapes,
name_stack=name_stack,
custom_pp_eqn_rules=custom_pp_eqn_rules,
print_effects=print_effects,
**kwargs)
def _repr_pretty_(self, p, cycle):
return p.text(self.pretty_print(use_color=True))
weakref_cache_key_types.add(ClosedJaxpr)
@curry
def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
# TODO(dougalm): remove this hack when we add contexts to jaxpr.
# debug_nans is sometimes disabled locally at the traceable level by ops that
# work with nans internally, like jnp.var. The right thing to do is to add
# contexts to our jaxpr representation so that we can capture these local
# context modifications. In the meantime, disabling the checks when we
# round-trip prevents those ops producing spurious errors.
with config.debug_nans(False):
return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
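# Hedged usage sketch: round-tripping a jaxpr back into a Python callable.
# `jaxpr_as_fun` is curried, so applying it to a ClosedJaxpr yields a function
# of the jaxpr's inputs that returns a list of outputs:
#
#   import jax
#   closed = jax.make_jaxpr(lambda x: x * 2.0)(1.0)
#   fn = jaxpr_as_fun(closed)
#   (y,) = fn(3.0)                    # evaluates the jaxpr; y == 6.0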
# This context manager is fairly hot, because it is frequently called for every
# jaxpr equation.
# This context manager is implemented as a class with explicit __enter__ and
# __exit__ methods since a @contextlib.contextmanager is significantly slower.
# We also in effect fuse four other context managers into one, mostly to
# save allocations.
class JaxprEqnContextManager:
__slots__ = ['context', 'prev_compute_type', 'prev_threefry_partitionable',
'prev_xla_metadata', 'prev_abstract_mesh']
def __init__(self, context):
self.context = context
def __enter__(self):
self.prev_compute_type = config.compute_on_context_manager.swap_local(
self.context.compute_type
)
if (
self.prev_compute_type is not None
and self.prev_compute_type is not config_ext.unset
and self.context.compute_type != self.prev_compute_type
):
config.compute_on_context_manager.set_local(self.prev_compute_type)
raise NotImplementedError(
"Nesting `compute_on` with different compute types is not supported"
f" yet. Current compute_on type: {self.prev_compute_type}"
)
self.prev_threefry_partitionable = config.threefry_partitionable.swap_local(
self.context.threefry_partitionable
)
if self.context.xla_metadata:
self.prev_xla_metadata = config.xla_metadata_context_manager.get_local()
updated = xla_metadata_lib.update_metadata(
self.prev_xla_metadata, self.context.xla_metadata
)
config.xla_metadata_context_manager.set_local(updated)
self.prev_abstract_mesh = config.abstract_mesh_context_manager.swap_local(
self.context.cur_abstract_mesh
)
def __exit__(self, exc_type, exc_value, traceback):
config.compute_on_context_manager.set_local(self.prev_compute_type)
config.threefry_partitionable.set_local(self.prev_threefry_partitionable)
if self.context.xla_metadata:
config.xla_metadata_context_manager.set_local(self.prev_xla_metadata)
config.abstract_mesh_context_manager.set_local(self.prev_abstract_mesh)
class JaxprEqnContext:
__slots__ = ['compute_type', 'threefry_partitionable', 'xla_metadata',
'cur_abstract_mesh']
compute_type: str | None
threefry_partitionable: bool
xla_metadata: dict[str, Any] | None
cur_abstract_mesh: mesh_lib.AbstractMesh
def __init__(self, compute_type: str | None, threefry_partitionable: bool,
xla_metadata: dict[str, Any] | None = None):
self.compute_type = compute_type
self.threefry_partitionable = threefry_partitionable
self.cur_abstract_mesh = mesh_lib.get_abstract_mesh()
self.xla_metadata = xla_metadata
@property
def manager(self):
return JaxprEqnContextManager(self)
def __repr__(self):
return (
f"JaxprEqnContext(compute_type={self.compute_type}, "
f"threefry_partitionable={self.threefry_partitionable}, "
f"cur_abstract_mesh={self.cur_abstract_mesh}, "
f"xla_metadata={self.xla_metadata})"
)
def __hash__(self):
return hash((
self.compute_type,
self.threefry_partitionable,
self.cur_abstract_mesh,
None if self.xla_metadata is None
else tuple(sorted(self.xla_metadata.items())),
))
def __eq__(self, other):
return (self.compute_type == other.compute_type and
self.threefry_partitionable == other.threefry_partitionable and
self.cur_abstract_mesh == other.cur_abstract_mesh and
self.xla_metadata == other.xla_metadata)
class JaxprEqn:
invars: list[Atom]
outvars: list[Var]
primitive: Primitive
params: dict[str, Any]
effects: Effects
# The source_info.name_stack is always relative to the enclosing jaxpr (only)
# and does not include any name context from the caller of the jaxpr. A jaxpr
# might have multiple callers, after all.
# TODO(phawkins): update source_info.tracebacks to also be relative to the
# enclosing jaxpr.
source_info: source_info_util.SourceInfo
ctx: JaxprEqnContext
# It's slightly faster to use a class with __slots__ than a NamedTuple.
__slots__ = ['invars', 'outvars', 'primitive', 'params', 'effects',
'source_info', 'ctx']
def __init__(self, invars, outvars, primitive, params, effs, source_info,
ctx):
self.invars = invars
self.outvars = outvars
self.primitive = primitive
self.params = params
self.effects = effs
self.source_info = source_info
self.ctx = ctx
def __repr__(self):
return str(pp_eqn(self, JaxprPpContext(), JaxprPpSettings())).rstrip()
def replace(
self,
invars: list[Atom] | None = None,
outvars: list[Var] | None = None,
primitive: Primitive | None = None,
params: dict[str, Any] | None = None,
effects: Effects | None = None,
source_info: source_info_util.SourceInfo | None = None,
ctx: JaxprEqnContext | None = None
):
return JaxprEqn(
self.invars if invars is None else invars,
self.outvars if outvars is None else outvars,
self.primitive if primitive is None else primitive,
self.params if params is None else params,
self.effects if effects is None else effects,
self.source_info if source_info is None else source_info,
self.ctx if ctx is None else ctx,
)
# TODO(mattjj): call typecheck rules here, so we don't form bad eqns
def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None,
ctx=None) -> JaxprEqn:
source_info = source_info or source_info_util.new_source_info()
ctx = ctx or JaxprEqnContext(
config.compute_on_context_manager.value,
config.threefry_partitionable.value,
xla_metadata_lib.current_xla_metadata())
if config.enable_checks.value:
assert all(isinstance(x, (Var, Literal)) for x in invars)
assert all(isinstance(v, Var) for v in outvars)
return JaxprEqn(invars, outvars, primitive, params, effects, source_info, ctx)
_var_counter = it.count()
class Var:
__slots__ = ["count", "aval", "initial_qdd", "final_qdd"]
count: int
aval: AbstractValue
# these are only useful for jaxpr binders but rather than create a separate
# type for those, breaking existing interpreters, we add fields here.
initial_qdd : QuasiDynamicData | None
final_qdd : QuasiDynamicData | None
def __init__(self, aval: AbstractValue, initial_qdd=None, final_qdd=None):
assert isinstance(aval, AbstractValue), aval
self.count = next(_var_counter)
self.aval = aval
self.initial_qdd = initial_qdd
self.final_qdd = final_qdd
def __repr__(self):
return f'Var(id={id(self)}):{self.aval.str_short()}'
def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True):
del print_dtype # unused
return f"{context.var_names[self]}"
gensym = lambda: Var
# In a jaxpr, `dropvar` can appear in place of a bound variable to indicate that
# the assignment is dropped, i.e. that an expression's output value will never
# be read. In that sense, `dropvar` is not a variable, but it is convenient to
# treat it as a special case of one. Its `aval` is similarly inexact.
class DropVar(Var):
def __init__(self, aval: AbstractValue):
super().__init__(aval)
def __repr__(self): return '_'
def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True):
del context, print_dtype # unused
return '_'
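# E.g. in a pretty-printed jaxpr a dropped output renders as `_`
# (illustrative line, not produced verbatim by this module):
#
#   a:f32[3] _:f32[3] = some_prim b c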
class Literal:
# See https://docs.jax.dev/en/latest/internals/constants.html
__slots__ = ["val", "aval"]
val: Any
aval: AbstractValue
def __init__(self, val, aval):
self.val = val
self.aval = aval
@property
def hash(self):
try:
return hash(self.val)
except TypeError:
if type(self.val) in literalable_types:
try:
return hash((self.val.item(), self.val.dtype))
except (TypeError, AttributeError, ValueError):
return None
__hash__ = None # type: ignore
def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True):
del context # unused
dtype = getattr(self.aval, 'dtype', None)
if not np.shape(self.val):
val_str = str(np.asarray(self.val).item())
else:
val_str = "[...]"
if print_dtype and dtype:
return f'{val_str}:{self.aval.str_short(short_dtypes=True)}'
else:
return val_str
def __repr__(self):
return f'Literal({self.val})'
# The types of constants that can be used with core.Literal. Other constants
# end up as `constvars`.
literalable_types: set[type] = set()
def is_literalable(x: Any) -> bool:
# See https://docs.jax.dev/en/latest/internals/constants.html
for t in type(x).__mro__:
if t in literalable_types:
return (not np.shape(x) or config.use_simplified_jaxpr_constants.value)
return False
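# Hedged sketch of the registration contract (the real registrations for
# Python scalars, NumPy scalars, etc. live elsewhere in jax._src; `int` here
# just illustrates the mechanics):
#
#   literalable_types.add(int)
#   is_literalable(3)              # True: scalar of a registered type
#   is_literalable(np.ones(3))     # False for non-scalars unless ndarray is
#                                  # registered and the simplified-jaxpr-
#                                  # constants config is enabled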
@partial(weakref_lru_cache, trace_context_in_key=False)
def jaxpr_const_args(jaxpr: Jaxpr) -> list[tuple[ArrayLike, AbstractValue]]:
# The non-scalar constants in core.Literal, in the entire Jaxpr,
# uniquified by id. These will be hoisted as const arguments to the functions
# in which they appear.
# See https://docs.jax.dev/en/latest/internals/constants.html
if not config.use_simplified_jaxpr_constants.value:
return []
consts_by_id: dict[int, tuple[ArrayLike, AbstractValue]] = {}
for v in jaxpr.outvars:
if type(v) is Literal and np.shape(v.val): # type: ignore
consts_by_id[id(v)] = (v.val, v.aval) # type: ignore
for eqn in jaxpr.eqns:
for v in eqn.invars:
if type(v) is Literal and np.shape(v.val): # type: ignore
consts_by_id[id(v)] = (v.val, v.aval) # type: ignore
consts_by_id.update({id(v_aval[0]): v_aval
for v_aval in eqn_params_const_args(eqn.params)})
return list(consts_by_id.values())
def eqn_params_const_args(params) -> list[tuple[ArrayLike, AbstractValue]]:
consts_by_id: dict[int, tuple[ArrayLike, AbstractValue]] = {}
for j in jaxprs_in_params(params):
consts_by_id.update(
{id(v_aval[0]): v_aval for v_aval in jaxpr_const_args(j)}
)
return list(consts_by_id.values())
Atom = Union[Var, Literal]
class Primitive:
name: str
# set for multi-output primitives.
multiple_results: bool = False
# set for call primitives processed in final style.
call_primitive: bool = False
# set for map primitives processed in final style.
map_primitive: bool = False
# set for ref primitives
ref_primitive: bool = False
# set for primitives that can skip canonicalization of values
skip_canonicalization: bool = False
is_effectful = None
def __init__(self, name: str):
self.name = name
def __repr__(self):
return f'{self.name}'
def bind(self, *args, **params):
args = args if self.skip_canonicalization else map(canonicalize_value, args)
return self._true_bind(*args, **params)
def _true_bind(self, *args, **params):
for arg in args:
if isinstance(arg, Tracer) and not arg._trace.is_valid():
raise escaped_tracer_error(arg)
# TODO: figure out how to handle function arguments for this assert
# assert (not config.enable_checks.value or
# all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
# This is equivalent to "with take_current_trace()", but the bind() code
# is called frequently and it's slightly faster to avoid using a context
# manager object.
prev_trace = trace_ctx.trace
trace_ctx.set_trace(eval_trace)
try:
return self.bind_with_trace(prev_trace, args, params)
finally:
trace_ctx.set_trace(prev_trace)
def bind_with_trace(self, trace, args, params):
# TODO(mattjj,dougalm): remove this block?
try: in_type = map(typeof, args)
except: pass # try lojax error message
else:
if self.is_high(*in_type, **params) and trace.requires_low:
with set_current_trace(trace):
return self.to_lojax(*args, **params) # type: ignore
return trace.process_primitive(self, args, params)
trace.process_primitive(self, args, params) # may raise lojax error
raise Exception(f"couldn't apply typeof to args: {args}")
def def_impl(self, impl):
self.impl = impl
return impl
def def_abstract_eval(self, abstract_eval):
self.abstract_eval = _effect_free_abstract_eval(abstract_eval)
return abstract_eval
def def_effectful_abstract_eval(self, effectful_abstract_eval):
self.abstract_eval = effectful_abstract_eval
return effectful_abstract_eval
def def_effectful_abstract_eval2(self, abstract_eval):
self.abstract_eval = _generic_effectful_abstract_eval(abstract_eval, self)
return abstract_eval
def def_bind_with_trace(self, bind_with_trace):
self.bind_with_trace = bind_with_trace
return bind_with_trace
def impl(self, *args, **params):
raise NotImplementedError("Evaluation rule for '{}' not implemented"
.format(self.name))
def abstract_eval(self, *args, **params):
raise NotImplementedError("Abstract evaluation for '{}' not implemented"
.format(self.name))
def get_bind_params(self, params):
return [], params
def is_high(self, *avals, **params) -> bool:
return False
def _effect_free_abstract_eval(abstract_eval):
def abstract_eval_(*args, **kwargs):
return abstract_eval(*args, **kwargs), no_effects
return abstract_eval_
@dataclass(frozen=True)
class GenericEffect(Effect):
prim: Primitive
effects.lowerable_effects.add_type(GenericEffect)
effects.control_flow_allowed_effects.add_type(GenericEffect)
effects.custom_derivatives_allowed_effects.add_type(GenericEffect)
def _generic_effectful_abstract_eval(abstract_eval, prim):
def abstract_eval_(*args, **kwargs):
return abstract_eval(*args, **kwargs), {GenericEffect(prim)}
return abstract_eval_
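# Hedged sketch of the extension pattern Primitive supports: define a new
# primitive, then register its concrete and abstract evaluation rules. All
# names below (`double_p`, `_double_impl`, ...) are illustrative, not JAX API:
#
#   double_p = Primitive('double')
#
#   @double_p.def_impl
#   def _double_impl(x):
#     return np.asarray(x) * 2      # concrete (eval-trace) rule
#
#   @double_p.def_abstract_eval
#   def _double_abstract_eval(x_aval):
#     return x_aval                 # shape and dtype are unchanged
#
#   double_p.bind(3.0)              # dispatches through the current trace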
# -------------------- lifting --------------------
# TODO(mattjj): replace this approach with a primitive-keyed table of rules
def traverse_jaxpr_params(f, params):
"""Applies f to each jaxpr parameter and returns a tuple of returned values."""
return {name: f(p)
for name, param in params.items()
for p in (param if isinstance(param, (tuple, list)) else [param])
if type(p) in (Jaxpr, ClosedJaxpr)}
def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True) -> list[Any]:
def read(v: Atom) -> Any:
return v.val if isinstance(v, Literal) else env[v]
def write(v: Var, val: Any) -> None:
if config.enable_checks.value:
assert typecheck(v.aval, val), (v.aval, get_aval(val), val)
env[v] = val
env: dict[Var, Any] = {}
foreach(write, jaxpr.constvars, consts)
foreach(write, jaxpr.invars, args)
lu = last_used(jaxpr)
for eqn in jaxpr.eqns:
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
traceback = eqn.source_info.traceback if propagate_source_info else None
with source_info_util.user_context(
traceback, name_stack=name_stack), eqn.ctx.manager:
ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
if eqn.primitive.multiple_results:
foreach(write, eqn.outvars, ans)
else:
write(eqn.outvars[0], ans)
clean_up_dead_vars(eqn, env, lu)
return map(read, jaxpr.outvars)
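# Hedged usage sketch for the interpreter above:
#
#   import jax
#   closed = jax.make_jaxpr(lambda x: x + 1.0)(0.0)
#   out, = eval_jaxpr(closed.jaxpr, closed.consts, 2.0)   # out == 3.0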
def check_avals_context_mesh(avals, prim_name):
cur_mesh = mesh_lib.get_abstract_mesh()
for a in avals:
# TODO(yashkatariya): Should be cur_mesh.unset
if cur_mesh.empty or a.sharding.mesh.empty:
continue
# avals can have meshes with different axis_names so allow that in
# full auto mode.
if a.sharding.mesh.are_all_axes_auto and cur_mesh.are_all_axes_auto:
continue
if a.sharding.mesh != cur_mesh:
raise ValueError(
f"For primitive {prim_name}, context mesh {cur_mesh} should match"
f" the aval mesh {a.sharding.mesh} for shape {a.str_short()}. This"
" error occurs at source: "
f" {source_info_util.summarize(source_info_util.current())}")
# -------------------- tracing --------------------
TracerType = TypeVar('TracerType', bound='Tracer')
class Trace(Generic[TracerType]):
__slots__ = ("__weakref__", "_invalidated", "_weakref", "requires_low")
def __init__(self):
self._invalidated = False
# We frequently need a weakref to a trace, so let's precompute one.
self._weakref = weakref.ref(self)
self.requires_low = True
def process_primitive(self, primitive, tracers, params):
raise NotImplementedError("must override")
def invalidate(self):
self._invalidated = True
def is_valid(self):
return not self._invalidated
def __repr__(self):
return f'{self.__class__.__name__}'
def process_call(self, call_primitive, f, tracers, params):
msg = (f"{type(self)} must override process_call to handle call-like "
"primitives")
raise NotImplementedError(msg)
def process_map(self, map_primitive, f, tracers, params):
msg = (f"{type(self)} must override process_map to handle map-like "
"primitives")
raise NotImplementedError(msg)
def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *,
symbolic_zeros):
msg = (f"{type(self)} must override process_custom_jvp_call "
"to handle custom_jvp primitives")
raise NotImplementedError(msg)
def process_custom_transpose(self, prim: Primitive,
call: lu.WrappedFun, tracers, **params):
msg = (f"{type(self)} must override process_custom_transpose "
"to handle custom_transpose_call primitives")
raise NotImplementedError(msg)
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
out_trees, symbolic_zeros):
msg = (f"{type(self)} must override process_custom_vjp_call "
"to handle custom_vjp primitives")
raise NotImplementedError(msg)
# TODO(dougalm): deprecate/delete
def full_raise(self, x):
return x
# TODO(dougalm): deprecate/delete
@property
def main(self):
return getattr(self, "tag", None)
def escaped_tracer_error(tracer, detail=None):
num_frames = _TRACER_ERROR_NUM_TRACEBACK_FRAMES.value
msg = ('Encountered an unexpected tracer. A function transformed by JAX '
'had a side effect, allowing for a reference to an intermediate value '
f'with type {tracer.aval.str_short()} wrapped in a '
f'{type(tracer).__name__} to escape the scope of the transformation.\n'
'JAX transformations require that functions explicitly return their '
'outputs, and disallow saving intermediate values to global state.')
dbg = getattr(tracer, '_debug_info', None)
if dbg is not None:
msg += ('\nThe function being traced when the value leaked was '
f'{dbg.func_src_info} traced for {dbg.traced_for}.')
line_info = getattr(tracer, '_line_info', None)
if line_info is not None:
divider = '\n' + '-'*30 + '\n'
msg += divider
msg += ('The leaked intermediate value was created on line '
f'{source_info_util.summarize(line_info)}. ')
msg += divider
if num_frames > 0:
msg += (f'When the value was created, the final {num_frames} stack '
'frames (most recent last) excluding JAX-internal frames were:')
msg += divider + source_info_util.summarize(
line_info, num_frames=num_frames) + divider
msg += ('\nTo catch the leak earlier, try setting the environment variable '
'JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context '
'manager.')
if detail:
msg += f'Detail: {detail}'
return UnexpectedTracerError(msg)
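# Hedged sketch of the failure mode this error reports: a traced function
# stashes a tracer in outer state, and the tracer is used after tracing ends.
#
#   import jax
#
#   leaked = []
#
#   @jax.jit
#   def f(x):
#     leaked.append(x)   # side effect: the tracer escapes the trace's scope
#     return x + 1
#
#   f(1.0)
#   leaked[0] + 1        # raises UnexpectedTracerError via the check in bind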
def check_scalar_conversion(arr: Array):
if arr.ndim > 0:
raise TypeError("Only scalar arrays can be converted to Python scalars; "
f"got {arr.ndim=}")
def check_integer_conversion(arr: Array):
if not (arr.shape == () and dtypes.issubdtype(arr.dtype, np.integer)):
raise TypeError("Only integer scalar arrays can be converted to a scalar index.")
def check_bool_conversion(arr: Array):
if arr.size == 0:
raise ValueError("The truth value of an empty array is ambiguous. Use"
" `array.size > 0` to check that an array is not empty.")
if arr.size > 1:
raise ValueError("The truth value of an array with more than one element"
" is ambiguous. Use a.any() or a.all()")
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
def _str_abstractify(x):
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
pytype_aval_mappings[str] = _str_abstractify
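# Hedged sketch of how this mapping is used: a type registers a function from
# its values to avals. `MyBox` is hypothetical; `get_aval` is defined later in
# this module. Real registrations (Tracer, ndarray, Python scalars, ...) are
# spread across jax._src.
#
#   class MyBox:
#     def __init__(self, arr):
#       self.arr = arr
#
#   pytype_aval_mappings[MyBox] = lambda box: get_aval(box.arr)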
def _aval_property(name):
return property(lambda self: getattr(self.aval, name))
if TYPE_CHECKING or jaxlib_extension_version < 388:
# We want Python type checkers to accept `some_tracer: jax.Array`, even though
# tracers can represent non-arrays. That is, ideally we would only accept that
# annotation when the Tracer instance has a ShapedArray aval, but we can't
# decide that at Python type checking time. So instead we're overly permissive
# and allow all Tracer instances to typecheck against a jax.Array annotation.
TracerBase = Array
TracerMeta = StrictABCMeta
else:
TracerBase = object
TracerMeta = type
class Tracer(TracerBase, metaclass=TracerMeta):
__array_priority__ = 1000
if jaxlib_extension_version >= 388:
__slots__ = ['__weakref__', '_trace', '_line_info']
else:
__slots__ = ['_trace', '_line_info']
__hash__ = None # type: ignore
_trace: Trace
_line_info: source_info_util.SourceInfo | None
dtype = _aval_property('dtype')
ndim = _aval_property('ndim')
size = _aval_property('size')
shape = _aval_property('shape')
def __init__(self, trace: Trace):
self._trace = trace
def _error_repr(self):
if self.aval is None:
return f"traced array with aval {self.aval}"
return f"traced array with shape {self.aval.str_short()}"
def __array__(self, *args, **kw):
raise TracerArrayConversionError(self)
# helper for isinstance(tracer, jax.Array), here to avoid circular imports
def _is_traced_array(self):
return isinstance(self.aval, ShapedArray)
def __dlpack__(self, *args, **kw):
raise ConcretizationTypeError(self,
f"The __dlpack__() method was called on {self._error_repr()}."
f"{self._origin_msg()}")
def tolist(self):
raise ConcretizationTypeError(self,
f"The tolist() method was called on {self._error_repr()}."
f"{self._origin_msg()}")
def tobytes(self, order="C"):
del order
raise ConcretizationTypeError(self,
f"The tobytes() method was called on {self._error_repr()}."
f"{self._origin_msg()}")
# TODO(dougalm): deprecate/delete
def full_lower(self):
raise NotImplementedError("must override: ", type(self))
def __iter__(self):
return iter(self.aval._iter(self))
def __reversed__(self):
return iter(self[::-1])
def __len__(self):
return self.aval._len(self)
def to_concrete_value(self):
# Should return the concrete value if there is one, or else None.
return None
@property
def sharding(self):
# This attribute is part of the jax.Array API, but only defined on concrete arrays.
# Raising a ConcretizationTypeError would make sense, but for backward compatibility
# we raise an AttributeError so that hasattr() and getattr() work as expected.
raise AttributeError(
f"The 'sharding' attribute is not available on {self._error_repr()}."
f"{self._origin_msg()}")
@property
def committed(self):
raise ConcretizationTypeError(
self,
f"The 'committed' attribute is not available on {self._error_repr()}."
f"{self._origin_msg()}")
@property
def device(self):
# This attribute is part of the jax.Array API, but only defined on concrete arrays.
# Raising a ConcretizationTypeError would make sense, but for backward compatibility
# we raise an AttributeError so that hasattr() and getattr() work as expected.
raise AttributeError(
f"The 'device' attribute is not available on {self._error_repr()}."
f"{self._origin_msg()}")
@property
def addressable_shards(self):
raise ConcretizationTypeError(self,
f"The 'addressable_shards' attribute is not available on {self._error_repr()}."
f"{self._origin_msg()}")
@property
def at(self):
return self.aval.at.fget(self)
@property
def aval(self):
raise NotImplementedError("must override")
def get_referent(self) -> Any:
return self # Override for object equivalence checking
def __bool__(self):
if is_concrete(self): return bool(self.to_concrete_value()) # pytype: disable=wrong-arg-types
check_bool_conversion(self)
return self.aval._bool(self)
def __int__(self):
if is_concrete(self): return int(self.to_concrete_value()) # pytype: disable=wrong-arg-types
check_scalar_conversion(self)
return self.aval._int(self)
def __float__(self):
check_scalar_conversion(self)
return self.aval._float(self)
def __complex__(self):
check_scalar_conversion(self)
return self.aval._complex(self)
def __hex__(self):
if is_concrete(self): return hex(self.to_concrete_value()) # pytype: disable=wrong-arg-types
check_integer_conversion(self)
return self.aval._hex(self)
def __oct__(self):
if is_concrete(self): return oct(self.to_concrete_value()) # pytype: disable=wrong-arg-types
check_integer_conversion(self)
return self.aval._oct(self)
def __index__(self):
if is_concrete(self): return operator.index(self.to_concrete_value()) # pytype: disable=wrong-arg-types
check_integer_conversion(self)
return self.aval._index(self)
# raises a useful error on attempts to pickle a Tracer.
def __reduce__(self):
raise ConcretizationTypeError(
self, ("The error occurred in the __reduce__ method, which may "
"indicate an attempt to serialize/pickle a traced value."))
# raises the better error message from ShapedArray
def __setitem__(self, idx, val): return self.aval._setitem(self, idx, val)
# NumPy also only looks up special methods on classes.
def __array_module__(self, types): return self.aval._array_module(self, types)
def __getattr__(self, name):
# if the aval property raises an AttributeError, gets caught here
assert not config.enable_checks.value or name != "aval"
if name == 'sharding':
raise AttributeError(
f"The 'sharding' attribute is not available on {self._error_repr()}. "
"To query sharding information on tracers, use `jax.typeof(x)`.")
try:
attr = getattr(self.aval, name)
except AttributeError as err:
raise AttributeError(
f"{self.__class__.__name__} has no attribute {name}"
) from err
else:
t = type(attr)
if t is aval_property:
return attr.fget(self)
elif t is aval_method:
return types.MethodType(attr.fun, self)
else:
return attr
def _short_repr(self) -> str:
return f'{self.__class__.__name__}<{self.aval}>'
def _pretty_print(self, verbose: bool = False) -> pp.Doc:
if not verbose:
return pp.text(self._short_repr())
base = pp.text(f'Traced<{self.aval}>with<{self._trace}>')
contents = [(name, attr._pretty_print() if isinstance(attr, Tracer)
else pp.text(repr(attr))) for name, attr in self._contents()]
if contents:
base = pp.group(pp.nest(2, pp.concat([
base, pp.text(' with'), pp.brk(), pp.join(pp.brk(), [
pp.text(f'{name} = ') + pp_payload
for name, pp_payload in contents])
])))
return base
def __repr__(self):
return self._pretty_print(verbose=False).format()
def _contents(self):
try:
return [(name, getattr(self, name)) for name in self.__slots__]
except AttributeError:
return ()
def _origin_msg(self) -> str:
return ""
# Methods that are only valid for materialized arrays
def addressable_data(self, index):
raise ConcretizationTypeError(self,
f"The addressable_data() method was called on {self._error_repr()}."
f"{self._origin_msg()}")
@property
def block_until_ready(self):
# Raise AttributeError for backward compatibility with hasattr() and getattr() checks.
raise AttributeError(
f"The 'block_until_ready' method is not available on {self._error_repr()}."
f"{self._origin_msg()}")
@property
def copy_to_host_async(self):
# Raise AttributeError for backward compatibility with hasattr() and getattr() checks.
raise AttributeError(
f"The 'copy_to_host_async' method is not available on {self._error_repr()}."
f"{self._origin_msg()}")
def delete(self):
raise ConcretizationTypeError(self,
f"The delete() method was called on {self._error_repr()}."
f"{self._origin_msg()}")
def devices(self):
raise ConcretizationTypeError(self,
f"The devices() method was called on {self._error_repr()}."
f"{self._origin_msg()}")
@property
def global_shards(self):
raise ConcretizationTypeError(self,
f"The global_shards property was called on {self._error_repr()}."
f"{self._origin_msg()}")
def is_deleted(self):
raise ConcretizationTypeError(self,
f"The is_deleted() method was called on {self._error_repr()}."
f"{self._origin_msg()}")
@property
def is_fully_addressable(self):
raise ConcretizationTypeError(self,
f"The is_fully_addressable property was called on {self._error_repr()}."
f"{self._origin_msg()}")
@property
def is_fully_replicated(self):
raise ConcretizationTypeError(self,
f"The is_fully_replicated property was called on {self._error_repr()}."
f"{self._origin_msg()}")
def on_device_size_in_bytes(self):
raise ConcretizationTypeError(self,
f"The on_device_size_in_bytes() method was called on {self._error_repr()}."
f"{self._origin_msg()}")
@property
def traceback(self):
raise ConcretizationTypeError(self,
f"The traceback property was called on {self._error_repr()}."
f"{self._origin_msg()}")
def unsafe_buffer_pointer(self):
raise ConcretizationTypeError(self,
f"The unsafe_buffer_pointer() method was called on {self._error_repr()}."
f"{self._origin_msg()}")
if jaxlib_extension_version >= 388:
_jax.set_tracer_class(Tracer)
# these can be used to set up forwarding of properties and instance methods from
# Tracer instances to the underlying avals
aval_property = namedtuple("aval_property", ["fget"])
aval_method = namedtuple("aval_method", ["fun"])
pytype_aval_mappings[Tracer] = lambda x: x.aval
def check_eval_args(args):
for arg in args:
if isinstance(arg, Tracer):
raise escaped_tracer_error(arg)
class EvalTrace(Trace):
def process_primitive(self, primitive, args, params):
if config.debug_key_reuse.value:
# Import here to avoid circular imports
from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks # pytype: disable=import-error
return call_impl_with_key_reuse_checks(primitive, primitive.impl, *args, **params)
else:
# TODO(dougalm): delete. this shouldn't be necessary
args = map(full_lower, args)
check_eval_args(args)
return primitive.impl(*args, **params)
def process_call(self, primitive, f, tracers, params):
if config.debug_key_reuse.value:
# Import here to avoid circular imports
from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks # pytype: disable=import-error
return call_impl_with_key_reuse_checks(primitive, primitive.impl, f, *tracers, **params)
else:
return primitive.impl(f, *tracers, **params)
process_map = process_call
def process_custom_transpose(self, primitive, call, tracers, **_):
del primitive, _
return call.call_wrapped(*tracers)
def process_custom_jvp_call(self, primitive, fun, jvp, tracers, **_):
del primitive, jvp, _ # Unused.
return fun.call_wrapped(*tracers)
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, **_): # pytype: disable=signature-mismatch
del primitive, fwd, bwd, _ # Unused.
return fun.call_wrapped(*tracers)
def cur_qdd(self, x):
return x.cur_qdd()
class TraceTag:
# TODO: this works for surprisingly subtle reasons. Function transformations
# like `jvp_subtrace` are parameterized by a tag that identifies the set of
# pre-existing tracers we want to unpack during the transformation. A function
# defined in an outer scope can't have any closed-over traces, so the tag is
# irrelevant. A function defined in the current scope may have closed-over
# traces, but the tag will never change so we'll never get a spurious cache
# hit. The plan is to do away with `lu.cache` altogether, and use a simpler
# caching scheme that only caches top-level functions. Then we can remove this
# hack.
def __hash__(self):
return hash(TraceTag)
def __eq__(self, other):
return isinstance(other, TraceTag)
ParamDict = dict[str, Any]
AxisName = Hashable
no_axis_name = object()
@dataclass(frozen=True)
class AxisEnv:
axis_sizes : dict[AxisName, int]
spmd_axis_names : set[AxisName]
explicit_mesh_axis_names: frozenset[AxisName]
def axis_size(self, axis_name):
if axis_name not in self.axis_sizes:
raise NameError(f"unbound axis name: {axis_name}")
else:
return self.axis_sizes[axis_name]
def axis_exists(self, axis_name):
return axis_name in self.axis_sizes
def axis_names(self):
return tuple(k for k in self.axis_sizes)
def pop_pure(self, axis_name):
new_sizes = self.axis_sizes.copy()
new_sizes.pop(axis_name)
return AxisEnv(new_sizes, self.spmd_axis_names,
self.explicit_mesh_axis_names)
def extend_pure(self, name_size_pairs):
new_sizes = self.axis_sizes.copy()
new_sizes.update((name, size) for name, size in name_size_pairs
if name is not no_axis_name)
return AxisEnv(new_sizes, self.spmd_axis_names,
self.explicit_mesh_axis_names)
def add_spmd_axis_names(self, axis_names):
new_spmd_axis_names = self.spmd_axis_names | set(axis_names)
return AxisEnv(self.axis_sizes, new_spmd_axis_names,
self.explicit_mesh_axis_names)
def add_explicit_mesh_axis_names(self, axis_names):
new_ema = self.explicit_mesh_axis_names | frozenset(axis_names)
return AxisEnv(self.axis_sizes, self.spmd_axis_names, new_ema)
def as_hashable_key(self):
return tuple((name, size) for (name, size) in self.axis_sizes.items()
if name is not no_axis_name)
eval_trace = EvalTrace()
top_axis_env = AxisEnv({}, set(), frozenset())
class TracingContext(threading.local):
trace: Trace | None
axis_env : AxisEnv
def __init__(self):
self.reset()
def reset(self):
self.trace = eval_trace
self.axis_env = top_axis_env
def is_top_level(self) -> bool:
return (self.trace is eval_trace and
self.axis_env is top_axis_env)
def set_trace(self, trace):
self.trace = trace
ts = trace._weakref if trace is not None else None
config.trace_state.set_local(ts)
def set_axis_env(self, axis_env):
self.axis_env = axis_env
config.axis_env_state.set_local(axis_env.as_hashable_key())
def update_thread_local_jit_state(self):
ts = self.trace._weakref if self.trace is not None else None
config.trace_state.set_local(ts)
config.axis_env_state.set_local(self.axis_env.as_hashable_key())
trace_ctx = TracingContext()
class TakeCurrentTraceContextManager:
__slots__ = ['prev']
def __enter__(self):
self.prev = trace_ctx.trace
trace_ctx.set_trace(eval_trace)
return self.prev
def __exit__(self, exc_type, exc_value, traceback):
trace_ctx.set_trace(self.prev)
take_current_trace = TakeCurrentTraceContextManager
class SetCurrentTraceContextManager:
__slots__ = ['trace', 'check_leaks', 'prev']
def __init__(self, trace, check_leaks=False):
self.trace = trace
self.check_leaks = check_leaks
def __enter__(self):
self.prev = trace_ctx.trace
trace_ctx.set_trace(self.trace)
def __exit__(self, exc_type, exc_value, traceback):
trace_ctx.set_trace(self.prev)
if self.check_leaks and config.check_tracer_leaks.value:
self.trace.invalidate()
trace_ref = self.trace._weakref
del self.trace
live_trace = trace_ref()
if live_trace is not None:
leaked_tracers = maybe_find_leaked_tracers(live_trace)
if leaked_tracers:
raise leaked_tracer_error("trace", live_trace, leaked_tracers)
set_current_trace = SetCurrentTraceContextManager
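# A minimal sketch of how the two context managers above compose:
# take_current_trace installs eval_trace and hands back the previous trace,
# which can later be reinstalled with set_current_trace:
#
#   with take_current_trace() as prev:
#     ...  # runs under eval_trace
#     with set_current_trace(prev):
#       ...  # runs under the original trace again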
class ExtendAxisEnvNdContextManager:
__slots__ = ['prev', 'name_size_pairs']
def __init__(self, name_size_pairs: Iterable[tuple[AxisName, int]]):
self.name_size_pairs = name_size_pairs
def __enter__(self):
self.prev = trace_ctx.axis_env
trace_ctx.set_axis_env(self.prev.extend_pure(self.name_size_pairs))
def __exit__(self, exc_type, exc_value, traceback):
trace_ctx.set_axis_env(self.prev)
extend_axis_env_nd = ExtendAxisEnvNdContextManager
class AddSpmdAxisNamesContextManager:
__slots__ = ['prev', 'axis_names']
def __init__(self, axis_names: AxisName | None):
self.axis_names = axis_names
def __enter__(self):
self.prev = trace_ctx.axis_env
if self.axis_names is not None:
trace_ctx.set_axis_env(self.prev.add_spmd_axis_names(self.axis_names))
def __exit__(self, exc_type, exc_value, traceback):
trace_ctx.set_axis_env(self.prev)
add_spmd_axis_names = AddSpmdAxisNamesContextManager
class AddExplicitMeshAxisNamesContextManager:
__slots__ = ['prev', 'axis_names']
def __init__(self, axis_names: AxisName | None):
self.axis_names = axis_names
def __enter__(self):
self.prev = trace_ctx.axis_env
if self.axis_names is not None:
trace_ctx.set_axis_env(self.prev.add_explicit_mesh_axis_names(
self.axis_names))
def __exit__(self, exc_type, exc_value, traceback):
trace_ctx.set_axis_env(self.prev)
add_explicit_mesh_axis_names = AddExplicitMeshAxisNamesContextManager
def get_axis_env():
return trace_ctx.axis_env
def _initialize_jax_jit_thread_local_state():
"""Initializes the C++ thread-local context.
When the user spawns threads, the C++ `jax_jit.thread_local_state` is None.
The C++ accessor calls this function if it realizes the thread_local_state
is None (which means it's not yet initialized for this thread).
This function does not live in `config.py`, to prevent circular imports.
"""
trace_ctx.update_thread_local_jit_state()
jax_jit.set_thread_local_state_initialization_callback(
_initialize_jax_jit_thread_local_state)
def trace_state_clean() -> bool:
return trace_ctx.is_top_level()
def reset_trace_state() -> bool:
"""Resets the global trace state and returns True if it was already clean."""
if not trace_ctx.is_top_level():
trace_ctx.reset()
trace_ctx.update_thread_local_jit_state()
return False
else:
return True
TRACER_LEAK_DEBUGGER_WARNING = """\
JAX check_tracer_leaks behavior can trigger false positives when used with a debugger.
To avoid false positives and silence this warning, you can disable thread tracing using
the following:
import threading
threading.current_thread().pydev_do_not_trace = True
"""
@contextmanager
def ensure_no_leaks(trace:Trace):
yield
trace.invalidate()
if config.check_tracer_leaks.value:
trace_ref = trace._weakref
del trace
live_trace = trace_ref()
if live_trace is not None:
leaked_tracers = maybe_find_leaked_tracers(live_trace)
if leaked_tracers:
raise leaked_tracer_error("trace", live_trace, leaked_tracers)
def maybe_find_leaked_tracers(trace: Trace) -> list[Tracer]:
"""Find the leaked tracers holding a reference to the Trace
"""
if not getattr(threading.current_thread(), 'pydev_do_not_trace', True):
warnings.warn(TRACER_LEAK_DEBUGGER_WARNING)
# Trigger garbage collection to filter out unreachable objects that are alive
# only due to cyclical dependencies. (We don't care about unreachable leaked
# tracers since they can't interact with user code and cause a problem.)
gc.collect()
tracers = list(filter(lambda x: isinstance(x, Tracer), gc.get_referrers(trace)))
return tracers
def leaked_tracer_error(name: str, t, tracers: list[Tracer]) -> Exception:
assert tracers
why = partial(_why_alive, {id(tracers)})
msgs = '\n\n'.join(f'{tracers[i]}{tracers[i]._origin_msg()}{why(tracers[i])}'
for i in range(len(tracers)))
return Exception(f'Leaked {name} {t}. Leaked tracer(s):\n\n{msgs}\n')
def _why_alive(ignore_ids: set[int], x: Any) -> str:
parents = lambda x: [r for r in gc.get_referrers(x) if id(r) not in ignore_ids]
child, lines, seen = x, [], set()
while (id(child) not in seen and type(child) is not types.ModuleType
and parents(child)):
parent = parents(child)[0] # just pick one parent
# For namespaces (like modules and class instances) and closures, the
# references may form a simple chain: e.g. instance refers to its own
# __dict__ which refers to child, or function refers to its __closure__
# which refers to cells which refer to child. In these cases, we can provide
# a more intuitive description by collapsing the chain into a single
# parent->child jump. We do that by setting `parent` here to be a
# grandparent (or great-grandparent) of `child`, and then handling that case
# in _why_alive_container_info. See example:
# https://github.com/jax-ml/jax/pull/13022#discussion_r1008456599
# To prevent this collapsing behavior, just comment out this code block.
if (isinstance(parent, dict) and
getattr(parents(parent)[0], '__dict__', None) is parents(child)[0]):
parent = parents(parent)[0]
elif type(parent) is types.CellType:
parent = parents(parents(parent)[0])[0]
line = f'<{type(child).__name__} {id(child)}> is referred to by '
lines.append(line + _why_alive_container_info(parent, id(child)))
seen.add(id(child))
child = parent
return '\n' + '\n'.join(lines) if lines else ''
def _why_alive_container_info(container, obj_id) -> str:
name = f'<{type(container).__name__} {id(container)}>'
if type(container) is types.ModuleType:
name = getattr(container, '__name__', name)
if type(container) is types.FunctionType:
name_ = getattr(container, '__name__', '<no-name>')
closure = inspect.getclosurevars(container)
keys = [k for k, v in dict(closure.nonlocals, **closure.globals).items()
if id(v) == obj_id]
if len(keys) == 1: return f'{name} ({name_}) closed-over variable {keys[0]}'
elif len(keys) > 1: return (f'{name} in closed-over variables ' +
', '.join(map(repr, keys)))
if hasattr(container, '__dict__'):
keys = [k for k in vars(container) if id(vars(container)[k]) == obj_id]
if len(keys) == 1: return f'{name}.{keys[0]}'
elif len(keys) > 1: return f'{name} in vars ' + ', '.join(map(repr, keys))
if isinstance(container, (list, tuple)):
idxs = [i for i, x in enumerate(container) if id(x) == obj_id]
if len(idxs) == 1: return f'{name}[{idxs[0]}]'
else: return f'{name} at indices ' + ', '.join(map(str, idxs))
if isinstance(container, dict):
keys = [k for k in container if id(container[k]) == obj_id]
if len(keys) == 1: return f'{name}[{keys[0]!r}]'
else: return f'{name} at keys ' + ', '.join(map(repr, keys))
if isinstance(container, types.ModuleType):
return f' named {container.__name__}'
return name
@contextmanager
def ensure_compile_time_eval():
"""Context manager to ensure evaluation at trace/compile time (or error).
Some JAX APIs like :func:`jax.jit` and :func:`jax.lax.scan` involve staging,
i.e., delaying the evaluation of numerical expressions (like :mod:`jax.numpy`
function applications) so that instead of performing those computations
eagerly while evaluating the corresponding Python expressions, their
computation is carried out separately, e.g. after optimized compilation. But
this delay can be undesirable. For example, numerical values might be needed
to evaluate Python control flow and so their evaluation cannot be delayed. As
another example, it may be beneficial to ensure compile time evaluation (or
"constant folding") for performance reasons.
This context manager ensures that JAX computations are evaluated eagerly. If
eager evaluation is not possible, a ``ConcretizationTypeError`` is raised.
Here's a contrived example::
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
with jax.ensure_compile_time_eval():
y = jnp.sin(3.0)
z = jnp.sin(y)
z_positive = z > 0
if z_positive: # z_positive is usable in Python control flow
return jnp.sin(x)
else:
return jnp.cos(x)
Here's a real-world example from https://github.com/jax-ml/jax/issues/3974::
import jax
import jax.numpy as jnp
from jax import random
@jax.jit
def jax_fn(x):
with jax.ensure_compile_time_eval():
y = random.randint(random.key(0), (1000,1000), 0, 100)
y2 = y @ y
x2 = jnp.sum(y2) * x
return x2
A similar behavior can often be achieved simply by 'hoisting' the constant
expression out of the corresponding staging API::
y = random.randint(random.key(0), (1000,1000), 0, 100)
@jax.jit
def jax_fn(x):
y2 = y @ y
x2 = jnp.sum(y2)*x
return x2
But in some cases it can be more convenient to use this context manager.
"""
with config.eager_constant_folding(True):
yield
@contextmanager
def eval_context():
with set_current_trace(eval_trace):
yield
# TODO(dougalm): deprecate/delete
def full_lower(val):
if isinstance(val, Tracer):
return val.full_lower()
else:
return val
def get_referent(x: Any) -> Any:
return x.get_referent() if isinstance(x, Tracer) else x
def same_referent(x: Any, y: Any) -> bool:
return get_referent(x) is get_referent(y)
def dedup_referents(itr: Iterable[Any]) -> list[Any]:
return list({HashableWrapper(get_referent(x)):x for x in itr}.values())
def definitely_equal(x, y):
if isinstance(x, Tracer) or isinstance(y, Tracer):
return same_referent(x, y)
elif x is y:
return True
try:
return x == y
except InconclusiveDimensionOperation:
return False
# -------------------- abstract values --------------------
class AbstractValue:
__slots__: list[str] = []
is_high = False
has_qdd = False
def to_tangent_aval(self):
raise NotImplementedError("must override")
def to_cotangent_aval(self):
raise NotImplementedError("must override")
# TODO(dougalm): deprecate this alias
def at_least_vspace(self):
return self.to_tangent_aval()
def __repr__(self):
try:
kv_pairs = (f'{k}={v}' for k, v in self.__dict__.items())
return '{}({})'.format(self.__class__.__name__, ','.join(kv_pairs))
except AttributeError:
return self.__class__.__name__
def update_weak_type(self, weak_type):
return self
def update_vma(self, vma):
return self
def strip_weak_type(self) -> AbstractValue:
return self.update_weak_type(False)
def normalize(self) -> AbstractValue:
return self.strip_weak_type()
def update(self, **kwargs):
raise NotImplementedError("must override")
def lo_ty(self):
return [self]
def lo_ty_qdd(self, qdd):
raise NotImplementedError("avals with qdd must override")
def str_short(self, short_dtypes=False, mesh_axis_types=False):
return str(self)
InputType = tuple[AbstractValue, ...]
OutputType = tuple[AbstractValue, ...]
# For use in typing annotations to denote either a Tracer or a `valid_jaxtype`.
Value = Any
def valid_jaxtype(x) -> bool:
try:
aval = abstractify(x)
except TypeError:
return False
else:
if hasattr(aval, "dtype") and dtypes.is_string_dtype(aval.dtype):
return False
else:
return True
def check_valid_jaxtype(x):
if not valid_jaxtype(x):
raise TypeError(
f"Value {x!r} of type {type(x)} is not a valid JAX type")
def mem_kind_to_space(mem_kind: str) -> MemorySpace:
if mem_kind == 'pinned_host':
return MemorySpace.Host
return MemorySpace.Device
def mem_space_to_kind(mem_space: MemorySpace) -> str:
if mem_space == MemorySpace.Device:
return 'device'
elif mem_space == MemorySpace.Host:
return 'pinned_host'
else:
assert False, "unreachable"
@cache(max_size=4096,
trace_context_in_key=lambda: config.remove_size_one_mesh_axis_from_type.value)
def update_aval_with_sharding(aval, sharding, vma=None):
if isinstance(sharding, NamedSharding):
s = NamedSharding(sharding.mesh.abstract_mesh,
sharding.spec._normalized_spec_for_aval(aval.ndim))
return aval.update(sharding=s, vma=aval.vma if vma is None else vma,
memory_space=mem_kind_to_space(sharding.memory_kind))
return aval if vma is None else aval.update(vma=vma)
# We have three flavors of abstractification APIs here which each used to have
# their own separate implementation. Now they're effectively the same, with the
# following differences:
#
# - abstractify returns avals for non-traced array-like objects.
# - get_aval is like abstractify, but also accepts tracers.
# - shaped_abstractify is like get_aval, but also accepts duck-typed arrays.
#
# TODO(jakevdp): can these be unified further?
def shaped_abstractify(x):
typ = type(x)
if (aval_fn := pytype_aval_mappings.get(typ)): # fast path
return aval_fn(x)
for t in typ.__mro__[1:]:
if (aval_fn := pytype_aval_mappings.get(t)):
return aval_fn(x)
if isinstance(x, AbstractValue):
return x
if hasattr(x, '__jax_array__'):
raise ValueError(
'Triggering __jax_array__() during abstractification is no longer'
' supported. To avoid this error, either explicitly convert your object'
' using jax.numpy.array(), or register your object as a pytree.'
)
if hasattr(x, 'dtype'):
aval = ShapedArray(
np.shape(x),
dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
weak_type=getattr(x, "weak_type", False),
)
return update_aval_with_sharding(aval, getattr(x, 'sharding', None))
raise TypeError(
f"Cannot interpret value of type {typ} as an abstract array; it "
"does not have a dtype attribute")
def abstractify(x):
if isinstance(x, Tracer):
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
return get_aval(x)
# TODO(phawkins): the return type should be AbstractValue.
def get_aval(x: Any) -> Any:
typ = type(x)
if (aval_fn := pytype_aval_mappings.get(typ)): # fast path
return aval_fn(x)
for t in typ.__mro__[1:]:
if (aval_fn := pytype_aval_mappings.get(t)):
return aval_fn(x)
if hasattr(x, '__jax_array__'):
raise ValueError(
'Triggering __jax_array__() during abstractification is no longer'
' supported. To avoid this error, either explicitly convert your object'
' using jax.numpy.array(), or register your object as a pytree.'
)
raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")
# TODO(phawkins): the return type should be AbstractValue.
def typeof(x: Any, /) -> Any:
"""Return the JAX type (i.e. :class:`AbstractValue`) of the input.
Raises a ``TypeError`` if ``x`` is not a valid JAX type.
"""
return get_aval(x)
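# A quick sketch of the three flavors described above (the int dtype shown
# assumes the default, non-x64 mode):
#
#   get_aval(np.arange(3))     # -> ShapedArray((3,), int32)
#   abstractify(np.arange(3))  # same result, but rejects Tracers
#   typeof(np.arange(3))       # public-facing alias of get_aval
#
# shaped_abstractify additionally accepts duck-typed objects that merely
# expose .shape/.dtype attributes.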
def is_concrete(x):
return to_concrete_value(x) is not None
def to_concrete_value(x):
if isinstance(x, Tracer):
return x.to_concrete_value()
else:
return x
def concretization_function_error(fun, suggest_astype=False):
fname = getattr(fun, "__name__", fun)
fname_context = f"The problem arose with the `{fname}` function. "
if suggest_astype:
fname_context += ("If trying to convert the data type of a value, "
f"try using `x.astype({fun.__name__})` "
f"or `jnp.array(x, {fun.__name__})` instead.")
if fun is bool:
def error(self, arg):
raise TracerBoolConversionError(arg)
elif fun in (hex, oct, operator.index):
def error(self, arg):
raise TracerIntegerConversionError(arg)
else:
def error(self, arg):
raise ConcretizationTypeError(arg, fname_context)
return error
def concrete_or_error(force: Any, val: Any, context=""):
"""Like force(val), but gives the context in the error message."""
if force is None:
force = lambda x: x
if isinstance(val, Tracer):
maybe_concrete = val.to_concrete_value()
if maybe_concrete is None:
raise ConcretizationTypeError(val, context)
else:
return force(maybe_concrete)
else:
return force(val)
def concrete_dim_or_error(val: Any, context=""):
"""Like concrete_or_error(operator.index), allowing symbolic dimensions."""
if is_symbolic_dim(val):
return val
else:
return concrete_or_error(operator.index, val, context=context)
### Quasi-dynamic data
# Quasi-dynamic data includes things like liveness bits and the content type of
# a type-changeable box. These change throughout the program but at a given
# point in the program they have a single statically known value.
class MutableQuasiDynamicData:
def __init__(self, val : QuasiDynamicData | None):
self.init_val = val
self.cur_val = val # immutable payload
def update(self, val):
self.cur_val = val
def __repr__(self):
return f'MutableQuasiDynamicData(init_val={self.init_val}, cur_val={self.cur_val})'
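# A small sketch of the wrapper's contract: the initial payload is remembered
# while updates replace only the current value (`some_qdd`/`new_qdd` are
# hypothetical QuasiDynamicData instances):
#
#   mqdd = MutableQuasiDynamicData(some_qdd)
#   mqdd.update(new_qdd)
#   assert mqdd.init_val is some_qdd and mqdd.cur_val is new_qdd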
class QuasiDynamicData:
pass
@dataclass(frozen=True)
class AvalQDD:
is_high = True
aval: AbstractValue
qdd: QuasiDynamicData | None # immutable
has_qdd = True
def lo_ty(self):
return self.aval.lo_ty_qdd(self.qdd) # type: ignore
def read_loval(self, val):
return self.aval.read_loval(self.qdd, val) # type: ignore
def new_from_loval(self, *lovals):
return self.aval.new_from_loval(self.qdd, *lovals) # type: ignore
def to_tangent_aval(self):
return AvalQDD(self.aval.to_tangent_aval(), self.qdd.to_tangent_qdd())
@dataclass(frozen=True)
class AvalMutableQDD:
aval: AbstractValue
mutable_qdd: MutableQuasiDynamicData
def cur_qdd(x):
prev_trace = trace_ctx.trace
trace_ctx.set_trace(eval_trace)
try:
return prev_trace.cur_qdd(x)
finally:
trace_ctx.set_trace(prev_trace)
def cur_aval_qdd(x):
aval = typeof(x)
qdd = cur_qdd(x) if aval.has_qdd else None
return AvalQDD(aval, qdd)
### Extended dtypes
#
# Extended dtypes are JAX-specific dtypes that allow us to represent logical
# arrays of element types that do not have an obvious direct correspondence
# to ("physical") arrays of basic types in a compiler. In particular, their
# element types differ from those of XLA and NumPy (e.g. int32). These dtypes
# are only known to JAX. Their implementation is determined by:
# a) an object representing the extended dtype, accessible via the `dtype`
# attribute on corresponding JAX arrays and, internally, on avals such
# as ShapedArrays that correspond to such JAX arrays;
# b) a set of rules, available via a private attribute on the extended dtype
# object in (a).
# The rules in (b) tell JAX internals how to ground out the element
# type for interaction with the compiler and runtime, e.g. when lowering
# to the compiler's language.
@overload
def physical_aval(aval: ShapedArray) -> ShapedArray: ...
@overload # TODO(frostig): remove this case
def physical_aval(aval: AbstractValue) -> AbstractValue: ...
def physical_aval(aval):
if (isinstance(aval, ShapedArray) and
isinstance(aval.dtype, dtypes.ExtendedDType)):
elt_aval = physical_element_aval(aval.dtype)
from jax._src.sharding_impls import physical_sharding # type: ignore
return ShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype,
sharding=physical_sharding(aval, aval.sharding),
vma=aval.vma)
return aval
def physical_shape(logical_shape, dtype):
elt_aval = physical_element_aval(dtype)
return (*logical_shape, *elt_aval.shape)
def physical_element_aval(edtype: dtypes.ExtendedDType) -> ShapedArray:
duck = edtype._rules.physical_element_aval(edtype) # type: ignore
return ShapedArray(duck.shape, dtypes.dtype(duck.dtype))
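# For example, for a typical PRNG key dtype whose physical element aval is
# (assumed here to be) ShapedArray((2,), uint32), a logical key array of shape
# (3,) grounds out as a physical uint32 array of shape (3, 2):
#
#   physical_shape((3,), key_dtype) == (3, 2)  # key_dtype is hypothetical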
def _dtype_object(dtype):
return dtype if isinstance(dtype, dtypes.ExtendedDType) else np.dtype(dtype)
def _canonicalize_dimension(dim: DimSize) -> DimSize:
# Dimensions are most commonly integral (by far), so we check that first.
try:
return operator.index(dim)
except TypeError as e:
type_error = e
if is_dim(dim):
return dim
else:
raise type_error
def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]:
"""Canonicalizes and checks for errors in a user-provided shape value.
Args:
shape: a Python value that represents a shape.
Returns:
A tuple of canonical dimension values.
"""
if isinstance(shape, int):
shape = shape,
try:
return tuple(unsafe_map(_canonicalize_dimension, shape))
except TypeError:
pass
raise _invalid_shape_error(shape, context)
def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
"""Canonicalizes and checks for errors in a user-provided shape dimension value.
Args:
d: a Python value that represents a dimension.
Returns:
A canonical dimension value.
"""
return canonicalize_shape((d,), context)[0]
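# Examples with plain integer inputs:
#
#   canonicalize_shape(3)          # -> (3,)
#   canonicalize_shape([2, 3])     # -> (2, 3)
#   canonicalize_dim(np.int64(7))  # -> 7
#   canonicalize_shape((2, 3.5))   # raises TypeError via _invalid_shape_error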
def _invalid_shape_error(shape: Shape, context: str=""):
msg = ("Shapes must be 1D sequences of concrete values of integer type, "
f"got {shape}.")
if context:
msg += f" {context}."
if any(isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
and not is_concrete(x) for x in shape):
msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
"smaller subfunctions.")
for x in shape:
if isinstance(x, Tracer) and hasattr(x, "_origin_msg"):
msg += x._origin_msg()
return TypeError(msg)
class ShardingTypeError(Exception):
pass
# TODO(dougalm): Cast scalars, numpy arrays, etc. to jax arrays so that values
# passed to primitives always have avals, i.e. they are canonical.
def canonicalize_value(val):
try:
aval = get_aval(val)
except TypeError:
return val
if not isinstance(aval, ShapedArray):
return val
if aval.sharding.mesh.empty:
return val
cur_mesh = mesh_lib.get_abstract_mesh()
if cur_mesh == aval.sharding.mesh:
return val
# TODO(yashkatariya): Casting to Explicit is not yet allowed. Maybe we need
# cast_and_slice_p for it since shape might change?
  # At least 1 mesh axis should be Manual and all other axes should be
# Manual or Auto to allow casting.
if cur_mesh._any_axis_manual and cur_mesh._are_all_axes_auto_or_manual:
if aval.sharding.mesh.are_all_axes_auto:
from jax._src.pjit import reshard # pytype: disable=import-error
return reshard(val, NamedSharding(cur_mesh, P(*[None] * aval.ndim)))
elif aval.sharding.mesh._any_axis_explicit:
raise NotImplementedError(
"Closing over inputs to shard_map where the input is sharded on"
" `Explicit` axes is not implemented. As a workaround, please pass"
" those inputs as an argument to shard_map. Got input with shape"
f" {aval.str_short(True, True)}")
return val
def get_cur_mesh_sharding(spec=None):
spec = P() if spec is None else spec
return NamedSharding(mesh_lib.get_abstract_mesh(), spec)
def _make_lengths_same(sharding, ndim):
pspec = sharding.spec
if ndim > len(pspec):
return sharding.update(spec=pspec._normalized_spec_for_aval(ndim))
if ndim < len(pspec):
assert all(s is None for s in pspec[ndim:]), (ndim, pspec)
return sharding.update(spec=P(*pspec[:ndim], unreduced=pspec.unreduced,
reduced=pspec.reduced))
assert False, "unreachable"
def modify_spec_for_auto_manual(spec, mesh) -> P:
new_spec = [] # type: ignore
# PartitionSpec can only mention mesh axes that are Explicit.
for s in spec:
if s is None:
new_spec.append(s) # type: ignore
elif isinstance(s, tuple):
new_spec.append(tuple(
p for p in s if mesh._name_to_type[p] == AxisType.Explicit))
else:
new_spec.append(s if mesh._name_to_type[s] == AxisType.Explicit else None) # type: ignore
# Unreduced and reduced can mention mesh axes that are Explicit and Manual.
new_unreduced = {u for u in spec.unreduced
if mesh._name_to_type[u] != AxisType.Auto}
new_reduced = {u for u in spec.reduced
if mesh._name_to_type[u] != AxisType.Auto}
return P(*new_spec, unreduced=new_unreduced, reduced=new_reduced)
def remove_size_one_mesh_axis(spec, mesh) -> P:
new_spec = [] # type: ignore
for s in spec:
if s is None:
new_spec.append(s) # type: ignore
elif isinstance(s, tuple):
new_spec.append(tuple(i for i in s if mesh.shape[i] != 1))
else:
new_spec.append(None if mesh.shape[s] == 1 else s) # type: ignore
return P(*new_spec, unreduced=spec.unreduced, reduced=spec.reduced)
def _maybe_modify_sharding(sharding, ndim):
if len(sharding.spec) == 0 or all(s is None for s in sharding.spec):
out = sharding
elif sharding.mesh.are_all_axes_explicit:
out = sharding
else:
out = sharding.update(spec=modify_spec_for_auto_manual(
sharding.spec, sharding.mesh))
if config.remove_size_one_mesh_axis_from_type.value:
out = out.update(spec=remove_size_one_mesh_axis(out.spec, out.mesh))
if len(out.spec) != ndim:
out = _make_lengths_same(out, ndim)
return out
def _check_divisibility(sharding, shape):
mesh = sharding.mesh
for dim, (spec, sh) in enumerate(zip(sharding.spec, shape)):
if spec is None:
continue
spec = spec if isinstance(spec, tuple) else (spec,)
size = math.prod(mesh.shape[s] for s in spec)
_, remainder = divmod(sh, size)
if remainder != 0:
raise ValueError(
f"Sharding spec {spec} implies that array axis {dim} is partitioned"
f" {size} times, but does not evenly divide the dimension size {sh}."
f" Got shape: {shape} and sharding {sharding}")
@cache(max_size=4096,
trace_context_in_key=lambda: config.remove_size_one_mesh_axis_from_type.value)
def get_sharding(sharding, shape):
"""Modifies and checks the sharding.
Some modifications/checks include:
* Making the length of specs the same as ndim
  * Replacing mesh axes mentioned in the pspec with None if they are Auto/Manual
* Checking for len(spec)-ndim match
* Checking if the mesh is an AbstractMesh.
"""
ndim = len(shape)
if sharding is None:
return NamedSharding(mesh_lib.empty_abstract_mesh, P(*[None] * ndim))
out_s = _maybe_modify_sharding(sharding, ndim)
if len(out_s.spec) != ndim:
raise ValueError(
"Length of sharding.spec must be equal to aval's ndim. Got"
f" sharding.spec {out_s.spec}, aval.ndim {ndim} and sharding {out_s}")
if not isinstance(out_s.mesh, mesh_lib.AbstractMesh):
raise ValueError("Mesh of an aval must be an AbstractMesh. "
f"Got {out_s.mesh} of type {type(out_s.mesh)}")
_check_divisibility(out_s, shape)
assert out_s.memory_kind is None
return out_s
@cache(max_size=4096,
trace_context_in_key=lambda: config.remove_size_one_mesh_axis_from_type.value)
def get_vma(vma, sharding):
mesh = sharding.mesh
spec = sharding.spec
if mesh.empty:
assert not vma, vma
return vma
axis_env = get_axis_env()
for i in vma:
if axis_env.axis_exists(i) and i not in mesh._name_to_type:
continue
if mesh._name_to_type[i] != AxisType.Manual:
raise ValueError(
"Axes mentioned in `vma` field of ShapedArray should"
f" be of type `Manual`. Got axis: {i} of type {mesh._name_to_type[i]}")
if config.remove_size_one_mesh_axis_from_type.value:
vma = frozenset(i for i in vma if mesh.shape[i] != 1)
if vma & spec.unreduced:
raise ValueError(
f"vma and unreduced cannot have common mesh axes. Got {vma=} and"
f" unreduced={spec.unreduced}")
if vma & spec.reduced:
raise ValueError(
f"vma and reduced cannot have common mesh axes. Got {vma=} and"
f" reduced={spec.reduced}")
assert isinstance(vma, frozenset)
return vma
def get_memory_space(memory_space):
assert isinstance(memory_space, MemorySpace)
return memory_space
class ShapedArray(AbstractValue):
# inherits slots from parent
__slots__ = ['shape', 'dtype', 'weak_type', 'sharding', 'vma', 'memory_space']
array_abstraction_level = 2
def __init__(self, shape, dtype, weak_type=False, *, sharding=None,
vma: frozenset[AxisName] = frozenset(),
memory_space: MemorySpace = MemorySpace.Device):
self.shape = canonicalize_shape(shape)
self.dtype = _dtype_object(dtype)
self.weak_type = weak_type
self.sharding = get_sharding(sharding, self.shape)
# short for varying_manual_axes. See docs at
# https://docs.jax.dev/en/latest/notebooks/shard_map.html#tracking-how-values-vary-over-manual-mesh-axes-and-check-vma-true
self.vma = get_vma(vma, self.sharding)
# See description of https://github.com/jax-ml/jax/pull/30556
self.memory_space = get_memory_space(memory_space)
def lower_val(self, val): return [val]
def raise_val(self, val): return val
def lo_ty(self): return [self]
def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
if shape is None:
shape = self.shape
if dtype is None:
dtype = self.dtype
if weak_type is None:
weak_type = self.weak_type
if 'sharding' not in kwargs:
kwargs['sharding'] = self.sharding
if 'vma' not in kwargs:
kwargs['vma'] = self.vma
if 'memory_space' not in kwargs:
kwargs['memory_space'] = self.memory_space
return ShapedArray(shape, dtype, weak_type, **kwargs)
ndim = property(lambda self: len(self.shape))
size = property(lambda self:
0 if any(type(d) is int and d == 0 for d in self.shape)
else math.prod(self.shape))
broadcast: ClassVar[aval_method | None] = None
transpose: ClassVar[aval_method | None] = None
reshape: ClassVar[aval_method | None] = None
_iter: ClassVar[staticmethod | None] = None
def __eq__(self, other):
return (type(self) is type(other)
and self.dtype == other.dtype and self.shape == other.shape
and self.weak_type == other.weak_type
and self.sharding == other.sharding
and self.vma == other.vma
and self.memory_space == other.memory_space)
def __hash__(self):
# can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
# objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
# the unique character code via hash(self.dtype.char)
return hash((self.shape, self.dtype, self.weak_type, self.sharding,
self.vma, self.memory_space))
def __ne__(self, other):
return not self == other
def __repr__(self):
wt_str = ", weak_type=True" if self.weak_type else ""
return f'ShapedArray({self.str_short()}{wt_str})'
def __str__(self):
wt_str = "~" if self.weak_type else ""
return f'{wt_str}{self.str_short()}'
def to_tangent_aval(self):
return ShapedArray(
self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type, sharding=self.sharding, vma=self.vma,
memory_space=self.memory_space)
def to_cotangent_aval(self):
dtype = primal_dtype_to_tangent_dtype(self.dtype)
sharding = primal_sharding_to_cotangent_sharding(self.sharding)
return ShapedArray(
self.shape, dtype, self.weak_type, sharding=sharding, vma=self.vma,
memory_space=self.memory_space)
def str_short(self, short_dtypes=False, mesh_axis_types=False):
return str_short_aval(
self.shape, self.dtype, self.sharding.mesh, self.sharding.spec,
self.vma, self.memory_space, short_dtypes, mesh_axis_types)
def _len(self, ignored_tracer):
try:
return self.shape[0]
except IndexError as err:
raise TypeError("len() of unsized object") from err # same as numpy error
def update_vma(self, vma):
return self.update(vma=vma)
def update_weak_type(self, weak_type):
return self.update(weak_type=weak_type)
_bool = concretization_function_error(bool)
_int = concretization_function_error(int, True)
_float = concretization_function_error(float, True)
_complex = concretization_function_error(complex, True)
_hex = concretization_function_error(hex)
_oct = concretization_function_error(oct)
_index = concretization_function_error(operator.index)
def _get_shape_sharding_str(shape, spec):
out = []
for s1, s2 in zip(shape, spec):
if s2 is None:
out.append(f"{s1}")
elif isinstance(s2, tuple):
ss = ','.join(s for s in s2)
out.append(f"{s1}@({ss})")
else:
out.append(f"{s1}@{s2}")
return ','.join(out)
@cache(max_size=1024, trace_context_in_key=False)
def _axis_types_dict(mesh):
if not mesh.axis_names:
return {}
d = defaultdict(list)
for n, t in safe_zip(mesh.axis_names, mesh.axis_types):
d[t].append(n)
return {t: tuple(n) for t, n in d.items()}
def str_short_aval(shape, dtype, mesh, spec, vma, memory_space,
short_dtypes=False, mesh_axis_types=False) -> str:
dt_str = dtypes.short_dtype_name(dtype) if short_dtypes else dtype.name
dt_str = dt_str.replace('void', 'float0')
shapestr = _get_shape_sharding_str(shape, spec)
mesh_axes = f'({_axis_types_dict(mesh)})' if mesh_axis_types else ''
vma_ur = _vma_ur_str(vma, spec.unreduced, spec.reduced, mesh)
ms_str = ("" if memory_space == MemorySpace.Device else
f"<{memory_space.name.lower()}>")
return f'{dt_str}{ms_str}[{shapestr}]{vma_ur}{mesh_axes}'
def _create_str(x, prefix):
x_str = f"{','.join(i for i in x)}"
x_str = x_str if len(x) == 1 else f"({x_str})"
return f"{prefix}:{x_str}, "
def order_wrt_mesh(mesh, x):
return tuple(a for a in mesh.axis_names if a in x)
def _vma_ur_str(vma, unreduced, reduced, mesh):
if not vma and not unreduced and not reduced:
return ''
vma_str = _create_str(order_wrt_mesh(mesh, vma), 'V') if vma else ''
ur_str = _create_str(unreduced, 'U') if unreduced else ''
red_str = _create_str(reduced, 'R') if reduced else ''
m_str = f"{vma_str}{ur_str}{red_str}".rstrip(', ')
return f"{{{m_str}}}"
def primal_dtype_to_tangent_dtype(primal_dtype):
if isinstance(primal_dtype, dtypes.ExtendedDType):
return primal_dtype._rules.tangent_dtype(primal_dtype)
elif not dtypes.issubdtype(primal_dtype, np.inexact):
return dtypes.float0
else:
return primal_dtype
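# Examples: inexact dtypes are their own tangent dtype, while integer and bool
# dtypes get the trivial tangent dtype float0:
#
#   primal_dtype_to_tangent_dtype(np.dtype('float32'))  # -> float32
#   primal_dtype_to_tangent_dtype(np.dtype('int32'))    # -> float0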
def primal_spec_to_cotangent_spec(spec):
return P(*spec, unreduced=spec.reduced, reduced=spec.unreduced)
def primal_sharding_to_cotangent_sharding(sharding):
return sharding.update(spec=primal_spec_to_cotangent_spec(sharding.spec))
############################## pvary #################################
# Invariant -> Variant no-op cast
def pvary(x, axis_name):
axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
if not axis_name:
return x
# TODO(yashkatariya): Maybe move `order_wrt_mesh` to pvary_transpose_rule?
# Across hosts we should have the same order of axes during lowering time and
# pvary_p transposes to psum_invariant_p.
cur_mesh = mesh_lib.get_abstract_mesh()
new_axes = axes if cur_mesh.empty else order_wrt_mesh(cur_mesh, axes)
assert set(new_axes) == set(axes)
del axes
return tree_map(lambda leaf: pvary_p.bind(leaf, axes=new_axes), x)
pvary_p = Primitive('pvary')
####################### reduced_vary_cast #############################
# Reduced -> Varying no-op cast
def reduced_vary_cast(x, axis_name):
axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
if not axis_name:
return x
return tree_map(lambda leaf: reduced_vary_cast_p.bind(leaf, axes=axes), x)
reduced_vary_cast_p = Primitive('reduced_vary_cast_p')
#######################################################################
def check_unreduced_args(args, name):
for a in args:
if a.sharding.spec.unreduced:
raise ValueError(
f"{name} cannot accept args which are unreduced. Got"
f" {a.str_short(True)}")
if a.sharding.spec.reduced:
raise ValueError(
f"{name} cannot accept args which are reduced. Got"
f" {a.str_short(True)}")
def standard_insert_pvary(*args):
if not config._check_vma.value:
return args
if not args:
return args
in_vma = [aval.vma if isinstance(aval := get_aval(a), ShapedArray)
else frozenset() for a in args]
in_reduced = [aval.sharding.spec.reduced
if isinstance(aval := get_aval(a), ShapedArray) else frozenset()
for a in args]
out_vma = frozenset.union(*in_vma)
out = []
for arg, src_vma, src_reduced in zip(args, in_vma, in_reduced):
if (isinstance(get_aval(arg), ShapedArray) and
(rest_vma := out_vma - src_vma)):
# TODO(yashkatariya): Handle partial reduced_vary_cast and partial pvary.
# Will need more changes to pvary to allow such partialness.
if src_reduced == rest_vma:
out.append(
reduced_vary_cast(arg, tuple(n for n in out_vma if n in rest_vma)))
else:
out.append(pvary(arg, tuple(n for n in out_vma if n in rest_vma)))
else:
out.append(arg)
return out
def standard_vma_rule(prim_name, *avals, **kwargs) -> frozenset[AxisName]:
if not config._check_vma.value:
return frozenset()
avals = tuple(a for a in avals if a is not abstract_token)
if not avals:
return frozenset()
vma, *vmas = (a.vma for a in avals)
if not all(vma == vma_ for vma_ in vmas):
raise ValueError(
f'Primitive {prim_name} requires varying manual axes '
f'to match, but got {[vma, *vmas]}. Please open an issue at '
'https://github.com/jax-ml/jax/issues and as a temporary '
'workaround pass the check_vma=False argument to `jax.shard_map`')
return vma
@dataclass(frozen=True)
class bint(dtypes.ExtendedDType):
bound: int
@property
def type(self) -> type:
return dtypes.extended
@property
def name(self) -> str:
return f'bint{{≤{self.bound}}}'
def __str__(self) -> str:
return self.name
AxisSize = Union[int, Tracer, Var]
class RefMeta(type):
def __instancecheck__(self, inst):
from jax._src.state.types import AbstractRef # pytype: disable=import-error
return (super().__instancecheck__(inst) or
isinstance(inst, Tracer) and isinstance(inst.aval, AbstractRef))
class Ref(metaclass=RefMeta):
"""Mutable array reference.
In most cases this should not be constructed directly, but rather
via :func:`jax.ref.new_ref`. For examples of how this can be
used, refer to the `Ref guide`_.
.. _Ref guide: https://docs.jax.dev/en/latest/array_refs.html
"""
_aval: AbstractValue
_refs: PyTree # list of ArrayRefImpl
def __init__(self, aval, refs):
from jax._src.state.types import AbstractRef # pytype: disable=import-error
assert isinstance(aval, AbstractRef)
self._aval = aval
self._refs = refs
# TODO(mattjj): update repr to handle non-lojax refs
def __repr__(self) -> str: return 'Ref' + repr(self._refs._buf)[5:]
# forward type-level info to aval
aval = property(lambda self: self._aval)
shape = property(lambda self: self._aval.shape)
ndim = property(lambda self: len(self._aval.shape))
dtype = property(lambda self: self._aval.dtype)
# get operations from aval, munging the name
def __getitem__(self, idx): return self._aval._getitem(self, idx)
def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x)
def __len__(self) -> int: return self._aval._len(self)
def addupdate(self, x, idx=()): return self._aval._addupdate(self, idx, x)
# some attributes/methods only work for lojax refs
sharding = property(lambda self: self._refs._buf.sharding)
format = property(lambda self: self._refs._buf.format)
committed = _committed = property(lambda self: True)
def unsafe_buffer_pointer(self): return self._refs._buf.unsafe_buffer_pointer()
@property
def at(self): raise NotImplementedError() # TODO(mattjj)
class ArrayRefImpl:
_aval: ShapedArray
_buf: Array # mutable field
def __init__(self, aval, buf):
from jax._src.state.types import AbstractRef # pytype: disable=import-error
assert isinstance(aval, AbstractRef) and isinstance(aval.inner_aval, ShapedArray)
self._aval = aval
self._buf = buf
pytype_aval_mappings[Ref] = lambda x: x._aval
dtypes.canonicalize_value_handlers[Ref] = lambda x: x
def new_ref(init_val: Any, *, memory_space: Any = None, kind: Any = None):
"""Create a mutable array reference with initial value ``init_val``.
For more discussion, see the `Ref guide`_.
Args:
init_val: A :class:`jax.Array` representing the initial state
of the buffer.
    memory_space: An optional memory space attribute for the Ref.
    kind: An optional kind attribute for the Ref.
Returns:
A :class:`jax.ref.Ref` containing a reference to a mutable buffer.
.. _Ref guide: https://docs.jax.dev/en/latest/array_refs.html
"""
return ref_p.bind(init_val, memory_space=memory_space, kind=kind)
ref_p = Primitive('new_ref')
ref_p.is_effectful = lambda params: True # type: ignore
ref_p.ref_primitive = True
ref_p.is_high = lambda aval, *, memory_space, kind: aval.is_high # type: ignore
def _ref_to_lojax(init_val, *, memory_space, kind):
from jax._src.state.types import AbstractRef # pytype: disable=import-error
val_ty = typeof(init_val)
hival_of_refs = val_ty.raise_val(*map(new_ref, val_ty.lower_val(init_val))) # type: ignore
  return Ref(AbstractRef(val_ty), hival_of_refs)
ref_p.to_lojax = _ref_to_lojax # type: ignore
class InternalMutableArrayEffect(effects.Effect):
pass
array_ref_effect = internal_mutable_array_effect = InternalMutableArrayEffect()
effects.control_flow_allowed_effects.add_type(InternalMutableArrayEffect)
effects.remat_allowed_effects.add_type(InternalMutableArrayEffect)
@ref_p.def_effectful_abstract_eval
def _ref_abstract_eval(init_aval, *, memory_space: Any, kind: Any):
from jax._src.state.types import AbstractRef # pytype: disable=import-error
return (AbstractRef(init_aval, memory_space=memory_space, kind=kind),
{internal_mutable_array_effect})
@ref_p.def_impl
def _ref_impl(init_val, *, memory_space: Any, kind: Any):
if memory_space is not None:
raise NotImplementedError(
"array ref with memory space only works inside of a `jit`.")
from jax._src.state.types import AbstractRef # pytype: disable=import-error
from jax._src.lax.lax import _array_copy # pytype: disable=import-error
aval = AbstractRef(typeof(init_val), kind=kind)
return Ref(aval, ArrayRefImpl(aval, _array_copy(init_val)))
def freeze(ref: Ref) -> Array:
"""Invalidate a given reference and return its final value.
For more information about mutable array references, refer to the
`Ref guide`_.
Args:
ref: A :class:`jax.ref.Ref` object.
Returns:
A :class:`jax.Array` containing the contents of ``ref``.
Examples:
>>> import jax
>>> ref = jax.new_ref(jax.numpy.arange(5))
>>> ref[3] = 100
>>> ref
Ref([ 0, 1, 2, 100, 4], dtype=int32)
>>> jax.ref.freeze(ref)
Array([ 0, 1, 2, 100, 4], dtype=int32)
.. _Ref guide: https://docs.jax.dev/en/latest/array_refs.html
"""
return freeze_p.bind(ref)
freeze_p = Primitive('freeze')
freeze_p.is_effectful = lambda params: True # type: ignore
freeze_p.ref_primitive = True
@freeze_p.def_effectful_abstract_eval
def freeze_abstract_eval(ref_aval):
return ref_aval.inner_aval, {internal_mutable_array_effect}
@freeze_p.def_impl
def _freeze_impl(ref):
return ref[()]
def accum_grad_in_ref(x):
return accum_grad_in_ref_p.bind(x)
accum_grad_in_ref_p = Primitive('accum_grad_in_ref')
accum_grad_in_ref_p.is_high = lambda *_: True # type: ignore
accum_grad_in_ref_p.to_lojax = lambda x: x # type: ignore
accum_grad_in_ref_p.def_abstract_eval(lambda x: x) # type: ignore
accum_grad_in_ref_p.def_impl(lambda x: x) # type: ignore
class AbstractToken(AbstractValue):
def str_short(self, short_dtypes=False, mesh_axis_types=False): return 'Tok'
def to_tangent_aval(self): return self
def to_cotangent_aval(self): return self
abstract_token: AbstractToken = AbstractToken()
# Shaped array aval used by all abstract tokens when a shape/dtype is needed.
def get_token_aval():
return ShapedArray((0,), np.dtype(np.bool_), sharding=None)
# Concrete token object
class Token:
  # The underlying data wrapped by the token; it can be threaded in and out of
  # computations to build up data dependencies.
_buf: Array
def __init__(self, buf):
self._buf = buf
def block_until_ready(self):
self._buf.block_until_ready()
pytype_aval_mappings[Token] = lambda _: abstract_token
dtypes.canonicalize_value_handlers[Token] = lambda x: x
### Operations on shapes and dimension sizes.
class InconclusiveDimensionOperation(Exception):
"""Raised when we cannot conclusively compute with symbolic dimensions."""
def is_symbolic_dim(v: Any) -> bool:
"""Checks if a value is a symbolic dimension used for shape polymorphism.
This should be used very rarely, because symbolic dimensions overload all
operators, and should just work.
"""
return hasattr(v, "dimension_as_value")
def is_constant_dim(d: DimSize) -> bool:
# Whether the dimension is a static integer constant.
# Try using a fast path for non-concrete Tracers.
if isinstance(d, Tracer) and not is_concrete(d):
return False
try:
operator.index(d)
return True
except:
return False
def is_dim(v: Any) -> bool:
return is_symbolic_dim(v) or is_constant_dim(v)
def is_constant_shape(s: Shape) -> bool:
# Whether the shape is a static constant.
return all(is_constant_dim(d) for d in s)
def definitely_equal_one_of_dim(d1: DimSize, dlist: Sequence[DimSize]) -> bool:
return any(definitely_equal(d1, d) for d in dlist)
def definitely_equal_shape(s1: Shape, s2: Shape) -> bool:
"""Check that two shapes are guaranteed to be element-wise equal.
In presence of dynamic shapes may return False even when the shapes may
be equal at runtime.
"""
return (len(s1) == len(s2) and
all(unsafe_map(definitely_equal, s1, s2)))
def divide_shape_sizes(s1: Shape, s2: Shape) -> DimSize:
"""Returns an integer "i" s.t., i * size(s2) == size(s1).
Raises InconclusiveDimensionOperation if there is no such integer."""
sz1 = math.prod(s1)
sz2 = math.prod(s2)
if definitely_equal(sz1, sz2): # Takes care of sz1 and sz2 being 0
return 1
q, r = divmod(sz1, sz2)
if isinstance(r, Tracer) or r != 0:
raise InconclusiveDimensionOperation(
f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}. "
f"The remainder {r} should be 0.")
return q
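# Worked example: size((2, 3, 4)) == 24 and size((6,)) == 6, so
#
#   divide_shape_sizes((2, 3, 4), (6,))  # -> 4
#
# whereas divide_shape_sizes((2, 3), (4,)) raises
# InconclusiveDimensionOperation, since 6 is not divisible by 4.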
def cancel_divide_tracers(num, denom):
partition = lambda l: partition_list([isinstance(d, Tracer) for d in l], l)
num, num_tracers = partition(num)
denom, denom_tracers = partition(denom)
if num_tracers or denom_tracers:
factor = _cancel_divide(num_tracers, denom_tracers)
if factor is not None:
size1 = math.prod(num)
size2 = math.prod(denom)
if size1 == size2 or size2 != 0:
return factor * (size1 // size2 if size1 != size2 else 1)
def _cancel_divide(num, denom):
num = list(num)
for a in denom:
i = next((i for i, b in enumerate(num) if definitely_equal(a, b)), None)
if i is None:
break # couldn't cancel
del num[i]
else:
return math.prod(num)
def is_empty_shape(s: Shape) -> bool:
return any(definitely_equal(d, 0) for d in s)
def dilate_dim(d: DimSize, dilation: DimSize) -> DimSize:
"""max(0, 1 + dilation * (d - 1)).
Assumes dilation >= 1.
"""
if definitely_equal(dilation, 1): # fast path
return d
return max_dim(1 + dilation * (d - 1), 0)
def stride_dim(d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
"""max(0, (d - window_size) // window_stride + 1)
If d < window_size, returns 0.
We assume window_size >= 1 and window_stride >= 1.
"""
# If d < window_size then (d - window_size) // window_stride < 0
return max_dim((d - window_size) // window_stride + 1, 0)
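# Worked example: a window of size 3 moved with stride 2 over an axis of size
# 10 yields (10 - 3) // 2 + 1 == 4 valid window positions:
#
#   stride_dim(10, 3, 2)  # -> 4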
def min_dim(d1: DimSize, d2: DimSize) -> DimSize:
"""Like min(d1, d2) but for both constant and symbolic dimensions."""
d1_is_constant = is_constant_dim(d1)
if d1_is_constant and is_constant_dim(d2):
return min(d1, d2)
d1 = concrete_dim_or_error(d1, "argument `d1` of `core.min_dim`")
d2 = concrete_dim_or_error(d2, "argument `d2` of `core.min_dim`")
if d1_is_constant:
return d2.rmin(d1)
else:
return d1.min(d2)
def max_dim(d1: DimSize, d2: DimSize) -> DimSize:
"""Like max(d1, d2) but for both constant and symbolic dimensions."""
d1_is_constant = is_constant_dim(d1)
if d1_is_constant and is_constant_dim(d2):
return max(d1, d2)
d1 = concrete_dim_or_error(d1, "argument `d1` of `core.max_dim`")
d2 = concrete_dim_or_error(d2, "argument `d2` of `core.max_dim`")
if d1_is_constant:
return d2.rmax(d1)
else:
return d1.max(d2)
def dimension_as_value(d: DimSize):
"""Turns a dimension size into a JAX array.
This is the identity function for constant dimensions.
Has the same abstract value as Python constants.
"""
if isinstance(d, (int, Tracer, np.int32, np.int64)): return d
# For shape_poly._DimPolynomial
if hasattr(d, "dimension_as_value"): return d.dimension_as_value()
return operator.index(d)
def canonicalize_slice(
s: slice,
axis_size: DimSize
) -> tuple[DimSize, DimSize, DimSize]:
"""Computes the start index, step, and size of the slice `x[s]`.
This is similar to `s.indices(axis_size)`, except that it returns
`(start, step, size)`, and it works when the slice and/or the
`axis_size` are symbolic.
See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding
"""
def convert_to_index(d: DimSize) -> DimSize:
# Convert np.array and jax.Array to int, leave symbolic dimensions alone
try:
return operator.index(d)
except:
return d
# Must resolve statically if step is {<0, ==0, >0}
step = convert_to_index(s.step) if s.step is not None else 1
try:
if step == 0:
raise ValueError("slice step cannot be zero")
step_gt_0 = (step > 0)
except InconclusiveDimensionOperation as e:
raise InconclusiveDimensionOperation(
f"In slice with non-constant elements the step ({step}) must " +
f"be resolved statically if it is > 0 or < 0.\nDetails: {e}")
def clamp_index(i: DimSize, which: str):
try:
i_ge_0 = (i >= 0)
except InconclusiveDimensionOperation as e:
raise InconclusiveDimensionOperation(
f"In slice with non-constant elements the {which} ({i}) must " +
f"be resolved statically if it is >= 0.\nDetails: {e}")
if i_ge_0:
if step_gt_0:
return min_dim(axis_size, i)
else:
return min_dim(axis_size - 1, i)
else:
if step_gt_0:
return max_dim(0, axis_size + i)
else:
return max_dim(-1, axis_size + i)
if s.start is None:
start = 0 if step_gt_0 else axis_size - 1
else:
start = clamp_index(convert_to_index(s.start), "start")
if s.stop is None:
stop = axis_size if step_gt_0 else -1
else:
stop = clamp_index(convert_to_index(s.stop), "stop")
gap = step if step_gt_0 else - step
distance = (stop - start) if step_gt_0 else (start - stop)
slice_size = max_dim(0, distance + gap - 1) // gap
return start, step, slice_size
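# Worked examples with constant dimensions:
#
#   canonicalize_slice(slice(1, 10, 2), 5)        # -> (1, 2, 2): picks x[1], x[3]
#   canonicalize_slice(slice(None, None, -1), 5)  # -> (4, -1, 5): full reversal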
class SomeTracer:
__slots__ = ()
def __repr__(self): return "[dynamic]"
def replace_tracer_for_error_message(obj):
# TODO(mattjj): Many ideas for improving this. Crawl the stack and see if
# there are user variables whose value is == to this object? Or search
# parameters of functions being transformed, at least? Or at least assign
# short unique ids to them?
if isinstance(obj, Tracer):
return SomeTracer()
else:
return obj
def evaluate_shape(shape: Shape, dim_vars: Sequence[str],
*dim_values: Array) -> Sequence[Array]:
"""Evaluates a shape possibly containing non-constants.
Args:
shape: the shape to evaluate.
dim_vars: the dimension variables names that may appear in `shape`.
dim_values: the dimension values corresponding to `dim_vars`.
Returns:
a tuple of JAX values corresponding to `shape`, of type
`dim_value_dtype`.
"""
env = dict(zip(dim_vars, dim_values))
def eval_one_dim(d: DimSize):
try:
return operator.index(d)
except:
# Is a _DimExpr
return d._evaluate(env) # type: ignore
return tuple(eval_one_dim(d) for d in shape)
def dim_value_dtype():
"""The dtype to be used for dimension values."""
return dtypes.default_int_dtype()
def dim_constant(ct: int):
dtype = dim_value_dtype()
assert dtype in (np.int32, np.int64)
if dtype == np.int32:
return np.int32(ct)
elif dtype == np.int64:
return np.int64(ct)
def dim_value_aval() -> AbstractValue:
return ShapedArray((), dim_value_dtype(), weak_type=True, sharding=None)
# ------------------- Call -------------------
class CallPrimitive(Primitive):
multiple_results = True
call_primitive = True
def bind(self, *args, **params):
return self._true_bind(*args, **params)
def bind_with_trace(self, trace, fun_and_args, params):
fun = fun_and_args[0]
args = fun_and_args[1:]
return trace.process_call(self, fun, args, params)
def get_bind_params(self, params):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(
lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info), jaxpr, ())
return [subfun], new_params
def call_impl(f: lu.WrappedFun, *args, **params):
del params # params parameterize the call primitive, not the function
return f.call_wrapped(*args)
call_p: CallPrimitive = CallPrimitive('call')
call = call_p.bind
call_p.def_impl(call_impl)
class ClosedCallPrimitive(CallPrimitive):
def get_bind_params(self, params):
new_params = dict(params)
jaxpr: ClosedJaxpr = new_params.pop('call_jaxpr')
subfun = lu.wrap_init(partial(eval_jaxpr, jaxpr.jaxpr, jaxpr.consts),
debug_info=jaxpr.jaxpr.debug_info)
return [subfun], new_params
closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call')
closed_call_p.def_impl(call_impl)
closed_call_p.def_effectful_abstract_eval(
lambda *_, call_jaxpr: (call_jaxpr.out_avals, eqn_effects(call_jaxpr)))
# ------------------- Map -------------------
class MapPrimitive(Primitive):
multiple_results = True
map_primitive = True
def bind(self, *args, **params):
return self._true_bind(*args, **params)
def bind_with_trace(self, trace, fun_and_args, params):
fun: lu.WrappedFun = fun_and_args[0]
args = fun_and_args[1:]
assert len(params['in_axes']) == len(args)
return trace.process_map(self, fun, args, params)
def process(self, trace, fun, tracers, params):
return trace.process_map(self, fun, tracers, params)
def get_bind_params(self, params):
new_params = dict(params)
jaxpr: Jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(
lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info), jaxpr, ())
axes = new_params.pop('out_axes')
new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes)
return [subfun], new_params
def mapped_aval(size: AxisSize, axis: int | None,
aval: AbstractValue) -> AbstractValue:
handler, _ = aval_mapping_handlers.get(type(aval), (None, None))
if handler is not None:
return handler(size, axis, aval)
else:
raise TypeError(f"no mapping handler for {aval} of type {type(aval)}")
def unmapped_aval(size: AxisSize, axis: int | None,
aval: AbstractValue, explicit_mesh_axis=None) -> AbstractValue:
_, handler = aval_mapping_handlers.get(type(aval), (None, None))
if handler is not None:
return handler(size, axis, explicit_mesh_axis, aval)
else:
raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}")
def _map_shaped_array(
size: int, axis: int | None, aval: ShapedArray) -> ShapedArray:
assert axis is None or aval.shape[axis] == size
if axis is None:
return aval
aval_s = aval.sharding
sharding = aval_s.update(
spec=aval_s.spec.update(partitions=tuple_delete(aval_s.spec, axis)))
return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype,
weak_type=aval.weak_type, sharding=sharding, vma=aval.vma,
memory_space=aval.memory_space)
def _unmap_shaped_array(
size: int, axis: int | None, explicit_mesh_axis, aval: ShapedArray
) -> ShapedArray:
if axis is None:
return aval
elif type(axis) is int:
aval_s = aval.sharding
sharding = aval_s.update(spec=aval_s.spec.update(partitions=tuple_insert(
aval_s.spec, axis, explicit_mesh_axis)))
return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
weak_type=aval.weak_type, sharding=sharding,
vma=aval.vma, memory_space=aval.memory_space)
else:
raise TypeError(axis)
AvalMapHandlerPair = tuple[Callable, Callable]
aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
ShapedArray: (_map_shaped_array, _unmap_shaped_array),
    AbstractToken: (lambda _, __, a: a, lambda _, __, ___, a: a)
}
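# A small sketch of the ShapedArray handlers above (default sharding, so the
# spec bookkeeping is trivial):
#
#   a = ShapedArray((4, 3), np.dtype('float32'))
#   mapped_aval(4, 0, a)                       # -> f32[3], axis 0 removed
#   unmapped_aval(4, 0, mapped_aval(4, 0, a))  # -> f32[4,3], axis 0 restored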
# When a mapped function is given no axis name, we generate a name object based
# on the id of the function object. Collisions aren't important because this
# name can't be used in collectives, as user code never gets a ref to this
# object. We don't want to use the function object itself because that could
# keep the function object alive longer than necessary.
# TODO(mattjj): revisit this unique axis name strategy
@total_ordering
class _TempAxisName:
def __init__(self, obj):
self.id = id(obj)
def __repr__(self):
return f'<axis {hex(self.id)}>'
def __hash__(self):
return hash(self.id)
def __eq__(self, other):
return type(other) is _TempAxisName and self.id == other.id
def __lt__(self, other):
return type(other) is _TempAxisName and self.id < other.id
@dataclass(frozen=True)
class NamedAxisEffect(effects.Effect):
"""A side-effect introducing a new named axis into the current scope."""
name: AxisName
effects.control_flow_allowed_effects.add_type(NamedAxisEffect)
effects.custom_derivatives_allowed_effects.add_type(NamedAxisEffect)
effects.lowerable_effects.add_type(NamedAxisEffect)
effects.remat_allowed_effects.add_type(NamedAxisEffect)
def filter_named_axis_effects(
effects: Effects, names: Collection[AxisName]
) -> Effects:
return {e for e in effects
if not isinstance(e, NamedAxisEffect) or e.name not in names}
def remove_named_axis_effects(
jaxpr: Jaxpr, names: Collection[AxisName]
) -> Jaxpr:
if not names or not jaxpr.effects:
return jaxpr
return jaxpr.replace(effects=filter_named_axis_effects(jaxpr.effects, names))
def used_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr):
return {e.name for e in jaxpr.effects if isinstance(e, NamedAxisEffect)}
def replace_jaxpr_effects(jaxpr: ClosedJaxpr, effects: Effects):
return _replace_jaxpr_effects(jaxpr, frozenset(effects))
@weakref_lru_cache
def _replace_jaxpr_effects(jaxpr: ClosedJaxpr, effects: frozenset[Effect]):
return jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(effects=set(effects)))
# ------------------- Jaxpr checking -------------------
def typecheck(aval: AbstractValue, x) -> bool:
return typecompat(aval, get_aval(x))
def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool:
"""Determine whether `aval` conforms to `aval_ref`. Ignores weak_type."""
try:
return typematch(aval_ref, aval)
except TypeError:
return False
def typematch(t1: AbstractValue, t2: AbstractValue,
only_shape_shd_check: bool = False) -> bool:
"""Determine whether `t1` and `t2` are equivalent. Ignores weak_type."""
t1 = t1.normalize()
t2 = t2.normalize()
from jax._src.state.types import AbstractRef # pytype: disable=import-error
if t1 == t2:
return True
elif isinstance(t1, ShapedArray) and isinstance(t2, ShapedArray):
if only_shape_shd_check:
return cmp_shape_sharding_vma(t1, t2)
return (t1.dtype == t2.dtype and cmp_shape_sharding_vma(t1, t2) and
t1.memory_space == t2.memory_space)
elif isinstance(t1, AbstractRef) and isinstance(t2, AbstractRef):
# We want to use the regular typecheck for ShapedArray here.
return (typematch(t1.inner_aval, t2.inner_aval, only_shape_shd_check) and # type: ignore
(t1.memory_space is None or t2.memory_space is None or # type: ignore
t1.memory_space == t2.memory_space)) # type: ignore
else:
return False
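# A small sketch of the equivalence above: `typematch` ignores weak_type but
# not dtype or shape (the shapes and dtypes below are arbitrary examples):
#
#   assert typematch(ShapedArray((3,), np.float32),
#                    ShapedArray((3,), np.float32, weak_type=True))
#   assert not typematch(ShapedArray((3,), np.float32),
#                        ShapedArray((3,), np.int32))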
def cmp_shape_sharding_vma(t1, t2):
# TODO(yashkatariya): Expand this to Manual and Auto mode.
# See https://github.com/jax-ml/jax/issues/26474
if (not t1.sharding.mesh.empty and not t2.sharding.mesh.empty and
(t1.sharding.mesh._any_axis_explicit or
t2.sharding.mesh._any_axis_explicit)):
shd_eq = t1.sharding == t2.sharding
else:
shd_eq = True
return (shd_eq and definitely_equal_shape(t1.shape, t2.shape) and
t1.vma == t2.vma)
def aval_mismatch_extra(a1: AbstractValue, a2: AbstractValue) -> str:
assert not typematch(a1, a2)
if isinstance(a1, ShapedArray) and isinstance(a2, ShapedArray):
mismatches = []
if a1.dtype != a2.dtype:
mismatches.append('the dtypes do not match')
if a1.shape != a2.shape:
mismatches.append('the shapes do not match')
if a1.vma != a2.vma:
mismatches.append('the varying manual axes do not match')
# TODO(yashkatariya,mattjj): add check for sharding-in-types mismatch
if len(mismatches) == 0:
return ''
elif len(mismatches) == 1:
return ', so ' + mismatches[0]
else:
return ', so ' + ', '.join(mismatches[:-1]) + ', and ' + mismatches[-1]
return ''
class JaxprTypeError(TypeError): pass
custom_typechecks: dict[Primitive, Callable] = {}
def _check_closed_call(_, *in_atoms, call_jaxpr):
in_avals = [x.aval for x in in_atoms]
if not all(map(typecompat, call_jaxpr.in_avals, in_avals)):
raise JaxprTypeError("Closed call in_avals mismatch")
return call_jaxpr.out_avals, eqn_effects(call_jaxpr)
custom_typechecks[closed_call_p] = _check_closed_call
def check_jaxpr(jaxpr: Jaxpr):
"""Checks well-formedness of a jaxpr.
Specifically, check that:
- variables that are read are bound beforehand
- variables are typed equally throughout a jaxpr
- variable type annotations are compatible with their binding expression
Raises `JaxprTypeError` if `jaxpr` is determined invalid. Returns `None`
otherwise.
"""
@functools.cache
def ctx_factory():
ctx = JaxprPpContext()
pp_settings = JaxprPpSettings()
try: pp_jaxpr(jaxpr, ctx, pp_settings) # side-effect on ctx, build variable names
except: pass
return ctx, pp_settings
try:
_check_jaxpr(ctx_factory, jaxpr)
except JaxprTypeError as e:
ctx, pp_settings = ctx_factory()
if len(e.args) == 2:
msg, eqnidx = e.args
jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, eqnidx - 10, eqnidx + 10, ctx,
pp_settings))
else:
msg, = e.args
jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, 0, 20, ctx, pp_settings))
msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str])
raise JaxprTypeError(msg) from None
# Run key reuse checker after validating jaxpr:
if config.debug_key_reuse.value:
# Import here to avoid circular imports
from jax.experimental.key_reuse._core import check_key_reuse_jaxpr # pytype: disable=import-error
check_key_reuse_jaxpr(jaxpr)
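# A hedged usage sketch: build a jaxpr with the public `jax.make_jaxpr` API and
# validate it. `check_jaxpr` returns None on success and raises JaxprTypeError
# for a malformed jaxpr.
#
#   import jax
#   closed = jax.make_jaxpr(lambda x: x * 2.0)(1.0)
#   check_jaxpr(closed.jaxpr)  # no error for a well-formed jaxpr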
# A place to track the quasi-dynamic data associated with a variable during typechecking
@dataclass(frozen=True)
class MutableTypecheckVal:
  aval: AbstractValue
  mutable_qdd: MutableQuasiDynamicData
_ref_allocating_primitives = {ref_p}
def _check_jaxpr(
ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]],
jaxpr: Jaxpr
) -> None:
env: dict[Var, Atom | MutableTypecheckVal] = {}
def read(x: Atom) -> Atom | MutableTypecheckVal:
# Check the type annotation is itself well-typed.
check_type(ctx_factory, env, x.aval)
if isinstance(x, Var):
# Check the variable is in-scope and consistently typed.
if x not in env:
ctx, _ = ctx_factory()
raise JaxprTypeError(f"Variable '{pp_var(x, ctx)}' not defined")
return env[x]
elif isinstance(x, Literal):
# Check that the literal matches its type annotation.
if not typecheck(x.aval, x.val):
ctx, _ = ctx_factory()
raise JaxprTypeError(
f"Literal value {x.val} does not match its type annotation "
f"{pp_aval(x.aval, ctx)}")
return x
else:
assert False, "syntactically invalid jaxpr"
def write(v: Var, a: AvalQDD) -> None:
aval, qdd = a.aval, a.qdd
assert isinstance(v, Var), "syntactically invalid jaxpr"
# Check the type annotation of the binder is itself well-typed.
check_type(ctx_factory, env, v.aval)
# Check that the variable is not already bound.
if v in env:
ctx, _ = ctx_factory()
raise JaxprTypeError(f"Variable '{pp_var(v, ctx)}' already bound")
# Check that the computed type is consistent with the binder annotation.
if not typematch(v.aval, aval):
ctx, _ = ctx_factory()
raise JaxprTypeError(
f"Value for variable '{pp_var(v, ctx)}' inconsistently typed "
f"as {pp_aval(aval, ctx)} for let-binder of type {pp_aval(v.aval, ctx)}")
# If the variable is not a DropVar, add it to the environment.
if not isinstance(v, DropVar):
if qdd is None:
env[v] = v
else:
env[v] = MutableTypecheckVal(aval, MutableQuasiDynamicData(qdd))
  # Don't return refs
if config.mutable_array_checks.value:
from jax._src.state.types import AbstractRef # pytype: disable=import-error
for v in jaxpr.outvars:
if isinstance(v.aval, AbstractRef):
raise JaxprTypeError("returned a ref!")
# Check type annotations on lambda binders.
for v in it.chain(jaxpr.constvars, jaxpr.invars):
check_type(ctx_factory, env, v.aval)
write(v, AvalQDD(v.aval, v.initial_qdd))
# Check each eqn.
sentinel = object()
in_idx = {v: i for i, v in enumerate(it.chain(jaxpr.constvars, jaxpr.invars))}
mut_arrays = set()
for eqn_idx, eqn in enumerate(jaxpr.eqns):
prim = eqn.primitive
try:
in_atoms = map(read, eqn.invars)
in_avals = [AvalMutableQDD(x.aval, x.mutable_qdd) if isinstance(x, MutableTypecheckVal)
else x.aval for x in in_atoms] # use in_atoms for dyn shapes
# Compute the type of the primitive application.
with eqn.ctx.manager:
if prim in custom_typechecks:
out_type, eqn_effects = custom_typechecks[prim](
ctx_factory, *in_atoms, **eqn.params)
elif prim.call_primitive:
out_type, eqn_effects = _check_call(ctx_factory, prim, in_atoms,
eqn.params)
elif prim.map_primitive:
out_type, eqn_effects = _check_map(ctx_factory, prim, in_avals,
eqn.params)
else:
out_type, eqn_effects = check_eqn(prim, in_avals, eqn.params)
# Check the computed effect type matches the eqn's annotation, and is
# included in the jaxpr's annotation.
if prim.ref_primitive:
if prim in _ref_allocating_primitives:
outvar, = eqn.outvars
in_idx[outvar] = None # type: ignore
mut_arrays.add(outvar)
if eqn.effects != eqn_effects:
raise JaxprTypeError("Inferred effects do not match equation effects. "
f"Equation effects: {eqn.effects}. "
f"Inferred effects: {eqn_effects}")
for eff in eqn.effects:
if isinstance(eff, effects.JaxprInputEffect):
eqn_invar = eqn.invars[eff.input_index]
if type(eqn_invar) is Literal or eqn_invar in mut_arrays:
continue
if (jaxpr_index := in_idx.get(eqn_invar, sentinel)) is sentinel:
raise JaxprTypeError(
"Invalid `JaxprInputEffect`: must correspond to a jaxpr invar")
jaxpr_effect = eff.replace(input_index=jaxpr_index)
if jaxpr_effect not in jaxpr.effects:
raise JaxprTypeError(
"Invalid `JaxprInputEffect`: must be present in jaxpr. "
f"{jaxpr_effect} is not in {jaxpr.effects}.")
elif isinstance(eff, NamedAxisEffect):
# It is valid for a primitive to discharge the named axis effect.
continue
elif eff not in jaxpr.effects:
raise JaxprTypeError("Equation effect not present in jaxpr effects. "
f"Equation effect: {eff}. "
f"Jaxpr effects: {jaxpr.effects}")
# Check out_type matches the let-binders' annotation (after substitution).
out_type = [t if isinstance(t, AvalQDD) else AvalQDD(t, None) for t in out_type]
foreach(write, eqn.outvars, out_type)
except JaxprTypeError as e:
ctx, settings = ctx_factory()
msg, = e.args
src = source_info_util.summarize(eqn.source_info)
msg = "\n\n".join([msg, "in equation:", str(pp.nest(2, pp_eqn(eqn, ctx, settings))),
f"from source: {src}"])
raise JaxprTypeError(msg, eqn_idx) from None
# Check there are no output refs
# TODO(mattjj): improve this error message
if config.mutable_array_checks.value:
from jax._src.state.types import AbstractRef # pytype: disable=import-error
for v in jaxpr.outvars:
if isinstance(v.aval, AbstractRef): raise TypeError("returned ref")
# TODO(mattjj): include output type annotation on jaxpr and check it here
foreach(read, jaxpr.outvars)
def check_type(
ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]],
env: dict[Var, Atom | MutableTypecheckVal],
ty: AbstractValue,
) -> None:
  return  # All syntactic forms are currently valid; nothing to check here.
def check_eqn(prim, in_avals, params):
for jaxpr in jaxprs_in_params(params):
check_jaxpr(jaxpr)
out_avals, effects = prim.abstract_eval(*in_avals, **params)
if not prim.multiple_results:
out_avals = [out_avals]
return out_avals, effects
def _check_call(ctx_factory, prim, in_atoms, params):
if "call_jaxpr" not in params:
raise JaxprTypeError(
f"Call primitive {prim} missing 'call_jaxpr' parameter")
if isinstance(prim, ClosedCallPrimitive):
call_jaxpr = params["call_jaxpr"].jaxpr
else:
call_jaxpr = params["call_jaxpr"]
if len(in_atoms) != len(call_jaxpr.invars):
raise JaxprTypeError(f"Call primitive {prim} with {len(in_atoms)} "
f"operands cannot call jaxpr with "
f"{len(call_jaxpr.invars)} inputs")
# Check `call_jaxpr` can be applied to in_atoms.
env: dict[Var, Atom | MutableTypecheckVal] = {}
for v, x in zip(call_jaxpr.invars, in_atoms):
if not typecompat(v.aval, x.aval):
# TODO(mattjj): vars in error message are confusing b/c of Var.__repr__
raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type "
f"{x.aval} to jaxpr expecting type "
f"{v.aval}")
env[v] = x.val if type(x) is Literal else x
check_jaxpr(call_jaxpr)
  out_type = [v.aval for v in call_jaxpr.outvars]
# jaxpr input effects are indexed to include jaxpr.constvars, but the eqn
# should have effects indexed only on its explicit arguments
effs = {e.replace(input_index=e.input_index - len(call_jaxpr.constvars))
if isinstance(e, effects.JaxprInputEffect)
else e for e in call_jaxpr.effects}
return out_type, effs
def _check_map(ctx_factory, prim, in_avals, params):
if "call_jaxpr" not in params:
raise JaxprTypeError(f"Map primitive {prim} missing 'call_jaxpr' parameter")
call_jaxpr = params["call_jaxpr"]
ordered_effects_ = effects.ordered_effects.filter_in(call_jaxpr.effects)
if ordered_effects_:
raise JaxprTypeError(
f"Map primitive {prim} mapping ordered effects: {ordered_effects_}")
if "axis_size" not in params:
raise JaxprTypeError(f"Map primitive {prim} missing 'axis_size' parameter")
axis_size = params["axis_size"]
if "axis_name" not in params:
raise JaxprTypeError(f"Map primitive {prim} missing 'axis_name' parameter")
axis_name = params["axis_name"]
if "in_axes" not in params:
raise JaxprTypeError(f"Map primitive {prim} missing 'in_axes' parameter")
in_axes = params["in_axes"]
if "out_axes" not in params:
raise JaxprTypeError(f"Map primitive {prim} missing 'out_axes' parameter")
out_axes = params["out_axes"]
binder_avals = [unmapped_aval(axis_size, in_axis, v.aval)
if in_axis is not None else v.aval
for v, in_axis in zip(call_jaxpr.invars, in_axes)]
for binder_aval, in_aval in zip(binder_avals, in_avals):
if not typecompat(binder_aval, in_aval):
raise JaxprTypeError(f"Call primitive {prim} passes operand {in_aval} "
f"to jaxpr expecting {binder_aval}")
  with extend_axis_env_nd([(axis_name, axis_size)]):
_check_jaxpr(ctx_factory, call_jaxpr)
mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
out_avals = [unmapped_aval(axis_size, out_axis, aval)
if out_axis is not None else aval
for aval, out_axis in zip(mapped_out_avals, out_axes)]
return out_avals, filter_named_axis_effects(call_jaxpr.effects, {axis_name})
def eqn_effects(jaxpr):
# jaxpr input effects are indexed to include jaxpr.constvars, but the eqn
# should have effects indexed only on its explicit arguments
effs = jaxpr.effects
return {e.replace(input_index=e.input_index - len(jaxpr.constvars))
if isinstance(e, effects.JaxprInputEffect) else e for e in effs}
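# For example, in a jaxpr with one constvar, a JaxprInputEffect carrying
# input_index=1 (the first invar) becomes an equation effect with
# input_index=0, since equation operands exclude the constvars.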
# ------------------- ShapeDtypeStruct -------------------
def _check_sharding(sharding, shape):
if sharding is None:
return
if isinstance(sharding, P):
sharding._check_compatible_wrt_shape(shape)
else:
sharding.check_compatible_aval(shape)
@set_module("jax")
class ShapeDtypeStruct:
"""A container for the shape, dtype, and other static attributes of an array.
``ShapeDtypeStruct`` is often used in conjunction with :func:`jax.eval_shape`.
Args:
shape: a sequence of integers representing an array shape
dtype: a dtype-like object
sharding: (optional) a :class:`jax.Sharding` object
"""
__slots__ = ["shape", "dtype", "_sharding", "_dll", "weak_type", "vma",
"is_ref"]
def __init__(self, shape, dtype, *, sharding=None, weak_type=False,
vma=None, is_ref=False):
self.shape = tuple(shape)
if dtype is None:
raise ValueError("ShapeDtypeStruct: dtype must be specified.")
self.dtype = dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype)
if sharding is not None and not isinstance(sharding, (Sharding, Format, P)):
raise ValueError(
"sharding should be an instance of `jax.sharding.Sharding`, "
"`jax.sharding.PartitionSpec` or"
f" `jax.experimental.layout.Format`. Got {sharding} of type"
f" {type(sharding)}.")
if (isinstance(sharding, Format) and
isinstance(sharding.layout, AutoLayout)):
raise TypeError(
"`Layout.AUTO` cannot be used in place of a device-local"
f" layout in a `ShapeDtypeStruct`. Got {sharding}")
self._sharding = (sharding.sharding if isinstance(sharding, Format)
else sharding)
_check_sharding(self._sharding, self.shape)
self._dll = sharding.layout if isinstance(sharding, Format) else None
self.weak_type = weak_type
if vma is not None and not isinstance(vma, (set, frozenset)):
raise TypeError(
"`vma` argument passed to ShapeDtypeStruct should be of type `set`"
f" or `frozenset`. Got type {type(vma)}")
self.vma = None if vma is None else frozenset(vma)
self.is_ref = is_ref
size = property(lambda self: math.prod(self.shape))
ndim = property(lambda self: len(self.shape))
@property
def sharding(self):
if isinstance(self._sharding, P):
# TODO(yashkatariya): Maybe use `get_abstract_mesh()` here but switch
# on `core.trace_state_clean()`?
cur_mesh = mesh_lib.get_concrete_mesh()
if cur_mesh.empty:
raise TypeError(
"When specifying PartitionSpec to `ShapeDtypeStruct`, the context"
" mesh cannot be empty. Please use `jax.set_mesh` to set"
" the mesh context.")
return NamedSharding(cur_mesh, self._sharding)
else:
return self._sharding
@property
def format(self):
return Format(self._dll, self.sharding)
def __len__(self):
try:
return self.shape[0]
except IndexError as e:
raise TypeError("len() of unsized object") from e # same as numpy error
def __repr__(self):
sh = f", sharding={self.sharding}" if self.sharding is not None else ""
l = f", format={self._dll}" if self._dll is not None else ""
wt = f", weak_type={self.weak_type}" if self.weak_type else ""
vma = f", vma={self.vma}" if self.vma else ""
is_ref = f", is_ref={self.is_ref}" if self.is_ref else ""
return (f"{type(self).__name__}(shape={self.shape}, "
f"dtype={self.dtype.name}{sh}{l}{wt}{vma}{is_ref})")
__str__ = __repr__
def __eq__(self, other):
if not isinstance(other, ShapeDtypeStruct):
return False
else:
return ((self.shape, self.dtype, self.sharding, self._dll,
self.weak_type, self.vma, self.is_ref) ==
(other.shape, other.dtype, other.sharding, other._dll,
other.weak_type, other.vma, other.is_ref))
def __hash__(self):
# TODO(frostig): avoid the conversion from dict by addressing
# https://github.com/jax-ml/jax/issues/8182
return hash((self.shape, self.dtype, self.sharding, self._dll,
self.weak_type, self.vma, self.is_ref))
def __setattr__(self, name, value):
if hasattr(self, name):
if getattr(self, name) == value:
# This can happen if two threads race, for example if two threads
# are trying to hash the same SDS instance.
return
raise RuntimeError(
f"Cannot reassign attributes ({name}) of immutable ShapeDtypeStruct"
" objects")
super().__setattr__(name, value)
def update(self, **kwargs):
if 'sharding' in kwargs:
s = kwargs['sharding']
if self._dll is not None and isinstance(s, Sharding):
raise ValueError(
f"You are updating ShapeDtypeStruct with a {type(s)} when the"
f" original ShapeDtypeStruct had a concrete layout {self.format}."
" This might lead to bugs. If you want to do this, create a new"
" ShapeDtypeStruct via the constructor.")
sharding = s
else:
sharding = self.format
return ShapeDtypeStruct(
shape=kwargs.pop('shape', self.shape),
dtype=kwargs.pop('dtype', self.dtype),
sharding=sharding,
weak_type=kwargs.pop('weak_type', self.weak_type),
vma=kwargs.pop('vma', self.vma),
is_ref=kwargs.pop('is_ref', self.is_ref))
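# A hedged usage sketch pairing ShapeDtypeStruct with jax.eval_shape, as
# mentioned in the class docstring (assumes the public `jax` package):
#
#   import jax, jax.numpy as jnp
#   x = ShapeDtypeStruct((4, 8), jnp.float32)
#   out = jax.eval_shape(lambda a: a.sum(axis=0), x)
#   assert out.shape == (8,) and out.dtype == jnp.float32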
def _sds_aval_mapping(x):
aval = ShapedArray(
x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
weak_type=x.weak_type)
aval = update_aval_with_sharding(
aval, x.sharding, vma=(frozenset() if x.vma is None else x.vma))
if x.is_ref:
from jax._src.state.types import AbstractRef # type: ignore
return AbstractRef(aval)
return aval
pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping
# ------------------- Jaxpr printed representation -------------------
def pp_toplevel_jaxpr(jaxpr_to_print: Jaxpr, *,
source_info: bool = False,
print_shapes: bool = True,
custom_pp_eqn_rules : bool = True,
name_stack: bool = False,
print_effects: bool = False) -> pp.Doc:
context = JaxprPpContext()
settings = JaxprPpSettings(
source_info=source_info,
print_shapes=print_shapes,
custom_pp_eqn_rules=custom_pp_eqn_rules,
name_stack=name_stack,
print_effects=print_effects)
# Compute how many times each jaxpr is used.
names = defaultdict[Jaxpr, str](lambda: "jaxpr")
jaxpr_counts = Counter[Jaxpr]()
s = deque([jaxpr_to_print])
while s:
jaxpr = s.popleft()
jaxpr_counts[jaxpr] += 1
for eqn in jaxpr.eqns:
# TODO(slebedev): Come up with a more elaborate heuristic for name=.
name = eqn.params.get("name")
if name is None:
s.extend(jaxprs_in_params(eqn.params))
continue
name = name.strip("<>") # <lambda> -> lambda
for subjaxpr in jaxprs_in_params(eqn.params):
s.append(subjaxpr)
names.setdefault(subjaxpr, name)
# Pull jaxprs occurring more than once to the top-level, making sure
# that their names are unique.
docs = []
name_counts = Counter[str]()
for jaxpr, c in jaxpr_counts.items():
if c == 1:
continue
name = names[jaxpr]
if (count := name_counts[name]) > 0:
name_counts[name] += 1
name += str(count)
name_counts[name] += 1
else:
name_counts[name] += 1
docs.append(pp_shared_jaxpr(name, jaxpr, context, settings))
context.shared_jaxpr_names.add(name)
context.shared_jaxprs[jaxpr] = name
docs.append(pp_jaxpr(jaxpr_to_print, context, settings))
return pp.concat(docs)
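# A hedged rendering sketch: the returned pp.Doc can be converted to text with
# str(), the same rendering used in check_jaxpr's error paths above (`jaxpr`
# stands for any Jaxpr instance):
#
#   doc = pp_toplevel_jaxpr(jaxpr, source_info=True)
#   print(str(doc))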
class JaxprPpSettings(NamedTuple):
print_shapes: bool = True
source_info: bool = False
name_stack: bool = False
custom_pp_eqn_rules: bool = True
print_effects: bool = False
def _encode_digits_alphabetic(n: int) -> str:
if n == -1:
return '*'
s = ''
while len(s) == 0 or n:
n, i = n // 26, n % 26
s = chr(97 + i % 26) + s
return s
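# Examples of the encoding above: 0 -> 'a', 25 -> 'z', 26 -> 'ba', and the
# sentinel -1 -> '*':
#
#   assert _encode_digits_alphabetic(0) == 'a'
#   assert _encode_digits_alphabetic(26) == 'ba'
#   assert _encode_digits_alphabetic(-1) == '*'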
# A JaxprPpContext allows us to globally uniquify variable names within nested
# Jaxprs.
class JaxprPpContext:
var_names: defaultdict[Var, str]
# Shared jaxprs are those that are used multiple times and are printed
# first.
shared_jaxprs: MutableMapping[Jaxpr, str] # maps shared jaxpr to its name
shared_jaxpr_names: MutableSet[str]
def __init__(self) -> None:
self.shared_jaxprs = {}
self.shared_jaxpr_names = set()
fresh_names: Iterator[str] = (
name
for i in it.count()
if (name := _encode_digits_alphabetic(i)) not in self.shared_jaxpr_names
)
self.var_names = defaultdict(fresh_names.__next__)
def suggest_same_var_names(self,
for_vars: Sequence[Atom],
like_vars: Sequence[Atom]) -> None:
"""Suggests the names for `for_vars` to match those of `like_vars`.
`for_vars` are distinct Vars, and are aliased with `like_vars`.
"""
used_like_vars: set[Var] = set()
if len(for_vars) != len(like_vars):
# The mismatch can happen if a primitive containing a subjaxpr is invoked
# with the wrong number of arguments, e.g., when printing an invalid Jaxpr.
return
for for_v, like_v in zip(for_vars, like_vars):
if (isinstance(like_v, Var) and
like_v not in used_like_vars and
isinstance(for_v, Var) and
for_v not in self.var_names):
used_like_vars.add(like_v)
self.var_names[for_v] = pp_var(like_v, self)
def pp_var(v: Var | Literal, context: JaxprPpContext, *,
print_literal_dtype: bool = True) -> str:
return v.pretty_print(context, print_dtype=print_literal_dtype)
def pp_aval(a: AbstractValue, context: JaxprPpContext) -> str:
return a.str_short(short_dtypes=True)
def pp_vars(vs: Sequence[Atom], context: JaxprPpContext,
*, separator="", print_shapes: bool = False) -> pp.Doc:
if print_shapes:
return pp.nest(2, pp.group(
pp.join(pp.text(separator) + pp.group(pp.brk()), [
pp.text(pp_var(v, context)) +
pp.type_annotation(pp.text(":" + pp_aval(v.aval, context)))
for v in vs
])
))
else:
return pp.nest(2, pp.group(
pp.join(pp.text(separator) + pp.group(pp.brk()),
[pp.text(pp_var(v, context)) for v in vs])
))
def pp_kv_pair(k:str, v: Any, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
if type(v) is tuple and all(isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v):
pp_v = pp_jaxprs(v, context, settings)
elif isinstance(v, Jaxpr):
pp_v = pp_jaxpr(v, context, settings)
elif isinstance(v, ClosedJaxpr):
pp_v = pp_jaxpr(v.jaxpr, context, settings)
else:
pp_v = pp.text(str(v))
return pp.text(f'{k}=') + pp_v
def pp_kv_pairs(kv_pairs, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
if not kv_pairs:
return pp.nil()
return pp.group(pp.concat([
pp.nest(2, pp.concat([
pp.text("["), pp.brk(""),
pp.join(pp.brk(), [pp_kv_pair(k, v, context, settings) for k, v in kv_pairs])
])),
pp.brk(""), pp.text("]")
]))
def pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings
) -> pp.Doc:
rule = (_pp_eqn if not settings.custom_pp_eqn_rules else
pp_eqn_rules.get(eqn.primitive, _pp_eqn))
doc = rule(eqn, context, settings)
user_frame = source_info_util.user_frame(eqn.source_info.traceback)
return doc if user_frame is None else pp.source_map(doc, user_frame)
def _pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings,
params: Sequence[str] | None = None) -> pp.Doc:
annotation = (source_info_util.summarize(eqn.source_info)
if settings.source_info else None)
if params is None:
params = sorted(eqn.params)
name_stack_annotation = f'[{eqn.source_info.name_stack}]' if settings.name_stack else None
lhs = pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
rhs = [pp.text(eqn.primitive.name, annotation=name_stack_annotation),
pp_kv_pairs([(p, eqn.params[p]) for p in params], context, settings),
pp.text(" ") + pp_vars(eqn.invars, context)]
if eqn.outvars:
return pp.concat([lhs, pp.text(" = ", annotation=annotation), *rhs])
else:
return pp.concat(rhs)
CustomPpEqnRule = Callable[[JaxprEqn, JaxprPpContext, JaxprPpSettings], pp.Doc]
pp_eqn_rules: dict[Primitive, CustomPpEqnRule] = {}
def pp_eqns(eqns: Sequence[JaxprEqn],
context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
return pp.join(
pp.brk("; "),
[pp_eqn(e, context, settings) for e in eqns])
def _compact_eqn_should_include(k: str, v: Any) -> bool:
if k == 'branches': return False
if isinstance(v, (Jaxpr, ClosedJaxpr)): return False
if (isinstance(v, tuple) and
any(isinstance(e, (Jaxpr, ClosedJaxpr)) for e in v)):
return False
return True
def str_eqn_compact(primitive: Primitive, params: dict[Any, Any]) -> str:
"Compact equation to string conversion used in HLO metadata."
if primitive in custom_str_eqn_compact_rules:
return custom_str_eqn_compact_rules[primitive](primitive, params)
primitive_name = primitive.name
kvs = " ".join(f"{k}={v}" for k, v in params.items()
if _compact_eqn_should_include(k, v))
return f"{primitive_name}[{kvs}]" if len(kvs) > 0 else primitive_name
custom_str_eqn_compact_rules: dict[
Primitive, Callable[[Primitive, dict[Any, Any]], str]
] = {}
def pp_jaxpr_skeleton(jaxpr: Jaxpr, eqns_fn, context: JaxprPpContext,
settings: JaxprPpSettings) -> pp.Doc:
constvars = pp_vars(jaxpr.constvars, context, print_shapes=settings.print_shapes)
invars = pp_vars(jaxpr.invars, context, print_shapes=settings.print_shapes)
eqns = eqns_fn()
outvars = pp.concat([
pp.text("("), pp_vars(jaxpr.outvars, context, separator=","),
pp.text(")" if len(jaxpr.outvars) != 1 else ",)")])
if settings.print_effects:
# TODO(sharadmv): render an entire signature here
eff_text = [pp.text(" : { ")]
for i, eff in enumerate(jaxpr.effects):
if i > 0:
eff_text.append(pp.text(", "))
if isinstance(eff, effects.JaxprInputEffect):
index = eff.input_index
all_vars = [*jaxpr.constvars, *jaxpr.invars]
eff_text.append(pp_effect(eff.replace(input_index=all_vars[index]),
context))
else:
eff_text.append(pp_effect(eff, context))
eff_text.append(pp.text(" }"))
else:
eff_text = []
return pp.group(pp.nest(2, pp.concat([
pp.text("{ "), pp.keyword(pp.text("lambda ")),
constvars, pp.text("; "), invars,
pp.text(". "), pp.keyword(pp.text("let")),
pp.nest(2, pp.brk() + eqns), pp.brk(),
pp.keyword(pp.text("in ")), outvars,
pp.concat(eff_text)
])) + pp.text(" }"))
def pp_shared_jaxpr(
name: str,
jaxpr: Jaxpr,
context: JaxprPpContext,
settings: JaxprPpSettings,
) -> pp.Doc:
return pp.concat([
pp.text("let " + name + " = "),
pp_jaxpr(jaxpr, context, settings),
pp.text(" in"),
pp.brk(),
])
def pp_jaxpr(
jaxpr: Jaxpr,
context: JaxprPpContext,
settings: JaxprPpSettings,
) -> pp.Doc:
if name := context.shared_jaxprs.get(jaxpr):
return pp.text(name)
eqns_fn = lambda: pp_eqns(jaxpr.eqns, context, settings)
return pp_jaxpr_skeleton(jaxpr, eqns_fn, context, settings)
def pp_jaxprs(jaxprs: Sequence[ClosedJaxpr | Jaxpr],
context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc:
jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs]
return pp.group(pp.concat([pp.nest(2, pp.concat([
pp.text('('), pp.brk(""),
pp.join(pp.brk(), map(lambda x: pp_jaxpr(x, context, settings), jaxprs))]
)), pp.brk(""), pp.text(')')])
)
def pp_jaxpr_eqn_range(jaxpr: Jaxpr, lo: int, hi: int, context: JaxprPpContext,
settings: JaxprPpSettings) -> pp.Doc:
lo = max(lo, 0)
hi = max(lo, min(hi, len(jaxpr.eqns)))
eqns = jaxpr.eqns[lo:hi]
def eqns_fn():
pps = []
if len(eqns) == 0 and len(jaxpr.eqns) != 0:
pps.append(pp.text('...'))
else:
if lo != 0:
pps.append(pp.text('...'))
pps.extend(map((lambda e: pp_eqn(e, context, settings)), eqns))
if hi != len(jaxpr.eqns):
pps.append(pp.text('...'))
return pp.join(pp.brk("; "), pps)
return pp_jaxpr_skeleton(jaxpr, eqns_fn, context, settings)
def pp_effect(effect: Effect, context: JaxprPpContext) -> pp.Doc:
if hasattr(effect, "_pretty_print"):
return effect._pretty_print(context)
return pp.text(str(effect))
# ------------------- Jaxpr util -------------------
def last_used(jaxpr: Jaxpr) -> dict[Var, JaxprEqn | None]:
"""Returns a mapping from every var in jaxpr to what equation uses it last."""
last_used: dict[Var, JaxprEqn | None] = {
v: None for v in jaxpr.outvars if not isinstance(v, Literal)}
for eqn in reversed(jaxpr.eqns):
for v in eqn.invars:
if not isinstance(v, Literal) and v not in last_used:
last_used[v] = eqn
return last_used
def clean_up_dead_vars(eqn: JaxprEqn, env: dict[Var, Any],
last_used: dict[Var, JaxprEqn | None]):
"""Remove all eqn.invars from env if eqn is the last time they were used."""
for v in {v for v in eqn.invars if not isinstance(v, Literal)}:
if last_used[v] is eqn:
# Delete ref to variable when it is no longer needed by next equations.
del env[v]
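# A hedged sketch of how the two helpers above cooperate in an interpreter
# loop, where `env` maps Vars to runtime values (evaluation elided):
#
#   lu = last_used(jaxpr)
#   for eqn in jaxpr.eqns:
#     ...  # evaluate eqn, writing eqn.outvars into env
#     clean_up_dead_vars(eqn, env, lu)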
# Used in shard_map for converting avals
shard_aval_handlers = {} # type: ignore
unshard_aval_handlers = {} # type: ignore
def shard_aval(mesh, manual_axes, check_vma, spec, aval: AbstractValue
) -> AbstractValue:
if type(aval) in shard_aval_handlers:
return shard_aval_handlers[type(aval)](mesh, manual_axes, check_vma,
spec, aval)
raise NotImplementedError(f"Unsupported aval type: {type(aval)}")
def unshard_aval(mesh, check_vma, spec, aval: AbstractValue
) -> AbstractValue:
if type(aval) in unshard_aval_handlers:
return unshard_aval_handlers[type(aval)](mesh, check_vma, spec, aval)
else:
raise NotImplementedError(f"Unsupported aval type: {type(aval)}")
# ----------------- external APIs for querying tracing context -----------------
# TODO(dougalm, jakevdp): expose these via jax.extend
# Comparable object for checking whether JAX's trace state has changed.
class OpaqueTraceState:
def __init__(self, trace_ref):
self._trace_ref = trace_ref
def __eq__(self, other):
if isinstance(other, OpaqueTraceState):
return self._trace_ref == other._trace_ref
else:
return False
def get_opaque_trace_state(convention=None):
del convention
return OpaqueTraceState(trace_ctx.trace._weakref)
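# A hedged sketch: snapshots taken at the same trace level compare equal, so
# callers can detect whether JAX's trace state changed between two points:
#
#   s1 = get_opaque_trace_state()
#   s2 = get_opaque_trace_state()
#   assert s1 == s2  # no transformation entered in between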
def nonempty_axis_env() -> bool:
return bool(trace_ctx.axis_env.axis_sizes)
def unsafe_am_i_under_a_jit() -> bool:
return 'DynamicJaxprTrace' in str(unsafe_get_trace_stack(trace_ctx.trace))
def unsafe_am_i_under_a_vmap() -> bool:
return 'BatchTrace' in str(unsafe_get_trace_stack(trace_ctx.trace))
# TODO(dougalm): deprecate/delete
def find_top_trace(_):
return unsafe_get_current_trace()
def unsafe_get_current_trace():
return trace_ctx.trace
def unsafe_get_trace_stack(trace):
if hasattr(trace, "parent_trace"):
return unsafe_get_trace_stack(trace.parent_trace) + [trace]
else:
return [trace]
def unsafe_get_axis_names() -> list[Any]:
return list(trace_ctx.axis_env.axis_sizes)
# TODO(dougalm): deprecate/delete
def axis_frame(axis_name):
return trace_ctx.axis_env.axis_size(axis_name)