• python ssh库


    1. 版本1
    # -*- coding: utf-8 -*-
    
    import time
    import paramiko
    import sys
    import re
    from tenacity import retry, stop_after_attempt, wait_fixed
    import functools
    from concurrent import futures
    from threading import Lock
    import socket
    from paramiko.ssh_exception import SSHException
    from paramiko.ssh_exception import AuthenticationException
    
    
    class NetmikoTimeoutException(SSHException):
        """Raised when an SSH session times out (connecting, or waiting for channel data)."""
    
    
    class NetmikoAuthenticationException(AuthenticationException):
        """Authentication failure; a specialisation of paramiko's AuthenticationException."""
    
    
    # Largest number of bytes requested from a single channel recv() call.
    MAX_BUFFER = 65535
    # Default multiplier copied into BaseSSHSession.global_delay_factor.
    global_delay_factor = 1

    # Aliases with the alternative "NetMiko" capitalisation.
    NetMikoTimeoutException = NetmikoTimeoutException
    NetMikoAuthenticationException = NetmikoAuthenticationException

    # One worker thread shared by every function wrapped with ``timeout`` below.
    executor = futures.ThreadPoolExecutor(max_workers=1)
    
    
    def timeout(timeout):
        """Decorator factory: run the wrapped function on the module ``executor``
        and wait at most ``timeout`` seconds for its result.

        :param timeout: seconds to wait for the result.
        :raises concurrent.futures.TimeoutError: if the result is not ready in time.
            The call is not cancelled; it keeps occupying an executor worker,
            so later calls submitted to the same executor may queue behind it.
        Exceptions raised by the wrapped function propagate to the caller.
        """
        def decorator(func):
            # Apply wraps (it was previously called but its result discarded) so the
            # wrapper keeps func's __name__, __doc__ and __wrapped__.
            @functools.wraps(func)
            def wrapper(*args, **kw):
                return executor.submit(func, *args, **kw).result(timeout=timeout)
            return wrapper
        return decorator
    
    
    class ParaSession(object):
        """paramiko SSH client with an interactive shell and an optional SFTP client.

        Connection errors are printed rather than raised. If the SSH connection
        or the shell cannot be opened, the instance has no ``channel`` attribute;
        callers check for that with ``hasattr`` to detect a broken session.
        """

        def __init__(self, hostname, password, port=22, username='root', timeout=60):
            self.t = None  # paramiko.Transport backing the SFTP client (see sftp_client)
            self.sftp = None
            self._closed = True
            self._channel_closed = True
            self._sftp_closed = True
            self.hostname = hostname
            self.password = password
            self.port = port
            self.username = username
            self.timeout = timeout  # connect timeout in seconds
            print('- start to create SSH connection -')
            self.client = paramiko.SSHClient()
            # NOTE(review): AutoAddPolicy accepts unknown host keys without verification.
            self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            try:
                self.client.connect(hostname=hostname,
                                    port=port,
                                    username=username,
                                    password=password,
                                    timeout=timeout)
            except Exception as e:
                # Release anything the failed connect attempt may have opened.
                self.client.close()
                print(f'Connection Error And please check the ENV: {str(e)}')
                return
            self._closed = False
            try:
                self.channel = self.client.invoke_shell()
            except Exception as e:
                try:
                    self.close()
                except Exception:
                    pass  # best-effort cleanup; the original error is reported below
                print(f'Open channel failed And please check the ENV: {str(e)}')
                return
            self._channel_closed = False
            print('- Connection created successfully -')

        def sftp_client(self):
            """Open an SFTP client on a dedicated transport and store it in ``self.sftp``.

            :return: the SFTP client, or None if it could not be opened. On failure
                the error is printed and only the partially opened SFTP resources
                are released; the shell session is left untouched.
            """
            try:
                self.t = paramiko.Transport((self.hostname, self.port))
                self.t.connect(username=self.username, password=self.password)
                self.sftp = paramiko.SFTPClient.from_transport(self.t)
                # Only mark SFTP as open once the client actually exists.
                self._sftp_closed = False
            except Exception as e:
                try:
                    self._close_sftp()
                except Exception:
                    pass  # best-effort cleanup; the original error is reported below
                print(f'Open sftp failed And please check the ENV: {str(e)}')
                return None
            print('- SFTP created successfully -')
            return self.sftp

        def _close_sftp(self):
            """Close the SFTP client and its transport, skipping any never opened."""
            if self.sftp is not None:
                self.sftp.close()
                self.sftp = None
            if self.t is not None:
                # The transport has its own socket and must be closed separately.
                self.t.close()
                self.t = None
            self._sftp_closed = True

        def close(self):
            """Close the shell channel, the SSH client and the SFTP resources.

            Safe to call more than once.
            """
            if not self._closed:
                if not self._channel_closed:
                    self.channel.close()
                    self._channel_closed = True
                self.client.close()
                self._closed = True
            # SFTP uses its own transport, so release it even if SSH never connected.
            self._close_sftp()

        def __del__(self):
            # close() is idempotent, so always call it: an SFTP transport may be
            # open even when the SSH connection itself failed.
            try:
                self.close()
            except Exception:
                # Exceptions in __del__ cannot be handled by callers, and during
                # interpreter shutdown the objects used by close() may be gone.
                pass
    
    
    class BaseSSHSession(object):
        """Interactive SSH shell session with netmiko-style prompt handling.

        Opens a ``ParaSession`` and drives its shell channel by writing commands
        and reading output until the detected base prompt (or another regex)
        appears.
        """

        def __init__(self, hostname, username, password, su_password, port=22, timeout=100):
            self.hostname = hostname
            self.password = password
            self.username = username
            self.su_password = su_password
            self.port = port
            self.global_delay_factor = global_delay_factor
            self.base_prompt = None
            self.root_prompt = None
            self.timeout = timeout
            # Longest time (seconds) _lock_netmiko_session waits for the channel lock.
            self.session_timeout = timeout
            self.RETURN = "\n"
            # Created before any channel I/O so every reader can take it.
            self._session_locker = Lock()
            # ParaSessions opened by open_sftp_session. They are kept referenced because
            # a discarded ParaSession's __del__ can close the SFTP client it returned.
            self._sftp_sessions = []
            self.ssh_obj = ParaSession(hostname=hostname, username=username, password=password, port=port)
            # Raises AttributeError if ParaSession could not open the shell channel.
            self.remote_conn = self.ssh_obj.channel
            # Without a channel timeout recv() blocks forever, so the socket.timeout
            # handling in _read_channel_expect could never trigger.
            self.remote_conn.settimeout(self.timeout)
            self.session_preparation()

        # NOTE: when debugging, a failure here is not necessarily caused by a timeout.
        # The tenacity ``retry`` and module ``timeout`` decorators are intentionally
        # not applied to this method.
        def session_preparation(self, delay_factor=1):
            """Detect the base prompt and discard any output left in the channel.

            ``delay_factor`` is accepted for interface compatibility but unused.
            """
            self.set_base_prompt()
            self.clear_buffer()

        def exec_ssh_cmd(self, cmd, delay_factor=2):
            """Send ``cmd`` to the shell and return the output up to the base prompt.

            :param cmd: command line to send; ``self.RETURN`` is appended.
            :param delay_factor: seconds to wait after sending before reading
                (``delay_factor * 0.1`` seconds are also waited before sending).
            :return: all text read until ``self.base_prompt`` appeared, or until
                ``self.timeout`` seconds passed or a read error occurred (the
                error is printed). None if the command could not be sent or the
                session has no shell channel (all sessions are then closed).
            """
            time.sleep(delay_factor * 0.1)
            if not hasattr(self.ssh_obj, 'channel'):
                self.close_all_session()
                print('session error and please check the ENV.')
                return None
            try:
                self.write_channel(cmd + self.RETURN)
            except Exception as e:
                print(f'CMD send errors: {str(e)}')
                return None
            time.sleep(delay_factor)
            print('Exec ssh command: %s' % str(cmd))
            buff = ''
            deadline = time.monotonic() + self.timeout
            try:
                # Accumulate every chunk; the prompt may arrive in a later read than
                # the command output.
                while self.base_prompt not in buff:
                    if time.monotonic() > deadline:
                        print(f'CMD receive errors: prompt not found within {self.timeout}s')
                        break
                    resp = self.read_channel()
                    if resp:
                        buff += resp
                    else:
                        time.sleep(0.1)  # nothing available yet; avoid a busy loop
            except Exception as e:
                print(f'CMD receive errors: {str(e)}')
            return buff

        def open_sftp_session(self, hostname, password, port=22, username='root'):
            """Open an SFTP client on a new, independent connection.

            :return: the SFTP client, or None if it could not be opened.
            """
            session = ParaSession(hostname=hostname, password=password, port=port, username=username)
            ssh_sftp = session.sftp_client()
            if ssh_sftp is not None:
                # Keep the owner alive until close_all_session(); see __init__.
                self._sftp_sessions.append(session)
            return ssh_sftp

        def _get_sftp(self):
            """Return the SFTP client of ``self.ssh_obj``, opening it on first use (None on failure)."""
            sftp = getattr(self.ssh_obj, 'sftp', None)
            if sftp is None:
                sftp = self.ssh_obj.sftp_client()
            return sftp

        def get_remotefile(self, remote_path, local_path, session='default'):
            """Download ``remote_path`` to ``local_path`` over SFTP.

            :param session: label used only in the error message.
            :return: True on success; False if the transfer failed (error printed);
                None if no SFTP client could be opened (all sessions are closed).
            """
            sftp = self._get_sftp()
            if sftp is None:
                self.close_all_session()
                print(f'session: {session} is error and please check the ENV.')
                return None
            try:
                sftp.get(remote_path, local_path)
            except Exception as e:
                print(f'Get remote file errors: {str(e)}')
                return False
            return True

        def put_localfile(self, local_path, remote_path, session='default'):
            """Upload ``local_path`` to ``remote_path`` over SFTP.

            :param session: label used only in the error message.
            :return: True on success; False if the transfer failed (error printed);
                None if no SFTP client could be opened (all sessions are closed).
            """
            sftp = self._get_sftp()
            if sftp is None:
                self.close_all_session()
                print(f'session: {session} is error and please check the ENV.')
                return None
            try:
                sftp.put(local_path, remote_path)
            except Exception as e:
                print(f'Put local file errors: {str(e)}')
                return False
            return True

        def close_session(self):
            """Close the main ParaSession (shell channel, SSH client, SFTP client)."""
            self.ssh_obj.close()

        def close_all_session(self):
            """Close the main session and every session opened by open_sftp_session."""
            self.close_session()
            extra_sessions = getattr(self, '_sftp_sessions', [])
            for extra in extra_sessions:
                extra.close()
            extra_sessions.clear()

        def __del__(self):
            # __init__ may have failed before ssh_obj was assigned.
            if getattr(self, 'ssh_obj', None) is None:
                return
            self.close_all_session()

        def _read_channel(self):
            """Return all data currently available on the channel, decoded as UTF-8.

            :raises EOFError: if the remote side closed the channel.
            """
            output = ""
            while True:
                if self.remote_conn.recv_ready():
                    # recv() would block when nothing is buffered, hence the recv_ready() check.
                    outbuf = self.remote_conn.recv(MAX_BUFFER)
                    if len(outbuf) == 0:
                        raise EOFError("Channel stream closed by remote device.")
                    output += outbuf.decode("utf-8", "ignore")
                else:
                    break
            return output

        def normalize_linefeeds(self, a_string):
            r"""Convert ``\r\r\r\n``, ``\r\r\n``, ``\r\n`` and ``\n\r`` to ``\n``, then any remaining ``\r`` to ``\n``."""
            newline = re.compile("(\r\r\r\n|\r\r\n|\r\n|\n\r)")
            a_string = newline.sub("\n", a_string)
            return re.sub("\r", "\n", a_string)

        def read_channel(self):
            """Return whatever data is currently available on the channel."""
            return self._read_channel()

        def write_bytes(self, out_data, encoding="ascii"):
            """Legacy for Python2 and Python3 compatible byte stream."""
            if sys.version_info[0] >= 3:
                if isinstance(out_data, type("")):
                    if encoding == "utf-8":
                        return out_data.encode("utf-8")
                    else:
                        return out_data.encode("ascii", "ignore")
                elif isinstance(out_data, type(b"")):
                    return out_data
            msg = "Invalid value for out_data neither unicode nor byte string: {}".format(
                out_data
            )
            raise ValueError(msg)

        def _write_channel(self, out_data):
            self.remote_conn.sendall(self.write_bytes(out_data))

        def write_channel(self, out_data):
            """Send ``out_data`` (str or bytes) to the channel."""
            self._write_channel(out_data)

        def find_prompt(self, delay_factor=1):
            """Return the current prompt line.

            Polls the channel, sending ``self.RETURN`` while nothing arrives (up to
            13 extra attempts with growing sleeps), then returns the last line of
            the output.

            :raises ValueError: if no prompt could be read.
            """
            sleep_time = delay_factor * 0.1
            time.sleep(sleep_time)
            prompt = self.read_channel().strip()
            # Check if the only thing you received was a newline
            count = 0
            while count <= 12 and not prompt:
                prompt = self.read_channel().strip()
                if not prompt:
                    self.write_channel(self.RETURN)
                    time.sleep(sleep_time)
                    if sleep_time <= 3:
                        # Double the sleep_time when it is small
                        sleep_time *= 2
                    else:
                        sleep_time += 1
                count += 1
            # If multiple lines in the output take the last line
            prompt = self.normalize_linefeeds(prompt)
            prompt = prompt.split("\n")[-1]
            prompt = prompt.strip()
            if not prompt:
                raise ValueError(f"Unable to find prompt: {prompt}")
            time.sleep(delay_factor * 0.1)
            return prompt

        def clear_buffer(self, backoff=True, delay_factor=1):
            """Read any data available in the channel."""
            sleep_time = 0.1 * delay_factor
            for _ in range(10):
                time.sleep(sleep_time)
                data = self.read_channel()
                if not data:
                    break
                if backoff:
                    sleep_time = min(sleep_time * 2, 3)

        # Deprecated: use enable() instead.
        def root_su(self, password, delay_factor=1):
            """Switch to root with ``su`` and return the resulting ``#`` prompt.

            :raises ValueError: if no prompt is read or the root prompt is not reached.
            """
            sleep_time = delay_factor * 0.1
            RETURN = "\n"
            password_re = re.compile("Password:")
            prompt = self.read_channel().strip()
            count = 0
            while count <= 13 and not prompt:
                prompt = self.read_channel().strip()
                if not prompt:
                    self.write_channel(RETURN)
                    time.sleep(sleep_time)
                    if sleep_time <= 3:
                        sleep_time *= 2
                    else:
                        sleep_time += 1
                else:
                    prompt = self.normalize_linefeeds(prompt)
                    prompt = prompt.split("\n")[-1]
                    prompt = prompt.strip()
                    if prompt.endswith('$'):
                        self.write_channel('su' + RETURN)
                        time.sleep(sleep_time)
                        prompt = self.read_channel().strip()
                    if password_re.search(prompt):
                        self.write_channel(password + RETURN)
                        time.sleep(sleep_time)
                        prompt = self.read_channel().strip()
                    if prompt.endswith('#'):
                        return prompt
                count += 1
            if not prompt:
                raise ValueError(f"Unable to find prompt: {prompt}")
            # Previously this fell through and returned None, making set_root_prompt
            # fail with an unrelated TypeError.
            raise ValueError(f"Unable to get root prompt, last output: {prompt!r}")

        def set_base_prompt(self, delay_factor=1, prompt_terminator="$", ):
            """Store the current prompt without its terminator in ``self.base_prompt``.

            :raises ValueError: if the prompt does not end with ``prompt_terminator``.
            """
            prompt = self.find_prompt(delay_factor=delay_factor)
            if not prompt[-1] in prompt_terminator:
                raise ValueError(f"Prompt not found: {repr(prompt)}")
            self.base_prompt = prompt[:-1]
            return self.base_prompt

        # Deprecated: use enable() instead.
        def set_root_prompt(self, delay_factor=1, prompt_terminator="#"):
            """Become root via root_su() and store the root prompt without its terminator."""
            prompt = self.root_su(self.su_password, delay_factor=delay_factor)
            if not prompt[-1] in prompt_terminator:
                raise ValueError(f"Router prompt not found: {repr(prompt)}")
            self.root_prompt = prompt[:-1]
            return self.root_prompt

        def check_base_prompt(self, check_sre, prompt_terminator="$"):
            return self.base_prompt + prompt_terminator in check_sre

        def check_root_prompt(self, check_sre, prompt_terminator="#"):
            return self.root_prompt + prompt_terminator in check_sre

        def enable(self, cmd="", pattern="ssword", secret="", re_flags=re.IGNORECASE):
            """Send ``cmd``, answer the password prompt with ``secret``, and verify enable mode.

            :raises ValueError: if enable mode is not reached.
            """
            output = ""
            msg = (
                "Failed to enter su mode. Please ensure you pass "
                "the 'secret' argument to ConnectHandler."
            )
            if not self.check_enable_mode():
                self.write_channel(self.normalize_cmd(cmd))
                try:
                    output += self.read_until_prompt_or_pattern(
                        pattern=pattern, re_flags=re_flags
                    )
                    self.write_channel(self.normalize_cmd(secret))
                    output += self.read_until_prompt()
                except NetmikoTimeoutException as err:
                    raise ValueError(msg) from err
                if not self.check_enable_mode():
                    raise ValueError(msg)
            return output

        def exit_enable_mode(self, exit_command=""):
            """Send ``exit_command`` and verify that enable mode was left."""
            output = ""
            if self.check_enable_mode():
                self.write_channel(self.normalize_cmd(exit_command))
                output += self.read_until_prompt()
                if self.check_enable_mode():
                    raise ValueError("Failed to exit enable mode.")
            return output

        def normalize_cmd(self, command):
            """Strip trailing whitespace from ``command`` and append ``self.RETURN``."""
            command = command.rstrip()
            command += self.RETURN
            return command

        def check_enable_mode(self, check_string=""):
            """Send a newline and report whether ``check_string`` is in the output up to the prompt."""
            self.write_channel(self.RETURN)
            output = self.read_until_prompt()
            return check_string in output

        def read_until_prompt_or_pattern(self, pattern="", re_flags=0):
            """Read until the base prompt or the ``pattern`` regex appears."""
            combined_pattern = re.escape(self.base_prompt)
            if pattern:
                combined_pattern = r"({}|{})".format(combined_pattern, pattern)
            return self._read_channel_expect(combined_pattern, re_flags=re_flags)

        def _read_channel_expect(self, pattern="", re_flags=0, max_loops=150):
            """Read from the channel until ``pattern`` (default: the base prompt) matches.

            :raises NetmikoTimeoutException: if a read times out or the pattern is
                not seen within roughly ``self.timeout`` seconds.
            :raises EOFError: if the remote side closed the channel.
            """
            output = ""
            if not pattern:
                # A wrong base_prompt makes this never match (and time out); subclasses
                # should override set_base_prompt to extract a stable prompt.
                pattern = re.escape(self.base_prompt)
            loop_delay = 0.1
            # Default to making loop time be roughly equivalent to self.timeout
            if max_loops == 150:
                max_loops = int(self.timeout / loop_delay)
            for _ in range(1, max_loops):
                # Acquire outside the try: if acquiring fails, the finally must not
                # release a lock held by another thread.
                self._lock_netmiko_session()
                try:
                    new_data = self.remote_conn.recv(MAX_BUFFER)
                    if len(new_data) == 0:
                        raise EOFError("Channel stream closed by remote device.")
                    output += new_data.decode("utf-8", "ignore")
                except socket.timeout:
                    raise NetmikoTimeoutException(
                        "Timed-out reading channel, data not available."
                    )
                finally:
                    self._unlock_netmiko_session()

                if re.search(pattern, output, flags=re_flags):
                    return output
                time.sleep(loop_delay * self.global_delay_factor)
            raise NetmikoTimeoutException(
                f"Timed-out reading channel, pattern not found in output: {pattern}"
            )

        def read_until_prompt(self, *args, **kwargs):
            return self._read_channel_expect(*args, **kwargs)

        def _lock_netmiko_session(self, start=None):
            """Wait for the channel lock; raises NetmikoTimeoutException after session_timeout."""
            if not start:
                start = time.time()
            while not self._session_locker.acquire(False) and not self._timeout_exceeded(
                start, "The netmiko channel is not available!"
            ):
                time.sleep(0.1)
            return True

        def _unlock_netmiko_session(self):
            if self._session_locker.locked():
                self._session_locker.release()

        def _timeout_exceeded(self, start, msg="Timeout exceeded!"):
            """Raise NetmikoTimeoutException(msg) once ``session_timeout`` seconds have passed since ``start``."""
            if not start:
                # Must provide a comparison time
                return False
            if time.time() - start > self.session_timeout:
                raise NetmikoTimeoutException(msg)
            return False
    
    
    class LinuxBaseConnection(BaseSSHSession):
        """Linux shell session: ``su`` for root access and ``user@host:path$`` prompts."""

        def check_enable_mode(self, check_string="#"):
            """Return True if the output after a newline contains ``check_string`` (root ``#``)."""
            return super().check_enable_mode(check_string=check_string)

        def enable(self, cmd="su", pattern="ssword", secret="nokia123", re_flags=re.IGNORECASE):
            """Switch to root with ``cmd``, answering the password prompt with ``secret``."""
            # NOTE(review): hard-coded default password; prefer passing ``secret``
            # explicitly (e.g. self.su_password).
            return super().enable(cmd=cmd, pattern=pattern, secret=secret, re_flags=re_flags)

        def exit_enable_mode(self, exit_command="exit"):
            """Leave the root shell opened by ``su``.

            The default used to be the non-Linux command "disable", which cannot
            leave a ``su`` shell.
            """
            return super().exit_enable_mode(exit_command=exit_command)

        def set_base_prompt(self, delay_factor=1, prompt_terminator="$"):
            """Set ``base_prompt`` to the host part of a ``user@host:path$`` prompt.

            The host part is also contained in the root prompt after ``su``
            (e.g. ``root@host:path#``), so it matches in both modes. If the prompt
            has a different shape, the full prompt without its terminator is kept.
            """
            prompt = super().set_base_prompt(delay_factor=delay_factor,
                                             prompt_terminator=prompt_terminator)
            match = re.search(r"@(\w.*):", prompt)
            if match is not None:
                self.base_prompt = match[1]
            return self.base_prompt
    
    
    if __name__ == '__main__':
        # NOTE(review): lab credentials are hard-coded; do not commit real secrets.
        host_ip = '10.101.35.249'
        user_name = 'int4'
        pass_word = 'nokia123'
        su_password = pass_word

        ssh_session_obj = LinuxBaseConnection(host_ip, user_name, pass_word, su_password)
        try:
            result_su = ssh_session_obj.exec_ssh_cmd('ifconfig')
            print(result_su)

            # Option 1 (deprecated): ssh_session_obj.root_su(password=su_password)
            # Option 2: switch to root with `su`, answering with the su password.
            ssh_session_obj.enable(secret=su_password)

            result_su = ssh_session_obj.exec_ssh_cmd('whoami')
            print(result_su)
            result_su = ssh_session_obj.exec_ssh_cmd('ifconfig')
            print(result_su)
        finally:
            # Close explicitly rather than relying on __del__ at interpreter exit.
            ssh_session_obj.close_all_session()
    
    
  • 相关阅读:
    ftp上传下载
    java生成xml
    Java:删除某文件夹下的所有文件
    java读取某个文件夹下的所有文件
    JFileChooser 中文API
    得到java异常printStackTrace的详细信息
    关于SQL命令中不等号(!=,<>)
    ABP前端保存notify提示在Edge浏览器显示null
    关于MY Sql 查询锁表信息和解锁表
    VS2019 backspace键失效,无法使用
  • 原文地址:https://www.cnblogs.com/amize/p/15146720.html
Copyright © 2020-2023  润新知