Skip to content

Grid sample

Coordinate system of F.grid_sample

2D case

First we use meshgrid to build grid coordinates:

import torch
import torch.nn.functional as F

# Grid resolution per axis and number of spatial dimensions.
res, dim = 2, 2
device = torch.device("cpu")

# 1D sample positions shared by every axis: [-1, 1]
line = torch.linspace(-1, 1, res, device=device)

# "ij" indexing: element [i, j] holds (line[i], line[j]).
#   [-1, -1] [-1, 1]
#   [ 1, -1] [ 1, 1]
points_ij = torch.stack(
    torch.meshgrid([line] * dim, indexing="ij"),
    dim=-1,
)

# "xy" indexing: element [i, j] holds (line[j], line[i]).
#   [-1, -1] [1, -1]
#   [-1,  1] [1,  1]
points_xy = torch.stack(
    torch.meshgrid([line] * dim, indexing="xy"),
    dim=-1,
)

# xy is the transpose of ij: swapping the two coordinate channels of the
# ij grid gives exactly the xy grid.
swapped_channels = [1, 0]
assert torch.allclose(points_ij[..., swapped_channels], points_xy)
del swapped_channels  # keep the module namespace identical to the original

Why ij and xy are transposed:

# coordinate system of a 2D image (i/y index the row, j/x index the column!)
O -------- x/j
|      
|  
|
y/i

# i.e.
-->
xy
ji
<--

Now we can try grid_sample:

# simplest image
grid = torch.arange(res**dim).view([res]*dim).float()
"""
0 1
2 3
"""

# use ij
val_ij = F.grid_sample(grid.unsqueeze(0).unsqueeze(0), points_ij.unsqueeze(0), mode='bilinear', align_corners=True).squeeze()
"""
0=[-1, -1] 2=[-1, 1]
1=[1,  -1] 3=[1,  1]
"""

# use xy (or ji)
val_xy = F.grid_sample(grid.unsqueeze(0).unsqueeze(0), points_xy.unsqueeze(0), mode='bilinear', align_corners=True).squeeze()
"""
0=[-1, -1] 1=[1, -1]
2=[-1,  1] 3=[1,  1]
"""

We can see the internal coordinate system of grid_sample is:

                 -       
                 |
      0=[-1,-1]--|---------1=[1,-1]
      |          |         |
      |          |         |
      |          |         |
 - ------------- O ------------ +x/j
      |          |         |
      |          |         |
      |          |         |
      2=[-1,1]---|---------3=[1,1]
                 |
                 +y/i    

Therefore, just use ji (i.e., xy) to make sure the output is aligned with the input:

# Sampling with xy-ordered coordinates reproduces the input image.
assert torch.allclose(grid, val_xy)

3D case

import torch
import torch.nn.functional as F

# Grid resolution per axis and number of spatial dimensions.
res, dim = 2, 3
device = torch.device("cpu")

# 1D sample positions shared by every axis, shape [N]
line = torch.linspace(-1, 1, res, device=device)

# "ij" indexing: element [i, j, k] holds (line[i], line[j], line[k]),
# shape [N, N, N, 3]. Drawn with k running to the right, j running down
# and i running into the page:
#
#         [1,-1,-1]---------[1,-1,1]
#        /|                /|
#      /  |              /  |
#     [-1,-1,-1]--------[-1,-1,1]
#     |   |             |   |
#     |   |             |   |
#     |   [1,1,-1]------|---[1,1,1]
#     |  /              |  /
#     |/                |/
#     [-1,1,-1]---------[-1,1,1]
points_ijk = torch.stack(
    torch.meshgrid([line] * dim, indexing="ij"),
    dim=-1,
)

