-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add naive Racket matrix-multiplication
- Loading branch information
1 parent
e1a3e15
commit 3e1c669
Showing
3 changed files
with
190 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#lang racket | ||
|
||
(require binaryio) | ||
|
||
;; Returns (as bs cs) where xs is a list of "arrays" | ||
;; Each array is a pair (n . v), where v is a vector | ||
;; and n the length of the rows. | ||
(provide load-float64-arrays-from-dir) | ||
|
||
(struct npy-header | ||
(magic-string | ||
major-version-number | ||
minor-version-number | ||
length | ||
format) | ||
#:transparent) | ||
|
||
(define (load-float64-arrays-from-dir dir) | ||
(define filenames (directory-list dir)) | ||
;; each filename is X_999.npy, where X is a, b, or c and 999 is a number->string | ||
|
||
(define n-matrices | ||
(+ 1 | ||
(apply max | ||
(map | ||
(λ (entry) | ||
(string->number | ||
(cadr (regexp-match #rx"[abc]_([0-9]+).npy" (path->string entry))))) | ||
filenames)))) | ||
|
||
;; returns (values as bs cs) | ||
(define-values (as bs cs) | ||
(for/lists (as bs cs) | ||
([i (in-range n-matrices)]) | ||
(when (zero? (modulo i (quotient n-matrices 8))) | ||
(display ".")) | ||
(apply values | ||
(for/list ([x '("a" "b" "c")]) | ||
(call-with-input-file | ||
(build-path dir (string-append x "_" (number->string i) ".npy")) | ||
load-float64-array-from-npy))))) | ||
|
||
(list as bs cs)) | ||
|
||
|
||
(define (load-float64-array-from-npy in) | ||
;; Read header | ||
(define the-header | ||
(let* ([mgc (read-bytes 6 in)] | ||
[maj (read-integer 1 #f in)] | ||
[min (read-integer 1 #f in)] | ||
[len (read-integer 2 #f in #f)] | ||
[fmt (bytes->string/latin-1 (read-bytes len in))]) | ||
(npy-header mgc maj min len fmt))) | ||
|
||
;; Parse shape from a Python dictionary (!) | ||
(define the-shape | ||
(let ([shape (regexp-match #rx"'shape': \\(([0-9]+), ([0-9]+)\\)" | ||
(npy-header-format the-header))]) | ||
(list (string->number (cadr shape)) | ||
(string->number (caddr shape))))) | ||
|
||
;; Read array and return it and the length of a row | ||
(cons (cadr the-shape) | ||
(for/vector ([_ (in-range (* (car the-shape) (cadr the-shape)))]) | ||
(read-float 8 in #f)))) | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
#lang racket | ||
|
||
(require "load.rkt") | ||
|
||
|
||
;; ------------------------------------------------------------ | ||
;; Represent a matrix as a struct | ||
|
||
(struct matrix (ncols vec) #:transparent) | ||
|
||
(define (make-matrix nrows ncols [v 0]) | ||
(matrix ncols (make-vector (* nrows ncols) v))) | ||
|
||
(define (matrix-nrows m) | ||
(quotient (vector-length (matrix-vec m)) (matrix-ncols m))) | ||
|
||
;; Return element in the ith row and jth column | ||
;; Index from 0 | ||
(define (matrix-ref m i j) | ||
(vector-ref (matrix-vec m) (+ j (* i (matrix-ncols m))))) | ||
|
||
(define (matrix-set! m i j v) | ||
(vector-set! (matrix-vec m) (+ j (* i (matrix-ncols m))) v)) | ||
|
||
(define (matrix-equal-within? delta m1 m2) | ||
(and | ||
(= (matrix-ncols m1) (matrix-ncols m2)) | ||
(= (matrix-nrows m1) (matrix-nrows m2)) | ||
(andmap | ||
(λ (v1 v2) | ||
(<= (abs (- v1 v2)) delta)) | ||
(vector->list (matrix-vec m1)) | ||
(vector->list (matrix-vec m2))))) | ||
|
||
;; ------------------------------------------------------------ | ||
;; Mathematics | ||
|
||
(define (matrix-mul m1 m2) | ||
(let ([I (matrix-nrows m1)] | ||
[J (matrix-ncols m2)] | ||
[K (matrix-ncols m1)]) | ||
(unless (= K (matrix-nrows m2)) | ||
(error "Matrices not compatible")) | ||
(define vec | ||
(for*/vector #:length (* I J) | ||
([i (in-range I)] | ||
[j (in-range J)]) | ||
(for/sum ([k (in-range K)]) | ||
(* (matrix-ref m1 i k) | ||
(matrix-ref m2 k j))))) | ||
|
||
(matrix J vec)) | ||
) | ||
|
||
;; ------------------------------------------------------------ | ||
;; main | ||
|
||
|
||
(display "Loading test data ") | ||
|
||
(match-define (list as bs cs) | ||
(map | ||
(λ (x) | ||
(map | ||
(λ (m) | ||
(let ([ncols (car m)] | ||
[vs (cdr m)]) | ||
(matrix ncols vs))) | ||
x)) | ||
(load-float64-arrays-from-dir "../testdata/matrices"))) | ||
|
||
(display "done \n") | ||
|
||
(display "Checking test cases...") | ||
|
||
(unless | ||
(andmap | ||
(λ (m1 m2) (matrix-equal-within? 0.01 m1 m2)) | ||
cs (map matrix-mul as bs)) | ||
(error "something went wrong!")) | ||
|
||
(displayln " done.") | ||
|
||
(define *repeats* 32768) | ||
|
||
(printf "Trying ~a matrix multipications.\n" (* *repeats* (length as))) | ||
|
||
(display | ||
(time | ||
(for ([i (in-range *repeats*)]) | ||
(when (zero? (modulo i 1024)) | ||
(display ".")) | ||
(map matrix-mul as bs)))) | ||
|
||
|