vmap ops.index_update in Jax

jax.vmap can express functionality in which a single operation is applied independently to each slice of an input along a given axis. Your function is a bit different: you have a single operation applied iteratively to a single input, where each iteration updates the array that is passed on to the next one.
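To make the distinction concrete, here is a minimal sketch of what vmap does. The array x and the helper col_norm are hypothetical, chosen only for illustration: vmap takes a function written for one 1-D column and applies it to every column along axis 1, and no column ever sees the result computed for another.

```python
import jax
import jax.numpy as jnp

# Hypothetical 3x4 input: 3 rows, 4 columns (column sums are all nonzero).
x = jnp.arange(12.0).reshape(3, 4)

def col_norm(col):
    # Works on a single 1-D column and knows nothing about the other columns.
    return col / jnp.sum(col)

# Map over axis 1 (columns) and stack the results back along axis 1.
out = jax.vmap(col_norm, in_axes=1, out_axes=1)(x)

# The equivalent explicit loop: each column is processed independently.
expected = jnp.stack([col_norm(x[:, j]) for j in range(x.shape[1])], axis=1)
```

Because every column is handled in isolation, vmap has no way to feed the output of one step into the next; that sequential dependence is what lax.scan is for.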

Fortunately, JAX provides lax.scan, which can handle this situation. The implementation would look something like this:

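import jax
import jax.numpy as jnp
import jax.scipy.signal as jscp  # assumed: the question's `jscp` is jax.scipy.signal
import numpy as np
from jax import lax

# Filter column `ind` of the carried array with impulse_20 (from the question)
# and write it back. jax.ops.index_update has since been removed from JAX;
# y.at[...].set(...) is its replacement.
def paraUpdate(y, ind):
    return y.at[:, ind].set(jscp.convolve(impulse_20, y[:, ind])[:-19]), ind

def filter_jax2(y):
  # Scan over column indices 0 .. len(y) - 1, carrying the updated array.
  ranger = jnp.arange(len(y))
  return lax.scan(paraUpdate, y, ranger)[0]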
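For readers new to lax.scan, a minimal sketch of its contract: the step function receives a carry and one element of xs, and returns a new carry plus a per-step output; scan returns the final carry and the stacked outputs. The scan_reference function below is a hypothetical pure-Python model of those semantics (not how scan is implemented), and the running-sum step is only an illustration.

```python
import jax.numpy as jnp
from jax import lax

def scan_reference(f, init, xs):
    # Pure-Python model of lax.scan's semantics (not its implementation):
    # thread `carry` through f and stack the per-step outputs.
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, jnp.stack(ys)

def step(carry, x):
    # The new carry is the running total; it is also emitted as this step's output.
    total = carry + x
    return total, total

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
init = jnp.zeros((), dtype=xs.dtype)

final, history = lax.scan(step, init, xs)
ref_final, ref_history = scan_reference(step, init, xs)
# final is 10.0 and history holds the running totals 1, 3, 6, 10.
```

In filter_jax2 the carry is the whole array y and xs is ranger, so each step sees the columns written by earlier steps. The stacked per-step outputs (the returned ind values) are discarded by the trailing [0].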

print(np.allclose(filter_jax(jnpData), filter_jax2(jnpData)))
# True

%timeit filter_jax(jnpData).block_until_ready()
# 10 loops, best of 3: 28.6 ms per loop

%timeit filter_jax2(jnpData).block_until_ready()
# 1000 loops, best of 3: 519 µs per loop
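Since filter_jax2 only keeps the final carry, the same loop could also be written with lax.fori_loop, a different lax primitive from the scan used above. The sketch below assumes jscp is jax.scipy.signal and uses a hypothetical 20-tap moving-average impulse_20 and random square data in place of the question's inputs; it makes no claim about timing.

```python
import jax
import jax.numpy as jnp
from jax import lax
import jax.scipy.signal as jscp  # assumed to match the question's `jscp`

# Hypothetical stand-ins for the question's data.
impulse_20 = jnp.ones(20) / 20.0
data = jax.random.normal(jax.random.PRNGKey(0), (8, 8))

def filter_fori(y):
    def body(ind, y):
        # Filter column `ind` of the current array and write it back.
        filtered = jscp.convolve(impulse_20, y[:, ind])[:-19]
        return y.at[:, ind].set(filtered)
    # Iterate ind = 0, 1, ..., len(y) - 1, as in the scan version.
    return lax.fori_loop(0, len(y), body, y)

def filter_loop(y):
    # Plain Python loop with the same per-column update, for comparison.
    for ind in range(len(y)):
        y = y.at[:, ind].set(jscp.convolve(impulse_20, y[:, ind])[:-19])
    return y

result = filter_fori(data)
reference = filter_loop(data)
```

lax.fori_loop passes the loop index and the carried value to body and returns only the final value, which matches how filter_jax2 uses scan.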

If you change your algorithm so that you're applying the operation to every column in the array rather than only the first N columns (where N = len(y), the number of rows), it can be expressed with vmap like this:

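def filter_jax3(y):
  # Filter every column independently: vmap maps f over axis 1 (columns)
  # and stacks the filtered columns back along axis 1.
  f = lambda col: jscp.convolve(impulse_20, col)[:-19]
  return jax.vmap(f, in_axes=1, out_axes=1)(y)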
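The scan and vmap versions agree exactly when the input is square, because then "the first N columns" means every column. A minimal check, assuming jscp is jax.scipy.signal and using a hypothetical moving-average impulse_20 and random arrays (the function names here are illustrative, not from the question):

```python
import jax
import jax.numpy as jnp
from jax import lax
import jax.scipy.signal as jscp  # assumed to match the question's `jscp`

impulse_20 = jnp.ones(20) / 20.0  # hypothetical impulse response

def para_update(y, ind):
    # Same update as paraUpdate, but emits no per-step output.
    return y.at[:, ind].set(jscp.convolve(impulse_20, y[:, ind])[:-19]), None

def filter_scan(y):
    return lax.scan(para_update, y, jnp.arange(len(y)))[0]

def filter_vmap(y):
    f = lambda col: jscp.convolve(impulse_20, col)[:-19]
    return jax.vmap(f, in_axes=1, out_axes=1)(y)

key = jax.random.PRNGKey(0)
square = jax.random.normal(key, (8, 8))  # rows == columns
wide = jax.random.normal(key, (8, 12))   # more columns than rows

# Square input: the scan filters every column, so the results match.
print(jnp.allclose(filter_scan(square), filter_vmap(square), atol=1e-5))  # True
# Wide input: the scan leaves columns 8..11 untouched; vmap filters them.
print(jnp.array_equal(filter_scan(wide)[:, 8:], wide[:, 8:]))  # True
```

So vmap is a drop-in replacement only if filtering every column is what you actually want.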
