nixos/tests/test-driver: better control test env symbols
Previous to this commit, the entire test driver environment was shared with the actual python test environment. This is a hefty api surface. This commit selectively exposes only those symbols to the test environment that are actually meant to be used by tests.
This commit is contained in:
parent
5edf5b60c3
commit
db614e11d6
2 changed files with 45 additions and 16 deletions
|
@ -89,9 +89,7 @@ CHAR_TO_KEY = {
|
|||
")": "shift-0x0B",
|
||||
}
|
||||
|
||||
# Forward references
|
||||
log: "Logger"
|
||||
machines: "List[Machine]"
|
||||
global log, machines, test_script
|
||||
|
||||
|
||||
def eprint(*args: object, **kwargs: Any) -> None:
|
||||
|
@ -103,7 +101,6 @@ def make_command(args: list) -> str:
|
|||
|
||||
|
||||
def create_vlan(vlan_nr: str) -> Tuple[str, str, "subprocess.Popen[bytes]", Any]:
|
||||
global log
|
||||
log.log("starting VDE switch for network {}".format(vlan_nr))
|
||||
vde_socket = tempfile.mkdtemp(
|
||||
prefix="nixos-test-vde-", suffix="-vde{}.ctl".format(vlan_nr)
|
||||
|
@ -246,6 +243,9 @@ def _perform_ocr_on_screenshot(
|
|||
|
||||
|
||||
class Machine:
|
||||
def __repr__(self) -> str:
|
||||
return f"<Machine '{self.name}'>"
|
||||
|
||||
def __init__(self, args: Dict[str, Any]) -> None:
|
||||
if "name" in args:
|
||||
self.name = args["name"]
|
||||
|
@ -910,29 +910,25 @@ class Machine:
|
|||
|
||||
|
||||
def create_machine(args: Dict[str, Any]) -> Machine:
|
||||
global log
|
||||
args["log"] = log
|
||||
return Machine(args)
|
||||
|
||||
|
||||
def start_all() -> None:
|
||||
global machines
|
||||
with log.nested("starting all VMs"):
|
||||
for machine in machines:
|
||||
machine.start()
|
||||
|
||||
|
||||
def join_all() -> None:
|
||||
global machines
|
||||
with log.nested("waiting for all VMs to finish"):
|
||||
for machine in machines:
|
||||
machine.wait_for_shutdown()
|
||||
|
||||
|
||||
def run_tests(interactive: bool = False) -> None:
|
||||
global machines
|
||||
if interactive:
|
||||
ptpython.repl.embed(globals(), locals())
|
||||
ptpython.repl.embed(test_symbols(), {})
|
||||
else:
|
||||
test_script()
|
||||
# TODO: Collect coverage data
|
||||
|
@ -942,12 +938,10 @@ def run_tests(interactive: bool = False) -> None:
|
|||
|
||||
|
||||
def serial_stdout_on() -> None:
|
||||
global log
|
||||
log._print_serial_logs = True
|
||||
|
||||
|
||||
def serial_stdout_off() -> None:
|
||||
global log
|
||||
log._print_serial_logs = False
|
||||
|
||||
|
||||
|
@ -989,6 +983,37 @@ def subtest(name: str) -> Iterator[None]:
|
|||
return False
|
||||
|
||||
|
||||
def _test_symbols() -> Dict[str, Any]:
|
||||
general_symbols = dict(
|
||||
start_all=start_all,
|
||||
test_script=globals().get("test_script"), # same
|
||||
machines=globals().get("machines"), # without being initialized
|
||||
log=globals().get("log"), # extracting those symbol keys
|
||||
os=os,
|
||||
create_machine=create_machine,
|
||||
subtest=subtest,
|
||||
run_tests=run_tests,
|
||||
join_all=join_all,
|
||||
serial_stdout_off=serial_stdout_off,
|
||||
serial_stdout_on=serial_stdout_on,
|
||||
)
|
||||
return general_symbols
|
||||
|
||||
|
||||
def test_symbols() -> Dict[str, Any]:
|
||||
|
||||
general_symbols = _test_symbols()
|
||||
|
||||
machine_symbols = {m.name: machines[idx] for idx, m in enumerate(machines)}
|
||||
print(
|
||||
"additionally exposed symbols:\n "
|
||||
+ ", ".join(map(lambda m: m.name, machines))
|
||||
+ ",\n "
|
||||
+ ", ".join(list(general_symbols.keys()))
|
||||
)
|
||||
return {**general_symbols, **machine_symbols}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
arg_parser = argparse.ArgumentParser(prog="nixos-test-driver")
|
||||
arg_parser.add_argument(
|
||||
|
@ -1028,12 +1053,9 @@ if __name__ == "__main__":
|
|||
)
|
||||
|
||||
args = arg_parser.parse_args()
|
||||
global test_script
|
||||
testscript = pathlib.Path(args.testscript).read_text()
|
||||
|
||||
def test_script() -> None:
|
||||
with log.nested("running the VM test script"):
|
||||
exec(testscript, globals())
|
||||
global log, machines, test_script
|
||||
|
||||
log = Logger()
|
||||
|
||||
|
@ -1062,6 +1084,11 @@ if __name__ == "__main__":
|
|||
process.terminate()
|
||||
log.close()
|
||||
|
||||
def test_script() -> None:
|
||||
with log.nested("running the VM test script"):
|
||||
symbols = test_symbols() # call eagerly
|
||||
exec(testscript, symbols, None)
|
||||
|
||||
interactive = args.interactive or (not bool(testscript))
|
||||
tic = time.time()
|
||||
run_tests(interactive)
|
||||
|
|
|
@ -42,7 +42,9 @@ rec {
|
|||
python <<EOF
|
||||
from pydoc import importfile
|
||||
with open('driver-symbols', 'w') as fp:
|
||||
fp.write(','.join(dir(importfile('${testDriverScript}'))))
|
||||
t = importfile('${testDriverScript}')
|
||||
test_symbols = t._test_symbols()
|
||||
fp.write(','.join(test_symbols.keys()))
|
||||
EOF
|
||||
'';
|
||||
|
||||
|
|
Loading…
Reference in a new issue