Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port proofs for constant_time_ops.rs #559

Merged
merged 15 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions fstar-helpers/fstar-bitvec/Tactics.GetBit.fst
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ open FStar.Option
open Tactics.Utils
open Tactics.Pow2

open BitVecEq {}
open Tactics.Seq {norm_index, tactic_list_index}
open BitVecEq
open Tactics.Seq


let _ = Rust_primitives.Hax.array_of_list
Expand Down
2 changes: 1 addition & 1 deletion libcrux-ml-kem/hax.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class extractAction(argparse.Action):

def __call__(self, parser, args, values, option_string=None) -> None:
# Extract platform interfaces
include_str = "+:**"
include_str = "+:** -**::x86::init::cpuid -**::x86::init::cpuid_count"
interface_include = "+**"
cargo_hax_into = [
"cargo",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,43 @@ open Core
open FStar.Mul

let inz (value: u8) =
let v__orig_value:u8 = value in
let value:u16 = cast (value <: u8) <: u16 in
let result:u16 =
((value |. (Core.Num.impl__u16__wrapping_add (~.value <: u16) 1us <: u16) <: u16) >>! 8l <: u16) &.
1us
let result:u8 =
cast ((Core.Num.impl__u16__wrapping_add (~.value <: u16) 1us <: u16) >>! 8l <: u16) <: u8
in
cast (result <: u16) <: u8
let res:u8 = result &. 1uy in
let _:Prims.unit =
if v v__orig_value = 0
then
(assert (value == zero);
lognot_lemma value;
assert ((~.value +. 1us) == zero);
assert ((Core.Num.impl__u16__wrapping_add (~.value <: u16) 1us <: u16) == zero);
logor_lemma value zero;
assert ((value |. (Core.Num.impl__u16__wrapping_add (~.value <: u16) 1us <: u16) <: u16) ==
value);
assert (v result == v ((value >>! 8l)));
assert ((v value / pow2 8) == 0);
assert (result == 0uy);
logand_lemma 1uy result;
assert (res == 0uy))
else
(assert (v value <> 0);
lognot_lemma value;
assert (v (~.value) = pow2 16 - 1 - v value);
assert (v (~.value) + 1 = pow2 16 - v value);
assert (v (value) <= pow2 8 - 1);
assert ((v (~.value) + 1) = (pow2 16 - pow2 8) + (pow2 8 - v value));
assert ((v (~.value) + 1) = (pow2 8 - 1) * pow2 8 + (pow2 8 - v value));
assert ((v (~.value) + 1) / pow2 8 = (pow2 8 - 1));
assert (v ((Core.Num.impl__u16__wrapping_add (~.value <: u16) 1us <: u16) >>! 8l) =
pow2 8 - 1);
assert (result = ones);
logand_lemma 1uy result;
assert (res = 1uy))
in
res

let is_non_zero (value: u8) = Core.Hint.black_box #u8 (inz value <: u8)

Expand All @@ -18,43 +49,143 @@ let compare (lhs rhs: t_Slice u8) =
let r:u8 =
Rust_primitives.Hax.Folds.fold_range (sz 0)
(Core.Slice.impl__len #u8 lhs <: usize)
(fun r temp_1_ ->
(fun r i ->
let r:u8 = r in
let _:usize = temp_1_ in
true)
let i:usize = i in
v i <= Seq.length lhs /\
(if (Seq.slice lhs 0 (v i) = Seq.slice rhs 0 (v i)) then r == 0uy else ~(r == 0uy)))
r
(fun r i ->
let r:u8 = r in
let i:usize = i in
r |. ((lhs.[ i ] <: u8) ^. (rhs.[ i ] <: u8) <: u8) <: u8)
let nr:u8 = r |. ((lhs.[ i ] <: u8) ^. (rhs.[ i ] <: u8) <: u8) in
let _:Prims.unit =
if r =. 0uy
then
(if (Seq.index lhs (v i) = Seq.index rhs (v i))
then
(logxor_lemma (Seq.index lhs (v i)) (Seq.index rhs (v i));
assert (((lhs.[ i ] <: u8) ^. (rhs.[ i ] <: u8) <: u8) = zero);
logor_lemma r ((lhs.[ i ] <: u8) ^. (rhs.[ i ] <: u8) <: u8);
assert (nr = r);
assert (forall j. Seq.index (Seq.slice lhs 0 (v i)) j == Seq.index lhs j);
assert (forall j. Seq.index (Seq.slice rhs 0 (v i)) j == Seq.index rhs j);
eq_intro (Seq.slice lhs 0 ((v i) + 1)) (Seq.slice rhs 0 ((v i) + 1)))
else
(logxor_lemma (Seq.index lhs (v i)) (Seq.index rhs (v i));
assert (((lhs.[ i ] <: u8) ^. (rhs.[ i ] <: u8) <: u8) <> zero);
logor_lemma r ((lhs.[ i ] <: u8) ^. (rhs.[ i ] <: u8) <: u8);
assert (v nr > 0);
assert (Seq.index (Seq.slice lhs 0 ((v i) + 1)) (v i) <>
Seq.index (Seq.slice rhs 0 ((v i) + 1)) (v i));
assert (Seq.slice lhs 0 ((v i) + 1) <> Seq.slice rhs 0 ((v i) + 1))))
else
(logor_lemma r ((lhs.[ i ] <: u8) ^. (rhs.[ i ] <: u8) <: u8);
assert (v nr >= v r);
assert (Seq.slice lhs 0 (v i) <> Seq.slice rhs 0 (v i));
if (Seq.slice lhs 0 ((v i) + 1) = Seq.slice rhs 0 ((v i) + 1))
then
(assert (forall j.
j < (v i) + 1 ==>
Seq.index (Seq.slice lhs 0 ((v i) + 1)) j ==
Seq.index (Seq.slice rhs 0 ((v i) + 1)) j);
eq_intro (Seq.slice lhs 0 (v i)) (Seq.slice rhs 0 (v i));
assert (False)))
in
let r:u8 = nr in
r)
in
is_non_zero r

let compare_ciphertexts_in_constant_time (lhs rhs: t_Slice u8) =
Core.Hint.black_box #u8 (compare lhs rhs <: u8)

#push-options "--ifuel 0 --z3rlimit 50"

let select_ct (lhs rhs: t_Slice u8) (selector: u8) =
let mask:u8 = Core.Num.impl__u8__wrapping_sub (is_non_zero selector <: u8) 1uy in
let _:Prims.unit =
assert (if selector = 0uy then mask = ones else mask = zero);
lognot_lemma mask;
assert (if selector = 0uy then ~.mask = zero else ~.mask = ones)
in
let out:t_Array u8 (sz 32) = Rust_primitives.Hax.repeat 0uy (sz 32) in
let out:t_Array u8 (sz 32) =
Rust_primitives.Hax.Folds.fold_range (sz 0)
Libcrux_ml_kem.Constants.v_SHARED_SECRET_SIZE
(fun out temp_1_ ->
(fun out i ->
let out:t_Array u8 (sz 32) = out in
let _:usize = temp_1_ in
true)
let i:usize = i in
v i <= v Libcrux_ml_kem.Constants.v_SHARED_SECRET_SIZE /\
(forall j.
j < v i ==>
(if (selector =. 0uy)
then Seq.index out j == Seq.index lhs j
else Seq.index out j == Seq.index rhs j)) /\
(forall j. j >= v i ==> Seq.index out j == 0uy))
out
(fun out i ->
let out:t_Array u8 (sz 32) = out in
let i:usize = i in
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize out
i
(((lhs.[ i ] <: u8) &. mask <: u8) |. ((rhs.[ i ] <: u8) &. (~.mask <: u8) <: u8) <: u8)
<:
t_Array u8 (sz 32))
let _:Prims.unit = assert ((out.[ i ] <: u8) = 0uy) in
let outi:u8 =
((lhs.[ i ] <: u8) &. mask <: u8) |. ((rhs.[ i ] <: u8) &. (~.mask <: u8) <: u8)
in
let _:Prims.unit =
if (selector = 0uy)
then
(logand_lemma (lhs.[ i ] <: u8) mask;
assert (((lhs.[ i ] <: u8) &. mask <: u8) == (lhs.[ i ] <: u8));
logand_lemma (rhs.[ i ] <: u8) (~.mask);
assert (((rhs.[ i ] <: u8) &. (~.mask <: u8) <: u8) == zero);
logor_lemma ((lhs.[ i ] <: u8) &. mask <: u8)
((rhs.[ i ] <: u8) &. (~.mask <: u8) <: u8);
assert ((((lhs.[ i ] <: u8) &. mask <: u8) |.
((rhs.[ i ] <: u8) &. (~.mask <: u8) <: u8)
<:
u8) ==
(lhs.[ i ] <: u8));
logor_lemma (out.[ i ] <: u8) (lhs.[ i ] <: u8);
assert (((out.[ i ] <: u8) |.
(((lhs.[ i ] <: u8) &. mask <: u8) |.
((rhs.[ i ] <: u8) &. (~.mask <: u8) <: u8)
<:
u8)
<:
u8) ==
(lhs.[ i ] <: u8));
assert (outi = (lhs.[ i ] <: u8)))
else
(logand_lemma (lhs.[ i ] <: u8) mask;
assert (((lhs.[ i ] <: u8) &. mask <: u8) == zero);
logand_lemma (rhs.[ i ] <: u8) (~.mask);
assert (((rhs.[ i ] <: u8) &. (~.mask <: u8) <: u8) == (rhs.[ i ] <: u8));
logor_lemma (rhs.[ i ] <: u8) zero;
assert ((logor zero (rhs.[ i ] <: u8)) == (rhs.[ i ] <: u8));
assert ((((lhs.[ i ] <: u8) &. mask <: u8) |.
((rhs.[ i ] <: u8) &. (~.mask <: u8) <: u8)) ==
(rhs.[ i ] <: u8));
logor_lemma (out.[ i ] <: u8) (rhs.[ i ] <: u8);
assert (((out.[ i ] <: u8) |.
(((lhs.[ i ] <: u8) &. mask <: u8) |.
((rhs.[ i ] <: u8) &. (~.mask <: u8) <: u8)
<:
u8)
<:
u8) ==
(rhs.[ i ] <: u8));
assert (outi = (rhs.[ i ] <: u8)))
in
let out:t_Array u8 (sz 32) =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize out i outi
in
out)
in
let _:Prims.unit = if (selector =. 0uy) then (eq_intro out lhs) else (eq_intro out rhs) in
out

