Skip to content

Commit 654c2db

Browse files
committed
Add pack and unpack
1 parent b43ee71 commit 654c2db

File tree

4 files changed

+91
-19
lines changed

4 files changed

+91
-19
lines changed

src/Einops.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,13 @@ export @einops_str
2222
include("rearrange.jl")
2323
export rearrange
2424

25+
include("pack_unpack.jl")
26+
export pack, unpack
27+
2528
# TODO: implement reduce, repeat
2629
Base.reduce(f, x::AbstractArray, pattern::Pattern; context...) = error("Not implemented")
2730
Base.repeat(x, pattern::Pattern; context...) = error("Not implemented")
2831

29-
# TODO: einsum, pack, unpack
32+
# TODO: einsum
3033

3134
end

src/einops_str.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
function parse_pattern(pattern::AbstractString)
2-
occursin("->", pattern) ||
3-
throw(ArgumentError("pattern must contain \"->\" (got \"$pattern\")"))
2+
occursin("->", pattern) || return tokenize_packing_pattern(pattern)
43

54
lhs, rhs = strip.(split(pattern, "->"; limit = 2))
6-
lhs_axes = tokenise_side(lhs)
7-
rhs_axes = tokenise_side(rhs)
5+
lhs_axes = tokenize_side(lhs)
6+
rhs_axes = tokenize_side(rhs)
87
return Tuple(lhs_axes) --> Tuple(rhs_axes)
98
end
109

11-
function tokenise_side(side::AbstractString)
10+
function tokenize_side(side::AbstractString)
1211
tokens = Any[]
1312
stack = Vector{Any}[]
1413
buf = IOBuffer()
@@ -73,8 +72,13 @@ function tokenise_side(side::AbstractString)
7372
return Tuple(tokens)
7473
end
7574

75+
function tokenize_packing_pattern(pattern::AbstractString)
76+
Tuple(map(s -> s == "*" ? (*) : Symbol(s), filter(!isempty, split(pattern, ' '))))
77+
end
78+
7679
"""
77-
einops"... -> ..."
80+
einops"a (b c) -> (c b a)"
81+
einops"i j * k"
7882
7983
For parity with Python implementation.
8084
@@ -86,6 +90,9 @@ julia> einops"a 1 b c -> (c b) a"
8690
8791
julia> einops"embed token (head batch) -> (embed head) token batch"
8892
(:embed, :token, (:head, :batch)) --> ((:embed, :head), :token, :batch)
93+
94+
julia> einops"i j * k" # for packing
95+
(:i, :j, *, :k)
8996
```
9097
"""
9198
macro einops_str(pattern)

src/pack_unpack.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
const Wildcard = typeof(*)
2+
const PackPattern{N} = NTuple{N,Union{Symbol,Wildcard}}
3+
4+
checkpacking(pattern::PackPattern) = count(x -> x isa Wildcard, pattern) == 1 || error("Only one wildcard (*) is allowed in the pattern")
5+
find_wildcard(pattern::PackPattern) = findfirst(x -> x isa Wildcard, pattern)
6+
7+
size_before_wildcard(dims::Dims, pattern::PackPattern) = dims[1:find_wildcard(pattern)-1]
8+
size_after_wildcard(dims::Dims, pattern::PackPattern) = dims[end-length(pattern)+find_wildcard(pattern)+1:end]
9+
size_wildcard(dims::Dims, pattern::PackPattern) = dims[begin+length(size_before_wildcard(dims, pattern)):end-length(size_after_wildcard(dims, pattern))]
10+
packed_size(dims::Dims, pattern::PackPattern) = (size_before_wildcard(dims, pattern)..., prod(size_wildcard(dims, pattern)), size_after_wildcard(dims, pattern)...)
11+
12+
function pack(unpacked_arrays::Vector, pattern::PackPattern{N}) where N
13+
checkpacking(pattern)
14+
reshaped_arrays = [reshape(A, packed_size(size(A), pattern)::Dims{N})::AbstractArray{<:Any,N} for A in unpacked_arrays]
15+
concatenated_array::AbstractArray{<:Any,N} = cat(reshaped_arrays..., dims=find_wildcard(pattern))
16+
packed_shapes = Dims[size_wildcard(size(A), pattern) for A in unpacked_arrays]
17+
return concatenated_array, packed_shapes
18+
end
19+
20+
splice(a::Dims, i::Int, r::Dims) = (a[1:i-1]..., r..., a[i+1:end]...)
21+
22+
# FIXME: results in a reshape of a view ... so should we collect?
23+
function unpack(packed_array::AbstractArray{<:Any,N}, packed_shapes, pattern::PackPattern{N}) where N
24+
checkpacking(pattern)
25+
inds = Iterators.accumulate(+, Iterators.map(prod, packed_shapes))
26+
unpacked_arrays = map(Iterators.flatten((0, inds)), inds, packed_shapes) do i, j, ps
27+
A = selectdim(packed_array, find_wildcard(pattern), i+1:j)
28+
reshape(A, splice(size(A), find_wildcard(pattern), ps))
29+
end
30+
return unpacked_arrays
31+
end

