It’s annoying, and “unsafe”, to manipulate tensors in PyTorch as they flow through a DL model without knowing for sure that the dimensions are being correctly permuted, transposed or reshaped with .view().

So I think explicit Tensor typing for:

  • each dimension’s meaning e.g. batch, sequence in time, embedding dimension
    • this subsumes (implies) annotating the number of dimensions (the tensor’s rank)
  • dtype (float32, float64, int8, bf16)

would be useful.
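To make the “unsafe” part concrete, here is a minimal PyTorch sketch (the sizes are arbitrary). Swapping the batch and sequence axes with view instead of transpose gives a tensor of exactly the right shape but with scrambled contents, and nothing raises.

```python
import torch

batch, seq, embed = 2, 3, 4
# A (batch, seq, embed) tensor with distinct values so mixing is visible.
x = torch.arange(batch * seq * embed).reshape(batch, seq, embed)

# Intended: swap batch and seq axes -> (seq, batch, embed).
correct = x.transpose(0, 1)        # same as x.permute(1, 0, 2)
wrong = x.view(seq, batch, embed)  # only reinterprets the memory layout

print(correct.shape == wrong.shape)  # True
print(torch.equal(correct, wrong))   # False: data is scrambled, no error
```

Shape checks alone cannot tell these two apart, which is why naming each dimension (what it means, not just its size) is the more useful annotation.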

jaxtyping

jaxtyping (repo) supports this.

The jaxtyping docs explicitly show how to annotate both the shape and the dtype:

from jaxtyping import Array, Float, PyTree
 
# Accepts floating-point 2D arrays with matching axes
def matrix_multiply(x: Float[Array, "dim1 dim2"],
                    y: Float[Array, "dim2 dim3"]
                  ) -> Float[Array, "dim1 dim3"]:
    ...
 
def accepts_pytree_of_ints(x: PyTree[int]):
    ...
 
def accepts_pytree_of_arrays(x: PyTree[Float[Array, "batch c1 c2"]]):
    ...

What is a pytree? From the JAX docs: “In JAX, we use the term pytree to refer to a tree-like structure built out of container-like Python objects. Classes are considered container-like if they are in the pytree registry, which by default includes lists, tuples, and dicts.”
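A small sketch of what that means in practice, assuming JAX is installed: a nested dict/list/tuple is a pytree, and jax.tree_util walks its leaves. Since every leaf here is an int, this is the kind of value PyTree[int] describes.

```python
import jax

# Nested containers (dict, list, tuple) with plain ints as leaves.
tree = {"a": [1, 2], "b": (3, {"c": 4})}

# tree_leaves flattens the structure into its leaves (dict keys in sorted order).
leaves = jax.tree_util.tree_leaves(tree)
print(leaves)  # [1, 2, 3, 4]

# tree_map applies a function to every leaf and rebuilds the same structure.
scaled = jax.tree_util.tree_map(lambda v: v * 10, tree)
print(scaled)  # {'a': [10, 20], 'b': (30, {'c': 40})}
```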

torchtyping

The project is deprecated in favour of jaxtyping, which, despite its name, also supports PyTorch tensors.