Array
Attributes
devices(): the device(s) the array is stored on
sharding: how the array is laid out / sharded across devices
Type promotion
unlike numpy, jax.numpy doesn't promote to float64 (64-bit floats are disabled by default for efficiency and TPU compatibility); it also supports bfloat16
use jax.config.update('jax_numpy_dtype_promotion', 'strict') to disable implicit type promotion (still allows python scalar promotion for convenience)
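a minimal illustration of these defaults (float64 off, bfloat16 available):
import numpy as np
import jax.numpy as jnp

np.arange(3.0).dtype              # float64
jnp.arange(3.0).dtype             # float32: 64-bit floats are off by default (jax_enable_x64)
jnp.zeros(3, jnp.bfloat16).dtype  # bfloat16 is supported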
Transformations
Jit
wrap input args → record all operations → produce a jaxpr using the tracer records
traces a function during the first call, wrapping objects & arguments in Traced<…> and keeping track of shapes and types (like graph guards in torch.compile) to produce operations that can be passed to XLA to compute
jax.make_jaxpr can be used to extract the produced sequence of operations
static args
static operations are evaluated at compile time in python
traced operations are evaluated at run time by XLA
arguments used in control flow should be flagged as static using static_argnums=(static_arg_pos,) or static_argnames=("arg_name",)
import jax
import jax.numpy as jnp

# equivalent decorator form: @partial(jax.jit, static_argnums=(1,))
def f(x, y):
    return x * 2 if y else x

x = jnp.arange(3.)
double = True
jax.jit(f, static_argnums=(1,))(x, double)
compiled func caching
the jax.jit decorator caches a compiled version of the function keyed by the function's hash; as long as the traced inputs keep the same type and shape the cached version is reused, otherwise the function is recompiled and re-cached
when an already-jitted function is wrapped with jax.jit again, the cached version of the inner function is reused
avoid producing a new hash on each call, e.g. by re-wrapping with partial(function) or a fresh lambda at every call (see the sketch below)
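a minimal sketch of this caching pitfall (function and variable names are illustrative):
import functools
import jax
import jax.numpy as jnp

def scale(x, factor):
    return x * factor

x = jnp.arange(3.)

# stable hash: the wrapped callable is created once, so the cache entry is reused
scaled = jax.jit(functools.partial(scale, factor=2.0))
for _ in range(3):
    scaled(x)

# new hash each call: a fresh lambda is created every iteration -> retrace & recompile
for _ in range(3):
    jax.jit(lambda v: scale(v, 2.0))(x)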
implicit, unwanted behaviors
if a global variable is read inside a jitted function, jit hard-codes its value at trace time (subsequent changes to the global won't be reflected in runs served from the cached compilation)
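a minimal sketch of the global-capture behavior:
import jax
import jax.numpy as jnp

g = 0.

@jax.jit
def add_global(x):
    return x + g                 # g is baked in as 0. at trace time

add_global(jnp.array(1.))        # 1.0
g = 10.
add_global(jnp.array(1.))        # still 1.0: served from the cached compilation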
notes
struggles with control flow: if, while, …
pure functions only: static shapes, no global state, deterministic
enables operation fusion
avoids temporary allocations (no local mem ↔ global mem round trips)
donated args
arguments passed to a jitted function can be "donated", i.e. the device memory allocated for the input value of that argument is reused to store the output / updated value (post func)
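a minimal sketch using donate_argnums (names are illustrative):
import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, donate_argnums=(0,))
def update(params, grads):
    return params - 0.1 * grads   # the output may reuse params' device buffer

# note: donation may be a no-op on some backends (e.g. CPU)
params = update(jnp.ones(1024), jnp.ones(1024))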
Tracing shapes with no comp
using jax.eval_shape we can trace inputs to figure out output shapes without running any computation (useful to avoid compiling twice when an argument starts as None and later becomes an array)
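a minimal sketch with jax.ShapeDtypeStruct placeholders (shapes are illustrative):
import jax
import jax.numpy as jnp

def f(x, w):
    return x @ w

x_spec = jax.ShapeDtypeStruct((8, 16), jnp.float32)
w_spec = jax.ShapeDtypeStruct((16, 4), jnp.float32)
out = jax.eval_shape(f, x_spec, w_spec)   # no FLOPs, no device computation
out.shape, out.dtype                      # (8, 4), float32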
Auto-diff
#jax.grad
jax.grad traces and transforms a given function f and returns a function that, given a set of inputs, differentiates f w.r.t. the first one by default
a jax.grad-transformed function computes the gradient w.r.t. its first argument (a jnp.array or PyTree) by tracing all operations inside that function
jax.value_and_grad is similar to grad, but evaluates the function & its grad: y, dfdx = jax.value_and_grad(f)(x)
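a minimal example:
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([1., 2., 3.])
dfdx = jax.grad(f)(x)               # 2 * x
y, dfdx = jax.value_and_grad(f)(x)  # value and gradient in one call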
#jax.jacobian
jax.jacobian is similar to jax.grad but operates on vector-valued functions
#jax.jvp
forward-mode auto-diff to compute the Jacobian-vector product, i.e. the directional derivative along an input vector v
v is a tangent vector in input space
y, jvp = jax.jvp(f, (x,), (v,))
#jax.vjp
reverse-mode vector-Jacobian product: pulls a cotangent vector in output space back to input space (computes c @ J)
c is a cotangent vector in output space
y, pullback = jax.vjp(f, x)
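a minimal example of both:
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0] ** 2, x[0] * x[1]])

x = jnp.array([2., 3.])
v = jnp.array([1., 0.])                  # tangent (input space)
y, jvp_out = jax.jvp(f, (x,), (v,))      # J(x) @ v

c = jnp.array([1., 1.])                  # cotangent (output space)
y, pullback = jax.vjp(f, x)
vjp_out, = pullback(c)                   # c @ J(x)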
#jax.jacfwd
computes the full Jacobian matrix with forward-mode diff: efficient when the input dim is small
#jax.jacrev
computes the full Jacobian matrix with reverse-mode diff: efficient when the output dim is small
#jax.jacobian
a high-level wrapper that chooses jax.jacrev or jax.jacfwd depending on input / output shapes
#jax.linearize
evaluates a function f at a point / input x and returns the output together with a linear approximation of f at that point (a function computing JVPs), used for efficient repeated JVPs
y, f_lin = jax.linearize(f, x)
#jax.hessian
computes the hessian matrix
import jax
import jax.numpy as jnp

def f(x):
    return x[0]**2 + x[0]*x[1] + x[1]**3

x, v = jnp.array([1., 2.]), jnp.array([1., 1.])
H = jax.jit(jax.hessian(f))(x)
Hv = H @ v
alternatives:
using jax.jacfwd and jax.jacrev
import jax
from jax import jacfwd, jacrev

def hessian(f):
    return jax.jit(jacfwd(jacrev(f)))

H = hessian(f)(x)
Hv = H @ v
using jax.linearize and jax.grad
import jax
import jax.numpy as jnp

# Hessian-vector product without materializing H
y, f_lin = jax.linearize(jax.grad(f), x)
v = jnp.array([1., 1.])
Hv = jax.jit(f_lin)(v)
vmap & pmap
vmap traces input arguments just like jax.jit and adds a batch dimension at axis 0 by default; the batch dimension position can be configured for inputs and outputs with in_axes & out_axes (if set to None, the corresponding argument won't be vectorized)
xs = jnp.stack([x, x])       # (b, n)
xst = jnp.stack([x, x]).T    # (n, b): batched along axis 1
jax.vmap(convolve1d, in_axes=[0, None], out_axes=0)(xs, w)   # (b, n)
jax.vmap(convolve1d, in_axes=[1, None], out_axes=0)(xst, w)  # (b, n)
note: in_axes & out_axes in vmap can be Pytrees with the same structure as the corresponding input & output Pytrees, and a whole branch of a Pytree can be set to the same value (no need to specify the whole structure with the same value at each leaf); see the sketch below
vmap: vectorized map
pmap: parallel map (SPMD)
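a minimal sketch of a pytree in_axes (shapes and names are illustrative):
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}   # shared, not batched
xs = jnp.ones((8, 3))                                 # batched along axis 0

def apply(params, x):
    return x @ params["w"] + params["b"]

ys = jax.vmap(apply, in_axes=({"w": None, "b": None}, 0))(params, xs)  # (8, 2)
ys = jax.vmap(apply, in_axes=(None, 0))(params, xs)   # same: one value for the whole branch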
Jaxpr
Jaxpr: JAX exPRession (the torch IR equivalent) is a representation of a sequence of primitive operations, encoding the shape & type of each operation's args, inputs, and outputs
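a minimal example:
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

print(jax.make_jaxpr(f)(jnp.arange(3.)))
# shows the traced sequence of primitives (sin, mul) with their shapes / dtypes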
PyTree
*a functional freak solution for tossing states around*
a pytree is a nested collection of tensors / containers, such as a model's params, an optimizer state, …
Key paths
a leaf's key path is a list of keys whose length equals the depth of the leaf
jax.tree_util.keystr() provides a reader-friendly key-path repr, with preset keys for predefined node types and __str__ for custom ones
#jax.tree.structure & #jax.tree.leaves
skims a collection and extracts its structure / leaves:
PyTreeDef([*, *, {"key": *, "keys": {"key1": *, "key2": *}}])
[num1, num2, Array([…], dtype=…), Array([…], dtype=…), Array([…], dtype=…)]
jax.tree.map(func, tree(s))
applies a function to all leaves (example: rescale every parameter)
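a minimal sketch:
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
grads  = {"w": jnp.ones((2, 2)), "b": jnp.ones(2)}

rescaled = jax.tree.map(lambda p: p * 0.5, params)
updated  = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)  # multiple trees, same structure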
jax.tree_util.register_pytree_node(ContainerClass, flatten_fn, unflatten_fn)
registers a custom container to use in a PyTree
jax.tree_util.register_dataclass(DataClass, data_fields=[...], meta_fields=[...])
import functools
from dataclasses import dataclass
import jax

@functools.partial(jax.tree_util.register_dataclass,
                   data_fields=[...],
                   meta_fields=[...])
@dataclass
class MyDataclassContainer:
    ...
jax.tree.leaves with is_leaf can be used to treat None as a leaf
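a minimal example:
import jax

tree = {"a": None, "b": 1}
jax.tree.leaves(tree)                               # [1]: None is dropped by default
jax.tree.leaves(tree, is_leaf=lambda x: x is None)  # [None, 1]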
#TO-DO: add tree_util functions
LAX
*jnp & LAX vs torch.nn.functional & aten*
lax is a rawer, lower-level mirror of jax.numpy focused on XLA primitives; it requires careful use but enables more efficient & general operations
at their core, all lax operations are python wrappers around XLA operations, with few to no guards, checks, or constraints (optimize at your own risk, you fool)
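a small illustration of the stricter, lower-level behavior:
import jax.numpy as jnp
from jax import lax

jnp.add(1, 1.0)       # ok: jnp implicitly promotes the int operand
lax.add(1.0, 2.0)     # ok: dtypes already match
# lax.add(1, 1.0)     # error: lax requires explicitly matching dtypes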
Debugging
#jax.debug.print(...)
python's print only runs during tracing; to keep printing / debugging in subsequent (compiled) runs, use jax.debug.print
jax.debug.print("x shape: {shape}", shape=x.shape, ordered=False)
#jax.debug.breakpoint()
opens an interactive debugger prompt at the specified point
#jax.debug.callback(python_func, inputs)
calls a python function from inside a jax-transformed function
Pseudo-Random Number Generator
in numpy the global random state is updated after each generation, which wouldn't be compatible with what jax aims to achieve while staying reproducible, mainly because it is stateful and order-dependent (jax executes vectorized and parallelized operations in no deterministic order)
in jax the random generation state / key is handled explicitly:
key = jax.random.key(seed)
new_key, *n_sub_keys = jax.random.split(key, n)
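a minimal example:
import jax

key = jax.random.key(0)
key, sub = jax.random.split(key)   # never reuse a key: split, then consume the subkey
x = jax.random.normal(sub, (3,))   # same key -> same numbers: reproducible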
jax.random.fold_in is used to generate per-device random keys starting from the same root key injected with the device idx
jax.config.update('jax_num_cpu_devices', 8)  # until we overcome unemployment
My Gotchas & Notes
jnp functions are strict with type rules (stricter than numpy); most operations expect floating-point arrays
result.block_until_ready(): account for asynchronous dispatch (e.g. when timing / benchmarking)
jnp.reshape expects a static shape
# wrong: traced shape, only known at run time by XLA
jnp.reshape(x, jnp.array(x.shape).prod())
# correct static shape: known at compile time
jnp.reshape(x, np.prod(x.shape))
flax.core.freeze & flax.core.unfreeze: freeze & unfreeze python containers (dicts) to use in a pytree
Advanced & Experimental
dynamic shapes in a jit trace:
jax.config.update("jax_dynamic_shapes", True)
disable implicit type promotion
# still allows python scalars
# standard is the default mode
jax.config.update('jax_numpy_dtype_promotion', 'strict')
LAX utils
logical_and, logical_or, and logical_not (bit-wise ops that can be used under jit; no short circuit)
lax.cond(bool_check, true_fn, false_fn, fn_input)
both true_fn and false_fn should produce the same type and shape so that the condition block has a single output trace
example:
def breakpoint_if_nonfinite(x):
    is_finite = jnp.isfinite(x).all()
    def true_fn(x):
        pass
    def false_fn(x):
        jax.debug.breakpoint()
    jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
    ...
    breakpoint_if_nonfinite(z)
    ...
lax.fori_loop
lax.fori_loop(i_start, i_finish, func, init_carry) -> final_carry
func(i, pre_carry) -> new_carry
lax.while_loop(cond_fn, body_fn, init_carry) -> final_carry
lax.scan
lax.scan(func, init_carry, xs_seq) -> final_carry, ys_seq
func(pre_carry, x) -> new_carry, y
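a minimal example (cumulative sum):
import jax
import jax.numpy as jnp

def step(carry, x):
    carry = carry + x
    return carry, carry                   # (new_carry, y)

final, ys = jax.lax.scan(step, jnp.array(0.), jnp.arange(5.))
# final == 10.0, ys == [0., 1., 3., 6., 10.]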
lax.map(f, xs)
a sequential map run
lax.select: jnp.where's core
lax.switch: jnp.piecewise's core
Collective operations
an op applied across shards, with the result replicated on every shard
jax.lax.psum
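a minimal sketch with pmap (assumes more than one local device, e.g. the 8 CPU devices configured above):
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(float(n))

# each device holds one element; psum sums across shards and replicates the result
y = jax.pmap(lambda v: jax.lax.psum(v, axis_name="i"), axis_name="i")(x)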
Distributed Gotchas
jax.device_put vs out_sharding
jax.device_put
moves an array from host to device according to a given sharding
the out_sharding kwarg indicates where / how to shard the result of an operation lowered to XLA
Asynchronous Dispatch
xla computations are delegated to the device(s) asynchronously; only shape, dtype, and sharding are available immediately, so the Python code can "run ahead" of the device
in a training loop (fetch batch, compute), both calls overlap, giving a pseudo-parallel run of the two; introducing a blocking operation ruins this orchestration
host bubble
logging metrics, for example (fetch batch, compute, log), would block the fetch op until the compute operation completes, introducing a bottleneck; to mitigate this we can log the previous step's metrics
import itertools as it

class RecordWriter:
    prev_metrics = None
    def __call__(self, cur_metrics: dict):
        self.prev_metrics, log_metrics = cur_metrics, self.prev_metrics
        if log_metrics is None:
            return
        print(*it.starmap("{}: {}".format, log_metrics.items()), sep="\t")
loop
fetch batch █
compute █, enqueue on accelerator: |
log █
=============================== host ===========================
[█:0]|[█:_][█:1]|[█:0][█:2]|[█:1]…
================|==========|== device ==========================
_____|__________[████0████]|[████1████]…
================================================================
# first data fetch, compute0 enqueued, log -1 (N/A),
# data fetch for next compute, log previous (compute0 res), compute1
# …
device bubble
printing computations outside of the main jit-compiled "compute" call forces a host-device synchronization
⇒ perform additional computations (learning-rate updates, metric / log computations) using numpy, or on the host in jax, to avoid the host-device round trip / sync
with jax.default_device(jax.devices("cpu")[0]):
...
Efficient Data Loading
jax.make_array_from_process_local_data()
overlaps data transfer and training on the accelerators, given a python iterator yielding a dict at each step (async host → device data transfer)
host
████████████…░░░░░░░░
░░░░░░░░░░░░░░░░░░░░░
device
data transfer
██████████
███…░░░░░░
░░░░░░░░░░
train step
███████████████████████████████████████████████████████████████████████████
Data sharding ops: intermediate layout
jax.lax.with_sharding_constraint
a jax.device_put-like function that can be called inside a jax.jit-ted function to shard intermediate / output arrays manually
Parameters Sharding
flax.linen.Partitioned
is a jax.numpy.array wrapper that tracks an axis name across which the array is sharded
Parallel Programming:
auto parallel via jax.jit
partial replication: computations are replicated across the devices that hold the same data
example: a 2x4 mesh with a 4x8 array sharded over it
after summing across the first array axis, each pair of devices along the first mesh axis stores / replicates the same values
explicit parallel (manually configured auto compute)
sharding is explicitly handled
manual parallel
a function is mapped over shards (as in CUDA kernels, but over blocks / shards)
for cross-shard operations use [lax utils → collective operations]
a mesh is the set / structure of devices, naming its axes and setting their types (Explicit, Auto, …); then the jax.sharding.PartitionSpec partitions an array across devices
jax.debug.visualize_array_sharding: visualizes how a (sharded) array is laid out across devices
jax.make_mesh((arr_dim0, ...), (arr_dim0_name, ...), axis_types=(jax.sharding.AxisType.Explicit, ...))
creates a mesh / matrix of devices with named axes
jax.sharding.NamedSharding(mesh, partition_spec)
with jax.set_mesh(mesh): sets the default mesh to use for array creation
example:
x : [4@x, 1]
██ ░░ ░░ ░░
██ ░░ ░░ ░░
y : [1, 8@y]
██ ██ ██ ██
░░ ░░ ░░ ░░
x + y : [4@x, 8@y]
██ ██ ██ ██
██ ██ ██ ██
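a minimal sketch of creating such a sharded array (assumes 8 available devices, e.g. the 2x4 mesh above):
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((2, 4), ("x", "y"))
arr = jax.device_put(jnp.zeros((4, 8)), NamedSharding(mesh, P("x", "y")))
jax.debug.visualize_array_sharding(arr)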
jax.shard_map(func, mesh, in_specs=..., out_specs=...)
example:
input is sharded 4@x
output is sharded 4@x
mesh:     █ █ █ █
device0:  █ · · ·
device1:  · █ · ·
device2:  · · █ ·
device3:  · · · █
[collective ops ?]
mesh:     █ █ █ █
each device does its computation on its own shard, and the output is sharded accordingly
potential practice micro projects
RNN with lax.scan
vmapped operation with pytree config (same weight over multiple inputs)
#jax.shard_map
a parallelization transform that executes a given function func on sharded inputs as specified by in_specs, then collects / re-assembles the outputs as specified by out_specs on the given device mesh
(note for my fish memory: shard_map does the sharding according to in_specs, no need (generally) to pre-shard inputs unless intended)
in_specs: None → replicate, axis_name → shard along this mesh axis
out_specs: None → fetch a single copy, no need to tile since it's already replicated (unvarying), axis_name → collect
use jax.lax.pvary to label a varying value so it can take part in operations with other varying values (mostly added by jax.lax in jaxprs passed to XLA; mentioned just in case a nasty surprise is encountered)
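a minimal sketch (assumes jax.shard_map is available at the top level and that the leading dim divides the device count):
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((jax.local_device_count(),), ("x",))

def block_sum(block):                    # each device sees one block of the input
    return jax.lax.psum(jnp.sum(block), axis_name="x")

f = jax.shard_map(block_sum, mesh=mesh, in_specs=P("x"), out_specs=P())
f(jnp.arange(8.))                        # full sum, replicated across devices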
read pages
Tutorials
jax.experimental.io_callback()
jax.check_tracer_leaks()
asynchronous dispatch