test/runtests.jl

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,28 @@ using Test, Statistics
1616
end
1717

1818
@testset "einops string tokenization" begin
19-
@test einops"a b c -> a (c b)" == ((:a, :b, :c) --> (:a, (:c, :b)))
20-
@test einops"a b c -> a(c b)" == ((:a, :b, :c) --> (:a, (:c, :b)))
21-
@test einops"a b 1 -> a 1 b" == ((:a, :b, 1) --> (:a, 1, :b))
22-
@test einops"a b () -> a () b" == ((:a, :b, ()) --> (:a, (), :b))
23-
@test einops"a b()->a()b" == ((:a, :b, ()) --> (:a, (), :b))
24-
@test einops"b ... -> b ..." == ((:b, ..) --> (:b, ..))
25-
@test einops"->" == (() --> ())
26-
@test einops"-> 1" == (() --> (1,))
27-
@test_throws "'.'" Einops.parse_pattern("-> .")
28-
@test_throws "'('" Einops.parse_pattern("-> (")
29-
@test_throws "')'" Einops.parse_pattern("-> )")
30-
@test_throws "->" Einops.parse_pattern("")
19+
20+
@testset "arrow pattern" begin
21+
@test einops"a b c -> a (c b)" == ((:a, :b, :c) --> (:a, (:c, :b)))
22+
@test einops"a b c -> a(c b)" == ((:a, :b, :c) --> (:a, (:c, :b)))
23+
@test einops"a b 1 -> a 1 b" == ((:a, :b, 1) --> (:a, 1, :b))
24+
@test einops"a b () -> a () b" == ((:a, :b, ()) --> (:a, (), :b))
25+
@test einops"a b()->a()b" == ((:a, :b, ()) --> (:a, (), :b))
26+
@test einops"b ... -> b ..." == ((:b, ..) --> (:b, ..))
27+
@test einops"->" == (() --> ())
28+
@test einops"-> 1" == (() --> (1,))
29+
@test_throws "'.'" Einops.parse_pattern("-> .")
30+
@test_throws "'('" Einops.parse_pattern("-> (")
31+
@test_throws "')'" Einops.parse_pattern("-> )")
32+
end
33+
34+
@testset "packing pattern" begin
35+
@test einops"i j * k" == (:i, :j, *, :k)
36+
@test einops" i j * k " == (:i, :j, *, :k)
37+
@test einops"* i" == (*, :i)
38+
@test einops"i *" == (:i, *)
39+
end
40+
3141
end
3242

3343
@testset "rearrange" begin
@@ -172,4 +182,25 @@ using Test, Statistics
172182

173183
end
174184

185+
@testset "pack_unpack" begin
186+
187+
@testset "Python API reference parity" begin
188+
# see https://einops.rocks/api/pack_unpack/
189+
190+
inputs = [rand(2, 3, 5), rand(2, 3, 7, 5), rand(2, 3, 7, 9, 5)]
191+
@test begin
192+
packed, ps = pack(inputs, einops"i j * k")
193+
packed |> size == (2, 3, 71, 5) && ps == [(), (7,), (7, 9)]
194+
end
195+
196+
@test begin
197+
packed, ps = pack(inputs, einops"i j * k")
198+
inputs_unpacked = unpack(packed, ps, einops"i j * k")
199+
all(inputs .== inputs_unpacked)
200+
end
201+
202+
end
203+
204+
end
205+
175206
end

0 commit comments

Comments
 (0)