# NOTE: with "xy" indexing the order is not xyz, but yzx !!!
# Element [i, j, k] holds (line[j], line[i], line[k]), shape [N, N, N, 3]:
#
#         [-1,1,-1]---------[-1,1,1]
#        /|                /|
#      /  |              /  |
#     [-1,-1,-1]--------[-1,-1,1]
#     |   |             |   |
#     |   |             |   |
#     |   [1,1,-1]------|---[1,1,1]
#     |  /              |  /
#     |/                |/
#     [1,-1,-1]---------[1,-1,1]
points_yzx = torch.stack(
    torch.meshgrid([line] * dim, indexing="xy"),
    dim=-1,
)

Relationship with xyz and ijk:

# coordinate system of a 3D volume
       z/i
     /
   /
 /  
O---------x/k
|   
|   
|
|
y/j

# this is because the xy[z] and ij[k] notations always run in opposite orders!
2D       3D
-->      --->
xy       xyz
ji       kji
<--      <---

# why does "xy" return yzx instead of xyz?
# this is because torch simply swaps the first and second dimensions when using "xy".
# for 2D case, ij --> ji == xy
# for 3D case, ijk --> jik == yzx
# ref: https://github.com/pytorch/pytorch/blob/8ed093217272f9978e33799cdcc928ae591664c2/aten/src/ATen/native/TensorShape.cpp#L3584

Try grid_sample:

# simplest volume
grid = torch.arange(res**dim).view([res]*dim).float()
"""
        4-----------5
       /|          /|
     /  |        /  |
    0-----------1   |
    |   |       |   |  
    |   6-------|---7
    |  /        |  /
    |/          |/
    2-----------3
"""

# use ijk
F.grid_sample(grid.unsqueeze(0).unsqueeze(0), points_ijk.unsqueeze(0), mode='bilinear', align_corners=True).squeeze()
"""
        1=[1,-1,-1]-------5=[1,-1,1]
       /|                /|
     /  |              /  |
    0=[-1,-1,-1]------4=[-1,-1,1]
    |   |             |   |  
    |   |             |   |
    |   3=[1,1,-1]----|---7=[1,1,1]
    |  /              |  /
    |/                |/
    2=[-1,1,-1]-------6=[-1,1,1]
"""

# use yzx
F.grid_sample(grid.unsqueeze(0).unsqueeze(0), points_yzx.unsqueeze(0), mode='bilinear', align_corners=True).squeeze()
"""
        2=[-1,1,-1]-------6=[-1,1,1]
       /|                /|
     /  |              /  |
    0=[-1,-1,-1]------4=[-1,-1,1]
    |   |             |   |  
    |   |             |   |
    |   3=[1,1,-1]----|---7=[1,1,1]
    |  /              |  /
    |/                |/
    1=[1,-1,-1]-------5=[1,-1,1]
"""

Internal coordinate system of grid_sample is:

                       -           +z/i
                       |          /
                       |        /
             4=[-1,-1,1]----------5=[1,-1,1]
            /|         |     /   /|
          /  |         |   /   /  |
         0=[-1,-1,-1]---------1=[1,-1,-1]
         |   |         |/     |   |  
- -------|---|-------- O -----|---|------------ +x/k
         |   |        /|      |   |
         |   6=[-1,1,1]-------|---7=[1,1,1]
         |  /      /   |      |  /
         |/      /     |      |/
         2=[-1,1,-1]----------3=[1,1,-1]
              /        |
            /          |
          /            |
        -              +y/j

Notice that neither ijk nor yzx aligns the output with the input; instead we need kji (i.e., xyz):

# Build the ij-indexed coordinates, shape [N, N, N, 3] ...
points_ijk = torch.stack(
    torch.meshgrid([line] * dim, indexing="ij"),
    dim=-1,
)
# ... and reverse the channel order (i, j, k) -> (k, j, i), i.e. (x, y, z).
points_kji = points_ijk[..., [2, 1, 0]]

val = F.grid_sample(
    grid[None, None],  # add batch and channel dimensions
    points_kji[None],  # add batch dimension
    mode="bilinear",
    align_corners=True,
).squeeze()
# With kji-ordered coordinates the sampled values reproduce the input volume.
assert torch.allclose(grid, val)