Skip to content

Commit

Permalink
Merge pull request #236 from horw/ref/revise-the-code-flashing-binari…
Browse files Browse the repository at this point in the history
…es-to-target
  • Loading branch information
hfudev authored Feb 23, 2024
2 parents 360215d + 9f39de7 commit 4b0d191
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 168 deletions.
41 changes: 16 additions & 25 deletions pytest-embedded-arduino/pytest_embedded_arduino/serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional

import esptool
from pytest_embedded_serial_esp.serial import EspSerial, EsptoolArgs
from pytest_embedded_serial_esp.serial import EspSerial

from .app import ArduinoApp

Expand Down Expand Up @@ -40,33 +40,24 @@ def flash(self) -> None:
"""
Flash the binary files to the board.
"""
flash_files = [
(offset, open(path, 'rb')) for (offset, path, encrypted) in self.app.flash_files if not encrypted
]
flash_files = []
for offset, path, encrypted in self.app.flash_files:
if encrypted:
continue
flash_files.extend((str(offset), path))

default_kwargs = {
'addr_filename': flash_files,
'encrypt_files': None,
'no_stub': False,
'compress': True,
'verify': False,
'ignore_flash_encryption_efuse_setting': False,
'erase_all': False,
'encrypt': False,
'force': False,
'chip': self.app.target,
}
flash_settings = []
for k, v in self.app.flash_settings[self.app.target].items():
flash_settings.append(f'--{k}')
flash_settings.append(v)

default_kwargs.update(self.app.flash_settings[self.app.target])
flash_args = EsptoolArgs(**default_kwargs)
if self.esp_flash_force:
flash_settings.append('--force')

try:
self.stub.change_baud(self.esptool_baud)
esptool.detect_flash_size(self.stub, flash_args)
esptool.write_flash(self.stub, flash_args)
self.stub.change_baud(self.baud)
esptool.main(
['--chip', self.app.target, 'write_flash', *flash_files, *flash_settings],
esp=self.esp,
)
except Exception:
raise
finally:
for _, f in flash_files:
f.close()
171 changes: 89 additions & 82 deletions pytest-embedded-idf/pytest_embedded_idf/serial.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import contextlib
import hashlib
import logging
import os
import tempfile
from typing import Optional, TextIO, Union

import esptool
from pytest_embedded.log import live_print_call
from pytest_embedded_serial_esp.serial import EspSerial, EsptoolArgs
from pytest_embedded_serial_esp.serial import EspSerial

from .app import IdfApp

