jax.nn.softmax is defined as:
def softmax(x: Array,
            axis: Optional[Union[int, Tuple[int, ...]]] = -1,
            where: Optional[Array] = None,
            initial: Optional[Array] = None) -> Array:
  x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
  unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
  return unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
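Just to ground the signature, here is a minimal call I wrote myself (my own example, using only the default axis=-1 and no where/initial masking):

import jax.numpy as jnp
from jax import nn

logits = jnp.arange(6.0).reshape(2, 3)
probs = nn.softmax(logits)                    # normalizes along the last axis by default
assert jnp.allclose(probs.sum(axis=-1), 1.0)  # each row sums to 1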
I'm particularly interested in the lax.stop_gradient(x_max) part, and I would love an explanation of why it's needed. From a practical standpoint, it seems that stop_gradient doesn't change the gradient calculation:
import jax
import jax.numpy as jnp

# naive softmax: exp can overflow/underflow for badly scaled inputs
def softmax_unstable(x):
    return jnp.exp(x) / jnp.sum(jnp.exp(x))

# standard trick: subtract the max, letting autodiff differentiate through jnp.max
def softmax_stable(x):
    x = x - jnp.max(x)
    return jnp.exp(x) / jnp.sum(jnp.exp(x))

# same trick, but the max is treated as a constant by autodiff, as in jax.nn.softmax
def softmax_stop_gradient(x):
    x = x - jax.lax.stop_gradient(jnp.max(x))
    return jnp.exp(x) / jnp.sum(jnp.exp(x))
# example input
x = jax.random.normal(jax.random.PRNGKey(123), (100,))
# make sure all forward passes are equal
a = softmax_unstable(x)
b = softmax_stable(x)
c = softmax_stop_gradient(x)
d = jax.nn.softmax(x)
assert jnp.allclose(a, b) and jnp.allclose(b, c) and jnp.allclose(c, d)
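The forward passes only agree here because x is well scaled; with large logits the unshifted version overflows, which I assume is the original motivation for the x - x_max trick (a quick check I added, under JAX's default float32):

big = jnp.array([100.0, 101.0, 102.0])
assert jnp.all(jnp.isnan(softmax_unstable(big)))       # exp(100) overflows float32 -> inf / inf -> nan
assert jnp.allclose(softmax_stable(big), softmax_stop_gradient(big))
assert jnp.isclose(jnp.sum(softmax_stable(big)), 1.0)  # the shifted versions stay finite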
# make sure all gradient calculations are the same
a = jax.grad(lambda x: -jnp.log(softmax_unstable(x))[2])(x)
b = jax.grad(lambda x: -jnp.log(softmax_stable(x))[2])(x)
c = jax.grad(lambda x: -jnp.log(softmax_stop_gradient(x))[2])(x)
d = jax.grad(lambda x: -jnp.log(jax.nn.softmax(x))[2])(x)
assert jnp.allclose(a, b) and jnp.allclose(b, c) and jnp.allclose(c, d)
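As an extra sanity check of my own, the full Jacobians also appear to match, not just the gradient of this particular scalar loss:

Ja = jax.jacobian(softmax_unstable)(x)
Jb = jax.jacobian(softmax_stable)(x)
Jc = jax.jacobian(softmax_stop_gradient)(x)
Jd = jax.jacobian(jax.nn.softmax)(x)
assert jnp.allclose(Ja, Jb) and jnp.allclose(Jb, Jc) and jnp.allclose(Jc, Jd)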
# make sure all gradient calculations are the same, this time we use softmax functions twice
a = jax.grad(lambda x: -jnp.log(softmax_unstable(softmax_unstable(x)))[2])(x)
b = jax.grad(lambda x: -jnp.log(softmax_stable(softmax_stable(x)))[2])(x)
c = jax.grad(lambda x: -jnp.log(softmax_stop_gradient(softmax_stop_gradient(x)))[2])(x)
d = jax.grad(lambda x: -jnp.log(jax.nn.softmax(jax.nn.softmax(x)))[2])(x)
assert jnp.allclose(a, b) and jnp.allclose(b, c) and jnp.allclose(c, d)
^ All implementations agree, even the one that applies the x - x_max trick WITHOUT stop_gradient.
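For contrast, stop_gradient obviously does change gradients in general; it just seems to make no difference in this particular spot (again my own toy example):

# stop_gradient makes autodiff treat its argument as a constant
g_with = jax.grad(lambda v: v * jax.lax.stop_gradient(v))(3.0)   # d/dv (v * const) -> 3.0
g_without = jax.grad(lambda v: v * v)(3.0)                       # d/dv (v * v)     -> 6.0
assert g_with == 3.0 and g_without == 6.0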