
Below is an example where a function with a custom vector-Jacobian product (custom_vjp) is vmapped. Even for a simple function like this, calling vjp on the vmapped version fails:

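from functools import partial
from typing import Callable

import jax.numpy as jnp
from jax import Array, custom_vjp, vjp, vmap


@partial(custom_vjp, nondiff_argnums=(0,))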
def test_func(f: Callable[..., float],
              R: Array
              ) -> float:

    return f(jnp.dot(R, R))


def test_func_fwd(f, primal):

    primal_out = test_func(f, primal)
    residual = 2. * primal * primal_out
    return primal_out, residual


def test_func_bwd(f, residual, cotangent):

    cotangent_out = residual * cotangent
    return (cotangent_out, )


test_func.defvjp(test_func_fwd, test_func_bwd)

test_func = vmap(test_func, in_axes=(None, 0))


if __name__ == "__main__":

    def f(x):
        return x

    # vjp
    primal, f_vjp = vjp(partial(test_func, f),
                        jnp.ones((10, 3))
                        )

    cotangent = jnp.ones(10)
    cotangent_out = f_vjp(cotangent)

    print(cotangent_out[0].shape)

The error message says:

ValueError: Shape of cotangent input to vjp pullback function (10,) must be the same as the shape of corresponding primal input (10, 3).

I think the error message is misleading here, because the cotangent input should have the same shape as the primal output, which should be (10,) in this case. Still, it's not clear to me why this error occurs.
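As a quick check of the shape rule assumed above, here is a minimal sketch using plain jax.vjp, with no custom_vjp or vmap involved. The helper row_sq_norms is hypothetical: it maps a (10, 3) input to a (10,) output, so the pullback accepts a cotangent shaped like the output and returns a cotangent shaped like the input.

```python
import jax
import jax.numpy as jnp


def row_sq_norms(R):
    # Hypothetical helper: squared norm of each row, (10, 3) -> (10,)
    return jnp.sum(R * R, axis=1)


R = jnp.ones((10, 3))
primal_out, pullback = jax.vjp(row_sq_norms, R)

# The cotangent passed to the pullback matches the primal OUTPUT shape (10,);
# the pullback returns a tuple with one cotangent per primal input.
(cotangent_in,) = pullback(jnp.ones(primal_out.shape))

print(primal_out.shape)    # (10,)
print(cotangent_in.shape)  # (10, 3)
```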

Answer


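The problem is that test_func_fwd calls test_func recursively, but you've overwritten test_func in the global namespace with its vmapped version. Inside the forward rule each call sees a single row of shape (3,), so the vmapped test_func maps over that row's 3 elements and returns a (3,) array instead of a scalar. The batched primal output therefore has shape (10, 3), and your (10,) cotangent no longer matches it. If you leave the original test_func unchanged in the global namespace, your code will work as expected: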

...

test_func_mapped = vmap(test_func, in_axes=(None, 0))

... 

primal, f_vjp = vjp(partial(test_func_mapped, f),
                    jnp.ones((10, 3))
                    )
