Einops.jl brings the readable and concise tensor operations of einops to Julia, unifying `reshape`, `permutedims`, `reduce` and `repeat` functions, with support for automatic differentiation.
The Python implementation specifies each operation with a string, which is hard to resolve at compile time in Julia. For parity, a string macro is exported: `einops"a b -> (b a)"` expands to the form `(:a, :b) --> ((:b, :a),)`, where `-->` is a custom operator that stores the left and right operands as type parameters of a special pattern type. Because the pattern is encoded in the type, the dimensionality of each side is known at compile time, which keeps the operations type-stable.
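To make the expansion step concrete, here is a minimal sketch of how a pattern string could be parsed into nested tuples of `Symbol`s and then lifted into type parameters. Every name in it (`parse_side`, `parse_pattern`, `ArrowSketch`, `make_pattern`, `lhs`, `rhs`, `input_ndims` and the `pat"..."` macro) is hypothetical and only mirrors the idea described above. It is not Einops.jl's implementation, and a plain function stands in for the `-->` operator.

```julia
# Hypothetical sketch: parse an einops-style string into nested tuples of Symbols,
# then store both sides as type parameters. Not Einops.jl's actual code.

# Parse one side of a pattern, e.g. "c (w w2) h" -> (:c, (:w, :w2), :h).
function parse_side(s::AbstractString)
    out = Any[]              # items at the current nesting level
    stack = Vector{Any}[]    # enclosing levels: pushed on "(", popped on ")"
    for m in eachmatch(r"\(|\)|[A-Za-z_][A-Za-z0-9_]*", s)
        tok = m.match
        if tok == "("
            push!(stack, out)
            out = Any[]
        elseif tok == ")"
            group = Tuple(out)
            out = pop!(stack)
            push!(out, group)
        else
            push!(out, Symbol(tok))
        end
    end
    return Tuple(out)
end

# Split "lhs -> rhs" and parse both sides.
function parse_pattern(s::AbstractString)
    lhs_str, rhs_str = split(s, "->")
    return parse_side(lhs_str), parse_side(rhs_str)
end

# Tuples of Symbols are valid type parameters (NamedTuple relies on the same rule),
# so each side of the pattern can live in the type itself.
struct ArrowSketch{L,R} end

# Hypothetical stand-in for the `-->` operator.
make_pattern(left::Tuple, right::Tuple) = ArrowSketch{left,right}()

# Both sides are recovered from the type alone, so they are compile-time constants.
lhs(::ArrowSketch{L,R}) where {L,R} = L
rhs(::ArrowSketch{L,R}) where {L,R} = R
input_ndims(::ArrowSketch{L,R}) where {L,R} = length(L)

# Hypothetical string macro: the string is parsed once, at macro-expansion time.
macro pat_str(s)
    l, r = parse_pattern(s)
    return :(make_pattern($(QuoteNode(l)), $(QuoteNode(r))))
end

p = pat"a b -> (b a)"
```

Here `p` is an `ArrowSketch{(:a, :b), ((:b, :a),)}`, so a method such as `input_ndims(p)` is answered from the type alone, which is what lets the compiler specialize on the pattern.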
The `rearrange` function combines reshaping and permutation operations into a single, expressive command.
```julia
julia> using Einops

julia> images = randn(3, 40, 30, 32); # channel, width, height, batch

# reorder axes to "w h c b" format:
julia> rearrange(images, (:c, :w, :h, :b) --> (:w, :h, :c, :b)) |> size
(40, 30, 3, 32)

# flatten each image into a vector
julia> rearrange(images, (:c, :w, :h, :b) --> ((:c, :w, :h), :b)) |> size
(3600, 32)

# split each image into 4 smaller images (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
julia> rearrange(images, (:c, (:w, :w2), (:h, :h2), :b) --> (:c, :w, :h, (:w2, :h2, :b)), h2=2, w2=2) |> size
(3, 20, 15, 128)
```
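For intuition, the three `rearrange` patterns above can be reproduced with Base `permutedims` and `reshape`. This sketch rests on one stated assumption: a parenthesized group such as `(:w, :w2)` is taken to follow Julia's column-major order, with the first name varying fastest. It illustrates what the patterns compute, not how Einops.jl implements `rearrange`; the variable names are illustrative only.

```julia
images = randn(3, 40, 30, 32)  # c, w, h, b

# (:c, :w, :h, :b) --> (:w, :h, :c, :b) is a pure axis permutation.
reordered = permutedims(images, (2, 3, 1, 4))

# (:c, :w, :h, :b) --> ((:c, :w, :h), :b) merges the leading axes. They are already
# adjacent and in order, so in column-major layout this is a reshape (no copy).
flat = reshape(images, 3 * 40 * 30, 32)

# (:c, (:w, :w2), (:h, :h2), :b) --> (:c, :w, :h, (:w2, :h2, :b)) with w2 = h2 = 2
# (assuming the first name in each group varies fastest):
# 1. split w = 40 into (w, w2) = (20, 2) and h = 30 into (h, h2) = (15, 2)
split_axes = reshape(images, 3, 20, 2, 15, 2, 32)   # c, w, w2, h, h2, b
# 2. move w2 and h2 next to b
moved = permutedims(split_axes, (1, 2, 4, 3, 5, 6))  # c, w, h, w2, h2, b
# 3. merge (w2, h2, b) into one axis of length 2 * 2 * 32 = 128
quadrants = reshape(moved, 3, 20, 15, 2 * 2 * 32)

@show size(reordered) size(flat) size(quadrants)
```

Running it prints the sizes `(40, 30, 3, 32)`, `(3600, 32)` and `(3, 20, 15, 128)`. Under the stated assumption, `quadrants[:, :, :, 1]` holds `w` in `1:20` and `h` in `1:15` of the first image, and `quadrants[:, :, :, 2]` holds `w` in `21:40` of the same rows.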
The method for `Base.reduce` dispatches on `ArrowPattern` and applies a reduction function (such as `sum`, `mean` or `maximum`) along the axes that do not appear on the right-hand side of the pattern. This differs from the standard `Base.reduce(op, itr)`, which folds a collection with a binary operator.
```julia
julia> using Statistics # provides `mean`

julia> x = randn(64, 32, 100);

# perform max-reduction on the last axis
# Axis t does not appear on the right, so we reduce over t
julia> reduce(maximum, x, (:c, :b, :t) --> (:c, :b)) |> size
(64, 32)

julia> reduce(mean, x, (:c, :b, (:t, :t5)) --> (:b, :c, :t), t5=5) |> size
(32, 64, 20)
```
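The same two reductions can be written with Base and the Statistics standard library, which makes the rule explicit: every name missing from the right-hand side is reduced away. This is a sketch, not Einops.jl's implementation, and it assumes that a group like `(:t, :t5)` splits in column-major order (first name fastest); the variable names are illustrative only.

```julia
using Statistics: mean

x = randn(64, 32, 100)  # c, b, t

# (:c, :b, :t) --> (:c, :b): t is missing on the right, so reduce over dim 3,
# then drop the resulting singleton axis.
maxed = dropdims(maximum(x; dims=3); dims=3)

# (:c, :b, (:t, :t5)) --> (:b, :c, :t) with t5 = 5:
# 1. split t = 100 into (t, t5) = (20, 5), assuming t varies fastest
split_t = reshape(x, 64, 32, 20, 5)                 # c, b, t, t5
# 2. average over t5 (missing on the right) and drop that axis
averaged = dropdims(mean(split_t; dims=4); dims=4)  # c, b, t
# 3. swap c and b to match the output order
result = permutedims(averaged, (2, 1, 3))           # b, c, t
```

The explicit `dropdims` calls are needed because Base reductions with `dims` keep the reduced axis with length 1.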
The method for `Base.repeat` also dispatches on `ArrowPattern`, and repeats elements along existing or new axes.
```julia
julia> image = randn(40, 30); # a grayscale image (of shape width x height)

# change it to RGB format by repeating in each channel
julia> repeat(image, (:w, :h) --> (:c, :w, :h), c=3) |> size
(3, 40, 30)

# repeat image 2 times along height (vertical axis)
julia> repeat(image, (:w, :h) --> ((:repeat, :h), :w), repeat=2) |> size
(60, 40)

# repeat image 2 times along height and 3 times along width
julia> repeat(image, (:w, :h) --> ((:w, :w3), (:h, :h2)), w3=3, h2=2) |> size
(120, 60)
```
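The three `repeat` patterns also have Base counterparts. This sketch assumes that names inside a parenthesized group combine in column-major order, first name fastest. Under that assumption, putting the new axis first in its group, as in `(:repeat, :h)`, repeats each element in place (Base's `inner` keyword), while putting it last, as in `(:w, :w3)`, tiles the whole image (the positional, `outer`-style form). It illustrates the semantics only and is not how Einops.jl implements `repeat`; the variable names are illustrative.

```julia
image = randn(40, 30)  # w, h

# (:w, :h) --> (:c, :w, :h), c = 3: add a new leading axis of length 1, then tile it.
rgb = repeat(reshape(image, 1, 40, 30), 3, 1, 1)

# (:w, :h) --> ((:repeat, :h), :w), repeat = 2: move h to the front, then repeat each
# h-slice twice back to back (:repeat is assumed to vary fastest in its group).
stretched = repeat(permutedims(image), inner=(2, 1))

# (:w, :h) --> ((:w, :w3), (:h, :h2)), w3 = 3, h2 = 2: w and h vary fastest,
# so the whole image is tiled 3 times along w and 2 times along h.
tiled = repeat(image, 3, 2)
```

All three results have the sizes shown in the REPL session, `(3, 40, 30)`, `(60, 40)` and `(120, 60)`; only the ordering of the copies depends on the assumed group convention.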
- Implement `rearrange`.
- Support the Python implementation's string syntax for patterns with a string macro.
- Implement `pack` and `unpack`.
- Implement `parse_shape`.
- Implement `repeat`.
- Implement `reduce`.
- Support automatic differentiation (tested with Zygote.jl).
- Implement `einsum` (or wrap an existing implementation) (see #3).
- Support ellipsis notation (using `..` from EllipsisNotation.jl) (see #9).
- Explore integration with `PermutedDimsArray` or TransmuteDims.jl for lazy and statically inferrable permutations (see #4).
Contributions are welcome! Please feel free to open an issue to report a bug or start a discussion.