Dealing with the dimensionality of input tensors
There is a lot of code that checks the number of dimensions of input tensors and adjusts them accordingly. I would propose two changes here:
- Only do this in the top-level API functions that users of the library call. Internal functions should make strong assumptions about the number and length of all dimensions, and these assumptions should be enforced at the beginning of the user-facing APIs.
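As a minimal sketch of this pattern (the function names, the 4D `(batch, channels, height, width)` convention, and NumPy itself are all assumptions for illustration, not the library's actual API):

```python
import numpy as np

def _spatial_mean(images):
    # Internal function: assumes exactly 4D (batch, channels, height, width).
    # No dimension juggling here; the contract is enforced at the boundary.
    assert images.ndim == 4, f"expected 4D input, got {images.ndim}D"
    return images.mean(axis=(-2, -1))  # per-channel spatial mean

def spatial_mean(images):
    # User-facing function: accepts 2D (H, W), 3D (C, H, W), or 4D
    # (B, C, H, W) and normalizes to the internal 4D convention up front.
    images = np.asarray(images)
    if images.ndim == 2:
        images = images[np.newaxis, np.newaxis]  # add batch and channel axes
    elif images.ndim == 3:
        images = images[np.newaxis]              # add batch axis
    elif images.ndim != 4:
        raise ValueError(f"expected 2 to 4 dimensions, got {images.ndim}")
    return _spatial_mean(images)
```

All the ndim checks live in one place; everything below `spatial_mean` can rely on a fixed rank.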
- Add singleton dimensions for all "optional" axes. For example, when we only have one tracer, add a dimension of length 1 and let all code be written under the assumption that the tracer axis exists. It is still fine to have functions that work on a variable number of batch dimensions, but these should be leading dimensions (placed first), while the semantically meaningful dimensions (e.g. channels, depth) should have fixed positions relative to the END of the shape.
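The two conventions above can be sketched together. Here the `(tracers, depth, channels)` trailing-axis layout and the function names are hypothetical; the point is that negative axis indices make the code independent of how many leading batch dimensions are present:

```python
import numpy as np

# Hypothetical convention: the last three axes are always
# (tracers, depth, channels); anything before them is a batch dimension.
TRACER_AXIS, DEPTH_AXIS, CHANNEL_AXIS = -3, -2, -1

def normalize_tracers(data):
    # Guarantee the tracer axis exists: a single-tracer input shaped
    # (depth, channels) gets a singleton tracer dimension, so downstream
    # code can always address axis -3 without checking ndim.
    data = np.asarray(data)
    if data.ndim == 2:
        data = data[np.newaxis]  # (depth, channels) -> (1, depth, channels)
    return data

def per_tracer_mean(data):
    # Works for any number of leading batch dimensions, because the
    # relevant axes are indexed relative to the end of the shape.
    data = normalize_tracers(data)
    return data.mean(axis=(DEPTH_AXIS, CHANNEL_AXIS))
```

Because the meaningful axes are pinned to the end, the same reduction works whether the input is `(depth, channels)`, `(tracers, depth, channels)`, or `(batch..., tracers, depth, channels)`.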