Commit 2223921

Update the simple examples in README

1 parent 425b97a commit 2223921
File tree: 4 files changed, +207 −62 lines changed

README.md

Lines changed: 49 additions & 26 deletions
@@ -13,58 +13,81 @@ The benchmarks at SOMEWHERE show that this library has performance highly compet
 It is hoped that the (well-typed) separation of AD logic and the tensor manipulation backend will enable similar speedups on numerical accelerators.
 
 
-# WIP: The examples below are outdated and will be replaced soon using a new API
-
-
 ## Computing the derivative of a simple function
 
 Here is an example of a Haskell function to be differentiated:
 
 ```hs
 -- A function that goes from R^3 to R.
-foo :: RealFloat a => (a,a,a) -> a
-foo (x,y,z) =
+foo :: RealFloat a => (a, a, a) -> a
+foo (x, y, z) =
   let w = x * sin y
   in atan2 z w + z * w -- note that w appears twice
 ```
 
-The gradient of `foo` is:
-<!--
-TODO: this may yet get simpler and the names not leaking implementation details
-("delta") so much, when the adaptor gets used at scale and redone.
-Alternatively, we could settle on Double already here.
--->
+The gradient of `foo` instantiated to `Double` is:
 ```hs
-grad_foo :: forall r. (HasDelta r, AdaptableScalar 'ADModeGradient r)
-         => (r, r, r) -> (r, r, r)
-grad_foo = rev @r foo
+gradFooDouble :: (Double, Double, Double) -> (Double, Double, Double)
+gradFooDouble = fromDValue . crev foo . fromValue
 ```
 
-As can be verified by computing the gradient at `(1.1, 2.2, 3.3)`:
+as can be verified by computing the gradient at `(1.1, 2.2, 3.3)`:
 ```hs
->>> grad_foo (1.1 :: Double, 2.2, 3.3)
+>>> gradFooDouble (1.1, 2.2, 3.3)
 (2.4396285219055063, -1.953374825727421, 0.9654825811012627)
 ```
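As a quick sanity check on the value above: with $w = x \sin y$, the chain rule gives

$$\frac{\partial}{\partial x}\bigl(\operatorname{atan2}(z, w) + z\,w\bigr) = \Bigl(z - \frac{z}{w^2 + z^2}\Bigr)\sin y,$$

which at `(1.1, 2.2, 3.3)` evaluates to approximately 2.4396285219, matching the first gradient component printed above.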
 
-As a side note, `w` is processed only once during gradient computation and this property of sharing preservation is guaranteed universally by horde-ad without any action required from the user. The property holds not only for scalar values, but for arbitrary tensors, e.g., those in further examples. We won't mention the property further.
+Instantiated to matrices, the gradient is:
+```hs
+gradFooMatrix :: (Differentiable r, GoodScalar r)
+              => (RepN (TKS '[2, 2] r), RepN (TKS '[2, 2] r), RepN (TKS '[2, 2] r))
+              -> (RepN (TKS '[2, 2] r), RepN (TKS '[2, 2] r), RepN (TKS '[2, 2] r))
+gradFooMatrix = crev foo
+```
 
-<!--
-Do we want yet another example here, before we reach Jacobians or shaped tensors? Perhaps one with the testing infrastructure, e.g., generating a single set of random tensors, or a full QuickCheck example or just a simple
+as can be verified by:
 ```hs
-assertEqualUpToEpsilon 1e-9
-  (6.221706565357043, -12.856908977773593, 6.043601532156671)
-  (rev bar (1.1, 2.2, 3.3))
+>>> gradFooMatrix (srepl 1.1, srepl 2.2, srepl (3.3 :: Double))
+(sfromListLinear [2.4396285219055063,2.4396285219055063,2.4396285219055063,2.4396285219055063],sfromListLinear [-1.953374825727421,-1.953374825727421,-1.953374825727421,-1.953374825727421],sfromListLinear [0.9654825811012627,0.9654825811012627,0.9654825811012627,0.9654825811012627])
 ```
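A brief note on why all four entries of each result matrix coincide: `srepl` above appears to build a constant-filled `[2, 2]` input from a scalar (an assumption based on its use here), and `foo` acts elementwise, so each entry's partial derivative reduces to the corresponding component of the scalar `gradFooDouble (1.1, 2.2, 3.3)` result.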
-? Or is there a risk the reader won't make it to the shaped example below if we tarry here? Or perhaps finish the shaped tensor example below with an invocation of `assertEqualUpToEpsilon`?
--->
+
+Note that `w` is processed only once during gradient computation; this sharing preservation is guaranteed universally by horde-ad for the `crev` tool, without any action required from the user. When computing symbolic derivative programs, however, the user has to explicitly mark values for sharing using `tlet`, with a more specific type for the objective function, as shown below.
+
+```hs
+fooLet :: (RealFloatH (target a), LetTensor target)
+       => (target a, target a, target a) -> target a
+fooLet (x, y, z) =
+  tlet (x * sin y) $ \w ->
+    atan2H z w + z * w
+```
+
+The symbolic derivative program (here presented with additional formatting) can be obtained using the `revArtifactAdapt` tool:
+```hs
+>>> let ftk = FTKS @'[2, 2] [2, 2] (FTKScalar @Double)
+    in printArtifactGradient
+         (fst $ revArtifactAdapt True fooLet (FTKProduct (FTKProduct ftk ftk) ftk))
+"\m6 m1 ->
+ let m3 = sin (tproject2 (tproject1 m1))
+     m4 = tproject1 (tproject1 m1) * m3
+     m5 = recip (tproject2 m1 * tproject2 m1 + m4 * m4)
+     m7 = (negate (tproject2 m1) * m5) * m6 + tproject2 m1 * m6
+ in tpair
+      ( tpair (m3 * m7, cos (tproject2 (tproject1 m1)) * (tproject1 (tproject1 m1) * m7))
+      , (m4 * m5) * m6 + m4 * m6)"
+```
+
+A quick inspection of the derivative program reveals that computations are not repeated, thanks to sharing. A concrete value of the symbolic derivative can be obtained by interpreting the derivative program in the context of the operations supplied by the horde-ad library; the value should be the same as when evaluating `fooLet` with `crev` on the concrete input, as before. A shorthand that creates the symbolic derivative program and evaluates it at a given input is called `rev`; it is used exactly like `crev`, but with potentially better performance.
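For instance, a minimal sketch, assuming `rev` really is a drop-in replacement for `crev` as the paragraph above states:

```hs
-- Hedged sketch: `rev` applied like `crev`, to the same input triple
-- as the matrix example earlier; the expected output is the same three
-- constant-filled matrices that `gradFooMatrix` produced.
>>> rev fooLet (srepl 1.1, srepl 2.2, srepl (3.3 :: Double))
```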
+
+
+# WIP: The examples below are outdated and will be replaced soon using a new API
 
 
-<!--
 ## Computing Jacobians
 
 -- TODO: we can have vector/matrix/tensor codomains, but not pair codomains
 -- until #68 is done;
 -- perhaps a vector codomain example, with a 1000x3 Jacobian, would make sense?
+-- 2 years later: actually, we can now have TKProduct codomains.
 
 Now let's consider a function from `R^n` to `R^m`. We don't want the gradient, but instead the Jacobian.
 ```hs
@@ -75,7 +98,7 @@ foo (x,y,z) =
   in (atan2 z w, z * w)
 ```
 TODO: show how the 2x3 Jacobian emerges from here
--->
+
 
 
 ## Forall shapes and sizes
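Regarding the TODO above, one way the 2x3 Jacobian could emerge, sketched under two assumptions: the pair-returning `foo` from this section is in scope, and tuple projections compose with the `crev` adaptor pipeline exactly as in the earlier gradient example (`jacobianRows` is a hypothetical name):

```hs
-- Each row of the Jacobian of an R^3 -> R^2 function is the gradient
-- of one scalar output component, obtainable by reverse mode.
jacobianRows :: (Double, Double, Double)
             -> ( (Double, Double, Double)   -- row for atan2 z w
                , (Double, Double, Double) ) -- row for z * w
jacobianRows inp =
  ( fromDValue (crev (fst . foo) (fromValue inp))
  , fromDValue (crev (snd . foo) (fromValue inp)) )
```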

src/HordeAd/Core/CarriersConcrete.hs

Lines changed: 99 additions & 1 deletion
@@ -114,7 +114,31 @@ instance (Nested.NumElt r, Nested.PrimElt r, Eq r, IntegralH r)
              , either (V.replicate (V.length x)) id y' )
        in V.zipWith
             (\a b -> if b == 0 then 0 else remH a b) x y)))
-    -- TODO: do better somehow
+    -- TODO: do better somehow'
+
+instance GoodScalar r
+         => Real (Nested.Ranked n r) where
+  toRational = error "horde-ad: operation not defined for tensor"
+
+instance GoodScalar r
+         => Real (Nested.Shaped sh r) where
+  toRational = error "horde-ad: operation not defined for tensor"
+
+instance GoodScalar r
+         => Real (Nested.Mixed sh r) where
+  toRational = error "horde-ad: operation not defined for tensor"
+
+instance (GoodScalar r, Nested.FloatElt r)
+         => RealFrac (Nested.Ranked n r) where
+  properFraction = error "horde-ad: operation not defined for tensor"
+
+instance (GoodScalar r, RealFrac r, Nested.FloatElt r)
+         => RealFrac (Nested.Shaped sh r) where
+  properFraction = error "horde-ad: operation not defined for tensor"
+
+instance (GoodScalar r, Nested.FloatElt r)
+         => RealFrac (Nested.Mixed sh r) where
+  properFraction = error "horde-ad: operation not defined for tensor"
 
 instance (Nested.NumElt r, Nested.PrimElt r, RealFloatH r, Nested.FloatElt r)
          => RealFloatH (Nested.Ranked n r) where
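These stub instances are presumably needed only because `Real` and `RealFrac` sit on the superclass path to `RealFloat`, whose `atan2` is the one method given a meaningful definition in the instances added below. For reference, the standard Prelude class heads:

```hs
class (Num a, Ord a) => Real a where
  toRational :: a -> Rational

class (Real a, Fractional a) => RealFrac a where
  properFraction :: Integral b => a -> (b, a)
  -- ... truncate, round, ceiling, floor

class (RealFrac a, Floating a) => RealFloat a where
  atan2 :: a -> a -> a
  -- ... floatRadix, floatDigits, decodeFloat, isNaN, etc.
```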
@@ -157,6 +181,77 @@ instance (Nested.NumElt r, Nested.PrimElt r, RealFloatH r, Nested.FloatElt r)
              , either (V.replicate (V.length x)) id y' )
        in V.zipWith atan2H x y))) -- TODO: do better somehow
 
+instance (GoodScalar r, Nested.PrimElt r, RealFloat r, Nested.FloatElt r)
+         => RealFloat (Nested.Ranked n r) where
+  atan2 = Nested.Internal.arithPromoteRanked2
+            (Nested.Internal.mliftNumElt2
+               (flip Nested.Internal.Arith.liftVEltwise2
+                  (\x' y' ->
+                     let (x, y) = case (x', y') of
+                           (Left x2, Left y2) ->
+                             (V.singleton x2, V.singleton y2)
+                           _ ->
+                             ( either (V.replicate (V.length y)) id x'
+                             , either (V.replicate (V.length x)) id y' )
+                     in V.zipWith atan2 x y))) -- TODO: do better somehow
+  floatRadix = error "horde-ad: operation not defined for tensor"
+  floatDigits = error "horde-ad: operation not defined for tensor"
+  floatRange = error "horde-ad: operation not defined for tensor"
+  decodeFloat = error "horde-ad: operation not defined for tensor"
+  encodeFloat = error "horde-ad: operation not defined for tensor"
+  isNaN = error "horde-ad: operation not defined for tensor"
+  isInfinite = error "horde-ad: operation not defined for tensor"
+  isDenormalized = error "horde-ad: operation not defined for tensor"
+  isNegativeZero = error "horde-ad: operation not defined for tensor"
+  isIEEE = error "horde-ad: operation not defined for tensor"
+
+instance (GoodScalar r, Nested.PrimElt r, RealFloat r, Nested.FloatElt r)
+         => RealFloat (Nested.Shaped sh r) where
+  atan2 = Nested.Internal.arithPromoteShaped2
+            (Nested.Internal.mliftNumElt2
+               (flip Nested.Internal.Arith.liftVEltwise2
+                  (\x' y' ->
+                     let (x, y) = case (x', y') of
+                           (Left x2, Left y2) ->
+                             (V.singleton x2, V.singleton y2)
+                           _ ->
+                             ( either (V.replicate (V.length y)) id x'
+                             , either (V.replicate (V.length x)) id y' )
+                     in V.zipWith atan2 x y))) -- TODO: do better somehow
+  floatRadix = error "horde-ad: operation not defined for tensor"
+  floatDigits = error "horde-ad: operation not defined for tensor"
+  floatRange = error "horde-ad: operation not defined for tensor"
+  decodeFloat = error "horde-ad: operation not defined for tensor"
+  encodeFloat = error "horde-ad: operation not defined for tensor"
+  isNaN = error "horde-ad: operation not defined for tensor"
+  isInfinite = error "horde-ad: operation not defined for tensor"
+  isDenormalized = error "horde-ad: operation not defined for tensor"
+  isNegativeZero = error "horde-ad: operation not defined for tensor"
+  isIEEE = error "horde-ad: operation not defined for tensor"
+
+instance (GoodScalar r, Nested.PrimElt r, RealFloat r, Nested.FloatElt r)
+         => RealFloat (Nested.Mixed sh r) where
+  atan2 = (Nested.Internal.mliftNumElt2
+             (flip Nested.Internal.Arith.liftVEltwise2
+                (\x' y' ->
+                   let (x, y) = case (x', y') of
+                         (Left x2, Left y2) ->
+                           (V.singleton x2, V.singleton y2)
+                         _ ->
+                           ( either (V.replicate (V.length y)) id x'
+                           , either (V.replicate (V.length x)) id y' )
+                   in V.zipWith atan2 x y))) -- TODO: do better somehow
+  floatRadix = error "horde-ad: operation not defined for tensor"
+  floatDigits = error "horde-ad: operation not defined for tensor"
+  floatRange = error "horde-ad: operation not defined for tensor"
+  decodeFloat = error "horde-ad: operation not defined for tensor"
+  encodeFloat = error "horde-ad: operation not defined for tensor"
+  isNaN = error "horde-ad: operation not defined for tensor"
+  isInfinite = error "horde-ad: operation not defined for tensor"
+  isDenormalized = error "horde-ad: operation not defined for tensor"
+  isNegativeZero = error "horde-ad: operation not defined for tensor"
+  isIEEE = error "horde-ad: operation not defined for tensor"
+
 
 -- * RepORArray and its operations
 
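The `Either`-based plumbing inside the three `atan2` definitions above is dense; here is a standalone, simplified model of just that broadcast convention (an editor's sketch with hypothetical names, not code from this commit): `Left` carries a scalar, `Right` a whole vector, and a scalar is replicated to the other operand's length.

```hs
import qualified Data.Vector.Storable as V

-- Model of the scalar/vector broadcast used by the lifted atan2:
-- two scalars yield a singleton; otherwise the scalar side (if any)
-- is replicated to the vector side's length before zipping.
broadcastZip :: (Double -> Double -> Double)
             -> Either Double (V.Vector Double)
             -> Either Double (V.Vector Double)
             -> V.Vector Double
broadcastZip f (Left x) (Left y) = V.singleton (f x y)
broadcastZip f x' y' =
  let x = either (V.replicate (V.length y)) id x'
      y = either (V.replicate (V.length x)) id y'
  in V.zipWith f x y

-- Example: broadcastZip atan2 (Left 3.3) (Right (V.fromList [0.5, 1.0]))
-- replicates 3.3 to length 2 and applies atan2 pointwise.
```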
@@ -286,9 +381,12 @@ deriving instance Eq (RepORArray y) => Eq (RepN y)
 deriving instance Ord (RepORArray y) => Ord (RepN y)
 deriving instance Num (RepORArray y) => Num (RepN y)
 deriving instance IntegralH (RepORArray y) => IntegralH (RepN y)
+deriving instance Real (RepORArray y) => Real (RepN y)
 deriving instance Fractional (RepORArray y) => Fractional (RepN y)
 deriving instance Floating (RepORArray y) => Floating (RepN y)
+deriving instance RealFrac (RepORArray y) => RealFrac (RepN y)
 deriving instance RealFloatH (RepORArray y) => RealFloatH (RepN y)
+deriving instance RealFloat (RepORArray y) => RealFloat (RepN y)
 
 rtoVector :: GoodScalar r => RepN (TKR n r) -> VS.Vector r
 rtoVector = Nested.rtoVector . unRepN
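With the three new `deriving` lines, concrete tensors pick up full `RealFloat` support; a minimal sketch of what now type-checks (`atan2Demo` is a hypothetical name, and this assumes `RepORArray (TKR n Double)` resolves to the `Nested.Ranked` type given its `RealFloat` instance above):

```hs
-- Elementwise atan2 on concrete ranked tensors, obtained for free
-- via `deriving instance RealFloat (RepORArray y) => RealFloat (RepN y)`.
atan2Demo :: RepN (TKR 1 Double) -> RepN (TKR 1 Double) -> RepN (TKR 1 Double)
atan2Demo = atan2
```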

src/HordeAd/Core/Types.hs

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ class GoodScalarConstraint r => GoodScalar r
 instance GoodScalarConstraint r => GoodScalar r
 
 type Differentiable r =
-  (RealFloatH r, Nested.FloatElt r, RealFrac r, Random r)
+  (RealFloatH r, Nested.FloatElt r, RealFrac r, RealFloat r, Random r)
 
 -- We white-list all types on which we permit differentiation (e.g., SGD)
 -- to work. This is for technical typing purposes and imposes updates
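The practical effect of adding `RealFloat r` to the constraint list, as a hedged sketch (`angleDemo` is a hypothetical name): code polymorphic over any white-listed differentiable scalar can now call full `RealFloat` methods such as `atan2` directly.

```hs
-- Compiles for every scalar satisfying Differentiable, e.g. Double or Float.
angleDemo :: Differentiable r => r -> r -> r
angleDemo y x = atan2 y x
```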
