-
Hi - the PyTorch
-
Even though there hasn't been an update in a while, I'm sharing my implementation of grid_sample in JAX here. I have tested it, and it produces the same results as the PyTorch version. It currently supports only 2D inputs. It outperforms the PyTorch version on CPU (about 5× faster) but is roughly 10× slower on GPU. Please check it out, and let me know if you have any feedback on the code.
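For readers who want to see what such an implementation involves, here is a minimal sketch of bilinear grid_sample in JAX. It is not the implementation linked above. It assumes a 4-D NCHW input, a grid of shape (N, H_out, W_out, 2) whose last axis holds normalized (x, y) coordinates in [-1, 1], and PyTorch's `mode="bilinear"`, `padding_mode="zeros"` semantics; the function name `grid_sample_bilinear` is hypothetical. Each output pixel gathers its four neighbouring input pixels with vectorized indexing and blends them with bilinear weights, and taps that fall outside the image contribute zero.

```python
import jax
import jax.numpy as jnp


def grid_sample_bilinear(inp, grid, align_corners=False):
    """Hypothetical sketch: bilinear sampling of inp (N, C, H, W) at grid (N, Ho, Wo, 2).

    grid[..., 0] is x (width axis), grid[..., 1] is y (height axis), both in [-1, 1].
    Out-of-bounds taps contribute zero, mimicking padding_mode="zeros".
    """
    n, c, h, w = inp.shape
    x = grid[..., 0]
    y = grid[..., 1]

    # Map normalized coordinates to (fractional) pixel coordinates.
    if align_corners:
        ix = (x + 1) / 2 * (w - 1)
        iy = (y + 1) / 2 * (h - 1)
    else:
        ix = ((x + 1) * w - 1) / 2
        iy = ((y + 1) * h - 1) / 2

    # Integer corners surrounding each sample point.
    x0 = jnp.floor(ix)
    y0 = jnp.floor(iy)
    x1 = x0 + 1
    y1 = y0 + 1

    # Bilinear weights, each of shape (N, Ho, Wo).
    wx1 = ix - x0
    wx0 = 1 - wx1
    wy1 = iy - y0
    wy0 = 1 - wy1

    def gather(yy, xx):
        # Zero out taps that fall outside the input image.
        valid = (xx >= 0) & (xx <= w - 1) & (yy >= 0) & (yy <= h - 1)
        xc = jnp.clip(xx, 0, w - 1).astype(jnp.int32)
        yc = jnp.clip(yy, 0, h - 1).astype(jnp.int32)
        # Per batch element: img (C, H, W) indexed at (Ho, Wo) -> (C, Ho, Wo).
        vals = jax.vmap(lambda img, yi, xi: img[:, yi, xi])(inp, yc, xc)
        return vals * valid[:, None, :, :]

    # Weighted sum of the four neighbours; weights broadcast over channels.
    return (gather(y0, x0) * (wy0 * wx0)[:, None]
            + gather(y0, x1) * (wy0 * wx1)[:, None]
            + gather(y1, x0) * (wy1 * wx0)[:, None]
            + gather(y1, x1) * (wy1 * wx1)[:, None])
```

Because `align_corners` is a Python bool that selects a branch, it would need to be marked static when compiling, e.g. `jax.jit(grid_sample_bilinear, static_argnames="align_corners")`. Other modes (`nearest`, `bicubic`) and padding modes (`border`, `reflection`) are not covered by this sketch.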
-
This seems like a good use case for JAX's FFI (foreign function interface) to call a C++/CUDA kernel.
-
It would be great to have an equivalent of torch.nn.functional.grid_sample in JAX. It is widely used in 3D vision (view synthesis, re-projection to other camera positions, etc.). My understanding is that, to do something similar in JAX today, one would have to implement it from scratch, like this TensorFlow example, which seems verbose and slow, whereas the PyTorch version seems to do this in a single CUDA kernel.
Thank you for your consideration!