
Commit cfd5b84

Revise the examples sections in the README
1 parent 4ddf2a0 commit cfd5b84

File tree

1 file changed: +9 -7 lines changed

README.md

Lines changed: 9 additions & 7 deletions
@@ -76,13 +76,13 @@ fooLet (x, y, z) =

### Vector-Jacobian product (VJP) and symbolic derivatives

-The most general symbolic reverse derivative program for this function can be obtained using the `vjpArtifact` tool. Because the `vjp` family of tools permits non-scalar codomains (but expects an incoming cotangent argument to compensate, visible in the code below as `dret`), we illustrate it using the original `fooLet`, without `ssum0`.
+The most general symbolic reverse derivative program for this function can be obtained using the `vjpArtifact` tool. Because the `vjp` family of tools permits non-scalar codomains (but expects an incoming cotangent argument to compensate, visible in the code below as `dret`), we illustrate it using the original `fooLet` from the previous section, without the need to add `ssum0`.
```hs
artifact :: AstArtifactRev (X (ThreeConcreteMatrices Double)) (TKS '[2, 2] Double)
artifact = vjpArtifact fooLet threeSimpleMatrices
```

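(For contrast with the non-scalar codomain chosen here, the same tool should apply just as well to the scalar-valued objective that `cgrad` was demonstrated on earlier in the README; the snippet below is only a sketch composed from names used elsewhere on this page, not code from the commit.)

```hs
-- Sketch only: applying vjpArtifact to the scalar-valued objective used with
-- cgrad earlier in the README; with a scalar codomain, the cotangent argument
-- degenerates to a single number instead of a [2, 2]-shaped dret.
artifactSum = vjpArtifact (kfromS . ssum0 . fooLet) threeSimpleMatrices
```
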
-The vector-Jacobian product program presented below with additional formatting looks like ordinary functional code with a lot of nested pairs and projections that represent tuples. A quick inspection of the code reveals that computations are not repeated, which is thanks to the `tlet` used above.
+The vector-Jacobian product program (presented below with additional formatting) looks like ordinary functional code with nested pairs and projections representing tuples. A quick inspection of the code reveals that computations are not repeated, which is thanks to the `tlet` used above.

```hs
>>> printArtifactPretty artifact
@@ -97,13 +97,13 @@ The vector-Jacobian product program presented below with additional formatting l
((m4 * m5) * dret + m4 * dret)"
```

-A concrete value of the symbolic reverse derivative at the same input as before can be obtained by interpreting its program in the context of the operations supplied by the horde-ad library. (Note that the output happens to be the same as `gradSumFooMatrix threeSimpleMatrices` above, which used `cgrad` on `kfromS . ssum0 . foo`; the reason is that `srepl 1.0` happens to be the reverse derivative of `kfromS . ssum0`.)
+A concrete value of this symbolic reverse derivative at the same input as before can be obtained by interpreting its program in the context of the operations supplied by the horde-ad library. (Note that the output happens to be the same as `gradSumFooMatrix threeSimpleMatrices` above, which used `cgrad` on `kfromS . ssum0 . foo`; the reason is that `srepl 1.0` happens to be the reverse derivative of `kfromS . ssum0`.)
```hs
->>> vjpInterpretArtifact artifact (toTarget threeSimpleMatrices) (srepl 1)
+>>> vjpInterpretArtifact artifact (toTarget threeSimpleMatrices) (srepl 1.0)
((sfromListLinear [2,2] [2.4396285219055063,2.4396285219055063,2.4396285219055063,2.4396285219055063],sfromListLinear [2,2] [-1.953374825727421,-1.953374825727421,-1.953374825727421,-1.953374825727421],sfromListLinear [2,2] [0.9654825811012627,0.9654825811012627,0.9654825811012627,0.9654825811012627]) :: ThreeConcreteMatrices Double)
```

-Note that, as evidenced by the `printArtifactPretty` call above, `artifact` contains the complete and simplified code of the VJP of `fooLet` so repeated calls of `vjpInterpretArtifact artifact` won't ever repeat differentiation nor simplification and will only incur the cost of straightforward interpretation. However the repeated call would fail with an error if the provided argument had a different shape than `threeSimpleMatrices`, which is nevertheless impossible for the examples we use here, because all tensors we present are shaped, meaning their full shape is stated in their type and so can't differ for two (tuples of) tensors of the same type.
+Note that, as evidenced by the `printArtifactPretty` call above, `artifact` contains the complete and simplified code of the VJP of `fooLet`, so repeated calls of `vjpInterpretArtifact artifact` won't ever repeat differentiation nor simplification and will only incur the cost of straightforward interpretation. The repeated call would fail with an error if the provided argument had a different shape than `threeSimpleMatrices`. For the examples we show here, such a scenario is ruled out by the types, however, because all tensors we present are shaped, meaning their full shape is stated in their type and so can't differ for two (tuples of) tensors of the same type. More loosely-typed variants of all the tensor operations, where runtime checks can really fail, are available in the horde-ad API and can be mixed and matched freely.

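(As a sketch of the reuse described in the revised paragraph above, the same `artifact` could be interpreted again at a second input without repeating differentiation or simplification; `threeOtherMatrices` below is a hypothetical value of type `ThreeConcreteMatrices Double`, not one defined in the README.)

```hs
-- Sketch only: threeOtherMatrices is a hypothetical second input of the same
-- type as threeSimpleMatrices; this call merely re-interprets the stored,
-- already simplified VJP program.
>>> vjpInterpretArtifact artifact (toTarget threeOtherMatrices) (srepl 1.0)
```
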
A shorthand that creates a symbolic gradient program, simplifies it and interprets it with a given input on the default CPU backend is called `grad` and is used exactly the same as (but with often much better performance on the same program than) `cgrad`:
```hs
@@ -118,11 +118,11 @@ An important feature of this library is a type system for tensor shape
arithmetic. The following code is part of a convolutional neural network
definition, for which horde-ad computes the shape of the gradient from
the shape of the input data and the initial parameters.
-The compiler is able to infer many tensor shapes, deriving them both
+The Haskell compiler is able to infer many tensor shapes, deriving them both
from dynamic dimension arguments (the first two lines of parameters
to the function below) and from static type-level hints.

-It is common to see neural network code with shape annotations in comments, hidden from the compiler:
+Let's look at the body of the `convMnistTwoS` function before we look at its signature. It is common to see neural network code like that, with shape annotations in comments, hidden from the compiler:
```hs
convMnistTwoS
  kh@SNat kw@SNat h@SNat w@SNat
@@ -187,6 +187,8 @@ convMnistTwoS
...
```

+This style gets verbose and the Haskell compiler needs some convincing to accept such programs, but type-safety is the reward. In practice, at least the parameters of the objective function are best expressed with shaped tensors, while the implementation can (zero-cost) convert the tensors to loosely typed variants as needed.
+
The full neural network definition from which this function is taken can be found at

https://github.com/Mikolaj/horde-ad/tree/master/example