#pop-options

let select_shared_secret_in_constant_time (lhs rhs: t_Slice u8) (selector: u8) =
Core.Hint.black_box #(t_Array u8 (sz 32)) (select_ct lhs rhs selector <: t_Array u8 (sz 32))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,67 @@ open Core
open FStar.Mul

/// Return 1 if `value` is not zero and 0 otherwise.
val inz (value: u8) : Prims.Pure u8 Prims.l_True (fun _ -> Prims.l_True)
val inz (value: u8)
: Prims.Pure u8
Prims.l_True
(ensures
fun result ->
let result:u8 = result in
Hax_lib.implies (value =. 0uy <: bool)
(fun temp_0_ ->
let _:Prims.unit = temp_0_ in
result =. 0uy <: bool) &&
Hax_lib.implies (value <>. 0uy <: bool)
(fun temp_0_ ->
let _:Prims.unit = temp_0_ in
result =. 1uy <: bool))

val is_non_zero (value: u8) : Prims.Pure u8 Prims.l_True (fun _ -> Prims.l_True)
val is_non_zero (value: u8)
: Prims.Pure u8
Prims.l_True
(ensures
fun result ->
let result:u8 = result in
Hax_lib.implies (value =. 0uy <: bool)
(fun temp_0_ ->
let _:Prims.unit = temp_0_ in
result =. 0uy <: bool) &&
Hax_lib.implies (value <>. 0uy <: bool)
(fun temp_0_ ->
let _:Prims.unit = temp_0_ in
result =. 1uy <: bool))

