The same JAX code that runs on CPU and GPU can also be run on TPU. Cloud TPUs have the advantage of quickly giving you access to multiple TPU accelerators, including in Colab. All of the example notebooks here use jax.pmap to run JAX computations across multiple TPU cores from Colab. You can also run the same code directly on a Cloud TPU VM.
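For orientation, the snippet below is a minimal sketch of the jax.pmap pattern those notebooks build on. The function and data are illustrative (not taken from any of the notebooks), and it runs on however many devices jax.device_count() reports, e.g. the 8 cores of a single Cloud TPU board.

```python
import jax
import jax.numpy as jnp

n_devices = jax.device_count()  # e.g. 8 cores on a single Cloud TPU board
xs = jnp.arange(n_devices * 4.0).reshape(n_devices, 4)

def normalize(x):
    # x is this device's shard; psum sums the per-shard totals across devices.
    return x / jax.lax.psum(x.sum(), axis_name="i")

# pmap maps normalize over the leading axis, one shard per device, SPMD-style.
ys = jax.pmap(normalize, axis_name="i")(xs)
print(ys.sum())  # ~1.0: every shard was divided by the global sum
```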
The following notebooks showcase how to use and what you can do with Cloud TPUs on Colab:
- A guide to getting started with pmap, a transform for easily distributing SPMD computations across devices. Contributed by Alex Alemi (alexalemi@)
- Solve and plot parallel ODE solutions with pmap. Contributed by Stephan Hoyer (shoyer@)
- Solve the wave equation with pmap, and make cool movies! The spatial domain is partitioned across the 8 cores of a Cloud TPU.
- An overview of JAX presented at the Program Transformations for ML workshop at NeurIPS 2019 and the Compilers for ML workshop at CGO 2020. Covers basic numpy usage, grad, jit, vmap, and pmap (a small sketch of these transforms follows this list).
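As a rough idea of what those transforms do (this sketch is not taken from the talks or notebooks above): grad differentiates a function, jit compiles it with XLA, and vmap vectorizes it over a batch axis.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.tanh(jnp.dot(x, w)).sum()

grad_loss = jax.grad(loss)                        # d loss / d w
fast_grad = jax.jit(grad_loss)                    # compile the gradient with XLA
batched_loss = jax.vmap(loss, in_axes=(None, 0))  # map over a batch of x, share w

w = jnp.ones(3)
xs = jnp.arange(12.0).reshape(4, 3)
print(fast_grad(w, xs[0]))   # gradient w.r.t. w for one example
print(batched_loss(w, xs))   # loss for each of the 4 examples
```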
The guidance on running TensorFlow on TPUs applies to JAX as well, with the exception of TensorFlow-specific details. Here we highlight a few important details that are particularly relevant to using TPUs in JAX.
One of the most common culprits for surprisingly slow code on TPUs is inadvertent padding: arrays on a Cloud TPU are tiled, so dimensions get padded up to multiples of the tile sizes (8 in one dimension, 128 in the other), and the matrix unit performs best on 128x128 blocks.
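As a hedged illustration (the shapes below are arbitrary, not prescriptive): a matmul whose dimensions are not tile-aligned gets padded on the TPU, so it can cost roughly as much as a slightly larger, tile-aligned one.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
padded = jax.random.normal(key, (1000, 1000))   # 1000 is not a multiple of 128: padded on TPU
aligned = jax.random.normal(key, (1024, 1024))  # multiples of 128: no padding

# Despite the size difference, the padded case wastes compute on zero padding,
# so on a TPU its runtime can end up close to the aligned case.
jnp.dot(padded, padded).block_until_ready()
jnp.dot(aligned, aligned).block_until_ready()
```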
By default*, matrix multiplication in JAX on TPUs uses bfloat16 with float32 accumulation. This can be controlled with the precision keyword argument on relevant jax.numpy functions (matmul, dot, einsum, etc.). In particular:

- precision=jax.lax.Precision.DEFAULT: uses mixed bfloat16 precision (fastest)
- precision=jax.lax.Precision.HIGH: uses multiple MXU passes to achieve higher precision
- precision=jax.lax.Precision.HIGHEST: uses even more MXU passes to achieve full float32 precision

JAX also adds the bfloat16 dtype, which you can use to explicitly cast arrays to bfloat16, e.g., jax.numpy.array(x, dtype=jax.numpy.bfloat16).
* We might change the default precision in the future, since it is arguably surprising. Please comment/vote on this issue if it affects you!
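For example (a small sketch; the shapes are arbitrary), the precision argument and the bfloat16 dtype can be used like this:

```python
import jax
import jax.numpy as jnp

a = jnp.ones((128, 128))
b = jnp.ones((128, 128))

default = jnp.dot(a, b)  # on TPU: bfloat16 inputs with float32 accumulation (fastest)
exact = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)  # extra MXU passes for full float32

# Explicitly casting to the bfloat16 dtype that JAX adds:
a_bf16 = jnp.array(a, dtype=jnp.bfloat16)
```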
Refer to the Cloud TPU VM documentation.
If you run into Cloud TPU-specific issues (e.g. trouble creating a Cloud TPU VM), please email cloud-tpu-support@google.com, or trc-support@google.com if you are a TRC member. You can also file a JAX issue or ask a discussion question for any issues with these notebooks or using JAX in general.
If you have any other questions or comments regarding JAX on Cloud TPUs, please email jax-cloud-tpu-team@google.com. We’d like to hear from you!