

PyTorch Internals

A summary of ezyang's blog.

Tensor View and Storage

import torch

## View (logical)
# shape: [2, 2], the logical shape
# stride: [2, 1], the stride for each axis
# offset: 0, the offset into the underlying storage
x = torch.tensor([[1, 2],
                  [3, 4]])

## Storage (physical)
# always a flat, 1D, contiguous array
x.storage()  # [1, 2, 3, 4]

## Slicing
y = x[1]     # tensor([3, 4]), shape = [2], stride = [1], offset = 2
z = x[:, 0]  # tensor([1, 3]), shape = [2], stride = [2], offset = 0
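
# Illustrative check (not in the original): the view metadata above can be
# read back directly from the views.
assert y.stride() == (1,) and y.storage_offset() == 2
assert z.stride() == (2,) and z.storage_offset() == 0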

## Stride and index
def index_to_pos(index, stride, offset=0):
    # index: logical, same dimension as stride
    # pos: physical, 1d value.
    # this works for both contiguous and non-contiguous tensors!
    pos = offset
    for i, s in zip(index, stride):
        pos += i * s
    return pos
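
# Illustrative check: for z above (stride [2], offset 0), the logical index (1,)
# maps to physical position 2, i.e. the value 3 in x's storage.
assert index_to_pos([1], stride=[2], offset=0) == 2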

def strides_from_shape(shape):
    # assuming a contiguous tensor!
    # example: shape [3, 4, 5] --> stride [20, 5, 1]
    layout = [1]
    prod = 1
    for s in reversed(shape[1:]):   # e.g. [5, 4]
        layout.append(s * prod)     # layout grows: [1, 5, 20]
        prod = s * prod             # prod: 5, then 20
    return tuple(reversed(layout))  # (20, 5, 1)
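
# Illustrative check: matches the strides PyTorch reports for a contiguous tensor.
assert strides_from_shape([3, 4, 5]) == (20, 5, 1)
assert torch.empty(3, 4, 5).stride() == (20, 5, 1)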

def is_contiguous(stride):
    # contiguous tensor <==> non-increasing strides && the last stride is 1
    # (simplified: a complete check would also use the shape)
    last = stride[0]
    for s in stride[1:]:
        if s > last:
            return False
        last = s
    return last == 1
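
# Illustrative checks using the simplified rule above.
assert is_contiguous([20, 5, 1])  # contiguous strides for shape [3, 4, 5]
assert not is_contiguous([1, 2])  # increasing strides, e.g. a transposed view
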
  • Every tensor is a view-storage pair.
  • Multiple tensor views can share the same storage (see the sketch after this list).
  • The tensor trinity that determines its actual implementation:

    • device: CPU, CUDA, XLA, ...
    • layout: Strided, Sparse, ...
    • dtype: int, float, ...
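
A minimal sketch of the sharing and trinity points above, using only standard torch calls:

import torch

x = torch.arange(4).reshape(2, 2)  # storage: [0, 1, 2, 3]
col = x[:, 0]                      # a view: shape [2], stride [2], offset 0

# both views point at the same underlying storage buffer
assert col.data_ptr() == x.data_ptr()

# writing through one view is visible through the other
col.fill_(7)
assert x[1, 0].item() == 7

# the trinity of this tensor
print(x.device, x.layout, x.dtype)  # cpu torch.strided torch.int64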

Source Structure

torch/ # torch frontend
torch/csrc/ # c++ frontend, python bindings, autograd engine, JIT compiler
aten/ # a tensor library, most tensor operations.
c10/ # implementation of core abstractions.

Most of the time, adding a new operator means extending the aten part.

First, register your op in the op configuration file (typically native_functions.yaml); see the detailed workflow.

Each op should be implemented in three versions:

myop_out(..., out) # write to out
# usually the following two funcs can just call myop_out()
myop_(...) # inplace
myop() # return the output
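
A minimal sketch of that pattern in Python (the real ATen implementations are written in C++; myop is hypothetical, and torch.neg here is just a stand-in computation):

import torch

def myop_out(x, out):
    # the "real" kernel: writes the result into `out`
    torch.neg(x, out=out)  # stand-in computation for illustration
    return out

def myop_(x):
    # in-place variant: the input buffer doubles as the output buffer
    return myop_out(x, out=x)

def myop(x):
    # functional variant: allocate a fresh output, then delegate
    return myop_out(x, out=torch.empty_like(x))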

Then write the kernels that contain the real implementation.

You'll need to dispatch based on dtype and implement a kernel for each dtype (or use templates).
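
As an illustration only, here is a hypothetical Python sketch of dtype-based dispatch (ATen kernels do this in C++, e.g. via the AT_DISPATCH_* macros; myop and both kernels below are made up):

import torch

def _myop_float_kernel(x):
    # hypothetical floating-point kernel
    return x * 0.5

def _myop_int_kernel(x):
    # hypothetical integer kernel
    return x // 2

_KERNELS = {
    torch.float32: _myop_float_kernel,
    torch.float64: _myop_float_kernel,
    torch.int32: _myop_int_kernel,
    torch.int64: _myop_int_kernel,
}

def myop(x):
    # pick the kernel that matches the tensor's dtype
    kernel = _KERNELS.get(x.dtype)
    if kernel is None:
        raise TypeError(f"myop: unsupported dtype {x.dtype}")
    return kernel(x)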

Workflow efficiency

  • Edit headers sparingly, since changing a header can trigger recompilation of many files.
  • Set up ccache.