/// Return 1 if the bytes of `lhs` and `rhs` do not exactly
/// match and 0 otherwise.
val compare (lhs rhs: t_Slice u8)
: Prims.Pure u8
(requires (Core.Slice.impl__len #u8 lhs <: usize) =. (Core.Slice.impl__len #u8 rhs <: usize))
(fun _ -> Prims.l_True)
(ensures
fun result ->
let result:u8 = result in
Hax_lib.implies (lhs =. rhs <: bool)
(fun temp_0_ ->
let _:Prims.unit = temp_0_ in
result =. 0uy <: bool) &&
Hax_lib.implies (lhs <>. rhs <: bool)
(fun temp_0_ ->
let _:Prims.unit = temp_0_ in
result =. 1uy <: bool))

val compare_ciphertexts_in_constant_time (lhs rhs: t_Slice u8)
: Prims.Pure u8
(requires (Core.Slice.impl__len #u8 lhs <: usize) =. (Core.Slice.impl__len #u8 rhs <: usize))
(fun _ -> Prims.l_True)
(ensures
fun result ->
let result:u8 = result in
Hax_lib.implies (lhs =. rhs <: bool)
(fun temp_0_ ->
let _:Prims.unit = temp_0_ in
result =. 0uy <: bool) &&
Hax_lib.implies (lhs <>. rhs <: bool)
(fun temp_0_ ->
let _:Prims.unit = temp_0_ in
result =. 1uy <: bool))

/// If `selector` is not zero, return the bytes in `rhs`; return the bytes in
/// `lhs` otherwise.
Expand All @@ -27,19 +73,32 @@ val select_ct (lhs rhs: t_Slice u8) (selector: u8)
(requires
(Core.Slice.impl__len #u8 lhs <: usize) =. (Core.Slice.impl__len #u8 rhs <: usize) &&
(Core.Slice.impl__len #u8 lhs <: usize) =. Libcrux_ml_kem.Constants.v_SHARED_SECRET_SIZE)
(fun _ -> Prims.l_True)
(ensures
fun result ->
let result:t_Array u8 (sz 32) = result in
Hax_lib.implies (selector =. 0uy <: bool) (fun _ -> result =. lhs <: bool) &&
Hax_lib.implies (selector <>. 0uy <: bool) (fun _ -> result =. rhs <: bool))

val select_shared_secret_in_constant_time (lhs rhs: t_Slice u8) (selector: u8)
: Prims.Pure (t_Array u8 (sz 32))
(requires
(Core.Slice.impl__len #u8 lhs <: usize) =. (Core.Slice.impl__len #u8 rhs <: usize) &&
(Core.Slice.impl__len #u8 lhs <: usize) =. Libcrux_ml_kem.Constants.v_SHARED_SECRET_SIZE)
(fun _ -> Prims.l_True)
(ensures
fun result ->
let result:t_Array u8 (sz 32) = result in
Hax_lib.implies (selector =. 0uy <: bool) (fun _ -> result =. lhs <: bool) &&
Hax_lib.implies (selector <>. 0uy <: bool) (fun _ -> result =. rhs <: bool))

val compare_ciphertexts_select_shared_secret_in_constant_time (lhs_c rhs_c lhs_s rhs_s: t_Slice u8)
: Prims.Pure (t_Array u8 (sz 32))
(requires
(Core.Slice.impl__len #u8 lhs_c <: usize) =. (Core.Slice.impl__len #u8 rhs_c <: usize) &&
(Core.Slice.impl__len #u8 lhs_s <: usize) =. (Core.Slice.impl__len #u8 rhs_s <: usize) &&
(Core.Slice.impl__len #u8 lhs_s <: usize) =. Libcrux_ml_kem.Constants.v_SHARED_SECRET_SIZE)
(fun _ -> Prims.l_True)
(ensures
fun result ->
let result:t_Array u8 (sz 32) = result in
let selector = if lhs_c =. rhs_c then 0uy else 1uy in
Hax_lib.implies (selector =. 0uy <: bool) (fun _ -> result =. lhs_s <: bool) &&
Hax_lib.implies (selector <>. 0uy <: bool) (fun _ -> result =. rhs_s <: bool))
Loading
Loading