Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Basic matrix operations until QR and SVD algorithms #172

Merged
merged 10 commits into from
Feb 16, 2024
223 changes: 189 additions & 34 deletions spork/math.janet
Original file line number Diff line number Diff line change
Expand Up @@ -392,10 +392,9 @@
(def cells @[])
(forever
(set (cells x)
(*
bc
(math/pow p x)
(math/pow (- 1 p) (- t x))))
(* bc
(math/pow p x)
(math/pow (- 1 p) (- t x))))
(+= cp (cells x))
(++ x)
(set bc (/ (* bc (+ (- t x) 1)) x))
Expand Down Expand Up @@ -1026,7 +1025,7 @@
[c &opt r]
(def v (array/new-filled c 0))
(if r
(seq [_ :range [0 c]] (array/slice v))
(seq [_ :range [0 r]] (array/slice v))
v))

(defn scalar
Expand All @@ -1046,16 +1045,16 @@
(scalar c 1))

(defn trans
"Returns a new transposed matrix from `m`."
"Tansposes a list of row vectors."
[m]
(def [c r] (size m))
(def res (array/new c))
(for i 0 r
(def cr (array/new c))
(for j 0 c
(array/push cr (get-in m [j i])))
(array/push res cr))
res)
(map array ;m))

(defn row->col
"Transposes a row vector `xs` to col vector. Returns `xs` if it has higher dimensions."
[xs]
(case (type (xs 0))
:number (map array xs)
:array xs))

(defn sop
```
Expand All @@ -1067,7 +1066,8 @@
(if-not (empty? a) |(op $ ;a) op))
(for i 0 (cols m)
(for j 0 (rows m)
(update-in m [i j] opa))))
(update-in m [i j] opa)))
m)

(defn mop
```
Expand All @@ -1077,7 +1077,7 @@
[m op a]
(for i 0 (cols m)
(for j 0 (rows m)
(update-in m [j i] op (get-in a [j i])))))
(update-in m [j i] op (get-in a [j i])))) m)

