@@ -17,57 +17,10 @@ function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect), gen::G) whe
17
17
ys, generator_pullback
18
18
end
19
19
20
- # Needed for Yota, but shouldn't these be automatic?
21
- ChainRulesCore. rrule (:: Type{<:Base.Generator} , f, iter) = Base. Generator (f, iter), dy -> (NoTangent (), dy. f, dy. iter)
22
- ChainRulesCore. rrule (:: Type{<:Iterators.ProductIterator} , iters) = Iterators. ProductIterator (iters), dy -> (NoTangent (), dy. iterators)
23
-
24
- #=
25
-
26
- Yota.grad(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3)
27
- Diffractor.gradient(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3)
28
-
29
- Yota.grad((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: all field arrays must have same shape
30
- Diffractor.gradient((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: type Array has no field iterators
31
-
32
- Yota.grad(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3)
33
- Diffractor.gradient(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3) # fails internally
34
-
35
- Yota.grad(xs -> sum(abs, [sin(x/y) for (x,y) in zip(xs, 1:2)]), [1,2,3]pi/3)
36
- Diffractor.gradient(xs -> sum(abs, [sin(x/y) for (x,y) in zip(xs, 1:2)]), [1,2,3]pi/3)
37
-
38
- Yota.grad(xs -> sum(abs, map((x,y) -> sin(x/y), xs, 1:2)), [1,2,3]pi/3)
39
- Diffractor.gradient(xs -> sum(abs, map((x,y) -> sin(x/y), xs, 1:2)), [1,2,3]pi/3)
40
-
41
-
42
- @btime Yota.grad($(rand(1000))) do xs
43
- sum(abs2, [sqrt(x) for x in xs])
44
- end
45
- # Yota min 759.000 μs, mean 800.754 μs (22041 allocations, 549.62 KiB)
46
- # Diffractor min 559.000 μs, mean 622.464 μs (18051 allocations, 612.34 KiB)
47
-
48
- # Zygote min 3.198 μs, mean 6.849 μs (20 allocations, 40.11 KiB)
49
-
50
-
51
- @btime Yota.grad($(rand(1000)), $(rand(1000))) do xs, ys
52
- zs = map(xs, ys) do x, y
53
- atan(x/y)
54
- end
55
- sum(abs2, zs)
56
- end
57
- # Yota + CR: min 1.598 ms, mean 1.691 ms (38030 allocations, 978.75 KiB)
58
- # Diffractor + CR: min 767.250 μs, mean 847.640 μs (26045 allocations, 838.66 KiB)
59
-
60
- # Zygote: min 13.417 μs, mean 22.896 μs (26 allocations, 79.59 KiB) -- 100x faster
61
-
62
-
63
- =#
64
-
65
-
66
20
# ####
67
21
# #### `zip`
68
22
# ####
69
23
70
-
71
24
function rrule (:: typeof (zip), xs:: AbstractArray... )
72
25
function zip_pullback (dy)
73
26
@debug " zip array pullback" summary (dy)
@@ -94,8 +47,6 @@ function _unmap_pad(x::AbstractArray, dx::AbstractArray)
94
47
@debug " _unmap_pad is extending gradient" length (x) == length (dx)
95
48
i1 = firstindex (x)
96
49
∇getindex (x, vec (dx), i1: i1+ length (dx)- 1 )
97
- # dx2 = vcat(vec(dx), similar(x, ZeroTangent, length(x) - length(dx)))
98
- # ProjectTo(x)(reshape(dx2, axes(x)))
99
50
end
100
51
end
101
52
0 commit comments