Below is an example where a function with a custom-defined vector-Jacobian product (`custom_vjp`) is `vmap`ped. Even for a simple function like this, calling `vjp` fails:
```python
from functools import partial
from typing import Callable

import jax.numpy as jnp
from jax import Array, custom_vjp, vjp, vmap


@partial(custom_vjp, nondiff_argnums=(0,))
def test_func(f: Callable[..., float],
              R: Array
              ) -> float:
    return f(jnp.dot(R, R))


def test_func_fwd(f, primal):
    primal_out = test_func(f, primal)
    residual = 2. * primal * primal_out
    return primal_out, residual


def test_func_bwd(f, residual, cotangent):
    cotangent_out = residual * cotangent
    return (cotangent_out, )


test_func.defvjp(test_func_fwd, test_func_bwd)
# Rebind `test_func` to a version batched over the rows of R (f is not mapped).
test_func = vmap(test_func, in_axes=(None, 0))

if __name__ == "__main__":
    def f(x):
        return x

    # VJP w.r.t. R, a batch of 10 vectors of length 3
    primal, f_vjp = vjp(partial(test_func, f),
                        jnp.ones((10, 3))
                        )
    cotangent = jnp.ones(10)
    cotangent_out = f_vjp(cotangent)
    print(cotangent_out[0].shape)
```
The error message says:

```
ValueError: Shape of cotangent input to vjp pullback function (10,) must be the same as the shape of corresponding primal input (10, 3).
```
I think the error message is misleading: the cotangent passed to the pullback should have the same shape as the primal *output*, which should be `(10,)` in this case. Still, it's not clear to me why this error occurs.