Source code for apssh.sshproxy

#!/usr/bin/env python3

"""
The SshProxy class models an ssh connection, and is mainly in charge of
* lazily initializing connections on a need-by-need basis
* and reassembling lines as they come back from the remote
"""

import asyncio

import asyncssh

from .util import print_stderr, check_arg_type
# a dummy formatter
from .formatters import HostFormatter


class _LineBasedSession(asyncssh.SSHClientSession):
    """
    An ssh client session that keeps track of everything received
    on both of its output streams (stdout and stderr).

    Incoming text is cut into lines; each complete line is handed
    to the formatter attached to the session's proxy, through
    a call to that formatter's line() method.
    """

    ##########
    class Channel:
        """
        accumulator for one output stream; a session holds one such
        object for stdout and another one for stderr

        .buffer: everything received so far, left untouched
        .line: the line currently being assembled
        """

        def __init__(self, name, proxy):
            self.proxy = proxy
            self.name = name
            # the whole stream on one hand, the pending line on the other
            self.buffer, self.line = "", ""

        # pylint: disable=c0111
        def data_received(self, data, datatype):
            # record the raw data first, newlines included
            self.buffer += data
            if self.proxy.debug:
                print_stderr(
                    f'BS {self.proxy.hostname} DR: -> {data} [[of type {self.name}]]')
            pieces = data.split("\n")
            # split() always returns at least one item
            assert pieces, "unexpected data received"
            # the first piece completes whatever line is pending
            self.line += pieces[0]
            # each further piece means that one newline was crossed:
            # emit the finished line, then start over with that piece
            for piece in pieces[1:]:
                self.flush(datatype, newline=True)
                self.line = piece

        def flush(self, datatype, newline):
            # put back the newline that split() has eaten, if requested
            if newline:
                self.line = self.line + "\n"
            # nothing pending - this happens e.g. when flushing on EOF
            if not self.line:
                return
            hostname = self.proxy.hostname
            self.proxy.formatter.line(self.line, datatype, hostname)
            self.line = ""

    ##########
    def __init__(self, proxy, command, *args, **kwds):
        # the proxy gives access to the formatter, hostname and debug flag
        self.proxy = proxy
        self.command = command
        # one accumulator per output stream
        self.stdout = self.Channel("stdout", proxy)
        self.stderr = self.Channel("stderr", proxy)
        # exit status - or signal - of the remote command, once known
        self._exit = None
        super().__init__(*args, **kwds)

    # this seems right only for text streams...
    def data_received(self, data, datatype):
        # dispatch on the stream that the data belongs to
        if datatype == asyncssh.EXTENDED_DATA_STDERR:
            target = self.stderr
        else:
            target = self.stdout
        target.data_received(data, datatype)

    def connection_made(self, chan):               # pylint:disable=w0221,w0613
        proxy = self.proxy
        proxy.formatter.session_start(proxy.hostname, self.command)

    def connection_lost(self, exc):
        proxy = self.proxy
        proxy.formatter.session_stop(proxy.hostname, self.command)

    def eof_received(self):
        # emit whatever is still pending on either stream, as-is
        self.stdout.flush(None, newline=False)
        self.stderr.flush(asyncssh.EXTENDED_DATA_STDERR, newline=False)
        self.proxy.debug_line("EOF")

    def exit_status_received(self, status):
        self._exit = status
        self.proxy.debug_line(f"STATUS = {status}\n")

    def exit_signal_received(self, signal,
                             core_dumped, msg, lang):   # pylint: disable=w0613
        # when the remote process gets terminated by a signal,
        # the signal is recorded in _exit, so that the session
        # does not end up with None as its outcome
        self._exit = signal
        self.proxy.debug_line(f"SIGNAL = {signal}--{msg}\n")

# _VerboseClient is created through factories attached to each proxy

