Skip to content

Commit

Permalink
feat(core): Implement extend_mnemonics() for SLIP-39.
Browse files Browse the repository at this point in the history
[no changelog]
  • Loading branch information
andrewkozlik committed Oct 10, 2024
1 parent d71d9e9 commit fb60bbe
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 3 deletions.
80 changes: 77 additions & 3 deletions core/src/trezor/crypto/slip39.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,69 @@ def split_ems(
return mnemonics


def extend_mnemonics(
share_count: int, # The number of shares to create.
mnemonics: list[str], # A threshold set of the old mnemonics.
) -> list[str]:
"""
Extends a set of mnemonics to the desired share_count, while maintaining the threshold. This,
for example, allows extending a 2-of-2 backup to 2-of-3, where the first two shares remain the
same. It also allows reconstructing lost shares by providing any threshold number of shares and
requesting the original share_count. The current implementation is limited to Slip39_Basic,
i.e. single group.
It is not possible to tell how many shares the user originally created, so if share_count is
less than the original number of shares, then this function will return the first share_count
shares.
"""

if not mnemonics:
raise MnemonicError("The list of mnemonics is empty.")

(
identifier,
extendable,
iteration_exponent,
group_threshold,
group_count,
groups,
) = _decode_mnemonics(mnemonics)

if group_threshold != 1 or group_count != 1 or len(groups) != 1:
raise MnemonicError("Extending advanced backups is not supported.")

threshold = groups[0][0]
shares = groups[0][1]
if len(shares) != threshold:
raise MnemonicError(
f"Wrong number of mnemonics. Expected {threshold} mnemonics, but {len(shares)} were provided."
)

if threshold == 1 and share_count > 1:
raise ValueError(
"Creating multiple member shares with member threshold 1 is not allowed. Use 1-of-1 member sharing instead."
)

shares = _extend_shares(share_count, list(shares))

mnemonics = []
for index, value in shares:
mnemonics.append(
_encode_mnemonic(
identifier,
extendable,
iteration_exponent,
group_index=0,
group_threshold=1,
group_count=1,
member_index=index,
member_threshold=threshold,
value=value,
)
)
return mnemonics


def recover_ems(mnemonics: list[str]) -> tuple[int, bool, int, bytes]:
"""
Combines mnemonic shares to obtain the encrypted master secret which was previously
Expand Down Expand Up @@ -457,9 +520,7 @@ def _create_digest(random_data: bytes, shared_secret: bytes) -> bytes:
return hmac(hmac.SHA256, random_data, shared_secret).digest()[:_DIGEST_LENGTH_BYTES]


def _split_secret(
threshold: int, share_count: int, shared_secret: bytes
) -> list[tuple[int, bytes]]:
def _check_parameters(threshold: int, share_count: int) -> None:
if threshold < 1:
raise ValueError(
f"The requested threshold ({threshold}) must be a positive integer."
Expand All @@ -475,6 +536,12 @@ def _split_secret(
f"The requested number of shares ({share_count}) must not exceed {MAX_SHARE_COUNT}."
)


def _split_secret(
threshold: int, share_count: int, shared_secret: bytes
) -> list[tuple[int, bytes]]:
_check_parameters(threshold, share_count)

# If the threshold is 1, then the digest of the shared secret is not used.
if threshold == 1:
return [(i, shared_secret) for i in range(share_count)]
Expand All @@ -499,6 +566,13 @@ def _split_secret(
return shares


def _extend_shares(
share_count: int, old_shares: list[tuple[int, bytes]]
) -> list[tuple[int, bytes]]:
_check_parameters(len(old_shares), share_count)
return [(i, shamir.interpolate(old_shares, i)) for i in range(share_count)]


def _recover_secret(threshold: int, shares: list[tuple[int, bytes]]) -> bytes:
# If the threshold is 1, then the digest of the shared secret is not used.
if threshold == 1:
Expand Down
10 changes: 10 additions & 0 deletions core/tests/test_trezor.crypto.slip39.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ def test_basic_sharing_random(self):
slip39.recover_ems(mnemonics[:3]), slip39.recover_ems(mnemonics[2:])
)

def test_basic_sharing_extend(self):
identifier = slip39.generate_random_identifier()
for extendable in (False, True):
mnemonics = slip39.split_ems(1, [(2, 3)], identifier, extendable, 1, self.EMS)
mnemonics = mnemonics[0]
extended_mnemonics = slip39.extend_mnemonics(4, mnemonics[1:])
self.assertEqual(mnemonics, extended_mnemonics[:3])
for i in range(3):
self.assertEqual(slip39.recover_ems([extended_mnemonics[3], mnemonics[i]])[3], self.EMS)

def test_basic_sharing_fixed(self):
for extendable in (False, True):
generated_identifier = slip39.generate_random_identifier()
Expand Down

0 comments on commit fb60bbe

Please sign in to comment.