Skip to content

Commit 3b4cadd

Browse files
committed
multi-array input
1 parent 5560133 commit 3b4cadd

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

src/rearrange.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,11 @@ end
4545

4646
reshape_out(x, ::Tuple{Vararg{Symbol}}) = x
4747

48+
# TODO: support taking vectors of arrays
49+
4850
"""
49-
rearrange(x::AbstractArray, left --> right)
51+
rearrange(array::AbstractArray, left --> right; context...)
52+
rearrange(arrays, left --> right; context...)
5053
5154
Rearrange the axes of `x` according to the pattern specified by `left --> right`.
5255
@@ -68,11 +71,17 @@ julia> reshape(permutedims(reshape(rand(2,3,35), 2,3,5,7), (1,4,3,2)), 2,7,5*3)
6871
(2, 7, 15)
6972
```
7073
"""
71-
function rearrange(x, (left, right); context...)
74+
function rearrange(x::AbstractArray, pattern::Pattern; context...)
75+
@nospecialize pattern
76+
left, right = pattern
7277
(!isempty(extract(typeof(..), left)) || !isempty(extract(typeof(..), right))) && throw(ArgumentError("Ellipses (..) are currently not supported"))
7378
left_names, right_names = extract(Symbol, left), extract(Symbol, right)
7479
reshaped_in = reshape_in(x, left; context...)
7580
permuted = permute(reshaped_in, left_names, right_names)
7681
reshaped_out = reshape_out(permuted, right)
7782
return reshaped_out
7883
end
84+
85+
rearrange(x::AbstractArray{<:AbstractArray}, pattern::Pattern; context...) = rearrange(stack(x), pattern; context...)
86+
87+
rearrange(x, pattern::Pattern; context...) = rearrange(stack(x), pattern; context...)

test/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ using Test, Statistics
7474
@test rearrange(x, (:a, :b) --> (:b, 1, 1, :a, 1)) == reshape(permutedims(x, (2,1)), 3,1,1,2,1)
7575
@test rearrange(x, (:a, :b) --> (:b, (), :a)) == rearrange(x, (:a, :b) --> (:b, (), :a))
7676

77+
x = rand(2,3,5)
78+
@test rearrange([x, x], (:a, :b, :c, :d) --> (:c, :b, :a, :d)) == permutedims(stack([x, x]), (3,2,1,4))
79+
@test rearrange(reshape([x, x], 1, 2), (:a, :b, :c, 1, :d) --> (:c, :b, :a, :d)) == permutedims(reshape(cat(x, x, dims=5), 2,3,5,2), (3,2,1,4))
80+
@test rearrange((x, x), (:a, :b, :c, :d) --> (:c, :b, :a, :d)) == permutedims(cat(x, x, dims=4), (3,2,1,4))
81+
7782
@testset "Python API reference parity" begin
7883
# see https://einops.rocks/api/rearrange/
7984

0 commit comments

Comments
 (0)