(number? pi) ==> #t and (number? (/ pi 1)) ==> #t
After (require malt), the number? function behaves differently:
(number? pi) ==> still #t, but (number? (/ pi 1)) ==> #f???
The malt language redefines math functions like / with versions that accept and return tensors. It doesn't look like it provides an updated number? that's aware of those types.
Overriding a function with exactly the same name makes the old function unavailable.
My workaround is to alias / with (define div /) before calling (require malt). Not elegant at all.
Is there a way to put a namespace before a function in Scheme? Like math./ or malt./
How is such a problem solved in Scheme?
I have no experience with libraries or modules in Scheme.
(require (prefix-in malt: malt))
Then all the malt functions are imported with the prefix malt:.
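As a minimal sketch of what prefix-in does, here is the same form applied to racket/math, which ships with Racket and stands in for malt so the example runs without malt installed. The prefix m: is an arbitrary choice made for illustration.

```racket
#lang racket/base
;; Import every binding from racket/math under the prefix m:.
;; racket/math is only a stand-in for malt; the require form is the same.
(require (prefix-in m: racket/math))

(m:sqr 3)             ; racket/math's sqr, reached through the prefix
(number? (/ m:pi 1))  ; / is still racket/base's own division
```

Because nothing is imported under its bare name, none of the bindings that come from the #lang are shadowed.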
The documentation I linked to earlier mentions it in passing; you should (require malt/base-no-overrides) instead; that stops things like / from being overridden (use d/ to get the Malt version). For more options and details see section 2, "Entry points".
Indeed, it's written there. Sorry, my brain shuts down with so many possibilities for using malt.
I want to calculate (gradient-of sin (/ pi 4)).
==> contract violation saying that 0.7853... is not a number
because / is overridden and the number is now a tensor-0.
After some feedback here I use the primitive-/ in malt by writing:
(require malt/base-no-overrides)
(gradient-of d-sin (/ pi 4))
==> now gradient-of is undefined?
Why? It is not overridden.
This overriding is nice for writing concise code, but it comes at a cost: I have to start reading manuals just to be able to use a primitive function.
It would be nice if one could fall back to another implementation of the same function name by putting a prefix before it, like primitive-/, instead of having to change the require call and put a d- before all overridden malt functions.
Hmmm, now I'm asking Racket to be modified because I can't read the manual. That's no good.
Anyway I'm still stuck on a simple calculation.
You can write something like (require (prefix-in r: racket)) (or racket/base) to be able to write r:/.
In general, it might be helpful to keep in mind that “overridden” functions aren’t mutated or “monkey-patched” the way a function might be in Ruby, JavaScript, or Python. What happens is that bindings imported with require shadow bindings for the same name that come from the #lang, in just the same way that let can locally shadow bindings from some wider scope.
I want to calculate (gradient-of sin (/ pi 4)).
==> contract violation saying that 0.7853... is not a number
Ah, yes, I came across this sort of thing too.
It happens because "gradient-of" expects the first argument, i.e. the function, to be able to handle tensors, not plain old numbers. IOW, to be what the text calls an "extended" function.
For instance you will see that
(gradient-of sqrt (/ pi 4))
works just fine, as "sqrt" is already suitably defined.
Also, if you (define (square x) (* x x))
then
(gradient-of square (/ pi 4))
will also work correctly.
This is because the arithmetic functions +,-,*,/
are already extended to work for tensors, so any composition of them will also be "extended".
I'm not really sure what to recommend in this case, but I guess you'll have to find a way to extend the trigonometric primitives if you want to use AD on them. I need to know the system a bit better before I can recommend a method, to avoid giving you a bum steer.
It is a malt issue, however, not a Racket one!
PS. Interestingly enough, both exp and expt are defined in malt, so you can do stuff like
> (gradient-of exp 1)
2.718281828459045
> (gradient-of (curryr expt 1/3) 27)
0.037037037037037035
with no problem. You could easily, e.g. define hyperbolic functions!
It's just poor old sine, cosine and tangent that have not yet joined the party.
However, here's an example of a workaround for sin, using complex numbers:
(define i (sqrt -1))
(define neg-i (* -1 i))
(define (sin z) (/ (- (exp (* i z))(exp (* neg-i z))) (* 2 i)))
> (gradient-of sin (/ pi 4))
0.7071067811865476+0.0i
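For reference, the identity behind that workaround can be checked in plain Racket without malt. This sketch only confirms that the exponential form agrees with the built-in sin on a real input and shows where the trailing imaginary part comes from; the name sin-via-exp is made up for the example.

```racket
#lang racket/base
(require racket/math) ; for pi

(define i (sqrt -1)) ; the exact imaginary unit

;; sin z = (e^(iz) - e^(-iz)) / 2i, written with ordinary Racket arithmetic.
(define (sin-via-exp z)
  (/ (- (exp (* i z)) (exp (* -1 i z)))
     (* 2 i)))

(define approx (sin-via-exp (/ pi 4)))

(real-part approx) ; agrees with (sin (/ pi 4)) up to rounding
(imag-part approx) ; zero up to rounding; the result stays a complex number
```

The value is complex even for real input, which is why the gradient above prints with an imaginary component.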
I also found that it will work if you define sine in terms of its Taylor expansion.
First, some auxiliary stuff, for series...
foldl will not work if used with range (you can't do AD on it). So we provide our own Sigma routine, sum, which can be differentiated.
Remember, all extended functions are in terms of tensor, not number.
Effectively, malt hijacks all numbers into tensors.
(define ((reduce op init) term a next b)
  (define (reduce-tr a tot)
    (cond ((> a b) tot)
          (else (reduce-tr (next a) (op tot (term a))))))
  (reduce-tr a init))
(define sum (reduce + 0))
(define prod (reduce * 1))
(define (inc n) (+ 1 n))
(define (factorial n)
  (if (< n 2) 1
      (prod identity 2.0 inc n)))
Now, with the above, it's easy:
(define (sine x)
  (sum (λ (n) (/ (* (expt -1 n) (expt x (inc (* 2 n))))
                 (factorial (inc (* 2 n)))))
       0 inc 20))
> (gradient-of sine (/ pi 4))
0.7071067811865475
Just for fun
(define cosine (curry gradient-of sine))
> (cosine (/ pi 2))
4.253941440428854e-17
> (cosine (/ pi 6))
0.8660254037844386
> (cosine (/ pi 3))
0.5000000000000001
PS. However, the above does yield a bug:
> (cosine 0)
. . log: division by zero
>
I won't bother with it for now; the malt library is newish. Just define the cos Taylor series longhand, and it should work fine!
(define (cosine x)
  (sum (λ (n)
         (/ (* (expt -1 n) (expt x (* 2 n)))
            (factorial (* 2 n))))
       0 inc 20))
So I guess the quick answer is:
For any function to be automatically differentiable in malt, it must be composed of primitives that have already been defined as differentiable in malt (and which return tensors when applied to their args); see here.
Sin and cos have not yet been defined as differentiable primitives in malt, so they must be composed from other such primitives, e.g. using complex numbers or Taylor series as in the two examples above.
I am sure there is a simple way to add them as primitives (thus gaining efficiency) using the "dual" system, as I have seen implemented in other languages, notably Clojure and Julia and even Scheme. It should be very easy in Racket, I just haven't got that far yet.
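To make the "dual" idea concrete, here is a minimal forward-mode sketch in plain Racket. It is not malt's implementation and does not use malt's API; every name in it (dual, lift, add/dual, mul/dual, sin/dual, cos/dual, derivative) is hypothetical and chosen for illustration only.

```racket
#lang racket/base
(require racket/math) ; for pi

;; A dual number carries a value together with its derivative.
(struct dual (val der))

;; Promote a plain number to a dual whose derivative is zero.
(define (lift x) (if (dual? x) x (dual x 0.0)))

;; Sum rule.
(define (add/dual a b)
  (let ([a (lift a)] [b (lift b)])
    (dual (+ (dual-val a) (dual-val b))
          (+ (dual-der a) (dual-der b)))))

;; Product rule.
(define (mul/dual a b)
  (let ([a (lift a)] [b (lift b)])
    (dual (* (dual-val a) (dual-val b))
          (+ (* (dual-val a) (dual-der b))
             (* (dual-der a) (dual-val b))))))

;; Chain rule for the trigonometric primitives.
(define (sin/dual a)
  (let ([a (lift a)])
    (dual (sin (dual-val a))
          (* (dual-der a) (cos (dual-val a))))))

(define (cos/dual a)
  (let ([a (lift a)])
    (dual (cos (dual-val a))
          (* (dual-der a) (- (sin (dual-val a)))))))

;; Derivative of f at x: seed the derivative slot with 1 and read it back.
(define (derivative f x)
  (dual-der (lift (f (dual x 1.0)))))

(derivative sin/dual (/ pi 4))                        ; close to (cos (/ pi 4))
(derivative (λ (x) (mul/dual x (sin/dual x))) 1.0)    ; close to sin 1 + cos 1
```

Each primitive supplies its own derivative rule, so any function composed from them is differentiable without further work, which is the same property the thread relies on in malt.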
The code for duals is duplicated across the different tensor implementations in the repo, but here is where the other operators are defined for malt/learner:
https://github.com/themetaschemer/malt/blob/main/learner/ext-ops/A-scalar-ops.rkt
Thanks for that! The Learner code looks really easy to extend; all the necessary functions, prim1, ext1, etc., are available just with (require malt). I tried it out, following the source code, but I only got to first base; there's probably something more I need to do.
(Maybe the line (* z (cos a)) should have used plain-old Racket multiplication, i.e. the one from before it was commandeered by malt?)
(define sin-0
  (prim1 sin
         (λ (a z)
           (* z (cos a)))))
(define d-sin (ext1 sin-0 0))
> (d-sin 1)
0.8414709848078965
> (number? (d-sin 1))
#f
> (scalar? (d-sin 1))
#t
> (gradient-of d-sin 1.0)
. . +: contract violation
expected: number?
given: 0.8414709848078965
>
Seems like the extension of d-sin is blocked by the failure of the multiplication.
I'll have a better look at the implementation later on, for now, I'm trying to understand ML, of which I know even less than I do about AD!
Yup, got it all going fine now!
(require malt)
(require (only-in racket/base
                  [* racket:*]
                  [- racket:-]
                  [sin racket:sin]
                  [cos racket:cos]))
(define sin-0
  (prim1 racket:sin
         (λ (a z)
           (racket:* z (racket:cos a)))))
(define cos-0
  (prim1 racket:cos
         (λ (a z)
           (racket:* z (racket:- (racket:sin a))))))
(define sin (ext1 sin-0 0))
(define cos (ext1 cos-0 0))
(define (tan x) (/ (sin x) (cos x)))
Try a few tests...
> (sin 1)
0.8414709848078965
> (cos 1)
0.5403023058681398
> (sin (/ pi 4))
0.7071067811865476
> (cos (/ pi 4))
0.7071067811865476
> (tan 1)
1.557407724654902
> (tan 0)
0
> (number? (cos 1))
#f
> (number? (tan 1))
#f
> (scalar? (cos 1))
#t
> (scalar? (tan 1))
#t
> (tensor? (tan 1))
#t
> (gradient-of sin (/ pi 4))
0.7071067811865476
> (gradient-of cos (/ pi 4))
-0.7071067811865475
> (gradient-of tan 0)
1.0
> (gradient-of tan 1)
3.425518820814759
All seems good.