Skip to content

Commit e781ae1

Browse files
committed
fixup many unzip things
1 parent b020a2a commit e781ae1

File tree

2 files changed

+46
-33
lines changed

2 files changed

+46
-33
lines changed

src/unzipped.jl

+28-29
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,8 @@ end
6666
##### map
6767
#####
6868

69-
# `unzip_map` can use `StructArrays.components(StructArray(Iterators.map(f, args...)))`,
70-
# will be useful for the gradient of `map` etc.
71-
72-
7369
"""
74-
unzip_map(f, args...)
70+
unzip_map(f, args...)
7571
7672
For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`,
7773
but performed using `StructArrays` for efficiency.
@@ -86,40 +82,36 @@ function unzip_map(f::F, args...) where {F}
8682
end
8783

8884
unzip_map(f::F, args::Tuple...) where {F} = unzip(map(f, args...))
85+
# unzip_map(f::F, args::NamedTuple...) where {F} = unzip(map(f, args...))
8986

9087
unzip_map(f::F, args::AbstractGPUArray...) where {F} = unzip(map(f, args...))
9188

89+
"""
90+
unzip_map_reversed(f, args...)
91+
92+
For a pure function `f` which returns a tuple, this is `== unzip(map(f, args...))`.
93+
But the order of evaluation is should be the reverse.
94+
Does NOT handle `zip`-like behaviour.
95+
"""
9296
function unzip_map_reversed(f::F, args...) where {F}
9397
T = Broadcast.combine_eltypes(f, args)
9498
if isconcretetype(T)
9599
T <: Tuple || throw(ArgumentError("""unzip_map_reversed(f, args) only works on functions returning a tuple,
96100
but f = $(sprint(show, f)) returns type T = $T"""))
97101
end
98102
len1 = length(first(args))
99-
if all(a -> length(a)==len1, args)
100-
rev_args = map(Iterators.reverse, args)
101-
outs = StructArrays.components(StructArray(Iterators.map(f, rev_args...)))
102-
else
103-
len = minimum(length, args)
104-
rev_args = map(a -> Iterators.reverse(@view a[begin:begin+len-1]), args)
105-
outs = StructArrays.components(StructArray(Iterators.map(f, rev_args...)))
106-
end
107-
return map(reverse!!, outs)
103+
all(a -> length(a)==len1, args) || error("unzip_map_reversed does not handle zip-like behaviour.")
104+
return map(reverse!!, unzip_map(f, map(_safereverse, args)...))
108105
end
109106

107+
# This avoids MethodError: no method matching iterate(::Base.Iterators.Reverse{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}) on 1.6
108+
_safereverse(x) = VERSION > v"1.7" ? Iterators.reverse(x) : reverse(x)
109+
110110
function unzip_map_reversed(f::F, args::Tuple...) where {F}
111-
len = minimum(length, args)
112-
rev_args = map(a -> reverse(a[1:len]), args)
113-
# vlen = Val(len)
114-
# rev_args = map(args) do a
115-
# reverse(ntuple(i -> a[i], vlen)) # does not infer better
116-
# end
117-
return map(reverse, unzip(map(f, rev_args...)))
111+
len1 = length(first(args))
112+
all(a -> length(a)==len1, args) || error("unzip_map_reversed does not handle zip-like behaviour.")
113+
return map(reverse, unzip(map(f, map(reverse, args)...)))
118114
end
119-
# function unzip_map_reversed(f::F, args::Tuple{Vararg{Any, N}}...) where {F,N}
120-
# rev_args = map(reverse, args)
121-
# return map(reverse, unzip(map(f, rev_args...)))
122-
# end
123115

124116
"""
125117
reverse!!(x)
@@ -135,10 +127,11 @@ function reverse!!(x::AbstractArray)
135127
end
136128
end
137129
reverse!!(x::AbstractArray{<:AbstractZero}) = x
130+
reverse!!(x) = reverse(x)
138131

139-
frule((_, xdot), ::typeof(reverse!!), x::AbstractArray) = reverse!!(x), reverse!!(xdot)
132+
frule((_, xdot), ::typeof(reverse!!), x) = reverse!!(x), reverse!!(xdot)
140133

141-
function rrule(::typeof(reverse!!), x::AbstractArray)
134+
function rrule(::typeof(reverse!!), x)
142135
reverse!!_back(dy) = (NoTangent(), reverse(unthunk(dy)))
143136
return reverse!!(x), reverse!!_back
144137
end
@@ -181,10 +174,16 @@ end
181174
Expr(:tuple, each...)
182175
end
183176

