Skip to content

Commit 1d24622

Browse files
committed
stow away --> notation in favor of =>, docstrings
1 parent 654c2db commit 1d24622

File tree

6 files changed

+92
-56
lines changed

6 files changed

+92
-56
lines changed

README.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
> [!WARNING]
99
> This package is still under development, and does not implement all of the features of the original (see [Roadmap](#Roadmap)).
1010
11-
Einops.jl is a Julia implementation of the [einops](https://einops.rocks) Python package, providing an elegant and intuitive notation for tensor operations. We currently implement `rearrange`, offering a unified way to perform Julia's `reshape` and `permutedims`, with plans for also implementing `reduce` and `repeat`.
11+
Einops.jl is a Julia implementation of the [einops](https://einops.rocks) Python package, providing an elegant and intuitive notation for tensor operations. We currently implement `rearrange`, offering a unified way to perform Julia's `reshape` and `permutedims`, as well as `pack` and `unpack`, with plans for implementing `reduce` and `repeat`.
1212

13-
The Python implementation uses strings to specify the operation, but that would be tricky to compile in Julia, so a string macro `@einops_str` is exported for parity, e.g. `einops"a 1 b c -> (c b) a"`, which expands to the form `(:a, 1, :b, :c,) --> ((:c, :b), :a)` where `-->` is an operator that creates an `Einops.Pattern{(:a, 1, :b, :c), ((:c, :b), :a)}`, allowing for compile-time awareness of dimensionalities and permutations—this is not yet taken advantage of, since the tuple types are sufficient for at least ensuring type stability (see [Roadmap](#Roadmap)).
13+
The Python implementation uses strings to specify the operation, but that would be tricky to compile in Julia, so a string macro `@einops_str` is exported for parity, e.g. `einops"a 1 b c -> (c b) a"`, which expands to the form `(:a, 1, :b, :c,) => ((:c, :b), :a)`, allowing for compile-time awareness of dimensionalities, ensuring type stability.
1414

1515
## Operations
1616

@@ -22,15 +22,15 @@ The `rearrange` combines reshaping and permutation operations into a single, exp
2222
julia> images = randn(32, 30, 40, 3); # batch, height, width, channel
2323

2424
# reorder axes to "b c h w" format:
25-
julia> rearrange(images, (:b, :h, :w, :c) --> (:b, :c, :h, :w)) |> size
25+
julia> rearrange(images, (:b, :h, :w, :c) => (:b, :c, :h, :w)) |> size
2626
(32, 3, 30, 40)
2727

2828
# flatten each image into a vector
29-
julia> rearrange(images, (:b, :h, :w, :c) --> (:b, (:h, :w, :c))) |> size
29+
julia> rearrange(images, (:b, :h, :w, :c) => (:b, (:h, :w, :c))) |> size
3030
(32, 3600)
3131

3232
# split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
33-
julia> rearrange(images, (:b, (:h1, :h), (:w1, :w), :c) --> ((:b, :h1, :w1), :h, :w, :c), h1=2, w1=2) |> size
33+
julia> rearrange(images, (:b, (:h1, :h), (:w1, :w), :c) => ((:b, :h1, :w1), :h, :w, :c), h1=2, w1=2) |> size
3434
(128, 15, 20, 3)
3535
```
3636

@@ -41,7 +41,7 @@ The `reduce` function will allow for applying reduction operations (like `sum`,
4141
```julia
4242
# Example (conceptual):
4343
x = randn(100, 32, 64)
44-
y = reduce(maximum, x, (:t, :b, :c) --> (:b, :c)) # max-reduction on the first axis
44+
y = reduce(maximum, x, (:t, :b, :c) => (:b, :c)) # max-reduction on the first axis
4545
```
4646

4747
### `repeat` (Planned)
@@ -51,7 +51,7 @@ The `repeat` function will provide a concise way to repeat elements along existi
5151
```julia
5252
# Example (conceptual):
5353
image = randn(30, 40)
54-
rgb_image = repeat(image, (:h, :w) --> (:repeat, :h, :w), repeat=3)
54+
rgb_image = repeat(image, (:h, :w) => (:repeat, :h, :w), repeat=3)
5555
```
5656

5757
## Roadmap
@@ -63,7 +63,7 @@ rgb_image = repeat(image, (:h, :w) --> (:repeat, :h, :w), repeat=3)
6363
* [ ] Implement `repeat`.
6464
* [ ] Explore integration with `PermutedDimsArray` or `TransmuteDims.jl` for lazy and statically inferrable permutations.
6565
* [ ] Implement `einsum` (or wrap existing implementation)
66-
* [ ] Implement `pack`, and `unpack`.
66+
* [x] Implement `pack`, and `unpack`.
6767

6868
## Contributing
6969

src/Einops.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,7 @@ export ..
66

77
# TODO: use TransmuteDims.jl
88

9-
struct Pattern{L,R} end
10-
(-->)(left, right) = Pattern{left, right}()
11-
Base.show(io::IO, ::Pattern{L,R}) where {L,R} = print(io, "$L --> $R")
12-
Base.iterate(::Pattern{L}) where L = (L, Val(:right))
13-
Base.iterate(::Pattern{<:Any,R}, ::Val{:right}) where R = (R, nothing)
14-
Base.iterate(::Pattern, ::Nothing) = nothing
15-
169
include("utils.jl")
17-
export -->
1810

1911
include("einops_str.jl")
2012
export @einops_str

src/einops_str.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ function parse_pattern(pattern::AbstractString)
44
lhs, rhs = strip.(split(pattern, "->"; limit = 2))
55
lhs_axes = tokenize_side(lhs)
66
rhs_axes = tokenize_side(rhs)
7-
return Tuple(lhs_axes) --> Tuple(rhs_axes)
7+
return Tuple(lhs_axes) => Tuple(rhs_axes)
88
end
99

1010
function tokenize_side(side::AbstractString)
@@ -86,10 +86,10 @@ For parity with Python implementation.
8686
8787
```jldoctest
8888
julia> einops"a 1 b c -> (c b) a"
89-
(:a, 1, :b, :c) --> ((:c, :b), :a)
89+
(:a, 1, :b, :c) => ((:c, :b), :a)
9090
9191
julia> einops"embed token (head batch) -> (embed head) token batch"
92-
(:embed, :token, (:head, :batch)) --> ((:embed, :head), :token, :batch)
92+
(:embed, :token, (:head, :batch)) => ((:embed, :head), :token, :batch)
9393
9494
julia> einops"i j * k" # for packing
9595
(:i, :j, *, :k)

src/pack_unpack.jl

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,68 @@ size_after_wildcard(dims::Dims, pattern::PackPattern) = dims[end-length(pattern)
99
size_wildcard(dims::Dims, pattern::PackPattern) = dims[begin+length(size_before_wildcard(dims, pattern)):end-length(size_after_wildcard(dims, pattern))]
1010
packed_size(dims::Dims, pattern::PackPattern) = (size_before_wildcard(dims, pattern)..., prod(size_wildcard(dims, pattern)), size_after_wildcard(dims, pattern)...)
1111

12-
function pack(unpacked_arrays::Vector, pattern::PackPattern{N}) where N
12+
"""
13+
pack(unpacked_arrays, pattern)
14+
15+
Pack a vector of arrays into a single array according to the pattern.
16+
17+
# Examples
18+
19+
```jldoctest
20+
julia> inputs = [rand(2,3,5), rand(2,3,7,5), rand(2,3,7,9,5)]
21+
22+
julia> packed_array, packed_shapes = pack(inputs, (:i, :j, *, :k));
23+
24+
julia> size(packed_array)
25+
(2, 3, 71, 5)
26+
27+
julia> packed_shapes
28+
3-element Vector{NTuple{N, Int64} where N}:
29+
()
30+
(7,)
31+
(7, 9)
32+
"""
33+
function pack(unpacked_arrays, pattern::PackPattern{N}) where N
1334
checkpacking(pattern)
1435
reshaped_arrays = [reshape(A, packed_size(size(A), pattern)::Dims{N})::AbstractArray{<:Any,N} for A in unpacked_arrays]
1536
concatenated_array::AbstractArray{<:Any,N} = cat(reshaped_arrays..., dims=find_wildcard(pattern))
16-
packed_shapes = Dims[size_wildcard(size(A), pattern) for A in unpacked_arrays]
37+
packed_shapes = Dims[size_wildcard(size(unpacked_array), pattern) for unpacked_array in unpacked_arrays]
1738
return concatenated_array, packed_shapes
1839
end
1940

2041
splice(a::Dims, i::Int, r::Dims) = (a[1:i-1]..., r..., a[i+1:end]...)
2142

22-
# FIXME: results in a reshape of a view ... so should we collect?
43+
# FIXME: result is a reshape of a view ... so we should collect?
44+
"""
45+
unpack(packed_array, packed_shapes, pattern)
46+
47+
Unpack a single array into a vector of arrays according to the pattern.
48+
49+
# Examples
50+
51+
```jldoctest
52+
julia> inputs = [rand(2,3,5), rand(2,3,7,5), rand(2,3,7,9,5)];
53+
54+
julia> inputs == unpack(pack(inputs, (:i, :j, *, :k))..., (:i, :j, *, :k))
55+
true
56+
57+
julia> packed_array = rand(2,3,16);
58+
59+
julia> packed_shapes = [(), (7,), (4, 2)];
60+
61+
julia> unpack(packed_array, packed_shapes, (:i, :j, *)) .|> size
62+
3-element Vector{Tuple{Int64, Int64, Vararg{Int64}}}:
63+
(2, 3)
64+
(2, 3, 7)
65+
(2, 3, 4, 2)
66+
```
67+
"""
2368
function unpack(packed_array::AbstractArray{<:Any,N}, packed_shapes, pattern::PackPattern{N}) where N
2469
checkpacking(pattern)
2570
inds = Iterators.accumulate(+, Iterators.map(prod, packed_shapes))
2671
unpacked_arrays = map(Iterators.flatten((0, inds)), inds, packed_shapes) do i, j, ps
27-
A = selectdim(packed_array, find_wildcard(pattern), i+1:j)
28-
reshape(A, splice(size(A), find_wildcard(pattern), ps))
72+
unpacked_array = selectdim(packed_array, find_wildcard(pattern), i+1:j)
73+
return collect(reshape(unpacked_array, splice(size(unpacked_array), find_wildcard(pattern), ps)))
2974
end
3075
return unpacked_arrays
3176
end

src/rearrange.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,21 +55,20 @@ Can always be expressed as a `reshape` + `permutedims` + `reshape`.
5555
# Examples
5656
5757
```jldoctest
58-
julia> rearrange(rand(2,3,5), (:a, :b, :c) --> (:c, :b, :a)) |> size
58+
julia> rearrange(rand(2,3,5), (:a, :b, :c) => (:c, :b, :a)) |> size
5959
(5, 3, 2)
6060
6161
julia> permutedims(rand(2,3,5), (3,2,1)) |> size
6262
(5, 3, 2)
6363
64-
julia> rearrange(rand(2,3,35), (:a, :b, (:c, :d)) --> (:a, :d, (:c, :b)), c=5) |> size
64+
julia> rearrange(rand(2,3,35), (:a, :b, (:c, :d)) => (:a, :d, (:c, :b)), c=5) |> size
6565
(2, 7, 15)
6666
6767
julia> reshape(permutedims(reshape(rand(2,3,35), 2,3,5,7), (1,4,3,2)), 2,7,5*3) |> size
6868
(2, 7, 15)
6969
```
7070
"""
71-
function rearrange(x, @nospecialize pattern::Pattern; context...)
72-
left, right = pattern
71+
function rearrange(x, (left, right); context...)
7372
(!isempty(extract(typeof(..), left)) || !isempty(extract(typeof(..), right))) && throw(ArgumentError("Ellipses (..) are currently not supported"))
7473
left_names, right_names = extract(Symbol, left), extract(Symbol, right)
7574
reshaped_in = reshape_in(x, left; context...)

test/runtests.jl

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ using Test, Statistics
1818
@testset "einops string tokenization" begin
1919

2020
@testset "arrow pattern" begin
21-
@test einops"a b c -> a (c b)" == ((:a, :b, :c) --> (:a, (:c, :b)))
22-
@test einops"a b c -> a(c b)" == ((:a, :b, :c) --> (:a, (:c, :b)))
23-
@test einops"a b 1 -> a 1 b" == ((:a, :b, 1) --> (:a, 1, :b))
24-
@test einops"a b () -> a () b" == ((:a, :b, ()) --> (:a, (), :b))
25-
@test einops"a b()->a()b" == ((:a, :b, ()) --> (:a, (), :b))
26-
@test einops"b ... -> b ..." == ((:b, ..) --> (:b, ..))
27-
@test einops"->" == (() --> ())
28-
@test einops"-> 1" == (() --> (1,))
21+
@test einops"a b c -> a (c b)" == ((:a, :b, :c) => (:a, (:c, :b)))
22+
@test einops"a b c -> a(c b)" == ((:a, :b, :c) => (:a, (:c, :b)))
23+
@test einops"a b 1 -> a 1 b" == ((:a, :b, 1) => (:a, 1, :b))
24+
@test einops"a b () -> a () b" == ((:a, :b, ()) => (:a, (), :b))
25+
@test einops"a b()->a()b" == ((:a, :b, ()) => (:a, (), :b))
26+
@test einops"b ... -> b ..." == ((:b, ..) => (:b, ..))
27+
@test einops"->" == (() => ())
28+
@test einops"-> 1" == (() => (1,))
2929
@test_throws "'.'" Einops.parse_pattern("-> .")
3030
@test_throws "'('" Einops.parse_pattern("-> (")
3131
@test_throws "')'" Einops.parse_pattern("-> )")
@@ -43,36 +43,36 @@ using Test, Statistics
4343
@testset "rearrange" begin
4444

4545
x = rand(2,3,5)
46-
@test rearrange(x, (:a, :b, :c) --> (:c, :b, :a)) == permutedims(x, (3,2,1))
47-
@test rearrange(x, (:a, :b, :c) --> (:a, (:c, :b))) == reshape(permutedims(x, (1,3,2)), 2,5*3)
48-
@test rearrange(x, (:first, :second, :third) --> (:third, :second, :first)) == rearrange(x, (:a, :b, :c) --> (:c, :b, :a))
49-
@test_throws "Input length" rearrange(x, (:a, (:b, :c)) --> (:c, :b, :a))
50-
@test_throws ["Set of", "does not match"] rearrange(x, (:a, :b, :c) --> (:a, :b, :a))
51-
@test_throws ["Left names", "not unique"] rearrange(x, (:a, :a, :b) --> (:a, :b))
52-
@test_throws ["Right names", "not unique"] rearrange(x, (:a, :b, :c) --> (:a, :b, :c, :a))
53-
@test_throws "Invalid input dimension" rearrange(x, (:a, :b, 'c') --> (:a, :b, :c))
54-
@test_broken rearrange(x, (:a, :b, ..) --> (:a, .., :b)) == rearrange(x, (:a, :b, :c) --> (:a, :c, :b))
46+
@test rearrange(x, (:a, :b, :c) => (:c, :b, :a)) == permutedims(x, (3,2,1))
47+
@test rearrange(x, (:a, :b, :c) => (:a, (:c, :b))) == reshape(permutedims(x, (1,3,2)), 2,5*3)
48+
@test rearrange(x, (:first, :second, :third) => (:third, :second, :first)) == rearrange(x, (:a, :b, :c) => (:c, :b, :a))
49+
@test_throws "Input length" rearrange(x, (:a, (:b, :c)) => (:c, :b, :a))
50+
@test_throws ["Set of", "does not match"] rearrange(x, (:a, :b, :c) => (:a, :b, :a))
51+
@test_throws ["Left names", "not unique"] rearrange(x, (:a, :a, :b) => (:a, :b))
52+
@test_throws ["Right names", "not unique"] rearrange(x, (:a, :b, :c) => (:a, :b, :c, :a))
53+
@test_throws "Invalid input dimension" rearrange(x, (:a, :b, 'c') => (:a, :b, :c))
54+
@test_broken rearrange(x, (:a, :b, ..) => (:a, .., :b)) == rearrange(x, (:a, :b, :c) => (:a, :c, :b))
5555

5656
x = reshape(rand(1)) # size (), length 1
57-
@test rearrange(x, () --> ()) == x
58-
@test rearrange(x, () --> (1,)) == reshape(x, 1)
57+
@test rearrange(x, () => ()) == x
58+
@test rearrange(x, () => (1,)) == reshape(x, 1)
5959

6060
x = rand(2,3,5*7)
61-
@test rearrange(x, (:a, :b, (:c, :d)) --> (:a, :d, (:c, :b)), c=5) == reshape(permutedims(reshape(x, 2,3,5,7), (1,4,3,2)), 2,7,5*3)
61+
@test rearrange(x, (:a, :b, (:c, :d)) => (:a, :d, (:c, :b)), c=5) == reshape(permutedims(reshape(x, 2,3,5,7), (1,4,3,2)), 2,7,5*3)
6262

6363
x = rand(2,3,5*7*11)
64-
@test rearrange(x, (:a, :b, (:c, :d, :e)) --> ((:a, :e), :d, (:c, :b)), c=5, d=7) == reshape(permutedims(reshape(x, 2,3,5,7,11), (1,5,4,3,2)), 2*11,7,5*3)
65-
@test_throws "Unknown dimension sizes" rearrange(x, (:a, :b, (:c, :d, :e)) --> (:a, :b, :c, :d, :e), c=5)
64+
@test rearrange(x, (:a, :b, (:c, :d, :e)) => ((:a, :e), :d, (:c, :b)), c=5, d=7) == reshape(permutedims(reshape(x, 2,3,5,7,11), (1,5,4,3,2)), 2*11,7,5*3)
65+
@test_throws "Unknown dimension sizes" rearrange(x, (:a, :b, (:c, :d, :e)) => (:a, :b, :c, :d, :e), c=5)
6666

6767
x = rand(2,1,3)
68-
@test rearrange(x, (:a, 1, :b) --> (:a, :b)) == dropdims(x, dims=2)
69-
@test_throws "Singleton dimension size is not 1" rearrange(x, (2, :a, :b) --> (:a, :b))
70-
@test_throws "Singleton dimension size is not 1" rearrange(x, (:a, :b, :c) --> (:a, :b, :c, 2))
68+
@test rearrange(x, (:a, 1, :b) => (:a, :b)) == dropdims(x, dims=2)
69+
@test_throws "Singleton dimension size is not 1" rearrange(x, (2, :a, :b) => (:a, :b))
70+
@test_throws "Singleton dimension size is not 1" rearrange(x, (:a, :b, :c) => (:a, :b, :c, 2))
7171

7272
x = rand(2,3)
73-
@test rearrange(x, (:a, :b) --> (:b, 1, :a)) == reshape(permutedims(x, (2,1)), 3,1,2)
74-
@test rearrange(x, (:a, :b) --> (:b, 1, 1, :a, 1)) == reshape(permutedims(x, (2,1)), 3,1,1,2,1)
75-
@test rearrange(x, (:a, :b) --> (:b, (), :a)) == rearrange(x, (:a, :b) --> (:b, (), :a))
73+
@test rearrange(x, (:a, :b) => (:b, 1, :a)) == reshape(permutedims(x, (2,1)), 3,1,2)
74+
@test rearrange(x, (:a, :b) => (:b, 1, 1, :a, 1)) == reshape(permutedims(x, (2,1)), 3,1,1,2,1)
75+
@test rearrange(x, (:a, :b) => (:b, (), :a)) == rearrange(x, (:a, :b) => (:b, (), :a))
7676

7777
@testset "Python API reference parity" begin
7878
# see https://einops.rocks/api/rearrange/

0 commit comments

Comments
 (0)