Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve mp_root_n code #532

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 27 additions & 35 deletions mp_root_n.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
*/
mp_err mp_root_n(const mp_int *a, int b, mp_int *c)
{
mp_int t1, t2, t3, a_;
mp_int t1, t2, t3, a_, d;
int ilog2;
mp_err err;

Expand All @@ -27,15 +27,15 @@ mp_err mp_root_n(const mp_int *a, int b, mp_int *c)
return MP_VAL;
}

if ((err = mp_init_multi(&t1, &t2, &t3, NULL)) != MP_OKAY) {
if ((err = mp_init_multi(&t1, &t2, &t3, &d, NULL)) != MP_OKAY) {
return err;
}

/* if a is negative fudge the sign but keep track */
a_ = *a;
a_.sign = MP_ZPOS;

/* Compute seed: 2^(log_2(n)/b + 2)*/
/* Compute seed: 2^(log_2(n)/b + 1)*/
ilog2 = mp_count_bits(a);

/*
Expand All @@ -57,21 +57,21 @@ mp_err mp_root_n(const mp_int *a, int b, mp_int *c)
err = MP_OKAY;
goto LBL_ERR;
}
ilog2 = ilog2 / b;
ilog2 = (ilog2 - 1) / b;
if (ilog2 == 0) {
mp_set(c, 1uL);
c->sign = a->sign;
err = MP_OKAY;
goto LBL_ERR;
}
/* Start value must be larger than root */
ilog2 += 2;
if ((err = mp_2expt(&t2,ilog2)) != MP_OKAY) goto LBL_ERR;
ilog2 += 1;
if ((err = mp_2expt(&t1, ilog2)) != MP_OKAY) goto LBL_ERR;

do {
/* t1 = t2 */
if ((err = mp_copy(&t2, &t1)) != MP_OKAY) goto LBL_ERR;
mp_ord cmp;

/* t2 = t1 - ((t1**b - a) / (b * t1**(b-1))) */
/* t2 = t1 - ceiling(((t1**b - a) / (b * t1**(b-1)))) */

/* t3 = t1**(b-1) */
if ((err = mp_expt_n(&t1, b - 1, &t3)) != MP_OKAY) goto LBL_ERR;
Expand All @@ -80,6 +80,14 @@ mp_err mp_root_n(const mp_int *a, int b, mp_int *c)
/* t2 = t1**b */
if ((err = mp_mul(&t3, &t1, &t2)) != MP_OKAY) goto LBL_ERR;

cmp = mp_cmp(&t2, &a_);
if (cmp == MP_EQ || cmp == MP_LT) {
err = MP_OKAY;
mp_exch(&t1, c);
c->sign = a->sign;
goto LBL_ERR;
}

/* t2 = t1**b - a */
if ((err = mp_sub(&t2, &a_, &t2)) != MP_OKAY) goto LBL_ERR;

Expand All @@ -88,35 +96,19 @@ mp_err mp_root_n(const mp_int *a, int b, mp_int *c)
if ((err = mp_mul_d(&t3, (mp_digit)b, &t3)) != MP_OKAY) goto LBL_ERR;

/* t3 = (t1**b - a)/(b * t1**(b-1)) */
if ((err = mp_div(&t2, &t3, &t3, NULL)) != MP_OKAY) goto LBL_ERR;
if ((err = mp_div(&t2, &t3, &t3, &d)) != MP_OKAY) goto LBL_ERR;
/* round up t3 - so t1 will be rounded down */
if(!mp_iszero(&d)) {
if ((err = mp_add_d(&t3, 1uL, &t3)) != MP_OKAY) goto LBL_ERR;
}

if ((err = mp_sub(&t1, &t3, &t2)) != MP_OKAY) goto LBL_ERR;
/* t1 = t1 - t3 */
if ((err = mp_sub(&t1, &t3, &t1)) != MP_OKAY) goto LBL_ERR;

/*
Number of rounds is at most log_2(root). If it is more it
got stuck, so break out of the loop and do the rest manually.
*/
if (ilog2-- == 0) {
break;
}
} while (mp_cmp(&t1, &t2) != MP_EQ);
/* while t3 != 1 */
} while (!((t3.used == 1u) && (t3.dp[0] == 1u)));

/* result can be off by a few so check */
/* Loop beneath can overshoot by one if found root is smaller than actual root */
for (;;) {
mp_ord cmp;
if ((err = mp_expt_n(&t1, b, &t2)) != MP_OKAY) goto LBL_ERR;
cmp = mp_cmp(&t2, &a_);
if (cmp == MP_EQ) {
err = MP_OKAY;
goto LBL_ERR;
}
if (cmp == MP_LT) {
if ((err = mp_add_d(&t1, 1uL, &t1)) != MP_OKAY) goto LBL_ERR;
} else {
break;
}
}
/* correct overshoot from above or from recurrence */
for (;;) {
if ((err = mp_expt_n(&t1, b, &t2)) != MP_OKAY) goto LBL_ERR;
Expand All @@ -134,7 +126,7 @@ mp_err mp_root_n(const mp_int *a, int b, mp_int *c)
c->sign = a->sign;

LBL_ERR:
mp_clear_multi(&t1, &t2, &t3, NULL);
mp_clear_multi(&t1, &t2, &t3, &d, NULL);
return err;
}

Expand Down