JAX: Jit Autograd XLA

Array

Attributes

devices(): the set of devices the array is stored on
sharding: how the array is laid out / distributed across those devices

Type promotion

unlike numpy, jax.numpy doesn't promote to float64, for efficiency and compatibility (TPUs); it also supports bfloat16
use jax.config.update('jax_numpy_dtype_promotion', 'strict') to disable implicit type promotion (python scalar promotion is still allowed for convenience)
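a minimal sketch of the default behavior (assumption: jax_enable_x64 is ideally set once at startup, before any arrays are created):
import jax
import jax.numpy as jnp

x = jnp.arange(3.0)
print(x.dtype)            # float32: no implicit float64 promotion
print((x + 1.5).dtype)    # float32: python scalars don't upcast the array

jax.config.update("jax_enable_x64", True)   # opt in to 64-bit types
print(jnp.arange(3.0).dtype)                # float64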

Transformations

Jit

wrap input args in tracers → record all operations → produce a jaxpr from the tracer records
with jit: the jaxpr operations are handed to XLA all at once, as a graph / sequence
without jit: operations are dispatched one by one as they are interpreted
jit traces a function during the first call, wrapping arguments in Traced<…> objects and keeping track of shapes & types (like graph guards in torch.compile) to produce operations that can be passed to XLA to compute
jax.make_jaxpr can be used to extract the produced sequence of operations
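a minimal sketch of inspecting the traced operations (hypothetical toy function f):
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2 + 1.0

print(jax.make_jaxpr(f)(jnp.ones((3,))))   # prints the jaxpr recorded by tracing f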

static args

static operations are evaluated at compile-time in python
traced operations are evaluated at run-time in XLA
arguments used in python control flow should be flagged as static using static_argnums=(static_arg_pos,) or static_argnames=("arg_name",)
# equivalently: @partial(jit, static_argnums=(1,))
def f(x, y):
    return x * 2 if y else x

double = True
jax.jit(f, static_argnums=(1,))(x, double)

compiled func caching

the jax.jit decorator caches a compiled version of the function keyed on its hash; as long as the traced inputs have the same type & shape, the cached version is reused, otherwise the function is recompiled and re-cached
when an already-jitted function is wrapped again with jax.jit, the cached version of the inner function is reused
avoid wrappers that produce a new hash on each call: partial(function), lambda functions re-created at the call site
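a minimal sketch of the cache-miss pitfall (hypothetical toy lambda):
import jax
import jax.numpy as jnp

x = jnp.ones((3,))

# a fresh lambda on every call is a new callable -> the jit cache is missed each time
for _ in range(3):
    jax.jit(lambda v: v * 2.0)(x)

# wrap once and reuse: traced & compiled only on the first call
double = jax.jit(lambda v: v * 2.0)
for _ in range(3):
    double(x)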

implicit, unwanted behaviors

if a global var is read inside a jitted function, jit bakes its value in at trace time (subsequent changes to the global won't be reflected in runs that hit the cached compilation)
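a minimal sketch of the baked-in global (hypothetical scale variable):
import jax
import jax.numpy as jnp

scale = 2.0

@jax.jit
def f(x):
    return x * scale          # `scale` is captured as a constant at trace time

x = jnp.ones((3,))
print(f(x))                   # [2. 2. 2.]
scale = 10.0
print(f(x))                   # still [2. 2. 2.]: the cached compilation ignores the new value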

notes

struggles with python control flow: if, while, …
pure functions only: static shapes, no global vars, deterministic
operation fusion avoids temporary allocations (round trips between local and global memory)

donated args

arguments passed to a jitted function can be "donated", i.e. the device memory already allocated for the input value of that argument is reused to store the output / updated value (after the function runs)
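a minimal sketch with donate_argnums (hypothetical update step; note that donation is honored on GPU/TPU and ignored with a warning on CPU):
import jax
import jax.numpy as jnp

def step(params, grads):
    return params - 0.1 * grads

# donate the `params` buffer so the output can reuse its device memory
step_jit = jax.jit(step, donate_argnums=(0,))

params = jnp.ones((1024,))
grads = jnp.ones((1024,))
params = step_jit(params, grads)   # the old `params` buffer is invalidated after this call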

Tracing shapes with no comp

using jax.eval_shape we can trace inputs to figure out output shapes without running any computation (useful e.g. to avoid compiling twice when one of the arguments starts as None and later becomes an array)
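a minimal sketch, using jax.ShapeDtypeStruct as the abstract input (hypothetical shapes):
import jax
import jax.numpy as jnp

def f(w, x):
    return jnp.tanh(x @ w)

w_spec = jax.ShapeDtypeStruct((128, 64), jnp.float32)
x_spec = jax.ShapeDtypeStruct((32, 128), jnp.float32)

out = jax.eval_shape(f, w_spec, x_spec)   # no FLOPs, no device memory
print(out.shape, out.dtype)               # (32, 64) float32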

Auto-diff

#jax.grad

jax.grad traces and transforms a given function f and returns a function that, given a set of inputs, differentiates f w.r.t. the first one
by default a jax.grad-transformed function computes the gradient w.r.t. its first argument (a jnp.array or a PyTree) by tracing all operations inside the function
jax.value_and_grad is similar to grad but evaluates both the function and its gradient
y, dfdx = jax.value_and_grad(f)(x)
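a minimal sketch (hypothetical least-squares loss):
import jax
import jax.numpy as jnp

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

w, x, y = jnp.ones((3,)), jnp.ones((4, 3)), jnp.zeros((4,))
dw = jax.grad(loss)(w, x, y)                # gradient w.r.t. the first argument only
l, dw = jax.value_and_grad(loss)(w, x, y)   # loss value and gradient in one pass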

#jax.jacobian

jax.jacobian is similar to jax.grad but operates on vector-valued functions

#jax.jvp

forward-mode auto-diff computing the Jacobian-vector product J·v, i.e. the directional derivative along an input vector v
v is a tangent vector in input space
y, jvp = jax.jvp(f, (x,), (v,))

#jax.vjp

reverse-mode vector-Jacobian product: computes vᵀJ (equivalently Jᵀv), the building block of backprop
the cotangent vector c lives in output space and is passed to the returned pullback function
y, pullback = jax.vjp(f, x)
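a minimal sketch contrasting the two (hypothetical R² → R² function):
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0] ** 2, x[0] * x[1]])

x = jnp.array([2.0, 3.0])
v = jnp.array([1.0, 0.0])   # tangent vector in input space
c = jnp.array([1.0, 1.0])   # cotangent vector in output space

y, Jv = jax.jvp(f, (x,), (v,))   # forward mode: J @ v
y, pullback = jax.vjp(f, x)
(vJ,) = pullback(c)              # reverse mode: c^T @ J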

#jax.jacobian

#jax.jacfwd

computes the full Jacobian matrix with forward-mode diff: efficient when the input dim is small

#jax.jacrev

computes the full Jacobian matrix with reverse-mode diff: efficient when the output dim is small

#jax.jacobian

a high-level wrapper; in practice jax.jacobian is an alias of jax.jacrev, so pick jax.jacfwd or jax.jacrev yourself depending on the input / output shapes

#jax.linearize

evaluates a function f at a point x and returns the output together with a linearization of f at x (a function computing JVPs at that point); useful for efficiently applying many tangent vectors at the same point
y, f_lin = jax.linearize(f, x)

#jax.hessian

computes the hessian matrix
import jax
import jax.numpy as jnp

def f(x):
    return x[0]**2 + x[0]*x[1] + x[1]**3

x = jnp.array([1., 2.])
v = jnp.array([1., 1.])
H = jax.jit(jax.hessian(f))(x)
Hv = H @ v
alternatives:
using
jax.jacfwd
and
jax.jacrev
from jax import jacfwd, jacrev

def hessian(f):
    return jax.jit(jacfwd(jacrev(f)))

H = hessian(f)(x)
Hv = H @ v
using
jax.linearize
and
jax.grad
import jax

y, f_lin = jax.linearize(jax.grad(f), x)
v = jnp.array([1., 1.])
Hv = jax.jit(f_lin)(v)

vmap & pmap

vmap traces input arguments just like jax.jit and adds a batch dimension at axis 0 by default; the batch dimension position can be configured for inputs and outputs with in_axes & out_axes (if set to None, the corresponding argument won't be vectorized)
xs = jnp.stack([x, x])        # (b, n): batched along axis 0
wst = jnp.stack([w, w]).T     # (n, b): batched along axis 1
jax.vmap(convolve1d, in_axes=[0, 1], out_axes=0)(xs, wst)     # both batched -> (b, n)
jax.vmap(convolve1d, in_axes=[0, None], out_axes=0)(xs, w)    # shared w -> (b, n)
note: in_axes & out_axes in vmap & co. can themselves be Pytrees with the same structure as the corresponding input & output Pytrees, and a whole branch of a Pytree can be set to a single value (no need to spell out the full structure with the same value at each leaf); see the sketch below
vmap: vectorized map
pmap: parallel map (SPMD across devices)
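a minimal sketch of the shared-weights case mentioned above (hypothetical params dict and apply function):
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((5, 3)), "b": jnp.zeros((3,))}
xs = jnp.ones((8, 5))                        # batch of 8 inputs

def apply(params, x):
    return x @ params["w"] + params["b"]

# None covers the whole params branch (shared weights), 0 batches the inputs
batched = jax.vmap(apply, in_axes=(None, 0), out_axes=0)
ys = batched(params, xs)                     # (8, 3)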

Jaxpr

Jaxpr: JAX exPRession (the torch IR equivalent); a representation of a sequence of primitive operations encoding, for each operation, the shapes & types of its args, inputs and outputs

PyTree

*a functional freak solution for tossing states around*
a pytree is a nested collection of tensors such as a model's params, an optimizer state, …

Key paths

a leaf's key path is a list of keys whose length equals the depth of the leaf
jax.tree_util.keystr() provides a reader-friendly key-path repr, with preset key types for predefined node types and __str__ for custom ones
function: jax.tree.flatten, jax.tree.map
key-path variant: jax.tree_util.tree_flatten_with_path, jax.tree_util.tree_map_with_path
difference: the *_with_path variants return the key path with each leaf / pass the key path along with each leaf to the mapped function
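a minimal sketch (hypothetical params pytree):
import jax
import jax.numpy as jnp

params = {"layer1": {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}}

def show(path, leaf):
    print(jax.tree_util.keystr(path), leaf.shape)
    return leaf

jax.tree_util.tree_map_with_path(show, params)
# ['layer1']['b'] (2,)
# ['layer1']['w'] (2, 2)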

#jax.tree.structure & #jax.tree.leaves

skim a collection and extract its structure / leaves respectively
PyTreeDef([*, *, {"key": *, "keys": {"key1": *, "key2": *}}])
[num1, num2, Array([], dtype=...), Array([], dtype=...), Array([], dtype=...)]
jax.tree.map(func, tree(s)): applies a function to all leaves (example: rescale)
jax.tree_util.register_pytree_node(ContainerClass, flatten_fn, unflatten_fn): register a custom container to use in a PyTree (see the sketch below)
jax.tree_util.register_dataclass(DataClass, data_fields=[...], meta_fields=[...])
@functools.partial(jax.tree_util.register_dataclass, data_fields=[...], meta_fields=[...])
@dataclass
class MyDataclassContainer(object):
    ...
jax.tree.leaves with is_leaf can be used to treat None as a leaf
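a minimal sketch of registering a custom container (hypothetical Params class):
import jax
import jax.numpy as jnp

class Params:
    def __init__(self, w, b):
        self.w, self.b = w, b

def flatten_params(p):
    return (p.w, p.b), None             # (children, static aux data)

def unflatten_params(aux, children):
    return Params(*children)

jax.tree_util.register_pytree_node(Params, flatten_params, unflatten_params)

p = Params(jnp.ones((2, 2)), jnp.zeros((2,)))
doubled = jax.tree.map(lambda leaf: leaf * 2, p)   # tree utilities now see w and b as leaves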

#TO-DO: add tree_util functions

LAX

*jnp & LAX vs torch.nn.functional & aten*
lax is similar to jax.numpy, but as a raw, lower-level mirror of it focused on XLA primitives; it requires careful use and enables more efficient & general operations
at their core, all lax operations are python wrappers of XLA operations, with few to no guards, checks or constraints (optimize at your own risk, you fool)

Debugging

#jax.debug.print(...)

python's print only runs during tracing (it prints tracers, not runtime values); to keep printing / debugging across subsequent runs use jax.debug.print
jax.debug.print("x shape: {shape}", shape=x.shape, ordered=False)

#jax.debug.breakpoint()

opens an interactive, pdb-like prompt at that point in the traced function

#jax.debug.callback(python_func, inputs)

calls a regular python function from inside a jax-transformed function, passing it the runtime values
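a minimal sketch (hypothetical log_stats host function):
import jax
import jax.numpy as jnp

def log_stats(x):                      # plain python, runs on the host
    print("mean:", float(x.mean()))

@jax.jit
def f(x):
    y = jnp.sin(x)
    jax.debug.callback(log_stats, y)   # called with the runtime value of y
    return y

f(jnp.arange(4.0))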

Pseudo-Random Number Generator

in numpy the global random state is updated after each draw; being stateful and order-dependent, that wouldn't be compatible with what jax is aiming to achieve while staying reproducible (jax executes vectorized and parallelized operations in no deterministic order)
in jax the random generation state / key is handled explicitly
key = jax.random.key(seed)
new_key, *n_sub_keys = jax.random.split(key, n)
jax.random.fold_in is used to generate per-device random keys starting from the same root key folded with the device idx
jax.config.update('jax_num_cpu_devices', 8) simulates multiple devices on CPU (until we overcome unemployment)
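a minimal sketch of per-device keys (one stream per device index):
import jax

root = jax.random.key(0)

# one independent stream per device, all derived from the same root key
device_keys = [jax.random.fold_in(root, idx) for idx in range(jax.device_count())]

sample = jax.random.normal(device_keys[0], (3,))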

My Gotchas & Notes

jnp functions are strict with type rules (stricter than numpy); most operations expect floating point arrays
result.block_until_ready(): account for asynchronous dispatch (e.g. when timing)
jnp.reshape expects a static shape
# wrong, traced shape: only known at run time by XLA
jnp.reshape(x, jnp.array(x.shape).prod())
# correct, static shape: known at compile time
jnp.reshape(x, np.prod(x.shape))
flax.core.freeze & flax.core.unfreeze: freeze & unfreeze python containers (FrozenDict) to use in a pytree

Advanced & Experimental

dynamic shapes in the jit trace: jax.config.update("jax_dynamic_shapes", True)
disable implicit type promotion ('standard' is the default mode; python scalars are still promoted):
jax.config.update('jax_numpy_dtype_promotion', 'strict')

LAX utils

logical_and, logical_or, and logical_not: element-wise / bitwise ops that can be used under jit (no short-circuiting, unlike python's and/or)
lax.cond(bool_check, true_fn, false_fn, fn_input)
both true_fn and false_fn must produce the same output type and shape so the condition block has a single output trace
example:
def breakpoint_if_nonfinite(x):
    is_finite = jnp.isfinite(x).all()
    def true_fn(x):
        pass
    def false_fn(x):
        jax.debug.breakpoint()
    jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
    ...
    breakpoint_if_nonfinite(z)
    ...
lax.fori_loop
lax.fori_loop(i_start, i_finish, func, init_carry) -> final_carry
    func(i, prev_carry) -> new_carry
lax.while_loop
lax.scan (see the scan sketch after this list)
lax.scan(func, init_carry, xs_seq) -> final_carry, ys_seq
    func(prev_carry, x) -> new_carry, y
lax.map(f, xs): a sequential map
lax.select: jnp.where's core
lax.switch: jnp.piecewise's core
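a minimal lax.scan sketch (running sum as the carried state):
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    new_carry = carry + x            # carry: running sum so far
    return new_carry, new_carry      # (next carry, per-step output)

xs = jnp.arange(1.0, 6.0)
total, cumsum = lax.scan(step, jnp.array(0.0), xs)
# total  -> 15.0
# cumsum -> [ 1.  3.  6. 10. 15.]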

Collective operations

operations performed across shards whose result is replicated on every shard, e.g. jax.lax.psum

Distributed Gotchas

jax.device_put
vs
out_sharding

jax.device_put moves an array from host to device(s) according to a given sharding
the out_sharding kwarg is used to indicate where / how to shard the result of an operation lowered to XLA
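a minimal device_put sketch (assumes a recent jax with jax.make_mesh, and that the device count divides the array size):
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((jax.device_count(),), ("x",))
sharding = NamedSharding(mesh, P("x"))

x = jnp.arange(32.0)
x_sharded = jax.device_put(x, sharding)      # host -> devices, split along mesh axis "x"
jax.debug.visualize_array_sharding(x_sharded)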

Asynchronous Dispatch

XLA computations are delegated to the device(s) asynchronously; only shape, dtype and sharding are returned immediately, so the Python code can "run ahead" of the device
in a training loop (fetch batch, compute), the two calls overlap, achieving a pseudo-parallel run of both; introducing a blocking operation ruins this orchestration

host bubble

logging metrics, for example (fetch batch, compute, log), would block the next fetch until the compute operation completes, introducing a bottleneck; to mitigate that we can log the previous step's metrics instead
import itertools as it

class RecordWriter:
    prev_metrics = None
    def __call__(self, cur_metrics: dict):
        # stash the current metrics, log the previous step's
        self.prev_metrics, log_metrics = cur_metrics, self.prev_metrics
        if log_metrics is None:
            return
        print(*it.starmap("{}: {}".format, log_metrics.items()), sep="\t")
loop: fetch batch → compute (enqueue on accelerator) → log previous
=============================== host ===========================
[fetch:0][log:–]|[fetch:1][log:0]|[fetch:2][log:1]…
================|================|== device =====================
________________[===compute 0===]|[===compute 1===]…
=================================================================
# step 0: fetch batch 0, enqueue compute 0, log N/A
# step 1: fetch batch 1, log previous (compute-0 metrics), enqueue compute 1
# …

device bubble

printing computed values outside of the main, jit-compiled "compute" call forces a host-device synchronization
⇒ perform the additional computations (learning rate update, metrics / log computations) with numpy, or on the host in jax, to avoid the host-device roundtrip / sync
with jax.default_device(jax.devices("cpu")[0]): ...

Efficient Data Loading

jax.make_array_from_process_local_data() builds a (possibly sharded) global jax.Array from per-process local data; combined with a python iterator yielding a batch dict at each step, it lets the host → device data transfer overlap with training on the accelerators (async host → device transfer)
timeline (host data loading / device data transfer / device train step): the three overlap, so while step N trains, batch N+1 is already being loaded on the host and transferred to the device

Data sharding ops: intermediate layout

jax.lax.with_sharding_constraint: a jax.device_put-like function that can be called inside a jax.jit-ted function to shard intermediate / output arrays manually
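a minimal sketch (assumes the leading dim is divisible by the device count; shapes are hypothetical):
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((jax.device_count(),), ("x",))

@jax.jit
def f(a):
    h = jnp.tanh(a)
    # pin the intermediate's layout instead of leaving it to the compiler
    h = jax.lax.with_sharding_constraint(h, NamedSharding(mesh, P("x", None)))
    return h @ h.T

out = f(jnp.ones((8, 4)))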

Parameters Sharding

flax.linen.Partitioned is a jax.numpy array wrapper (metadata box) that tracks the axis names across which the array is partitioned

Parallel Programming:

auto parallel via jax.jit

with partial replication, some devices hold identical copies of the computation's results
example: a 2x4 mesh with a 4x8 sharded array; after summing across the first array axis, each pair of devices along the first mesh axis stores / replicates the same values

explicit parallel (manually configured, automatically computed)

sharding is explicitly handled

manual parallel

a function is mapped over shards (like CUDA kernels, but over blocks); for cross-shard operations use the collective operations listed under [LAX utils / collective operations]
a mesh is the set / structure of devices, naming its axes and setting their types (Explicit, Auto, …); jax.sharding.PartitionSpec then describes how an array is partitioned across those mesh axes
jax.debug.visualize_array_sharding: visualize how an array is sharded across devices
jax.make_mesh((mesh_dim0, ...), (mesh_axis0_name, ...), axis_types=(jax.sharding.AxisType.Explicit, ...)) creates a mesh / device matrix with named axes
jax.sharding.NamedSharding(mesh, partition_spec)
with jax.set_mesh(mesh): set the default mesh to use for array creation
example:
x : [4@x, 1] — sharded 4-way along mesh axis "x", replicated along "y"
y : [1, 8@y] — sharded 8-way along mesh axis "y", replicated along "x"
x + y : [4@x, 8@y] — the result is sharded along both mesh axes
jax.shard_map(func, mesh=..., in_specs=..., out_specs=...)
example: input sharded 4@x, output sharded 4@x
the mesh splits the input into 4 blocks, one per device (device0 … device3); each device runs func on its own block (using collective ops across shards when needed), and the per-device outputs are re-assembled into an output sharded 4@x
potential practice micro projects
RNN with lax.scan
vmapped operation with pytree config (same weight over multiple inputs)

#jax.shard_map

a parallelization transform that executes a given function func on sharded inputs as specified by in_specs, then collects / re-assembles the outputs as specified by out_specs on the given device mesh
(note for my fish memory: shard_map does the sharding itself according to in_specs; no need (generally) to pre-shard the inputs unless intended)
in_specs: None → replicate, axis_name → shard along this mesh axis
out_specs: None → fetch a single copy (no need to tile, it's already replicated / unvarying), axis_name → collect along this mesh axis
use jax.lax.pvary to label a value as varying so it can be combined with other varying values (mostly jax.lax adds these in jaxprs before passing them to XLA; mentioned just in case such a nasty surprise is encountered)
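a minimal shard_map sketch, assuming a 4-device mesh (import path is jax.experimental.shard_map on older versions, jax.shard_map on newer ones); block_fn and the shapes are hypothetical:
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = jax.make_mesh((4,), ("x",))              # 1-D mesh of 4 devices, axis named "x"

def block_fn(a, b):
    # a: local (n/4, d) shard, b: replicated (d,)
    partial = a @ b                              # per-shard matvec
    total = jax.lax.psum(partial.sum(), "x")     # collective sum across the "x" axis
    return partial, total

f = shard_map(block_fn, mesh=mesh,
              in_specs=(P("x", None), P()),      # shard a along "x", replicate b
              out_specs=(P("x"), P()))           # re-assemble partial; total is replicated

parts, total = f(jnp.ones((8, 3)), jnp.ones((3,)))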