Requesting help to translate a small array program into Typed Racket

When I added array broadcasting to SRFI 231 (in the 231-bis branch of gambiteer/srfi-231 on GitHub), I made some design decisions that differ from those in Racket's math/array, and I would like to compare the implementations of the two libraries.

I'd like someone to help by adding type annotations to the following small program, which tests the interaction of array broadcasting with folding along an axis, so that it will run in Typed Racket.

Thanks.

Brad

#lang racket

(require math/array)
(require math/matrix)

(require racket/flonum)

(define (matrix-multiply A B)
  (array-axis-fold
   (array-map
    fl*
    (array-axis-insert A 2)
    (array-axis-insert B 0))
   1
   fl+))

(define A (build-array #(100 100) (lambda (multi-index)
                                    (exact->inexact (- (vector-ref multi-index 0)
                                                       (vector-ref multi-index 1))))))

(define B (build-array #(100 100) (lambda (multi-index)
                                    (exact->inexact (+ (vector-ref multi-index 0)
                                                       (vector-ref multi-index 1))))))

(define C (time (matrix-multiply A B)))

(define A* (build-matrix 100 100 (lambda (i j)
                                   (exact->inexact (- i
                                                      j)))))

(define B* (build-matrix 100 100 (lambda (i j)
                                   (exact->inexact (+ i
                                                      j)))))

(define C* (time (matrix* A* B*)))

#lang typed/racket

(require math/array)
(require math/matrix)

(require/typed
 racket/flonum
 [fl+ (-> Float Float Float)]
 [fl* (-> Float Float Float)])

(: matrix-multiply (-> (Matrix Float) (Matrix Float) (Matrix Float)))
(define (matrix-multiply A B)
  (array-axis-fold
   (array-map
    fl*
    (array-axis-insert A 2)
    (array-axis-insert B 0))
   1
   fl+))

(: indexex In-Indexes)
(define indexex #(100 100))

(define A
  (build-array
   indexex
   (λ ({mi : Indexes}) (exact->inexact (- (vector-ref mi 0) (vector-ref mi 1))))))

(define B
  (build-array
   indexex
   (lambda ({multi-index : Indexes}) (exact->inexact (+ (vector-ref multi-index 0) (vector-ref multi-index 1))))))

(define C (time (matrix-multiply A B)))

(define A* (build-matrix 100 100 (lambda ({i : Index} {j : Index}) (exact->inexact (- i j)))))

(define B* (build-matrix 100 100 (lambda ({i : Index} {j : Index}) (exact->inexact (+ i j)))))

(define C* (time (matrix* A* B*)))

Thanks! The results on my 11-year-old Linux box, after upping the matrix size to 200x200, were

cpu time: 3378 real time: 3383 gc time: 873
cpu time: 36 real time: 37 gc time: 8

So the built-in matrix* is about 100 times as fast as the routine matrix-multiply (which was written to see how implicit array broadcasting interacted with array-axis-fold).
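To make the broadcasting step concrete, here is a small untyped sketch on 2x2 exact matrices. It uses the same three math/array calls as matrix-multiply above, with generic * and + swapped in for fl* and fl+ so that the result is easy to check by hand; the example matrices are made up for illustration.

```racket
#lang racket

(require math/array math/matrix)

;; Same structure as matrix-multiply above, with generic arithmetic.
;; A becomes m x p x 1 and B becomes 1 x p x n; array-map broadcasts
;; both to m x p x n, so element (i, k, j) is A[i,k] * B[k,j].
;; Folding axis 1 (the k axis) with + leaves the m x n product.
(define (matrix-multiply A B)
  (array-axis-fold
   (array-map * (array-axis-insert A 2) (array-axis-insert B 0))
   1
   +))

(define A (array #[#[1 2] #[3 4]]))
(define B (array #[#[5 6] #[7 8]]))

(define C  (array->list* (matrix-multiply A B)))
(define C* (array->list* (matrix* A B)))

C   ; '((19 22) (43 50))
C*  ; same entries, computed by the built-in matrix*
```

The intermediate array has m*p*n elements, which is one plausible reason the broadcasting version does so much more work than matrix*.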

So that's about 8,000,000 floating-point multiply-adds (200^3) in 36 ms, or roughly 220 million multiply-adds per second. Pretty good.

Similar code using the SRFI 231 followup library, with safe code in Gambit:

(define (matrix-multiply-2 A B)
  ;; We rely on implicit broadcasting of
  ;; the array arguments A* and B*
  (array-copy!
   (array-map (lambda (products)
                (array-fold-left fl+ 0. products))
              (array-curry
               (array-permute (array-map
                               fl*
                               (array-insert-axis A 2)
                               (array-insert-axis B 0))
                              (index-last 3 1))
               1))))


(define A (array-copy! (make-array (make-interval '#(200 200))
                                   (lambda (i j) (inexact (- i j))))))

(define B (array-copy! (make-array (make-interval '#(200 200))
                                   (lambda (i j) (inexact (+ i j))))))

(define D (time (matrix-multiply-2 A B)))

(define G (time (array-copy! (array-inner-product A fl+ fl* B))))

The timings are

(time (matrix-multiply-2 A B))
    0.380979 secs real time
    0.380508 secs cpu time (0.379516 user, 0.000992 system)
    7 collections accounting for 0.003366 secs real time (0.003349 user, 0.000012 system)
    20821184 bytes allocated
    78 minor faults
    no major faults
    1368323814 cpu cycles
(time (array-copy! (array-inner-product A fl+ fl* B)))
    0.349888 secs real time
    0.349500 secs cpu time (0.349500 user, 0.000000 system)
    9 collections accounting for 0.005723 secs real time (0.005717 user, 0.000000 system)
    23837600 bytes allocated
    78 minor faults
    no major faults
    1256656605 cpu cycles
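For comparison, the timings quoted so far can be converted to multiply-adds per second. This is only arithmetic on the numbers above (200x200 matrices, cpu times as reported); the helper names multiply-adds and rate are made up for this sketch.

```racket
#lang racket/base

;; An n x n matrix product needs n^3 multiply-adds.
(define (multiply-adds n) (* n n n))

;; Multiply-adds per second for a run that took `seconds`.
(define (rate n seconds) (/ (multiply-adds n) seconds))

;; Label and cpu time in seconds, copied from the timings above.
(define timings
  '(("Racket matrix*"             0.036)
    ("Racket broadcasting + fold" 3.378)
    ("Gambit matrix-multiply-2"   0.380508)
    ("Gambit array-inner-product" 0.349500)))

;; Print each rate in millions per second, rounded to one decimal place.
(for ([t (in-list timings)])
  (printf "~a: ~a million multiply-adds/s\n"
          (car t)
          (/ (round (/ (rate 200 (cadr t)) 1e5)) 10.0)))
```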

The Typed Racket results were so good, I looked at ./math-lib/math/private/matrix/matrix-arithmetic.rkt, which leads me to ask whether matrix* is an open-coded macro when used in this way.

The matrix multiplication is defined in

So I think the purpose of .../private/matrix/matrix-arithmetic.rkt is solely to make matrix* work from both untyped and typed Racket at the same time.

But the code is involved. The actual sum-and-multiply loop is here:

The macro stepper yields this for the line in question, which shows an application of the function make-matrix-multiply:

(define-values:365 (C*)
     (let-values:366 (((v:367 cpu:367 user:367 gc:367)
                       (#%app:367
                        time-apply:367
                        (lambda:367 ()
                          (let-values:368 (((arr:369)
                                            (with-continuation-mark:370
                                             parameterization-key:370
                                             (#%app:370
                                              extend-parameterization:370
                                              (#%app:370 continuation-mark-set-first:370 (quote #f) parameterization-key:370)
                                              array-strictness:371
                                              (quote #f))
                                             (let-values:373 ()
                                               (let-values:374 (((m:375 p:375 n:375 arr-data:375 brr-data:375 bx:375)
                                                                 (#%app:376 matrix-multiply-data:377 A* B*)))
                                                 (#%app:378
                                                  make-matrix-multiply:379
                                                  m:375
                                                  p:375
                                                  n:375
                                                  (lambda:380 (i:375 j:375)
                                                    (let-values:382 (((bx:375) (#%app:383 bx:375)))
                                                      (let-values:382 (((v:375)
                                                                        (#%app:384
                                                                         unsafe-fl*
                                                                         (#%app:385 unsafe-vector-ref:375 arr-data:375 i:375)
                                                                         (#%app:386 unsafe-vector-ref:375 brr-data:375 j:375))))
                                                        (#%app:387
                                                         (letrec-values:388 (((loop:375)
                                                                              (lambda:389 (k:375 v:375)
                                                                                (if (#%app:391 unsafe-fx< k:375 p:375)
                                                                                  (let-values:392 ()
                                                                                    (#%app:393
                                                                                     loop:375
                                                                                     (#%app:394 unsafe-fx+ k:375 (quote 1))
                                                                                     (#%app:395
                                                                                      unsafe-fl+
                                                                                      v:375
                                                                                      (#%app:396
                                                                                       unsafe-fl*
                                                                                       (#%app:397
                                                                                        unsafe-vector-ref:375
                                                                                        arr-data:375
                                                                                        (#%app:398 unsafe-fx+ i:375 k:375))
                                                                                       (#%app:399
                                                                                        unsafe-vector-ref:375
                                                                                        brr-data:375
                                                                                        (#%app:400 unsafe-fx+ j:375 k:375))))))
                                                                                  (let-values:401 ()
                                                                                    (#%app:402 set-box!:375 bx:375 v:375))))))
                                                           loop:375)
                                                         (quote 1)
                                                         v:375)
                                                        (#%app:403 unsafe-unbox bx:375))))))))))
                            (#%app:404 array-default-strict!:405 arr:369)
                            arr:369))
                        null:367)))
       (#%app:406 printf:367 (quote "cpu time: ~s real time: ~s gc time: ~s\n") cpu:367 user:367 gc:367)
       (#%app:367 apply:367 values:367 v:367)))

Thanks, I think I understand Racket’s array/matrix code better.

I want to ask: the definition of matrix*/ns in ./math-lib/math/private/matrix/typed-matrix-arithmetic.rkt is

(: matrix*/ns
   (case-> ((Matrix Flonum) (Listof (Matrix Flonum)) -> (Matrix Flonum))
           ((Matrix Real) (Listof (Matrix Real)) -> (Matrix Real))
           ((Matrix Float-Complex) (Listof (Matrix Float-Complex)) -> (Matrix Float-Complex))
           ((Matrix Number) (Listof (Matrix Number)) -> (Matrix Number))))
(define (matrix*/ns a as)
  (cond [(empty? as) (matrix-shape a) ;; does argument checking
                      a]
        [else  (matrix*/ns (inline-matrix-multiply a (first as)) (rest as))]))

and inline-matrix-multiply is defined as a macro in ./math-lib/math/private/matrix/untyped-matrix-arithmetic.rkt with the comment

  ;; This is a macro so the result can have as precise a type as possible

My question: Does this mean that Typed Racket compiles a version of matrix*/ns specialized to each of the type cases in the signature?

Studying this code leads me to reconsider the specifications of array-outer-product and array-inner-product in SRFI 231, which now seem a bit sloppy. Some things perhaps to copy from how matrix* is coded:

  1. If necessary, copy argument arrays to specialized (strict) arrays, after permuting the second array argument to place its first axis last.
  2. Return a specialized (strict) array.
  3. Right now, the calling sequence is
(array-inner-product A f g B)

and the inner loop is computed by (a and b are one-dimensional subarrays of A and B):

(lambda (a b)
  (array-reduce f (array-map g a b)))

where, for the matrix* loop, g is *, and f is +.

Perhaps the calling sequence should be

(array-inner-product A f-of-g g B)

where for matrix*, for example, g would be * and f-of-g would be

(lambda (sum a b) (+ sum (* a b)))

so it would be not two but one function call (other than the + and *, which can be inlined by the compiler, as in Gambit) per iteration.
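A rough sketch of the difference, written in plain Racket over vectors rather than SRFI 231 arrays. The names inner/f-and-g and inner/fused are hypothetical, and using g to seed the accumulator from the first pair is my assumption about why g would stay in the calling sequence.

```racket
#lang racket/base

;; Two-function style: each later element costs one call to g and one to f.
;; Assumes a and b are non-empty vectors of equal length.
(define (inner/f-and-g f g a b)
  (for/fold ([acc (g (vector-ref a 0) (vector-ref b 0))])
            ([x (in-vector a 1)]
             [y (in-vector b 1)])
    (f acc (g x y))))

;; Fused style: each later element costs a single call to f-of-g.
;; g is only used to start the accumulation from the first pair.
(define (inner/fused f-of-g g a b)
  (for/fold ([acc (g (vector-ref a 0) (vector-ref b 0))])
            ([x (in-vector a 1)]
             [y (in-vector b 1)])
    (f-of-g acc x y)))

(define a (vector 1.0 2.0 3.0))
(define b (vector 4.0 5.0 6.0))

(define r1 (inner/f-and-g + * a b))
(define r2 (inner/fused (lambda (sum x y) (+ sum (* x y))) * a b))

r1  ; 32.0
r2  ; 32.0
```

Both produce the same dot product; the fused version just gives the compiler one closure body containing both the + and the *.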

I'll have to think about things more.

On Jan 15, 2026, at 1:59 PM, Gambiteer asked: "Does this mean that Typed Racket compiles a version of matrix*/ns specialized to each of the type cases in the signature?"

Yes. TR does have type abstraction (polymorphism), but by using a macro, the primitive operations can be specialized to the type that’s provided.
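A minimal sketch of the idea behind that comment, not the math library's code: mul-add, fl-step and real-step are hypothetical names. Because the macro expands at each use site, the type checker sees the arithmetic with that site's argument types, so the same source text gets the type Flonum in one function and Real in the other.

```racket
#lang typed/racket

;; A macro rather than a function, so it has no single fixed type.
(define-syntax-rule (mul-add sum a b)
  (+ sum (* a b)))

;; Here the expansion is checked with Flonum operands.
(: fl-step (-> Flonum Flonum Flonum Flonum))
(define (fl-step sum a b)
  (mul-add sum a b))

;; Here the same expansion is checked with Real operands.
(: real-step (-> Real Real Real Real))
(define (real-step sum a b)
  (mul-add sum a b))

(fl-step 1.0 2.0 3.0)  ; 7.0
(real-step 1 2 3)      ; 7
```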