184-
unzip(xs::AbstractArray{Tuple{T}}) where {T} = (reinterpret(T, xs),) # best case, no copy
177+
function unzip(xs::AbstractArray{Tuple{T}}) where {T}
178+
if isbitstype(T)
179+
(reinterpret(T, xs),) # best case, no copy
180+
else
181+
(map(only, xs),)
182+
end
183+
end
185184

186185
@generated function unzip(xs::AbstractArray{Ts}) where {Ts<:Tuple}
187-
each = if count(!Base.issingletontype, Ts.parameters) < 2
186+
each = if count(!Base.issingletontype, Ts.parameters) < 2 && all(isbitstype, Ts.parameters)
188187
# good case, no copy of data, some trivial arrays
189188
[Base.issingletontype(T) ? :(similar(xs, $T)) : :(reinterpret($T, xs)) for T in Ts.parameters]
190189
else

test/unzipped.jl

+18-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
using ChainRules: unzip_broadcast, unzip, unzip_map, unzip_map_reversed
33

44
@testset "unzipped.jl" begin
5-
@testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzipmap, unzipbroadcast, unzip_map, unzip_map_reversed]
5+
@testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzipmap, unzipbroadcast, unzip_map]
66
@test_throws Exception fun(sqrt, 1:3)
77

88
@test @inferred(fun(tuple, 1:3, 4:6)) == ([1, 2, 3], [4, 5, 6])
@@ -27,22 +27,32 @@ using ChainRules: unzip_broadcast, unzip, unzip_map, unzip_map_reversed
2727
end
2828
@test @inferred(fun(tuple, (1,2,3), [4,5,6])) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector
2929
end
30-
30+
3131
@testset "zip behaviour: $unzip_map" for unzip_map in [unzip_map, unzip_map_reversed]
3232
check(f, args...) = @inferred(unzip_map(f, args...)) == unzip(map(f, args...))
33+
check_no_inferr(f, args...) = unzip_map(f, args...) == unzip(map(f, args...))
34+
3335
@test check(tuple, [1 2; 3 4], [5,6,7,8]) # makes a vector
36+
@test check_no_inferr(tuple, [1,2,3], (5,6,7))
37+
38+
unzip_map == unzip_map_reversed && continue # does not handle unequal lengths.
39+
3440
@test check(tuple, [1 2; 3 4], [5,6,7])
3541
@test check(tuple, [1 2; 3 4], [5,6,7,8,9,10])
42+
43+
@test check_no_inferr(tuple, [1,2,3], (5,6,7,8))
44+
@test check_no_inferr(tuple, [1,2,3,4], (5,6,7))
45+
@test check_no_inferr(tuple, [1 2;3 4], (5,6,7))
3646
end
3747

3848
@testset "unzip_map_reversed" begin
3949
cnt(x, y) = (x, y) .+ (CNT[] += 1)
4050
CNT = Ref(0)
41-
@test unzip_map_reversed(cnt, [10, 20], [30, 40, 50]) == ([12, 21], [32, 41])
51+
@test unzip_map_reversed(cnt, [10, 20], [30, 40]) == ([12, 21], [32, 41])
4252
@test CNT[] == 2
4353

4454
CNT = Ref(0)
45-
@test unzip_map_reversed(cnt, (10, 20, 99), (30, 40)) == ((12, 21), (32, 41))
55+
@test unzip_map_reversed(cnt, (10, 20), (30, 40)) == ((12, 21), (32, 41))
4656
end
4757

4858
@testset "rrules" begin
@@ -76,6 +86,10 @@ using ChainRules: unzip_broadcast, unzip, unzip_map, unzip_map_reversed
7686
@test unzip([(1,), (3,), (5,)])[1] isa Base.ReinterpretArray
7787

7888
@test unzip(((1,2), (3,4), (5,6))) == ((1, 3, 5), (2, 4, 6))
89+
90+
# Bug: these cases cannot be done by reinterpret
91+
@test unzip([([1,2],), ([3,4],)]) == ([[1, 2], [3, 4]],)
92+
@test unzip([(nothing, [1,2]), (nothing, [3,4])]) == ([nothing, nothing], [[1, 2], [3, 4]])
7993

8094
# test_rrule(unzip, [(1,2), (3,4), (5.0,6.0)], check_inferred=false) # DimensionMismatch: second dimension of A, 6, does not match length of x, 2
8195

0 commit comments

Comments
 (0)