Skip to content

Commit

Permalink
Update the method for retrieving the device label. (#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
lihuanhuan authored Aug 20, 2024
1 parent df6b13a commit b24ee6f
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 0 deletions.
37 changes: 37 additions & 0 deletions core/embed/extmod/modtrezorconfig/modtrezorconfig.c
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,41 @@ STATIC mp_obj_t mod_trezorconfig_set_needs_backup(mp_obj_t needs_backup) {
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorconfig_set_needs_backup_obj,
mod_trezorconfig_set_needs_backup);

/// def get_val_len(app: int, key: int, public: bool = False) -> int:
/// """
/// Gets the length of the value of the given key for the given app (or None
/// if not set). Raises a RuntimeError if decryption or authentication of
/// the stored value fails.
/// """
STATIC mp_obj_t mod_trezorconfig_get_val_len(size_t n_args,
const mp_obj_t *args) {
uint32_t key = trezor_obj_get_uint(args[1]);

bool is_private = key & (1 << 31);

secbool (*reader)(uint16_t, void *, uint16_t) =
is_private ? se_get_private_region : se_get_public_region;

// key is position
key &= ~(1 << 31);

uint8_t temp[4] = {0};
if (sectrue != reader(key, temp, 3)) {
return mp_const_none;
}
// has flag
if (temp[0] != 1) {
return mp_const_none;
}

uint16_t len = 0;
len = (temp[1] << 8) + temp[2];

return mp_obj_new_int_from_uint(len);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_get_val_len_obj, 2,
3, mod_trezorconfig_get_val_len);

/// def get(app: int, key: int, public: bool = False) -> bytes | None:
/// """
/// Gets the value of the given key for the given app (or None if not set).
Expand Down Expand Up @@ -1118,6 +1153,8 @@ STATIC const mp_rom_map_elem_t mp_module_trezorconfig_globals_table[] = {
MP_ROM_PTR(&mod_trezorconfig_has_wipe_code_obj)},
{MP_ROM_QSTR(MP_QSTR_change_wipe_code),
MP_ROM_PTR(&mod_trezorconfig_change_wipe_code_obj)},
{MP_ROM_QSTR(MP_QSTR_get_val_len),
MP_ROM_PTR(&mod_trezorconfig_get_val_len_obj)},
{MP_ROM_QSTR(MP_QSTR_get), MP_ROM_PTR(&mod_trezorconfig_get_obj)},
{MP_ROM_QSTR(MP_QSTR_set), MP_ROM_PTR(&mod_trezorconfig_set_obj)},
{MP_ROM_QSTR(MP_QSTR_delete), MP_ROM_PTR(&mod_trezorconfig_delete_obj)},
Expand Down
9 changes: 9 additions & 0 deletions core/mocks/generated/trezorconfig.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ def set_needs_backup(needs_backup: bool = False) -> bool:
"""


# extmod/modtrezorconfig/modtrezorconfig.c
def get_val_len(app: int, key: int, public: bool = False) -> int:
"""
Gets the length of the value of the given key for the given app (or None if not set).
Raises a RuntimeError if decryption or authentication of the stored
value fails.
"""


# extmod/modtrezorconfig/modtrezorconfig.c
def get(app: int, key: int, public: bool = False) -> bytes | None:
"""
Expand Down
4 changes: 4 additions & 0 deletions core/src/storage/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def get(app: int, key: int, public: bool = False) -> bytes | None:
return config.get(app, key, public)


def get_val_len(app: int, key: int, public: bool = False) -> int | None:
return config.get_val_len(app, key, public)


def delete(
app: int, key: int, public: bool = False, writable_locked: bool = False
) -> None:
Expand Down
5 changes: 5 additions & 0 deletions core/src/storage/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@
DEVICE_ID = struct_public["device_id"][0]
_VERSION = struct_public["version"][0]
_LANGUAGE = struct_public["language"][0]
_LABEL_DEPRECATED = struct_public["label_deprecated"][0]
_LABEL = struct_public["label"][0]
_USE_PASSPHRASE = struct_public["use_passphrase"][0]
_PASSPHRASE_ALWAYS_ON_DEVICE = struct_public["passphrase_always_on_device"][0]
Expand Down Expand Up @@ -723,6 +724,10 @@ def get_label() -> str:
global _LABEL_VALUE
if _LABEL_VALUE is None:
label = common.get(_NAMESPACE, _LABEL, True) # public
if label is None:
previous_label_len = common.get_val_len(_NAMESPACE, _LABEL_DEPRECATED, True)
if previous_label_len is not None and 0 < previous_label_len < 16:
label = common.get(_NAMESPACE, _LABEL_DEPRECATED, True)
_LABEL_VALUE = label.decode() if label else utils.DEFAULT_LABEL
return _LABEL_VALUE

Expand Down

0 comments on commit b24ee6f

Please sign in to comment.