
Commit 7cce935

Simplify Matrix2x2 in the README

1 parent b832a43

2 files changed (+19 -12 lines)

README.md

Lines changed: 9 additions & 6 deletions
@@ -3,12 +3,15 @@
 [![Hackage](https://img.shields.io/hackage/v/horde-ad.svg)](https://hackage.haskell.org/package/horde-ad)
 
 Welcome to the Automatic Differentiation library originally inspired by the paper [_"Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation"_](https://dl.acm.org/doi/10.1145/3498710). Compared to the paper and to classic taping AD Haskell packages, the library additionally efficiently supports array operations and generation of symbolic derivative programs, though the efficiency is confined to a narrowly typed class of source programs with limited higher-orderness. A detailed account of the extension is in the paper [_"Dual-Numbers Reverse AD for Functional Array Languages"_](http://arxiv.org/abs/2507.12640) by Tom Smeding, Mikolaj Konarski, Simon Peyton Jones and Andrew Fitzgibbon.
+<!--
+More specifically, in primitive pipelines (that match the Provable paper) the objective functions have types with ADVal in them which, e.g., permit dynamic control flow via inspecting the primal components of ADVal and permit higher order functions by just applying them (they are not symbolic for ADVal), but prevent vectorization, simplification and computing a derivative only once and evaluating on many inputs.
+-->
 
 This is an early prototype, both in terms of the engine performance, the API and the preliminary tools and examples built with it. At this development stage, it's not coded defensively but exactly the opposite: it will fail on cases not found in current tests so that new code and tests have to be added and old code optimized for the new specimens reported in the wild. The user should also be ready to add missing primitives and any obvious tools that should be predefined but aren't, such as weight normalization (https://github.com/Mikolaj/horde-ad/issues/42). It's already possible to differentiate basic neural network architectures, such as fully connected, recurrent, convolutional and residual. The library should also be suitable for defining exotic machine learning architectures and non-machine learning systems, given that no notion of a neural network nor of a computation graph are hardwired into the formalism, but instead they are compositionally and type-safely built up from general automatic differentiation building blocks.
 
 Mature Haskell libraries with similar capabilities, but varying efficiency, are https://hackage.haskell.org/package/ad and https://hackage.haskell.org/package/backprop. See also https://github.com/Mikolaj/horde-ad/blob/master/CREDITS.md. Benchmarks suggest that horde-ad has competitive performance on CPU.
 <!--
-The benchmarks at SOMEWHERE show that this library has performance highly competitive with (i.e. faster than) those and PyTorch on CPU.
+The benchmarks at _ (TBD after GHC 9.14 is out) show that this library has performance highly competitive with (i.e. faster than) those and PyTorch on CPU.
 -->
 It is hoped that the (well-typed) separation of AD logic and tensor manipulation backend will enable similar speedups on numerical accelerators, when their support is implemented. Contributions to this and other tasks are very welcome. The newcomer-friendly tickets are listed at https://github.com/Mikolaj/horde-ad/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22good%20first%20issue%22. Please don't hesitate to ask questions on github, on Matrix, via email.
 
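The ADVal trade-off described in the commented-out paragraph above is easiest to see with a toy dual number. Below is a minimal sketch assuming only what that comment says; the `Dual` type and `relu` are hypothetical illustrations, not horde-ad's actual `ADVal`:

```hs
-- Hypothetical forward-mode dual number, in the spirit of the ADVal remark
-- above; this is NOT horde-ad's real ADVal type.
data Dual a = Dual { primalPart :: a, tangentPart :: a }

instance Num a => Num (Dual a) where
  Dual x x' + Dual y y' = Dual (x + y) (x' + y')
  Dual x x' * Dual y y' = Dual (x * y) (x' * y + x * y')
  negate (Dual x x')    = Dual (negate x) (negate x')
  abs (Dual x x')       = Dual (abs x) (signum x * x')
  signum (Dual x _)     = Dual (signum x) 0
  fromInteger n         = Dual (fromInteger n) 0

-- Dynamic control flow by inspecting the primal: this works per value,
-- but the branch is re-decided for every input, so no single symbolic
-- derivative program can be extracted once and reused on many inputs.
relu :: (Ord a, Num a) => Dual a -> Dual a
relu d@(Dual x _) = if x > 0 then d else Dual 0 0
```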
@@ -36,13 +39,13 @@ which can be verified by computing the gradient at `(1.1, 2.2, 3.3)`:
 (2.4396285219055063, -1.953374825727421, 0.9654825811012627)
 ```
 
-We can instantiate `foo` to matrices; the operations within (`sin`, `+`, `*`, etc.) applying elementwise:
+We can instantiate `foo` to matrices (represented in the `Concrete` datatype of unboxed multi-dimensional arrays); the operations within (`sin`, `+`, `*`, etc.) applying elementwise:
 ```hs
-type Matrix2x2 f r = f (TKS '[2, 2] r) -- TKS means shapely-typed tensor kind
-type ThreeMatrices r = (Matrix2x2 Concrete r, Matrix2x2 Concrete r, Matrix2x2 Concrete r)
+type Matrix2x2 r = Concrete (TKS '[2, 2] r) -- TKS means shapely-typed tensor kind
+type ThreeMatrices r = (Matrix2x2 r, Matrix2x2 r, Matrix2x2 r)
 threeSimpleMatrices :: ThreeMatrices Double
 threeSimpleMatrices = (srepl 1.1, srepl 2.2, srepl 3.3) -- srepl replicates its argument to fill the whole matrix
-fooMatrixValue :: Matrix2x2 Concrete Double
+fooMatrixValue :: Matrix2x2 Double
 fooMatrixValue = foo threeSimpleMatrices
 >>> fooMatrixValue
 sfromListLinear [2,2] [4.242393641025528,4.242393641025528,4.242393641025528,4.242393641025528])
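Since the shape is carried in the type, the simplified alias pattern extends to any shape; the extra aliases below are illustrations in the same style, not definitions from the README:

```hs
-- Illustrative aliases only; the README itself defines just Matrix2x2.
type Vector3 r   = Concrete (TKS '[3] r)
type Matrix3x4 r = Concrete (TKS '[3, 4] r)

vec3Ones :: Vector3 Double
vec3Ones = srepl 1.0  -- srepl fills the whole shape, as above
```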
@@ -65,7 +68,7 @@ This works as well as before:
 
 We noted above that `w` appears twice in `foo`. A property of tracing-based AD systems is that such re-use may not be captured, with explosive results.
 In `cgrad`, such sharing is preserved, so `w` is processed only once during gradient computation and this property is guaranteed for the `cgrad` tool universally, without any action required from the user.
-`horde-ad` also allows computing _symbolic_ derivative programs: using this API, a program is differentiated only once, after which it can be run on many different input values.
+`horde-ad` also allows computing _symbolic_ derivative programs: with this API, a program is differentiated only once, after which it can be run on many different input values.
 In this case, however, sharing is _not_ automatically preserved, so shared variables have to be explicitly marked using `tlet`, as shown below in `fooLet`.
 This also makes the type of the function more specific: it now does not work on an arbitrary `Num` any more, but instead on an arbitrary `horde-ad` tensor that implements the standard arithmetic operations, some of which (e.g., `atan2H`) are implemented in custom numeric classes.
 ```hs
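The sharing hazard that motivates `tlet` can be illustrated with a toy symbolic differentiator; the `Expr` language and `d` below are a hypothetical sketch, not horde-ad's machinery:

```hs
-- Toy expression language (not horde-ad API). Note that d (Mul a b)
-- duplicates its subterms, so a subterm re-used at n nested levels can
-- blow the derivative up to O(2^n) in size unless the sharing is kept
-- explicit with a let construct, which is what tlet provides.
data Expr = Var | Const Double | Add Expr Expr | Mul Expr Expr
          | Sin Expr | Cos Expr
  deriving Show

d :: Expr -> Expr
d Var       = Const 1
d (Const _) = Const 0
d (Add a b) = Add (d a) (d b)
d (Mul a b) = Add (Mul (d a) b) (Mul a (d b))  -- a and b each appear twice
d (Sin a)   = Mul (Cos a) (d a)
d (Cos a)   = Mul (Mul (Const (-1)) (Sin a)) (d a)
```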

test/simplified/TestAdaptorSimplified.hs

Lines changed: 10 additions & 6 deletions
@@ -561,12 +561,12 @@ testGradFooDouble =
 (2.4396285219055063, -1.953374825727421, 0.9654825811012627)
 (gradFooDouble (1.1, 2.2, 3.3))
 
-type Matrix2x2 :: Target -> Type -> Type
-type Matrix2x2 f r = f (TKS '[2, 2] r)
-type ThreeMatrices r = (Matrix2x2 Concrete r, Matrix2x2 Concrete r, Matrix2x2 Concrete r)
+type Matrix2x2 :: Type -> Type
+type Matrix2x2 r = Concrete (TKS '[2, 2] r)
+type ThreeMatrices r = (Matrix2x2 r, Matrix2x2 r, Matrix2x2 r)
 threeSimpleMatrices :: ThreeMatrices Double
 threeSimpleMatrices = (srepl 1.1, srepl 2.2, srepl 3.3)
-fooMatrixValue :: Matrix2x2 Concrete Double
+fooMatrixValue :: Matrix2x2 Double
 fooMatrixValue = foo threeSimpleMatrices
 gradSumFooMatrix :: ThreeMatrices Double -> ThreeMatrices Double
 gradSumFooMatrix = cgrad (kfromS . ssum0 . foo)
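A usage sketch for `gradSumFooMatrix` follows; the expected values are derived by hand (because `foo` acts elementwise and `ssum0` only sums, each partial derivative should be the corresponding scalar partial at `(1.1, 2.2, 3.3)` replicated across the matrix), not copied from a run:

```hs
-- Hand-derived expectation, by the elementwise argument above.
expectedGrad :: ThreeMatrices Double
expectedGrad =
  ( srepl 2.4396285219055063
  , srepl (-1.953374825727421)
  , srepl 0.9654825811012627 )

-- Should match expectedGrad up to floating-point noise.
gradAtSimpleMatrices :: ThreeMatrices Double
gradAtSimpleMatrices = gradSumFooMatrix threeSimpleMatrices
```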
@@ -625,8 +625,12 @@ testGradFooLetMatrixSimpRPP = do
      in printArtifactPretty (simplifyArtifact $ revArtifactAdapt UseIncomingCotangent fooLet (FTKProduct (FTKProduct ftk ftk) ftk)))
   @?= "\\dret m1 -> tconvert (ConvT2 (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2] FTKScalar)) ConvSX))) (ConvCmp (ConvXR STKScalar) ConvSX)) (STKProduct (STKProduct (STKS [2,2] STKScalar) (STKS [2,2] STKScalar)) (STKS [2,2] STKScalar)) (let m3 = sin (sfromR (tproject2 (tproject1 m1))) ; m4 = sfromR (tproject1 (tproject1 m1)) * m3 ; m5 = recip (sfromR (tproject2 m1) * sfromR (tproject2 m1) + m4 * m4) ; m7 = (negate (sfromR (tproject2 m1)) * m5) * sfromR dret + sfromR (tproject2 m1) * sfromR dret in tpair (tpair (m3 * m7) (cos (sfromR (tproject2 (tproject1 m1))) * (sfromR (tproject1 (tproject1 m1)) * m7))) ((m4 * m5) * sfromR dret + m4 * sfromR dret))"
 
-sumFooMatrix :: (ADReady f, RealFloat (Matrix2x2 f r), GoodScalar r)
-             => (Matrix2x2 f r, Matrix2x2 f r, Matrix2x2 f r) -> f (TKScalar r)
+type Matrix2x2f :: Target -> Type -> Type
+type Matrix2x2f f r = f (TKS '[2, 2] r)
+
+sumFooMatrix :: (ADReady f, RealFloat (Matrix2x2f f r), GoodScalar r)
+             => (Matrix2x2f f r, Matrix2x2f f r, Matrix2x2f f r)
+             -> f (TKScalar r)
 sumFooMatrix = kfromS . ssum0 . foo
 
 testfooSumMatrix :: Assertion
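The test file keeps this second, target-polymorphic alias because `sumFooMatrix` must range over symbolic targets as well as `Concrete`; the sketch below only restates how the two aliases in this commit relate:

```hs
-- The polymorphic alias from this test file...
type Matrix2x2f :: Target -> Type -> Type
type Matrix2x2f f r = f (TKS '[2, 2] r)

-- ...recovers the README's simplified alias once the target is fixed:
type Matrix2x2 :: Type -> Type
type Matrix2x2 r = Matrix2x2f Concrete r  -- i.e. Concrete (TKS '[2, 2] r)
```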