class _VerboseClient(asyncssh.SSHClient):
    """
    An asyncssh client that reports connection-level events
    (connection made or lost, authentication completed)
    to the formatter of the proxy it is attached to.
    """

    # pylint: disable=c0111

    def __init__(self, proxy, direct, *args, **kwds):
        self.proxy = proxy
        self.formatter = proxy.formatter
        # forwarded as-is to the formatter when the connection is made
        self.direct = direct
        # turns True once a connection loss comes with an exception
        self._connection_lost = False
        super().__init__(*args, **kwds)

    def connection_made(self, conn):
        proxy = self.proxy
        self.formatter.connection_made(
            proxy.hostname, proxy.username, self.direct)

    # xxx this callback is not always triggered;
    # the issue seems to be that we use close() on the asyncssh connection,
    # which is a synchronous call, and it is unclear
    # what other future should be awaited instead/afterwards
    # it does trigger occasionally though, esp. with several targets
    def connection_lost(self, exc):
        proxy = self.proxy
        self.formatter.connection_lost(proxy.hostname, exc, proxy.username)
        if not exc:
            return
        self._connection_lost = True

    def auth_completed(self):
        proxy = self.proxy
        self.formatter.auth_completed(proxy.hostname, proxy.username)

####################


class SshProxy:                                         # pylint: disable=r0902
    """
    A proxy essentially wraps an ssh connection. It can connect to a remote,
    and then can run several commands in the most general sense, i.e.
    including file transfers.

    Parameters:
      hostname: remote hostname
      username: remote login name
      gateway (SshProxy): when set, this node is then used as a hop
        for creating a 2-leg ssh connection.
      keys: passed to asyncssh as ``client_keys``; this class has no smart
        way to guess for keys, and defaults to an empty list
      known_hosts: passed as-is to asyncssh as ``known_hosts``
      port: remote port number, converted with ``int()``
      formatter: each SshProxy instance has an attached formatter that is
        in charge of rendering the output of the various commands.
        The default is to use an instance of
        :class:`~apssh.formatters.HostFormatter`, that outputs lines
        of the form ``hostname:actual-output``
      verbose: allows to get some user-level feedback on ssh negotiation.
        `Permission denied` messages and similar won't show up
        unless verbose is set; when not None, it is stored in
        the formatter's ``verbose`` attribute.
      debug: when true, messages sent to :meth:`debug_line`
        are handed over to the formatter
      timeout: passed to ``asyncio.wait_for`` when opening
        a connection or a session
    """

    def __init__(self, hostname, *, username=None,
                 gateway=None,  # if another SshProxy is given
                 # it is used as an ssh gateway
                 keys=None,  # this class has no smart way to guess for keys
                 known_hosts=None, port=22,
                 formatter=None, verbose=None,
                 debug=False, timeout=30):
        # early type verifications
        check_arg_type(hostname, str, "SshProxy.hostname")
        self.hostname = hostname
        check_arg_type(username, (str, type(None)), "SshProxy.username")
        self.username = username
        self.known_hosts = known_hosts
        self.keys = keys if keys is not None else []
        self.port = int(port)
        check_arg_type(gateway, (SshProxy, type(None)), "SshProxy.gateway")
        self.gateway = gateway
        # if not specified we use a basic colon formatter
        self.formatter = formatter or HostFormatter()
        if verbose is not None:
            self.formatter.verbose = verbose
        self.debug = debug
        self.timeout = timeout
        # nothing is open until actually needed
        self.conn, self.sftp_client = None, None
        self.client = None
        # critical sections require mutual exclusions
        self._connect_lock = asyncio.Lock()
        self._disconnect_lock = asyncio.Lock()

    # make this an asynchronous context manager
    # async with SshProxy(...) as ssh:
    #
    async def __aenter__(self):
        await self.connect_lazy()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        # xxx this might be a little harsh, in the case
        # where an exception did occur
        await self.close()

    def __user_host__(self):
        """
        Returns:
          str: ``username@hostname``, or ``@hostname`` when no username is set
        """
        return f"{self.username}@{self.hostname}" if self.username \
            else "@" + self.hostname

    def __str__(self):
        text = ("" if not self.gateway
                else f"{self.gateway.__user_host__()}->")
        text += self.__user_host__()
        return text

    def __repr__(self):
        text = self.__str__() + " "
        text += "[no key] " if not self.keys else f"[{len(self.keys)} keys] "
        if self.conn:
            text += "<-SSH->"
        if self.sftp_client:
            text += "<-SFTP->"
        return f"<{type(self).__name__} {text}>"

    def debug_line(self, line):
        """
        When the debug flag is set, hand ``line`` over to the formatter
        as stderr material, tagged with the repr() of this proxy;
        otherwise do nothing.
        """
        # bail out early: no point in building the message - and
        # in particular repr(self) - if it is going to be discarded
        if not self.debug:
            return
        if line.endswith("\n"):
            line = line[:-1]
        line += " ((from:" + repr(self) + "))\n"
        self.formatter.line(
            line, asyncssh.EXTENDED_DATA_STDERR, self.hostname)

    def is_connected(self):
        """
        Returns:
          bool: whether the connection is up
        """
        return self.conn is not None

    async def connect_lazy(self):
        """
        Connects if needed - uses a lock to make it safe for several
        coroutines to simultaneously try to run commands on the same
        SshProxy instance.

        Returns:
          connection object
        """
        async with self._connect_lock:
            if self.conn is None:
                await self._connect()
            return self.conn

    async def _connect(self):
        """
        Unconditionally attempts to connect, through the gateway
        if there is one, and lets any exception through otherwise
        """
        if self.gateway:
            return await self._connect_tunnel()
        return await self._connect_direct()

    async def _connect_direct(self):
        """
        The code for connecting to the first ssh hop
        (i.e. when self.gateway is None)
        """
        assert self.gateway is None

        # pylint: disable=c0111
        class ClientClosure(_VerboseClient):
            # it is crucial that the first param here is *NOT* called self
            def __init__(client_self, *args, **kwds):  # pylint: disable=e0213
                # 'direct' is passed positionally: spelled as a keyword
                # ahead of *args, it would clash with any forwarded
                # positional argument
                _VerboseClient.__init__(
                    client_self, self, True, *args, **kwds)

        self.debug_line("SSH direct connecting")
        # second returned value is the client object
        self.conn, self.client = \
            await asyncio.wait_for(
                asyncssh.create_connection(
                    ClientClosure, self.hostname, port=self.port,
                    username=self.username,
                    known_hosts=self.known_hosts, client_keys=self.keys,
                    # it is rather crucial that we skip config-loading
                    # at least to be consistent with previous
                    # user-experience
                    config=None,
                ),
                timeout=self.timeout)

    async def _connect_tunnel(self):
        """
        The code to connect to a higher-degree hop
        We make sure the gateway has its connection open,
        and create our own connection through that one
        """
        # make sure the gateway has connected already
        assert self.gateway is not None
        await self.gateway.connect_lazy()

        # pylint: disable=c0111
        class ClientClosure(_VerboseClient):
            # it is crucial that the first param here is *NOT* called self
            def __init__(client_self, *args, **kwds):  # pylint: disable=e0213
                # see _connect_direct about 'direct' being positional
                _VerboseClient.__init__(
                    client_self, self, False, *args, **kwds)

        self.debug_line("SSH tunnel connecting")
        # second returned value is the client object
        try:
            self.conn, self.client = \
                await asyncio.wait_for(
                    self.gateway.conn.create_ssh_connection(
                        ClientClosure, self.hostname, port=self.port,
                        username=self.username,
                        known_hosts=self.known_hosts, client_keys=self.keys
                    ),
                    timeout=self.timeout)
            self.debug_line("SSH tunnel connected")
        except asyncssh.misc.ChannelOpenError:
            self.formatter.stderr_line(
                f"Cannot open channel to {self.username}@{self.hostname}",
                self.hostname)
            raise

    def is_sftp_connected(self):
        """
        Returns:
          bool: whether the SFTP subsystem is up
        """
        return self.sftp_client is not None

    async def sftp_connect_lazy(self):
        """
        Initializes SFTP connection if needed; the ssh connection
        is opened first if it is not up yet.

        Returns:
          SFTP connection object
        """
        await self.connect_lazy()
        async with self._connect_lock:
            if self.sftp_client is None:
                await self._sftp_connect()
            return self.sftp_client

    async def _sftp_connect(self):
        """
        Start the SFTP client on top of the ssh connection.

        Returns:
          False when there is no ssh connection to build upon,
          None otherwise

        Raises:
          asyncssh.sftp.SFTPError: reported on the formatter, then re-raised
        """
        if self.conn is None:
            return False
        try:
            self.sftp_client = await self.conn.start_sftp_client()
            self.formatter.sftp_start(self.hostname)
        except asyncssh.sftp.SFTPError:
            self.formatter.stderr_line(
                "Cannot start SFTP subsystem", self.hostname)
            raise

    async def _close_sftp(self):
        """
        close the SFTP client if relevant
        """
        if self.sftp_client is None:
            return
        # set self.sftp_client to None *before* awaiting
        # to avoid duplicate attempts
        preserve = self.sftp_client
        self.sftp_client = None
        try:
            preserve.exit()
        except Exception as exc:                    # pylint: disable=w0703
            # closing remains best-effort, but leave a trace in debug mode
            self.debug_line(
                f"ignoring error while exiting SFTP client: {exc}")
        await preserve.wait_closed()
        self.formatter.sftp_stop(self.hostname)

    async def _close_ssh(self):
        """
        close the SSH connection if relevant

        Raises:
          ConnectionError: when the client object has recorded
            a connection loss that came with an exception
        """
        if self.conn is None:
            return
        # same precaution as in _close_sftp
        preserve = self.conn
        self.conn = None
        try:
            preserve.close()
        except Exception as exc:                    # pylint: disable=w0703
            # xxx harsh here too; best-effort, traced in debug mode only
            self.debug_line(
                f"ignoring error while closing SSH connection: {exc}")
        await preserve.wait_closed()
        # pylint: disable=protected-access
        if self.client is not None and self.client._connection_lost:
            raise ConnectionError("Close connection went wrong")

    async def close(self):
        """
        Close everything open, i.e. ssh connection and SFTP subsystem
        """
        # beware that when used with asynciojobs, we often have several jobs
        # sharing the same proxy, and so there might be several calls to
        # close() sent to the same object at the same time...
        async with self._disconnect_lock:
            await self._close_sftp()
            await self._close_ssh()

    ##############################
    async def run(self, command, **x11_kwds):
        """
        Run a command, and write its output on the fly according to
        instance's formatter. The ssh connection is created if needed.

        Parameters:
          command: remote command to run
          x11_kwds: optional keyword args that will be passed
            to create_session, like typically ``x11_forwarding=True``

        Returns:
          what the session has recorded about the remote command's
          termination: its exit status, or the signal when it was
          terminated by one - or None if neither was received
        """
        # pylint: disable=c0111
        # this closure is a _LineBasedSession
        # with a .proxy attribute that points back here
        class SessionClosure(_LineBasedSession):
            # not using 'self' because 'self' is the SshProxy instance already
            def __init__(session_self, *args, **kwds):  # pylint: disable=e0213
                _LineBasedSession.__init__(
                    session_self, self, command, *args, **kwds)

        # make sure there is a connection to create the session on,
        # like the file-transfer methods do
        conn = await self.connect_lazy()
        chan, session = \
            await asyncio.wait_for(
                conn.create_session(SessionClosure, command, **x11_kwds),
                timeout=self.timeout)
        await chan.wait_closed()
        return session._exit                        # pylint: disable=w0212

    async def mkdir(self, remotedir):
        """
        Create a remote directory if needed.

        Parameters:
          remotedir(str): remote repository to create.

        Returns:
          True if remote directory existed or could be created,
          False if SFTP subsystem could not be set up.

        Raises:
          asyncssh.sftp.SFTPError
        """
        if not await self.sftp_connect_lazy():
            return False
        exists = await self.sftp_client.isdir(remotedir)
        if exists:
            self.debug_line(
                f"{remotedir} already exists - no need to create")
            return True
        try:
            self.debug_line(f"actual creation of {remotedir}")
            await self.sftp_client.mkdir(remotedir)
            return True
        except asyncssh.sftp.SFTPError as exc:
            self.debug_line(
                f"Could not create {remotedir} on {self}\n{exc}")
            raise

    # shows up first in doc
    async def get_file_s(self, remotepaths, localpath, **kwds):
        """
        Retrieve a collection of remote files locally into the same
        directory. The ssh connection and SFTP subsystem are created
        and set up if needed.

        Parameters:
          remotepaths(list): remote files to retrieve
          localpath: where to store them
          kwds: passed along to the underlying asyncssh's sftp client,
            typically: ``preserve``, ``recurse`` and ``follow_symlinks``
            are honored like in
            http://asyncssh.readthedocs.io/en/latest/api.html#asyncssh.SFTPClient.get

        Returns:
          True if all went well, or raise exception
        """
        await self.sftp_connect_lazy()
        try:
            self.debug_line(
                f"doing SFTP get with {remotepaths} -> {localpath}")
            await self.sftp_client.get(remotepaths, localpath, **kwds)
        except asyncssh.sftp.SFTPError as exc:
            self.debug_line(
                f"Could not SFTP GET remotes {remotepaths} to local {localpath}"
                f" - exception={exc}")
            raise
        return True

    async def put_file_s(self, localpaths, remotepath, **kwds):
        """
        Copy a collection of local files remotely into the same directory.
        The ssh connection and SFTP subsystem are created
        and set up if needed.

        Parameters:
          localpaths (list): files to copy
          remotepath (str): where to copy
          kwds: passed along to the underlying asyncssh's sftp client,
            typically: ``preserve``, ``recurse`` and ``follow_symlinks``
            are honored like in
            http://asyncssh.readthedocs.io/en/latest/api.html#asyncssh.SFTPClient.put

        Returns:
          True if all went well, or raise exception
        """
        await self.sftp_connect_lazy()
        try:
            self.debug_line(
                f"doing SFTP put with {localpaths} -> {remotepath}")
            await self.sftp_client.put(localpaths, remotepath, **kwds)
        except asyncssh.sftp.SFTPError as exc:
            self.debug_line(
                f"Could not SFTP PUT local {localpaths} to remote {remotepath}"
                f" - exception={exc}")
            raise
        return True

    async def put_string_script(self, script_body, remotefile, **kwds):
        """
        A convenience for copying over a local script before remote
        execution. The ssh connection and SFTP subsystem are created
        and set up if needed. Resulting remote file has mode `755`.

        Parameters:
          script_body (str): the **contents** of the script to create
            **WARNING** this is **not** a filename.
          remotefile: filename on the remote end
          kwds: passed along to
            http://asyncssh.readthedocs.io/en/latest/api.html#asyncssh.SFTPClient.open
            i.e. for setting ``encoding`` or ``errors``.

        Returns:
          True if all went well, or raise exception
        """
        await self.sftp_connect_lazy()
        sftp_attrs = asyncssh.SFTPAttrs()
        sftp_attrs.permissions = 0o755
        try:
            async with self.sftp_client.open(remotefile, pflags_or_mode='w',
                                             attrs=sftp_attrs,
                                             **kwds) as writer:
                await writer.write(script_body)
        except Exception as exc:
            self.debug_line(
                f"Could not create remotefile {remotefile} - exception={exc}")
            raise
        return True