diff --git a/aexpect/client.py b/aexpect/client.py index fba78aa..6f9cd70 100644 --- a/aexpect/client.py +++ b/aexpect/client.py @@ -27,6 +27,7 @@ import subprocess import locale import logging +import asyncio from aexpect.exceptions import ExpectError from aexpect.exceptions import ExpectProcessTerminatedError @@ -806,6 +807,41 @@ def _read_nonblocking(self, internal_timeout=None, timeout=None): if end_time and time.time() > end_time: return read, data + async def _read_nonblocking_async(self, internal_timeout=None, timeout=None): + """ + Read from child until there is nothing to read for timeout seconds via a coroutine. + + All arguments are identical to the regular function. + """ + if internal_timeout is None: + internal_timeout = 100 + else: + internal_timeout *= 1000 + end_time = None + if timeout: + end_time = time.time() + timeout + expect_pipe = self._get_fd("expect") + poller = select.poll() + poller.register(expect_pipe, select.POLLIN) + data = "" + read = 0 + while True: + try: + poll_status = poller.poll(internal_timeout) + except select.error: + return read, data + if poll_status: + raw_data = os.read(expect_pipe, 1024) + if not raw_data: + return read, data + read += len(raw_data) + data += raw_data.decode(self.encoding, "ignore") + else: + return read, data + if end_time and time.time() > end_time: + return read, data + await asyncio.sleep(1) + def read_nonblocking(self, internal_timeout=None, timeout=None): """ Read from child until there is nothing to read for timeout seconds. @@ -925,6 +961,54 @@ def read_until_output_matches(self, patterns, filter_func=lambda x: x, # This shouldn't happen raise ExpectError(patterns, output) + async def read_until_output_matches_async(self, patterns, filter_func=lambda x: x, + timeout=60.0, internal_timeout=None, + print_func=None, match_func=None): + """ + Read from child using read_nonblocking until a pattern matches via a coroutine. + + All arguments are identical to the regular function. + """ + if not match_func: + match_func = self.match_patterns + expect_pipe = self._get_fd("expect") + poller = select.poll() + poller.register(expect_pipe, select.POLLIN) + output = "" + end_time = time.time() + timeout + while True: + try: + max_ms = int((end_time - time.time()) * 1000) + poll_timeout_ms = max(0, max_ms) + poll_status = poller.poll(poll_timeout_ms) + except select.error: + break + if not poll_status: + raise ExpectTimeoutError(patterns, output) + # Read data from child + read, data = await self._read_nonblocking_async(internal_timeout, + end_time - time.time()) + if not read: + break + if not data: + continue + # Print it if necessary- + if print_func: + for line in data.splitlines(): + print_func(line) + # Look for patterns + output += data + match = match_func(filter_func(output), patterns) + if match is not None: + return match, output + + # Check if the child has terminated + if utils_wait.wait_for(lambda: not self.is_alive(), 5, 0, 0.1): + raise ExpectProcessTerminatedError(patterns, self.get_status(), + output) + # This shouldn't happen + raise ExpectError(patterns, output) + def read_until_last_word_matches(self, patterns, timeout=60.0, internal_timeout=None, print_func=None): """ @@ -953,6 +1037,25 @@ def _get_last_word(cont): timeout, internal_timeout, print_func) + async def read_until_last_word_matches_async(self, patterns, timeout=60.0, + internal_timeout=None, print_func=None): + """ + Read using read_nonblocking until the last word of the output matches + one of the patterns (using match_patterns), or until timeout expires + via a coroutine. + + All arguments are identical to the regular function. + """ + + def _get_last_word(cont): + if cont: + return cont.split()[-1] + return "" + + return await self.read_until_output_matches_async(patterns, _get_last_word, + timeout, internal_timeout, + print_func) + def read_until_last_line_matches(self, patterns, timeout=60.0, internal_timeout=None, print_func=None): """ @@ -987,6 +1090,25 @@ def _get_last_nonempty_line(cont): timeout, internal_timeout, print_func) + async def read_until_last_line_matches_async(self, patterns, timeout=60.0, + internal_timeout=None, print_func=None): + """ + Read until the last non-empty line matches a pattern via a coroutine. + + All arguments are identical to the regular function. + """ + + def _get_last_nonempty_line(cont): + nonempty_lines = [_ for _ in cont.splitlines() if _.strip()] + if nonempty_lines: + return nonempty_lines[-1] + return "" + + return await self.read_until_output_matches_async(patterns, + _get_last_nonempty_line, + timeout, internal_timeout, + print_func) + def read_until_any_line_matches(self, patterns, timeout=60.0, internal_timeout=None, print_func=None): """ @@ -1016,6 +1138,19 @@ def read_until_any_line_matches(self, patterns, timeout=60.0, print_func, self.match_patterns_multiline) + async def read_until_any_line_matches_async(self, patterns, timeout=60.0, + internal_timeout=None, print_func=None): + """ + Read using read_nonblocking until any line matches a pattern via a coroutine. + + All arguments are identical to the regular function. + """ + return await self.read_until_output_matches_async(patterns, + lambda x: x.splitlines(), + timeout, internal_timeout, + print_func, + self.match_patterns_multiline) + class ShellSession(Expect): @@ -1170,6 +1305,18 @@ def read_up_to_prompt(self, timeout=60.0, internal_timeout=None, internal_timeout, print_func)[1] + async def read_up_to_prompt_async(self, timeout=60.0, internal_timeout=None, + print_func=None): + """ + Read until the last non-empty line matches the prompt via a coroutine. + + All arguments are identical to the regular function. + """ + _, data = await self.read_until_last_line_matches_async([self.prompt], timeout, + internal_timeout, + print_func) + return data + def cmd_output(self, cmd, timeout=60, internal_timeout=None, print_func=None, safe=False): """ @@ -1215,6 +1362,35 @@ def cmd_output(self, cmd, timeout=60, internal_timeout=None, return self.remove_last_nonempty_line(self.remove_command_echo(out, cmd)) + async def cmd_output_async(self, cmd, timeout=60, internal_timeout=None, + print_func=None, safe=False): + """ + Send a command and return its output via a coroutine. + + All arguments are identical to the regular function. + """ + if safe: + return await self.cmd_output_safe_async(cmd, timeout) + session_tag = f"[{self.output_prefix}] " if self.output_prefix else "" + LOG.debug("%sSending command: %s", session_tag, cmd) + self.read_nonblocking(0, timeout) + self.sendline(cmd) + try: + out = await self.read_up_to_prompt_async(timeout, internal_timeout, print_func) + except ExpectTimeoutError as error: + output = self.remove_command_echo(error.output, cmd) + raise ShellTimeoutError(cmd, output) from error + except ExpectProcessTerminatedError as error: + output = self.remove_command_echo(error.output, cmd) + raise ShellProcessTerminatedError(cmd, error.status, output) from error + except ExpectError as error: + output = self.remove_command_echo(error.output, cmd) + raise ShellError(cmd, output) from error + + # Remove the echoed command and the final shell prompt + return self.remove_last_nonempty_line(self.remove_command_echo(out, + cmd)) + def cmd_output_safe(self, cmd, timeout=60): """ Send a command and return its output (serial sessions). @@ -1264,6 +1440,42 @@ def cmd_output_safe(self, cmd, timeout=60): return self.remove_last_nonempty_line(self.remove_command_echo(out, cmd)) + async def cmd_output_safe_async(self, cmd, timeout=60): + """ + Send a command and return its output (serial sessions) via a coroutine. + + All arguments are identical to the regular function. + """ + session_tag = f"[{self.output_prefix}] " if self.output_prefix else "" + LOG.debug("%sSending command (safe): %s", session_tag, cmd) + self.read_nonblocking(0, timeout) + self.sendline(cmd) + out = "" + success = False + start_time = time.time() + while (time.time() - start_time) < timeout: + try: + out += await self.read_up_to_prompt_async(0.5) + success = True + break + except ExpectTimeoutError as error: + out = f"{out}{error.output}" + self.sendline() + except ExpectProcessTerminatedError as error: + output = self.remove_command_echo(f"{out}{error.output}", cmd) + raise ShellProcessTerminatedError(cmd, error.status, + output) from error + except ExpectError as error: + output = self.remove_command_echo(f"{out}{error.output}", cmd) + raise ShellError(cmd, output) from error + + if not success: + raise ShellTimeoutError(cmd, out) + + # Remove the echoed command and the final shell prompt + return self.remove_last_nonempty_line(self.remove_command_echo(out, + cmd)) + def cmd_status_output(self, cmd, timeout=60, internal_timeout=None, print_func=None, safe=False): """ @@ -1304,6 +1516,28 @@ def cmd_status_output(self, cmd, timeout=60, internal_timeout=None, return int(digit_lines[0].strip()), out raise ShellStatusError(cmd, out) + async def cmd_status_output_async(self, cmd, timeout=60, internal_timeout=None, + print_func=None, safe=False): + """ + Send a command and return its exit status and output via a coroutine. + + All arguments are identical to the regular function. + """ + out = await self.cmd_output_async(cmd, timeout, internal_timeout, print_func, safe) + try: + # Send the 'echo $?' (or equivalent) command to get the exit status + status = self.cmd_output(self.status_test_command, 30, + internal_timeout, print_func, safe) + except ShellError as error: + raise ShellStatusError(cmd, out) from error + + # Get the first line consisting of digits only + digit_lines = [_ for _ in status.splitlines() + if self.__RE_STATUS.match(_.strip())] + if digit_lines: + return int(digit_lines[0].strip()), out + raise ShellStatusError(cmd, out) + def cmd_status(self, cmd, timeout=60, internal_timeout=None, print_func=None, safe=False): """ @@ -1331,6 +1565,16 @@ def cmd_status(self, cmd, timeout=60, internal_timeout=None, return self.cmd_status_output(cmd, timeout, internal_timeout, print_func, safe)[0] + async def cmd_status_async(self, cmd, timeout=60, internal_timeout=None, + print_func=None, safe=False): + """ + Send a command and return its exit status via a coroutine. + + All arguments are identical to the regular function. + """ + return await self.cmd_status_output_async(cmd, timeout, internal_timeout, + print_func, safe)[0] + def cmd(self, cmd, timeout=60, internal_timeout=None, print_func=None, ok_status=None, ignore_all_errors=False): """ @@ -1372,6 +1616,28 @@ def cmd(self, cmd, timeout=60, internal_timeout=None, print_func=None, return None raise + async def cmd_async(self, cmd, timeout=60, internal_timeout=None, print_func=None, + ok_status=None, ignore_all_errors=False): + """ + Send a command and return its output via a coroutine. If the command's + exit status is nonzero, raise an exception. + + All arguments are identical to the regular function. + """ + if ok_status is None: + ok_status = [0, ] + try: + status, output = await self.cmd_status_output_async(cmd, timeout, + internal_timeout, + print_func) + if status not in ok_status: + raise ShellCmdError(cmd, status, output) + return output + except ShellError: + if ignore_all_errors: + return None + raise + def get_command_output(self, cmd, timeout=60, internal_timeout=None, print_func=None): """