Skip to content

Commit 112ed1b

Browse files
committed
revert to -->
1 parent 1d24622 commit 112ed1b

File tree

6 files changed

+54
-42
lines changed

6 files changed

+54
-42
lines changed

README.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
1111
Einops.jl is a Julia implementation of the [einops](https://einops.rocks) Python package, providing an elegant and intuitive notation for tensor operations. We currently implement `rearrange`, offering a unified way to perform Julia's `reshape` and `permutedims`, as well as `pack` and `unpack`, with plans for implementing `reduce` and `repeat`.
1212

13-
The Python implementation uses strings to specify the operation, but that would be tricky to compile in Julia, so a string macro `@einops_str` is exported for parity, e.g. `einops"a 1 b c -> (c b) a"`, which expands to the form `(:a, 1, :b, :c,) => ((:c, :b), :a)`, allowing for compile-time awareness of dimensionalities, ensuring type stability.
13+
The Python implementation uses strings to specify the operation, but that would be tricky to compile in Julia, so a string macro `@einops_str` is exported for parity, e.g. `einops"a 1 b c -> (c b) a"`, which expands to the form `(:a, 1, :b, :c,) --> ((:c, :b), :a)`, allowing for compile-time awareness of dimensionalities, ensuring type stability.
1414

1515
## Operations
1616

@@ -22,15 +22,15 @@ The `rearrange` combines reshaping and permutation operations into a single, exp
2222
julia> images = randn(32, 30, 40, 3); # batch, height, width, channel
2323

2424
# reorder axes to "b c h w" format:
25-
julia> rearrange(images, (:b, :h, :w, :c) => (:b, :c, :h, :w)) |> size
25+
julia> rearrange(images, (:b, :h, :w, :c) --> (:b, :c, :h, :w)) |> size
2626
(32, 3, 30, 40)
2727

2828
# flatten each image into a vector
29-
julia> rearrange(images, (:b, :h, :w, :c) => (:b, (:h, :w, :c))) |> size
29+
julia> rearrange(images, (:b, :h, :w, :c) --> (:b, (:h, :w, :c))) |> size
3030
(32, 3600)
3131

3232
# split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
33-
julia> rearrange(images, (:b, (:h1, :h), (:w1, :w), :c) => ((:b, :h1, :w1), :h, :w, :c), h1=2, w1=2) |> size
33+
julia> rearrange(images, (:b, (:h1, :h), (:w1, :w), :c) --> ((:b, :h1, :w1), :h, :w, :c), h1=2, w1=2) |> size
3434
(128, 15, 20, 3)
3535
```
3636

@@ -41,7 +41,7 @@ The `reduce` function will allow for applying reduction operations (like `sum`,
4141
```julia
4242
# Example (conceptual):
4343
x = randn(100, 32, 64)
44-
y = reduce(maximum, x, (:t, :b, :c) => (:b, :c)) # max-reduction on the first axis
44+
y = reduce(maximum, x, (:t, :b, :c) --> (:b, :c)) # max-reduction on the first axis
4545
```
4646

4747
### `repeat` (Planned)
@@ -51,7 +51,7 @@ The `repeat` function will provide a concise way to repeat elements along existi
5151
```julia
5252
# Example (conceptual):
5353
image = randn(30, 40)
54-
rgb_image = repeat(image, (:h, :w) => (:repeat, :h, :w), repeat=3)
54+
rgb_image = repeat(image, (:h, :w) --> (:repeat, :h, :w), repeat=3)
5555
```
5656

5757
## Roadmap

src/Einops.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ export ..
77
# TODO: use TransmuteDims.jl
88

99
include("utils.jl")
10+
export -->
1011

1112
include("einops_str.jl")
1213
export @einops_str
@@ -19,7 +20,7 @@ export pack, unpack
1920

2021
# TODO: implement reduce, repeat
2122
Base.reduce(f, x::AbstractArray, pattern::Pattern; context...) = error("Not implemented")
22-
Base.repeat(x, pattern::Pattern; context...) = error("Not implemented")
23+
Base.repeat(x::AbstractArray, pattern::Pattern; context...) = error("Not implemented")
2324

2425
# TODO: einsum
2526

src/einops_str.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ function parse_pattern(pattern::AbstractString)
44
lhs, rhs = strip.(split(pattern, "->"; limit = 2))
55
lhs_axes = tokenize_side(lhs)
66
rhs_axes = tokenize_side(rhs)
7-
return Tuple(lhs_axes) => Tuple(rhs_axes)
7+
return Tuple(lhs_axes) --> Tuple(rhs_axes)
88
end
99

1010
function tokenize_side(side::AbstractString)
@@ -86,10 +86,10 @@ For parity with Python implementation.
8686
8787
```jldoctest
8888
julia> einops"a 1 b c -> (c b) a"
89-
(:a, 1, :b, :c) => ((:c, :b), :a)
89+
(:a, 1, :b, :c) --> ((:c, :b), :a)
9090
9191
julia> einops"embed token (head batch) -> (embed head) token batch"
92-
(:embed, :token, (:head, :batch)) => ((:embed, :head), :token, :batch)
92+
(:embed, :token, (:head, :batch)) --> ((:embed, :head), :token, :batch)
9393
9494
julia> einops"i j * k" # for packing
9595
(:i, :j, *, :k)

src/rearrange.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,22 +46,22 @@ end
4646
reshape_out(x, ::Tuple{Vararg{Symbol}}) = x
4747

4848
"""
49-
rearrange(x::AbstractArray, left => right)
49+
rearrange(x::AbstractArray, left --> right)
5050
51-
Rearrange the axes of `x` according to the pattern specified by `left => right`.
51+
Rearrange the axes of `x` according to the pattern specified by `left --> right`.
5252
5353
Can always be expressed as a `reshape` + `permutedims` + `reshape`.
5454
5555
# Examples
5656
5757
```jldoctest
58-
julia> rearrange(rand(2,3,5), (:a, :b, :c) => (:c, :b, :a)) |> size
58+
julia> rearrange(rand(2,3,5), (:a, :b, :c) --> (:c, :b, :a)) |> size
5959
(5, 3, 2)
6060
6161
julia> permutedims(rand(2,3,5), (3,2,1)) |> size
6262
(5, 3, 2)
6363
64-
julia> rearrange(rand(2,3,35), (:a, :b, (:c, :d)) => (:a, :d, (:c, :b)), c=5) |> size
64+
julia> rearrange(rand(2,3,35), (:a, :b, (:c, :d)) --> (:a, :d, (:c, :b)), c=5) |> size
6565
(2, 7, 15)
6666
6767
julia> reshape(permutedims(reshape(rand(2,3,35), 2,3,5,7), (1,4,3,2)), 2,7,5*3) |> size

src/utils.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
struct Pattern{L,R} end
2+
(-->)(L, R) = Pattern{L, R}()
3+
Base.show(io::IO, ::Pattern{L,R}) where {L,R} = print(io, "$L --> $R")
4+
Base.iterate(::Pattern{L}) where L = (L, Val(:R))
5+
Base.iterate(::Pattern{<:Any,R}, ::Val{:R}) where R = (R, nothing)
6+
Base.iterate(::Pattern, ::Nothing) = nothing
7+
18
function permutation_mapping(left::NTuple{N,T}, right::NTuple{N,T}) where {N,T}
29
perm::Vector{Int} = findfirst.(isequal.([right...]), Ref([left...]))
310
return ntuple(i -> perm[i], Val(N))

test/runtests.jl

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ using Test, Statistics
1818
@testset "einops string tokenization" begin
1919

2020
@testset "arrow pattern" begin
21-
@test einops"a b c -> a (c b)" == ((:a, :b, :c) => (:a, (:c, :b)))
22-
@test einops"a b c -> a(c b)" == ((:a, :b, :c) => (:a, (:c, :b)))
23-
@test einops"a b 1 -> a 1 b" == ((:a, :b, 1) => (:a, 1, :b))
24-
@test einops"a b () -> a () b" == ((:a, :b, ()) => (:a, (), :b))
25-
@test einops"a b()->a()b" == ((:a, :b, ()) => (:a, (), :b))
26-
@test einops"b ... -> b ..." == ((:b, ..) => (:b, ..))
27-
@test einops"->" == (() => ())
28-
@test einops"-> 1" == (() => (1,))
21+
@test einops"a b c -> a (c b)" == ((:a, :b, :c) --> (:a, (:c, :b)))
22+
@test einops"a b c -> a(c b)" == ((:a, :b, :c) --> (:a, (:c, :b)))
23+
@test einops"a b 1 -> a 1 b" == ((:a, :b, 1) --> (:a, 1, :b))
24+
@test einops"a b () -> a () b" == ((:a, :b, ()) --> (:a, (), :b))
25+
@test einops"a b()->a()b" == ((:a, :b, ()) --> (:a, (), :b))
26+
@test einops"b ... -> b ..." == ((:b, ..) --> (:b, ..))
27+
@test einops"->" == (() --> ())
28+
@test einops"-> 1" == (() --> (1,))
2929
@test_throws "'.'" Einops.parse_pattern("-> .")
3030
@test_throws "'('" Einops.parse_pattern("-> (")
3131
@test_throws "')'" Einops.parse_pattern("-> )")
@@ -43,36 +43,36 @@ using Test, Statistics
4343
@testset "rearrange" begin
4444

4545
x = rand(2,3,5)
46-
@test rearrange(x, (:a, :b, :c) => (:c, :b, :a)) == permutedims(x, (3,2,1))
47-
@test rearrange(x, (:a, :b, :c) => (:a, (:c, :b))) == reshape(permutedims(x, (1,3,2)), 2,5*3)
48-
@test rearrange(x, (:first, :second, :third) => (:third, :second, :first)) == rearrange(x, (:a, :b, :c) => (:c, :b, :a))
49-
@test_throws "Input length" rearrange(x, (:a, (:b, :c)) => (:c, :b, :a))
50-
@test_throws ["Set of", "does not match"] rearrange(x, (:a, :b, :c) => (:a, :b, :a))
51-
@test_throws ["Left names", "not unique"] rearrange(x, (:a, :a, :b) => (:a, :b))
52-
@test_throws ["Right names", "not unique"] rearrange(x, (:a, :b, :c) => (:a, :b, :c, :a))
53-
@test_throws "Invalid input dimension" rearrange(x, (:a, :b, 'c') => (:a, :b, :c))
54-
@test_broken rearrange(x, (:a, :b, ..) => (:a, .., :b)) == rearrange(x, (:a, :b, :c) => (:a, :c, :b))
46+
@test rearrange(x, (:a, :b, :c) --> (:c, :b, :a)) == permutedims(x, (3,2,1))
47+
@test rearrange(x, (:a, :b, :c) --> (:a, (:c, :b))) == reshape(permutedims(x, (1,3,2)), 2,5*3)
48+
@test rearrange(x, (:first, :second, :third) --> (:third, :second, :first)) == rearrange(x, (:a, :b, :c) --> (:c, :b, :a))
49+
@test_throws "Input length" rearrange(x, (:a, (:b, :c)) --> (:c, :b, :a))
50+
@test_throws ["Set of", "does not match"] rearrange(x, (:a, :b, :c) --> (:a, :b, :a))
51+
@test_throws ["Left names", "not unique"] rearrange(x, (:a, :a, :b) --> (:a, :b))
52+
@test_throws ["Right names", "not unique"] rearrange(x, (:a, :b, :c) --> (:a, :b, :c, :a))
53+
@test_throws "Invalid input dimension" rearrange(x, (:a, :b, 'c') --> (:a, :b, :c))
54+
@test_broken rearrange(x, (:a, :b, ..) --> (:a, .., :b)) == rearrange(x, (:a, :b, :c) --> (:a, :c, :b))
5555

5656
x = reshape(rand(1)) # size (), length 1
57-
@test rearrange(x, () => ()) == x
58-
@test rearrange(x, () => (1,)) == reshape(x, 1)
57+
@test rearrange(x, () --> ()) == x
58+
@test rearrange(x, () --> (1,)) == reshape(x, 1)
5959

6060
x = rand(2,3,5*7)
61-
@test rearrange(x, (:a, :b, (:c, :d)) => (:a, :d, (:c, :b)), c=5) == reshape(permutedims(reshape(x, 2,3,5,7), (1,4,3,2)), 2,7,5*3)
61+
@test rearrange(x, (:a, :b, (:c, :d)) --> (:a, :d, (:c, :b)), c=5) == reshape(permutedims(reshape(x, 2,3,5,7), (1,4,3,2)), 2,7,5*3)
6262

6363
x = rand(2,3,5*7*11)
64-
@test rearrange(x, (:a, :b, (:c, :d, :e)) => ((:a, :e), :d, (:c, :b)), c=5, d=7) == reshape(permutedims(reshape(x, 2,3,5,7,11), (1,5,4,3,2)), 2*11,7,5*3)
65-
@test_throws "Unknown dimension sizes" rearrange(x, (:a, :b, (:c, :d, :e)) => (:a, :b, :c, :d, :e), c=5)
64+
@test rearrange(x, (:a, :b, (:c, :d, :e)) --> ((:a, :e), :d, (:c, :b)), c=5, d=7) == reshape(permutedims(reshape(x, 2,3,5,7,11), (1,5,4,3,2)), 2*11,7,5*3)
65+
@test_throws "Unknown dimension sizes" rearrange(x, (:a, :b, (:c, :d, :e)) --> (:a, :b, :c, :d, :e), c=5)
6666

6767
x = rand(2,1,3)
68-
@test rearrange(x, (:a, 1, :b) => (:a, :b)) == dropdims(x, dims=2)
69-
@test_throws "Singleton dimension size is not 1" rearrange(x, (2, :a, :b) => (:a, :b))
70-
@test_throws "Singleton dimension size is not 1" rearrange(x, (:a, :b, :c) => (:a, :b, :c, 2))
68+
@test rearrange(x, (:a, 1, :b) --> (:a, :b)) == dropdims(x, dims=2)
69+
@test_throws "Singleton dimension size is not 1" rearrange(x, (2, :a, :b) --> (:a, :b))
70+
@test_throws "Singleton dimension size is not 1" rearrange(x, (:a, :b, :c) --> (:a, :b, :c, 2))
7171

7272
x = rand(2,3)
73-
@test rearrange(x, (:a, :b) => (:b, 1, :a)) == reshape(permutedims(x, (2,1)), 3,1,2)
74-
@test rearrange(x, (:a, :b) => (:b, 1, 1, :a, 1)) == reshape(permutedims(x, (2,1)), 3,1,1,2,1)
75-
@test rearrange(x, (:a, :b) => (:b, (), :a)) == rearrange(x, (:a, :b) => (:b, (), :a))
73+
@test rearrange(x, (:a, :b) --> (:b, 1, :a)) == reshape(permutedims(x, (2,1)), 3,1,2)
74+
@test rearrange(x, (:a, :b) --> (:b, 1, 1, :a, 1)) == reshape(permutedims(x, (2,1)), 3,1,1,2,1)
75+
@test rearrange(x, (:a, :b) --> (:b, (), :a)) == rearrange(x, (:a, :b) --> (:b, (), :a))
7676

7777
@testset "Python API reference parity" begin
7878
# see https://einops.rocks/api/rearrange/
@@ -106,6 +106,8 @@ using Test, Statistics
106106

107107
@testset "reduce" begin
108108

109+
@test_throws "Not implemented" reduce(+, rand(2,3,4), einops"a b c -> b c")
110+
109111
@testset "Python API reference parity" begin
110112
# see https://einops.rocks/api/reduce/
111113

@@ -154,6 +156,8 @@ using Test, Statistics
154156

155157
@testset "repeat" begin
156158

159+
@test_throws "Not implemented" repeat(rand(2,3,4), einops"a b c -> b c")
160+
157161
@testset "Python API reference parity" begin
158162
# see https://einops.rocks/api/repeat/
159163

0 commit comments

Comments
 (0)