(defn add
```
Expand All @@ -1090,34 +1090,40 @@
:array (mop m + a)))

(defn dot
"Computes dot product of matrices or vectors `x` and `y`."
[mx my]
(def [rx cx] (size mx))
(def [ry cy] (size my))
(assert (= cx ry) "matrices do not have right sizes for dot product")
(def res (array/new cy))
(for r 0 rx
(def cr (array/new cx))
(for c 0 cy
(var s 0)
(for rr 0 ry
(+= s (* (get-in mx [r rr]) (get-in my [rr c]))))
(array/push cr s))
(array/push res cr))
res)
"Dot product between two row vectors."
[v1 v2]
(apply + (map * v1 v2)))

(defn dot-fast
"Fast dot product between two row vectors of equal size."
[v1 v2]
(var t 0)
(for i 0 (length v1)
(+= t (* (get v1 i) (get v2 i))))
t)

(defn matmul
"Matrix multiplication between matrices `ma` and `mb`. Does not mutate."
[ma mb]
(map (fn [row-a]
(map (fn [col-b]
(apply + (map * row-a col-b)))
(trans mb)))
ma))

(defn mul
```
Multiply matrix `m` with `a` which can be matrix or vector.
Matrix `m` is mutated.
Multiply matrix `m` with `a` which can be matrix or a list.
Mutates `m`. A list `a` will be converted to column vector
then multiplifed from the right as `x * a`.
```
[m a]
(case (type a)
:number
(sop m * a)
:array
(if (number? (a 0))
(dot m (seq [x :in a] @[x]))
(matmul m (row->col a))
(mop m * a))))

(defn minor
Expand Down Expand Up @@ -1354,3 +1360,152 @@
(if (> x one)
(array/concat res (factor-pollard x))))
res)

(defn scale
"Scale a vector `v` by a number `k`."
[v k]
(map (fn [x] (* x k)) v))

(defn subtract
"Elementwise subtract vector `v2` from `v1`."
[v1 v2]
(map - v1 v2))

(defn copy
"Deep copy an array or view `xs`."
[xs]
(if (= :ta/view (type xs)) (:slice xs) (array/slice xs)))

(defn sign
"Sign function."
[x] (cmp x 0))

(defn outer
"Outer product of vectors `v1` and `v2`."
[v1 v2]
(matmul (map array v1) (array v2)))

(defn unit-e
"Unit vector of `n` dimensions along dimension `k`."
[n k]
(update-in
(zero n) [k] (fn [x] 1)))

(defn normalize-v
"Returns normalized vector of `xs` by Euclidian (L2) norm."
[xs]
(map |(/ $0 (math/sqrt (dot xs xs))) xs))

(defn join-rows
"Stack vertically rows of two matrices."
[m1 m2]
(array/concat @[] m1 m2))

(defn join-cols
"Stack horizontally columns of two matrices."
[m1 m2]
(map join-rows m1 m2))

(defn squeeze
"Concatenate a list of rows into a single row. Does not mutate `m`."
[m]
(array/concat @[] ;m))

(defn flipud
"Flip a matrix upside-down."
[m]
(reverse m))

(defn fliplr
"Flip a matrix leftside-right."
[m]
(map reverse m))

(defn expand-m
"Embeds a matrix `m` inside an identity matrix of size n."
[n m]
(let [I (ident n)
left (join-rows I (zero n (rows m)))
right (join-rows (zero (cols m) n) m)]
(join-cols left right)))

(defn slice-m
"Slice a matrix `m` by rows and columns."
[m rslice cslice]
(-> m
(array/slice ;rslice)
trans
(array/slice ;cslice)
trans))


(defn qr1
"Transform using Householder reflections by one step."
[m]
(let [x ((trans m) 0) # take first column
k 0
a (* -1 (sign (x k)) (math/sqrt (dot x x)))
e1 (unit-e (length x) 0)
u (subtract x (map |(* $ a) e1)) # (mul e1 a)
v (normalize-v u)
I (ident (length u))
Q (mop I - (sop (outer v v) * 2))
Qm (matmul Q m)
m^ (slice-m Qm [1] [1])]
{:Q Q
:m^ m^}))


(defn qr
```
Stable and robust QR decomposition of a matrix.
Decompose a matrix using Householder transformations. O(n^3).
```
[m]
(var m^ m)
(var Qs (seq [i :range [0 (min (- (rows m) 1) (cols m))]]
(def res (qr1 m^))
(set m^ (res :m^))
(def Q^ (expand-m i (res :Q)))
Q^))
(def I (ident (cols Qs)))
(var Q (reduce matmul I Qs))
(var R (reduce matmul I (array/concat (reverse Qs) (array m))))
{:Q Q
:R R})

(defn svd
```
Simple Singular-Value-Decomposition based on repeated QR decomposition. The algorithm converges at O(n^3).
```
[m &opt n-iter]
(def n-iter 100)
(var U (ident (rows m)))
(var V U)
(var Q1 U)
(var Q2 U)
(var R1 m)
(var R2 U)
(var Q1 U)
(loop [i :range [0 n-iter]]
(var res (qr R1))
(set Q1 (res :Q))
(set R1 (res :R))
(var res^ (qr (trans R1)))
(set Q2 (res^ :Q))
(set R2 (res^ :R))
(set R1 (trans R2))
(set U (matmul U Q1))
(set V (matmul V Q2)))
{:U U
:S R1
:V V})

(defn m-approx=
"Compares two matrices of equal size for equivalence within epsilon."
[m1 m2 &opt tolerance]
(let [v1 (squeeze m1)
v2 (squeeze m2)
b (map approx-eq v1 v2)]
(and (= (length v1) (length v2))
(every? b))))
Loading
Loading