Arithmetic coding demo #631

Open
wants to merge 17 commits into base: main
102 changes: 102 additions & 0 deletions examples/arithmetic-coding.dx
@@ -0,0 +1,102 @@
'## [Arithmetic coding](https://en.wikipedia.org/wiki/Arithmetic_coding)
This example demonstrates lossless compression of a string of lowercase letters.
Rather than assigning a separate code to each letter, it encodes the entire string
as a single floating-point number.

Alphabet = Fin 26
Interval = (Float&Float)
top:Interval = (0.,1.)

def charToIdx (c: Word8) : Int = W8ToI c - W8ToI 'a'
def idxToChar (i: Int) : Word8 = IToW8 (i + (W8ToI 'a'))

'### Statistical modelling
First, estimate the probability of each letter from its frequency in the string to be encoded.

def cumProb (ps: n=>Float) : n=>Float =
  withState 0.0 \total.
Contributor

I think this can be computed with Accum instead of state.
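
For reference, the value `cumProb` produces is an exclusive running total: entry `i` holds the sum of the probabilities before `i`, and letters with zero probability are left at 0. A minimal sketch of that behaviour, written in Python rather than Dex and not using Dex's `Accum` effect (the name `cum_prob` is illustrative only):

```python
from itertools import accumulate


def cum_prob(ps):
    """Exclusive prefix sum of ps, with 0.0 wherever ps[i] is zero."""
    # Inclusive running totals shifted right by one give the start of each slot.
    starts = [0.0] + list(accumulate(ps))[:-1]
    # Mirror the diff: letters that never occur are assigned 0.
    return [s if p > 0 else 0.0 for p, s in zip(ps, starts)]


print(cum_prob([0.25, 0.0, 0.5, 0.25]))  # [0.0, 0.0, 0.25, 0.75]
```

Any rewrite of `cumProb` should reproduce these values, since `getUpdateRule` uses them as the start points of each letter's sub-interval.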

    for i. if ps.i > 0.
      then
        currTotal = get total
        newTotal = currTotal + ps.i
        total := newTotal
        currTotal
      else 0.

def getFrequency (str: (Fin l)=>Word8) : Alphabet=>Int =
  a: Alphabet => Int = zero
  yieldState a \ref. for i.
Contributor

This one can also be computed with a parallel Accum.
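
As a reference for what `getFrequency` should return, here is a small sketch in Python rather than Dex. Each character adds 1 to exactly one bucket, and addition does not depend on order, which is the property an accumulation-based version would rely on. The name `get_frequency` is hypothetical, and the 26-letter lowercase alphabet is taken from the diff.

```python
def get_frequency(s):
    """Count occurrences of each lowercase letter 'a'..'z' in s."""
    freq = [0] * 26
    for ch in s:
        # Same index mapping as charToIdx in the diff: offset from 'a'.
        freq[ord(ch) - ord("a")] += 1
    return freq


print(get_frequency("abbadcabccdd")[:4])  # [3, 3, 3, 3]
```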

    i' = (charToIdx str.i)@_
    ref!i' := (get ref).i' + 1

def getProbability (l: Int) (freq: Alphabet=>Int) : Alphabet=>(Float&Float) =
  probs = for i. IToF freq.i / IToF l
  cums = cumProb probs
Contributor

Is this going to repeat work every time it's called?

Contributor Author

The probabilities are cached on line 92:
p = getProbability l $ getFrequency str
So it's only calculated once.

  for i. (probs.i, cums.i)

'### Scaling functions

def getUpdateRule (p: Alphabet=>(Float&Float)) : Alphabet=>(Interval->Interval) =
  for i.
    case p.i == (0.,0.) of
      True -> id
      False ->
        \(x, w).
          x' = x + w*(snd p.i)
          w' = w*(fst p.i)
          (x', w')

def subdivide (str: (Fin l)=>Word8)
              (rule: Alphabet=>(Interval->Interval))
              (i: (Fin l)) (in: Interval) : Interval =
  updateInterval = rule.((charToIdx str.i)@_)
  updateInterval in

def findInterval (l: Int)
                 (code: Float)
                 (rule: Alphabet=>(Interval->Interval))
                 (i: (Fin l))
                 ((str,in): (List Word8 & Interval)) : (List Word8 & Interval) =
  (letter, in') = boundedIter (size Alphabet) (' ', top) \j.
    case rule.(j@_) in == in of
      True -> Continue
      False ->
        (x, w) = rule.(j@_) in
        case code >= x && code < (x+w) of
Contributor

Is this a binary search?

Contributor Author

It's actually just a linear search over the intervals, following most implementations I've seen. But binary search might scale better to larger alphabets.

Contributor

You might be able to use searchSorted from the prelude:
https://github.com/google-research/dex-lang/blob/main/lib/prelude.dx#L1620
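
As a rough illustration of the binary-search idea, here is a sketch in Python using the standard `bisect` module, not the prelude's `searchSorted`. It assumes the code has already been rescaled to [0, 1) relative to the current interval, and that only letters with nonzero probability are listed, so their start points are strictly increasing. The names `find_letter`, `letters` and `starts` are hypothetical.

```python
from bisect import bisect_right


def find_letter(code, letters, starts):
    """Return the letter whose half-open sub-interval contains code.

    letters: symbols with nonzero probability, in alphabet order.
    starts:  start point of each letter's sub-interval, ascending, starts[0] == 0.
    """
    # bisect_right counts the start points <= code; the last of those owns code.
    return letters[bisect_right(starts, code) - 1]


letters = "abcd"
starts = [0.0, 0.25, 0.5, 0.75]
print(find_letter(0.081569, letters, starts))  # a
print(find_letter(0.5, letters, starts))       # c
```

The half-open comparison matches the `code >= x && code < (x+w)` test in the diff: a code that falls exactly on a boundary belongs to the interval that starts there.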

          True -> Done (idxToChar j, (x,w))
          False -> Continue
  (str <> AsList 1 [letter], in')

'### Coding interface
Start from the initial interval [0, 1).
For each letter of the string, the current interval is partitioned in proportion to the
cumulative probabilities of all letters, and the sub-interval that belongs to that letter
becomes the new current interval. Encoding returns the midpoint of the final interval.
Decoding retraces these steps: at each position it finds the sub-interval that contains
the code and emits the corresponding letter.

def encode (str: (Fin l)=>Word8) (rule: Alphabet=>(Interval->Interval)) : Float =
Contributor

Why Fin l and not just in?

Contributor Author

It gets us this error:

Type error:
Expected: ((Fin a) => Word8)
  Actual: (in => Word8)
(Solving for: [a:Int32])

  update = subdivide str rule

But yeah, it does look more succinct with in.

Contributor

Oh, I see. Well it's not a big deal either way.

  update = subdivide str rule
  finalInterval = fold top update
  fst finalInterval + (snd finalInterval)/2.

def decode (l: Int) (code: Float) (rule: Alphabet=>(Interval->Interval)) : List Word8 =
  update = findInterval l code rule
  initStr: List Word8 = AsList _ []
  fst $ fold (initStr, top) update

'### Demo: Lossless compression on a test string

str' = "abbadcabccdd"
Contributor

Can you make a longer test? I'm not convinced that there aren't lurking floating-point issues in this implementation.

Contributor Author

A longer test here would fail. I could change all instances of Float to Float64 for more precision and pass the longer test, but the code would look bloated. For example, this line,
top:Interval = (0., 1.)
would become
top:Interval = (FToF64 0.,FToF64 1.)
Is there a way to define a Float to arbitrary precision? I considered using integer arithmetic instead of floating-point arithmetic for better control over precision, but that code is still a WIP.

Contributor

Hmmm. Would an even longer test then fail even if you used F64? I get that this is just a demo, but if it only works for short sequences we should put some warnings in.

Contributor Author

I agree. I'll probably make a more serious attempt at integer arithmetic before I commit to this implementation and add warnings.
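
To illustrate the precision point in this thread, here is a sketch of the same encode/decode scheme in Python with exact rational arithmetic (`fractions.Fraction`) in place of floats. This is not the integer-arithmetic implementation mentioned above and it is not Dex. It is a stand-in showing that the algorithm round-trips longer strings once rounding is removed. The function names are hypothetical, and intervals are stored as (start, width) pairs as in the diff.

```python
from collections import Counter
from fractions import Fraction


def build_model(s):
    """Map each letter in s to the (start, width) of its slice of [0, 1)."""
    counts = Counter(s)
    model, start = {}, Fraction(0)
    for ch in sorted(counts):  # alphabet order, as in the diff
        width = Fraction(counts[ch], len(s))
        model[ch] = (start, width)
        start += width
    return model


def encode(s, model):
    """Narrow [0, 1) once per letter and return the final midpoint."""
    x, w = Fraction(0), Fraction(1)
    for ch in s:
        start, width = model[ch]
        x, w = x + w * start, w * width
    return x + w / 2


def decode(n, code, model):
    """Recover n letters by locating code in successive sub-intervals."""
    x, w = Fraction(0), Fraction(1)
    out = []
    for _ in range(n):
        for ch, (start, width) in model.items():
            lo = x + w * start
            if lo <= code < lo + w * width:
                out.append(ch)
                x, w = lo, w * width
                break
    return "".join(out)


s = "abbadcabccdd"
model = build_model(s)
print(float(encode(s, model)))  # approximately 0.081569

long_s = s * 20  # 240 letters
long_model = build_model(long_s)
print(decode(len(long_s), encode(long_s, long_model), long_model) == long_s)  # True
```

For the 12-letter test string this gives approximately 0.081569, which agrees with the output shown in the diff, and the 240-letter string also round-trips. The cost is that the numerator and denominator of the code grow with the length of the string, so this shows correctness of the scheme rather than a practical compressor.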

(AsList l str) = str'

p = getProbability l $ getFrequency str
r = getUpdateRule p

code = encode str r
code
> 0.081569

decoded = decode l code r
decoded == str'
> True