Expand Down Expand Up @@ -84,30 +83,38 @@ def load_ram(self) -> None:
if self.app.bin_file:
bin_file = self.app.bin_file
else:
live_print_call(
with contextlib.redirect_stdout(self._q):
esptool.main(
[
'--chip',
self.app.target,
'elf2image',
self.app.elf_file,
*self.app.write_flash_args,
],
esp=self.esp,
)
bin_file = self.app.elf_file.replace('.elf', '.bin')

with contextlib.redirect_stdout(self._q):
esptool.main(
[
'esptool.py',
'--chip',
self.app.target,
'elf2image',
self.app.elf_file,
*self.app.write_flash_args,
'--no-stub',
'load_ram',
bin_file,
],
msg_queue=self._q,
esp=self.esp,
)
bin_file = self.app.elf_file.replace('.elf', '.bin')

live_print_call(
[
'esptool.py',
'--chip',
self.app.target,
'--no-stub',
'load_ram',
bin_file,
],
msg_queue=self._q,
)
def _force_flag(self):
if self.esp_flash_force:
return ['--force']
config = self.app.sdkconfig
if any((config.get('SECURE_FLASH_ENC_ENABLED', False), config.get('SECURE_BOOT', False))):
return ['--force']
return []

@EspSerial.use_esptool()
def flash(self) -> None:
Expand All @@ -122,59 +129,56 @@ def flash(self) -> None:
logging.error('No flash settings detected. Skipping auto flash...')
return

flash_files = [(file.offset, open(file.file_path, 'rb')) for file in self.app.flash_files if not file.encrypted]
encrypt_files = [(file.offset, open(file.file_path, 'rb')) for file in self.app.flash_files if file.encrypted]

nvs_file = None
try:
if self.erase_nvs:
address = self.app.partition_table['nvs']['offset']
size = self.app.partition_table['nvs']['size']
nvs_file = tempfile.NamedTemporaryFile(delete=False)
nvs_file.write(b'\xff' * size)
if not isinstance(address, int):
address = int(address, 0)

if self.app.flash_settings['encrypt']:
encrypt_files.append((address, open(nvs_file.name, 'rb')))
_args = []
for k, v in self.app.flash_args['extra_esptool_args'].items():
if isinstance(v, bool):
if k == 'stub':
if v is False:
_args.append('--no-stub')
elif v:
_args.append(f'--{k}')
else:
_args.append(f'--{k}')
if k == 'after':
_args.append('hard_reset')
else:
flash_files.append((address, open(nvs_file.name, 'rb')))

# write_flash expects the parameter encrypt_files to be None and not
# an empty list, so perform the check here
default_kwargs = {
'addr_filename': flash_files,
'encrypt_files': encrypt_files or None,
'no_stub': False,
'compress': True,
'verify': False,
'ignore_flash_encryption_efuse_setting': False,
'erase_all': False,
'force': False,
}

default_kwargs.update(self.app.flash_settings)
default_kwargs.update(self.app.flash_args.get('extra_esptool_args', {}))
args = EsptoolArgs(**default_kwargs)

self.stub.change_baud(self.esptool_baud)
esptool.detect_flash_size(self.stub, args)
esptool.write_flash(self.stub, args)
self.stub.change_baud(self.baud)

if self._meta:
self._meta.set_port_app_cache(self.port, self.app)
finally:
if nvs_file:
nvs_file.close()
try:
os.remove(nvs_file.name)
except OSError:
pass
for _, f in flash_files:
f.close()
for _, f in encrypt_files:
f.close()
_args.append(str(v))

_args.append('write_flash')

if self.erase_nvs:
esptool.main(
[
'erase_region',
str(self.app.partition_table['nvs']['offset']),
str(self.app.partition_table['nvs']['size']),
],
esp=self.esp,
)
self.esp.connect()

encrypt_files = []
flash_files = []
for file in self.app.flash_files:
if file.encrypted:
encrypt_files.extend([hex(file.offset), str(file.file_path)])
else:
flash_files.extend([hex(file.offset), str(file.file_path)])

if flash_files and encrypt_files:
_args.extend([*flash_files, '--encrypt-files', *encrypt_files])
else:
if flash_files:
_args.extend(flash_files)
else:
_args.extend(['--encrypt', *encrypt_files])

_args.extend([*self.app.flash_args['write_flash_args'], *self._force_flag()])

esptool.main(_args, esp=self.esp)

if self._meta:
self._meta.set_port_app_cache(self.port, self.app)

@EspSerial.use_esptool()
def dump_flash(
Expand Down Expand Up @@ -207,15 +211,12 @@ def dump_flash(
else:
raise ValueError('You must specify "partition" or ("address" and "size") to dump flash')

content = self.stub.read_flash(_addr, _size)
if output:
if isinstance(output, str):
os.makedirs(os.path.dirname(output), exist_ok=True)
with open(output, 'wb') as f:
f.write(content)
else:
output.write(content)
esptool.main(['read_flash', str(_addr), str(_size), str(output)], esp=self.esp)
else:
with tempfile.NamedTemporaryFile() as fp:
esptool.main(['read_flash', str(_addr), str(_size), fp.name], esp=self.esp)
content = fp.read()
return content

@EspSerial.use_esptool()
Expand All @@ -233,7 +234,7 @@ def erase_partition(self, partition_name: str) -> None:
address = self.app.partition_table[partition_name]['offset']
size = self.app.partition_table[partition_name]['size']
logging.info(f'Erasing the partition "{partition_name}" of size {size} at {address}')
self.stub.erase_region(address, size)
esptool.main(['erase_region', str(address), str(size), *self._force_flag()], esp=self.esp)
else:
raise ValueError(f'partition name "{partition_name}" not found in app partition table')

Expand All @@ -254,7 +255,13 @@ def read_flash_elf_sha256(self) -> bytes:
if not bin_offset:
raise ValueError('.bin file not found in flash files')

return self.stub.read_flash(bin_offset + self.DEFAULT_SHA256_OFFSET, 32)
with tempfile.NamedTemporaryFile() as fp:
esptool.main(
['read_flash', str(bin_offset + self.DEFAULT_SHA256_OFFSET), str(32), fp.name],
esp=self.esp,
)
content = fp.read()
return content

def is_target_flashed_same_elf(self) -> bool:
"""
Expand Down
37 changes: 37 additions & 0 deletions pytest-embedded-idf/tests/test_idf.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,43 @@ def test_idf_serial_flash(dut):
result.assert_outcomes(passed=1)


def test_esp_flash_force_flag(testdir):
testdir.makepyfile("""
import pexpect
import pytest
def test_idf_serial_flash(dut):
dut.expect('Hello world!')
assert dut.serial.esp_flash_force == True
""")
result = testdir.runpytest(
'-s',
'--embedded-services', 'esp,idf',
'--app-path', os.path.join(testdir.tmpdir, 'hello_world_esp32'),
'--esp-flash-force',
)

result.assert_outcomes(passed=1)


def test_esp_flash_no_force_flag(testdir):
testdir.makepyfile("""
import pexpect
import pytest
def test_idf_serial_flash(dut):
dut.expect('Hello world!')
assert dut.serial.esp_flash_force == False
""")
result = testdir.runpytest(
'-s',
'--embedded-services', 'esp,idf',
'--app-path', os.path.join(testdir.tmpdir, 'hello_world_esp32'),
)

result.assert_outcomes(passed=1)


def test_expect_no_matching(testdir):
testdir.makepyfile("""
import pexpect
Expand Down
Loading

0 comments on commit 4b0d191

Please sign in to comment.