Manipulating tensors in a PyTorch deep learning model is annoying and “unsafe” when you can’t be sure that the dimensions are being permuted, transposed, or viewed correctly.

So I think explicit tensor typing would be useful for:

  • each dimension’s meaning, e.g. batch, sequence (time), embedding dimension
    • this subsumes (implies) annotating the tensor’s shape
  • dtype (float32, float64, int8, bfloat16)


jaxtyping (repo) supports this.

The jaxtyping docs show how to annotate both shape and dtype:

from jaxtyping import Array, Float, PyTree

# Accepts floating-point 2D arrays with matching axes
def matrix_multiply(x: Float[Array, "dim1 dim2"],
                    y: Float[Array, "dim2 dim3"]
                  ) -> Float[Array, "dim1 dim3"]:
    ...

def accepts_pytree_of_ints(x: PyTree[int]):
    ...

def accepts_pytree_of_arrays(x: PyTree[Float[Array, "batch c1 c2"]]):
    ...
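
To make the shared axis names concrete, here is a minimal pure-Python sketch of the kind of check these annotations imply: each name in a shape string is bound to a size the first time it appears, and every later use must agree. This is not how jaxtyping is implemented. `check_shape` and `bindings` are hypothetical names used only for illustration, and the sketch handles only plain names and integer literals.

```python
def check_shape(shape, spec, bindings):
    """Match a concrete shape against a spec such as "dim1 dim2".

    `bindings` maps axis names to the sizes seen so far and is updated
    in place. Hypothetical helper for illustration; not part of jaxtyping.
    """
    names = spec.split()
    if len(names) != len(shape):
        raise TypeError(f"expected {len(names)} axes for {spec!r}, got shape {shape}")
    for name, size in zip(names, shape):
        if name.isdigit():
            # An integer literal in the spec fixes the axis size directly.
            expected = int(name)
        else:
            # First use binds the name; later uses return the bound size.
            expected = bindings.setdefault(name, size)
        if expected != size:
            raise TypeError(f"axis {name!r} expected size {expected}, got {size}")
    return bindings


bindings = {}
check_shape((3, 4), "dim1 dim2", bindings)  # binds dim1=3, dim2=4
check_shape((4, 5), "dim2 dim3", bindings)  # dim2 agrees, binds dim3=5
```

Passing a shape of `(7, 5)` for `"dim2 dim3"` after the first call would raise a `TypeError`, because `dim2` is already bound to 4. That is the class of bug (a silently wrong permute or view) the annotations are meant to catch.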

What is a pytree? From the JAX docs: “In JAX, we use the term pytree to refer to a tree-like structure built out of container-like Python objects. Classes are considered container-like if they are in the pytree registry, which by default includes lists, tuples, and dicts.”
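
As a rough illustration of that definition, the toy function below walks nested lists, tuples, and dicts and collects everything else as a leaf. It is a sketch only: `tree_leaves` here is a hand-written stand-in, it ignores the pytree registry entirely, and the real utilities live in `jax.tree_util`.

```python
def tree_leaves(tree):
    """Return the leaves of a nested list/tuple/dict structure as a flat list."""
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in tree_leaves(child)]
    if isinstance(tree, dict):
        # Visit keys in sorted order so the result is deterministic.
        return [leaf for key in sorted(tree) for leaf in tree_leaves(tree[key])]
    return [tree]  # anything that is not a container counts as a leaf


example = {"weights": [1, 2], "bias": (3,), "scale": 4}
print(tree_leaves(example))  # [3, 4, 1, 2]
```

An annotation like `PyTree[int]` then says that every leaf of such a structure should be an `int`.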


Project deprecated in favour of jaxtyping