shell_command: add support for sending data to a process

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2022-01-13 12:39:11 -05:00 committed by Eric Callahan
parent 8546cd6ac5
commit 89d0bbdb63

@ -181,7 +181,8 @@ class ShellCommand:
timeout: float = 2., timeout: float = 2.,
verbose: bool = True, verbose: bool = True,
log_complete: bool = True, log_complete: bool = True,
sig_idx: int = 1 sig_idx: int = 1,
proc_input: Optional[str] = None
) -> bool: ) -> bool:
async with self.run_lock: async with self.run_lock:
self.factory.add_running_command(self) self.factory.add_running_command(self)
@ -196,12 +197,18 @@ class ShellCommand:
): ):
# No callbacks set so output cannot be verbose # No callbacks set so output cannot be verbose
verbose = False verbose = False
if not await self._create_subprocess(use_callbacks=verbose): created = await self._create_subprocess(
verbose, proc_input is not None)
if not created:
self.factory.remove_running_command(self) self.factory.remove_running_command(self)
return False return False
assert self.proc is not None assert self.proc is not None
try: try:
ret = self.proc.wait() if proc_input is not None:
ret: Coroutine = self.proc.communicate(
input=proc_input.encode())
else:
ret = self.proc.wait()
await asyncio.wait_for(ret, timeout=timeout) await asyncio.wait_for(ret, timeout=timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
complete = False complete = False
@ -215,19 +222,23 @@ class ShellCommand:
timeout: float = 2., timeout: float = 2.,
retries: int = 1, retries: int = 1,
log_complete: bool = True, log_complete: bool = True,
sig_idx: int = 1 sig_idx: int = 1,
proc_input: Optional[str] = None
) -> str: ) -> str:
async with self.run_lock: async with self.run_lock:
self.factory.add_running_command(self) self.factory.add_running_command(self)
retries = max(1, retries) retries = max(1, retries)
stdin: Optional[bytes] = None
if proc_input is not None:
stdin = proc_input.encode()
while retries > 0: while retries > 0:
self._reset_command_data() self._reset_command_data()
timed_out = False timed_out = False
stdout = stderr = b"" stdout = stderr = b""
if await self._create_subprocess(): if await self._create_subprocess(has_input=stdin is not None):
assert self.proc is not None assert self.proc is not None
try: try:
ret = self.proc.communicate() ret = self.proc.communicate(input=stdin)
stdout, stderr = await asyncio.wait_for( stdout, stderr = await asyncio.wait_for(
ret, timeout=timeout) ret, timeout=timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
@ -256,7 +267,10 @@ class ShellCommand:
f"Error running shell command: '{self.command}'", f"Error running shell command: '{self.command}'",
self.return_code, stdout, stderr) self.return_code, stdout, stderr)
async def _create_subprocess(self, use_callbacks: bool = False) -> bool: async def _create_subprocess(self,
use_callbacks: bool = False,
has_input: bool = False
) -> bool:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
def protocol_factory(): def protocol_factory():
@ -265,20 +279,24 @@ class ShellCommand:
std_out_cb=self.std_out_cb, std_err_cb=self.std_err_cb, std_out_cb=self.std_out_cb, std_err_cb=self.std_err_cb,
log_stderr=self.log_stderr) log_stderr=self.log_stderr)
try: try:
stdpipe: Optional[int] = None
if has_input:
stdpipe = asyncio.subprocess.PIPE
if self.std_err_cb is not None or self.log_stderr: if self.std_err_cb is not None or self.log_stderr:
errpipe = asyncio.subprocess.PIPE errpipe = asyncio.subprocess.PIPE
else: else:
errpipe = asyncio.subprocess.STDOUT errpipe = asyncio.subprocess.STDOUT
if use_callbacks: if use_callbacks:
transport, protocol = await loop.subprocess_exec( transport, protocol = await loop.subprocess_exec(
protocol_factory, *self.command, protocol_factory, *self.command, stdin=stdpipe,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=errpipe, env=self.env, cwd=self.cwd) stderr=errpipe, env=self.env, cwd=self.cwd)
self.proc = asyncio.subprocess.Process( self.proc = asyncio.subprocess.Process(
transport, protocol, loop) transport, protocol, loop)
else: else:
self.proc = await asyncio.create_subprocess_exec( self.proc = await asyncio.create_subprocess_exec(
*self.command, stdout=asyncio.subprocess.PIPE, *self.command, stdin=stdpipe,
stdout=asyncio.subprocess.PIPE,
stderr=errpipe, env=self.env, cwd=self.cwd) stderr=errpipe, env=self.env, cwd=self.cwd)
except Exception: except Exception:
logging.exception( logging.exception(