Skip to content

Commit 0c4873c

Browse files
committed
Tweak README based on Tom's feedback
1 parent ea8c852 commit 0c4873c

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ which can be verified by computing the gradient at `(1.1, 2.2, 3.3)`:
3636
(2.4396285219055063, -1.953374825727421, 0.9654825811012627)
3737
```
3838

39-
When `foo` is instantiated to matrices, which is a similarly trivial example due to the arithmetic operations working on the arrays element-wise, the gradient is:
39+
When `foo` is instantiated to matrices, which is a similarly trivial example as before due to the arithmetic operations working on the arrays element-wise, the gradient is:
4040
```hs
4141
type Matrix2x2 f r = f (TKS '[2, 2] r)
4242
type ThreeMatrices r = (Matrix2x2 Concrete r, Matrix2x2 Concrete r, Matrix2x2 Concrete r)
@@ -63,13 +63,13 @@ fooLet (x, y, z) =
6363
atan2H z w + z * w
6464
```
6565

66-
The most general symbolic gradient program can be then obtained using the `vjpArtifact` tool:
66+
The most general symbolic gradient program can be then obtained using the `vjpArtifact` tool. We are using `fooLet` without `ssum0` this time, becuase the `vjp` family of tools by convention permits non-scalar domains (but expects an incoming cotangent argument to compensate, visible in the code as `dret`).
6767
```hs
6868
artifact :: AstArtifactRev (X (ThreeConcreteMatrices Double)) (TKS '[2, 2] Double)
6969
artifact = vjpArtifact fooLet threeSimpleMatrices
7070
```
7171

72-
With additional formatting, it looks like an ordinary functional program with a lot of nested pairs and projections to represent tuples present in the objective function. A quick inspection of the gradient program reveals that computations are not repeated, which is thanks to the sharing mechanism, as promised.
72+
With additional formatting, the gradient program below looks like ordinary functional code with a lot of nested pairs and projections to represent tuples. A quick inspection of the gradient code reveals that computations are not repeated, which is thanks to the sharing mechanism, as promised.
7373

7474
```hs
7575
>>> printArtifactPretty artifact
@@ -83,14 +83,14 @@ With additional formatting, it looks like an ordinary functional program with a
8383
, (m4 * m5) * dret + m4 * dret)
8484
```
8585
86-
A concrete value of the symbolic gradient at the same input as before can be obtained by interpreting the gradient program in the context of the operations supplied by the horde-ad library. The value is the same as for `fooLet` evaluated by `cgrad` on the same input:
86+
A concrete value of the symbolic gradient at the same input as before can be obtained by interpreting the gradient program in the context of the operations supplied by the horde-ad library. The value is the same as for `fooLet` evaluated by `cgrad` on the same input, as long as the incoming cotangent argument consists of ones in all array cells, which is denoted by `srepl 1` in this case:
8787
8888
```hs
8989
>>> vjpInterpretArtifact artifact (toTarget threeSimpleMatrices) (srepl 1)
9090
((sfromListLinear [2,2] [2.4396285219055063,2.4396285219055063,2.4396285219055063,2.4396285219055063],sfromListLinear [2,2] [-1.953374825727421,-1.953374825727421,-1.953374825727421,-1.953374825727421],sfromListLinear [2,2] [0.9654825811012627,0.9654825811012627,0.9654825811012627,0.9654825811012627]) :: ThreeConcreteMatrices Double)
9191
```
9292
93-
A shorthand that creates the symbolic derivative program, simplifies it and interprets it with a given input on the default CPU backend is called `grad` and is used exactly the same (but with often much better performance) as `cgrad`:
93+
A shorthand that creates the symbolic derivative program, simplifies it and interprets it with a given input on the default CPU backend is called `grad` and is used exactly the same as (but with often much better performance) `cgrad`:
9494
```hs
9595
>>> grad (kfromS . ssum0 . fooLet) threeSimpleMatrices
9696
(sfromListLinear [2,2] [2.4396285219055063,2.4396285219055063,2.4396285219055063,2.4396285219055063],sfromListLinear [2,2] [-1.953374825727421,-1.953374825727421,-1.953374825727421,-1.953374825727421],sfromListLinear [2,2] [0.9654825811012627,0.9654825811012627,0.9654825811012627,0.9654825811012627])

0 commit comments

Comments
 (0)