Skip to content

Commit 43977a3

Browse files
committed
Add parse_shape
1 parent 3b4cadd commit 43977a3

File tree

6 files changed

+103
-46
lines changed

6 files changed

+103
-46
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,13 @@ rgb_image = repeat(image, (:h, :w) --> (:repeat, :h, :w), repeat=3)
5858

5959
* [x] Implement `rearrange`.
6060
* [x] Support Python implementation's string syntax for patterns with string macro.
61+
* [x] Implement `parse_shape`.
62+
* [x] Implement `pack` and `unpack`.
6163
* [ ] Support ellipsis notation (using `..` from [EllipsisNotation.jl](https://github.com/SciML/EllipsisNotation.jl)).
6264
* [ ] Implement `reduce`.
6365
* [ ] Implement `repeat`.
6466
* [ ] Explore integration with `PermutedDimsArray` or `TransmuteDims.jl` for lazy and statically inferrable permutations.
65-
* [ ] Implement `einsum` (or wrap existing implementation)
66-
* [x] Implement `pack`, and `unpack`.
67+
* [ ] Implement `einsum` (or wrap existing implementation).
6768

6869
## Contributing
6970

src/Einops.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ export ..
88

99
include("utils.jl")
1010
export -->
11+
export parse_shape
1112

1213
include("einops_str.jl")
1314
export @einops_str

src/einops_str.jl

Lines changed: 21 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,65 @@
11
function parse_pattern(pattern::AbstractString)
2-
occursin("->", pattern) || return tokenize_packing_pattern(pattern)
3-
2+
occursin("->", pattern) || return tokenize_generic(pattern)
43
lhs, rhs = strip.(split(pattern, "->"; limit = 2))
4+
occursin("->", rhs) && throw(ArgumentError("multiple \"->\" in pattern"))
55
lhs_axes = tokenize_side(lhs)
66
rhs_axes = tokenize_side(rhs)
77
return Tuple(lhs_axes) --> Tuple(rhs_axes)
88
end
99

1010
function tokenize_side(side::AbstractString)
11+
12+
function parse_token!(buf::IOBuffer, tokens::Vector)
13+
if position(buf) > 0
14+
s = String(take!(buf))
15+
token = tryparse(Int, s)
16+
isnothing(token) && (token = Symbol(s))
17+
push!(tokens, token)
18+
end
19+
end
20+
1121
tokens = Any[]
12-
stack = Vector{Any}[]
1322
buf = IOBuffer()
23+
stack = Vector{Any}[]
1424
i = firstindex(side)
15-
1625
while i <= lastindex(side)
1726
c = side[i]
18-
1927
if c == ' '
20-
if position(buf) > 0
21-
s = String(take!(buf))
22-
token = tryparse(Int, s)
23-
isnothing(token) && (token = Symbol(s))
24-
push!(tokens, token)
25-
end
28+
parse_token!(buf, tokens)
2629
i += 1
27-
2830
elseif c == '('
29-
if position(buf) > 0
30-
s = String(take!(buf))
31-
token = tryparse(Int, s)
32-
isnothing(token) && (token = Symbol(s))
33-
push!(tokens, token)
34-
end
31+
parse_token!(buf, tokens)
3532
push!(stack, tokens)
3633
tokens = Any[]
3734
i += 1
38-
3935
elseif c == ')'
40-
if position(buf) > 0
41-
s = String(take!(buf))
42-
token = tryparse(Int, s)
43-
isnothing(token) && (token = Symbol(s))
44-
push!(tokens, token)
45-
end
36+
parse_token!(buf, tokens)
4637
isempty(stack) && throw(ArgumentError("unmatched ')' in pattern"))
4738
sub = tokens
4839
tokens = pop!(stack)
4940
push!(tokens, Tuple(sub))
5041
i += 1
51-
5242
elseif c == '.'
5343
# Expect literal "..."
5444
(i + 2 lastindex(side) && side[i:i+2] == "...") ||
5545
throw(ArgumentError("single '.' not allowed in pattern"))
5646
push!(tokens, ..)
5747
i += 3
58-
5948
else
6049
write(buf, c)
6150
i += 1
6251
end
6352
end
53+
parse_token!(buf, tokens)
6454

65-
if position(buf) > 0
66-
s = String(take!(buf))
67-
token = tryparse(Int, s)
68-
isnothing(token) && (token = Symbol(s))
69-
push!(tokens, token)
70-
end
7155
!isempty(stack) && throw(ArgumentError("unmatched '(' in pattern"))
7256
return Tuple(tokens)
7357
end
7458

75-
function tokenize_packing_pattern(pattern::AbstractString)
76-
Tuple(map(s -> s == "*" ? (*) : Symbol(s), filter(!isempty, split(pattern, ' '))))
77-
end
59+
const SpecialToken = Dict(:* => (*), :_ => (-))
60+
get_special_token(symbol) = get(SpecialToken, symbol, symbol)
61+
mapfilter(f, pred, xs) = map(f, filter(pred, xs))
62+
tokenize_generic(pattern) = Tuple(mapfilter(get_special_token Symbol, !isempty, split(pattern, ' ')))
7863

7964
"""
8065
einops"a (b c) -> (c b a)"

src/pack_unpack.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
const Wildcard = typeof(*)
22
const PackPattern{N} = NTuple{N,Union{Symbol,Wildcard}}
33

4-
checkpacking(pattern::PackPattern) = count(x -> x isa Wildcard, pattern) == 1 || error("Only one wildcard (*) is allowed in the pattern")
4+
function check_packing_pattern(pattern::PackPattern)
5+
count(x -> x isa Wildcard, pattern) == 1 || error("Only one wildcard (*) is allowed in the pattern")
6+
allunique(pattern) || error("Pattern $(pattern) has duplicate elements")
7+
return nothing
8+
end
9+
510
find_wildcard(pattern::PackPattern) = findfirst(x -> x isa Wildcard, pattern)
611

712
size_before_wildcard(dims::Dims, pattern::PackPattern) = dims[1:find_wildcard(pattern)-1]
@@ -31,7 +36,7 @@ julia> packed_shapes
3136
(7, 9)
3237
"""
3338
function pack(unpacked_arrays, pattern::PackPattern{N}) where N
34-
checkpacking(pattern)
39+
check_packing_pattern(pattern)
3540
reshaped_arrays = [reshape(A, packed_size(size(A), pattern)::Dims{N})::AbstractArray{<:Any,N} for A in unpacked_arrays]
3641
concatenated_array::AbstractArray{<:Any,N} = cat(reshaped_arrays..., dims=find_wildcard(pattern))
3742
packed_shapes = Dims[size_wildcard(size(unpacked_array), pattern) for unpacked_array in unpacked_arrays]
@@ -66,7 +71,7 @@ julia> unpack(packed_array, packed_shapes, (:i, :j, *)) .|> size
6671
```
6772
"""
6873
function unpack(packed_array::AbstractArray{<:Any,N}, packed_shapes, pattern::PackPattern{N}) where N
69-
checkpacking(pattern)
74+
check_packing_pattern(pattern)
7075
inds = Iterators.accumulate(+, Iterators.map(prod, packed_shapes))
7176
unpacked_arrays = map(Iterators.flatten((0, inds)), inds, packed_shapes) do i, j, ps
7277
unpacked_array = selectdim(packed_array, find_wildcard(pattern), i+1:j)

src/utils.jl

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
struct Pattern{L,R} end
2+
3+
# TODO: document -->
24
(-->)(L, R) = Pattern{L, R}()
5+
36
Base.show(io::IO, ::Pattern{L,R}) where {L,R} = print(io, "$L --> $R")
47
Base.iterate(::Pattern{L}) where L = (L, Val(:R))
58
Base.iterate(::Pattern{<:Any,R}, ::Val{:R}) where R = (R, nothing)
69
Base.iterate(::Pattern, ::Nothing) = nothing
710

8-
function permutation_mapping(left::NTuple{N,T}, right::NTuple{N,T}) where {N,T}
9-
perm::Vector{Int} = findfirst.(isequal.([right...]), Ref([left...]))
10-
return ntuple(i -> perm[i], Val(N))
11-
end
1211

1312
extract(::Type, ::Tuple{}) = ()
1413
function extract(T::Type, input_tuple::Tuple)
@@ -24,6 +23,57 @@ function extract(T::Type, input_tuple::Tuple)
2423
return (instances_from_first..., extract(T, rest_elements)...)
2524
end
2625

26+
27+
@generated function findtype(::Type{T}, xs::Tuple) where T
28+
inds = Int[]
29+
for (i, el_type) in enumerate(xs.parameters)
30+
el_type <: T && push!(inds, i)
31+
end
32+
return Expr(:tuple, inds...)
33+
end
34+
35+
const Ignored = typeof(-)
36+
const ShapePattern{N} = NTuple{N,Union{Symbol,Ignored}}
37+
38+
"""
39+
parse_shape(x, pattern)
40+
41+
Capture the shape of an array in a pattern by naming dimensions using `Symbol`s,
42+
and `-` to ignore dimensions.
43+
44+
```jldoctest
45+
julia> parse_shape(rand(2,3,4), (:a, :b, -))
46+
(a = 2, b = 3)
47+
48+
julia> parse_shape(rand(2,3), (-, -))
49+
NamedTuple()
50+
51+
julia> parse_shape(rand(2,3,4,5), (:first, :second, :third, :fourth))
52+
(first = 2, second = 3, third = 4, fourth = 5)
53+
```
54+
55+
The output is a `NamedTuple`, whose type contains the `Symbol` elements of the `pattern::NTuple{N,Union{Symbol,typeof(-)}}`,
56+
meaning that, unless the pattern is [constant-propagated](https://discourse.julialang.org/t/how-does-constant-propagation-work/22735/4),
57+
the output type is not known at compile time.
58+
59+
`@code_warntype parse_shape(rand(2,3,4), (:a, :b, -))`
60+
61+
`h() = parse_shape(rand(2,3,4), (:a, :b, -)); @code_warntype h()`
62+
"""
63+
function parse_shape(x::AbstractArray{<:Any,N}, pattern::ShapePattern{N}) where N
64+
names = extract(Symbol, pattern)
65+
allunique(names) || error("Pattern $(pattern) has duplicate elements")
66+
inds = findtype(Symbol, pattern)
67+
return NamedTuple{names,NTuple{length(inds),Int}}(size(x, i) for i in inds)
68+
end
69+
70+
71+
function permutation_mapping(left::NTuple{N,T}, right::NTuple{N,T}) where {N,T}
72+
perm::Vector{Int} = findfirst.(isequal.([right...]), Ref([left...]))
73+
return ntuple(i -> perm[i], Val(N))
74+
end
75+
76+
2777
# fix for 1.10:
2878
_permutedims(x::AbstractArray{T,0}, ::Tuple{}) where T = x
2979
_permutedims(x, perm) = permutedims(x, perm)

test/runtests.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,42 @@ using Test, Statistics
1515
@test repr((:a, :b, :c) --> (:c, :b, :a)) == "(:a, :b, :c) --> (:c, :b, :a)"
1616
end
1717

18+
@testset "parse_shape" begin
19+
@test begin
20+
x = rand(2,3,5)
21+
shape = parse_shape(x, (:a, :b, :c))
22+
(shape[:a], shape[:b], shape[:c]) == size(x)
23+
end
24+
end
25+
1826
@testset "einops string tokenization" begin
1927

28+
@testset "parse_shape pattern" begin
29+
@test einops"a _ c" == (:a, -, :c)
30+
@test einops"_ _ _" == (-, -, -)
31+
end
32+
2033
@testset "arrow pattern" begin
2134
@test einops"a b c -> a (c b)" == ((:a, :b, :c) --> (:a, (:c, :b)))
2235
@test einops"a b c -> a(c b)" == ((:a, :b, :c) --> (:a, (:c, :b)))
2336
@test einops"a b 1 -> a 1 b" == ((:a, :b, 1) --> (:a, 1, :b))
2437
@test einops"a b () -> a () b" == ((:a, :b, ()) --> (:a, (), :b))
2538
@test einops"a b()->a()b" == ((:a, :b, ()) --> (:a, (), :b))
2639
@test einops"b ... -> b ..." == ((:b, ..) --> (:b, ..))
40+
@test einops"b b -> a a" == ((:b, :b) --> (:a, :a))
2741
@test einops"->" == (() --> ())
2842
@test einops"-> 1" == (() --> (1,))
2943
@test_throws "'.'" Einops.parse_pattern("-> .")
3044
@test_throws "'('" Einops.parse_pattern("-> (")
3145
@test_throws "')'" Einops.parse_pattern("-> )")
3246
end
3347

34-
@testset "packing pattern" begin
48+
@testset "pack and unpack pattern" begin
3549
@test einops"i j * k" == (:i, :j, *, :k)
3650
@test einops" i j * k " == (:i, :j, *, :k)
3751
@test einops"* i" == (*, :i)
3852
@test einops"i *" == (:i, *)
53+
@test einops"i i" == (:i, :i)
3954
end
4055

4156
end

0 commit comments

Comments
 (0)