• 批量执行(Linux命令,上传/下载文件)


    前言:

                                                       

    每个公司的网络环境大都划分 办公网络、线上网络,之所以划分的主要原因是为了保证线上操作安全;

    对于外部用户而言也只能访问线上网络的特定开放端口,那么是什么控制了用户访问线上网络的呢?

    防火墙过滤......!

    对于内部员工而言对线上系统日常运维、代码部署如何安全访问线上业务系统呢?如何监控、记录技术人员的操作记录?

    堡垒机策略:

    1.回收所有远程登录Linux主机的用户名、密码;

    2.中间设置堡垒机(保存所有线上Linux主机的用户名、密码);

    3.所有技术人员都必须先登录堡垒机,由堡垒机使用其保存的用户名、密码代为连接线上系统(技术人员本身拿不到密码),并记录操作日志;

    堡垒机策略优点:

    1.记录用户操作;

    2.实现远程操作权限集中管理;

    一、堡垒机表结构设计

    from django.db import models
    from django.contrib.auth.models import  User
    # Create your models here.
    
    
    class IDC(models.Model):
        """A data centre (IDC); every Host is assigned to exactly one."""

        name = models.CharField(max_length=64, unique=True)

        def __str__(self):
            return self.name
    
    class Host(models.Model):
        """A remote Linux host that users reach through the bastion."""

        hostname = models.CharField(max_length=64, unique=True)
        ip_addr = models.GenericIPAddressField(unique=True)
        port = models.IntegerField(default=22)  # SSH port used when connecting
        # on_delete is mandatory from Django 2.0 on; CASCADE was the implicit
        # default before that, so behaviour on older versions is unchanged.
        idc = models.ForeignKey("IDC", on_delete=models.CASCADE)
        enabled = models.BooleanField(default=True)

        def __str__(self):
            return "%s-%s" % (self.hostname, self.ip_addr)
    
    class HostGroup(models.Model):
        """Named set of host/login bindings that can be granted to accounts as a unit."""

        name = models.CharField(max_length=64, unique=True)
        host_user_binds = models.ManyToManyField("HostUserBind")

        def __str__(self):
            return self.name
    
    
    class HostUser(models.Model):
        """A login (username + credential) usable on remote hosts.

        The same username may appear several times with different passwords
        (e.g. several distinct ``root`` logins); uniqueness is enforced on the
        (username, password) pair.
        """

        auth_type_choices = ((0, 'ssh-password'), (1, 'ssh-key'))
        auth_type = models.SmallIntegerField(choices=auth_type_choices)
        username = models.CharField(max_length=32)
        # Stored in plain text because the bastion replays it when it opens SSH sessions.
        password = models.CharField(blank=True, null=True, max_length=128)

        def __str__(self):
            # Never include the password. This text shows up in the admin and,
            # through HostUserBind.__str__, in the host menu printed to users, who
            # must never see remote passwords. The pk keeps same-named logins
            # distinguishable.
            return "%s-%s(#%s)" % (self.get_auth_type_display(), self.username, self.pk)

        class Meta:
            unique_together = ('username', 'password')
    
    
    class HostUserBind(models.Model):
        """Pairs a Host with a HostUser: "this login may be used on this host"."""

        # on_delete is mandatory from Django 2.0; CASCADE matches the old implicit default.
        host = models.ForeignKey("Host", on_delete=models.CASCADE)
        host_user = models.ForeignKey("HostUser", on_delete=models.CASCADE)

        def __str__(self):
            return "%s-%s" % (self.host, self.host_user)

        class Meta:
            unique_together = ('host', 'host_user')
    
    
    class SessionLog(models.Model):
        '''One login session through the bastion.

        Its primary key is passed to the session tracker shell script, which
        uses it to name the strace log file.
        '''
        # on_delete is mandatory from Django 2.0; CASCADE matches the old implicit default.
        account = models.ForeignKey('Account', on_delete=models.CASCADE)
        host_user_bind = models.ForeignKey('HostUserBind', on_delete=models.CASCADE)
        # DateTimeField rather than DateField: an audit trail needs the time of day,
        # not just the date. (datetime is a subclass of date, so readers still work.)
        start_date = models.DateTimeField(auto_now_add=True)
        end_date = models.DateTimeField(blank=True, null=True)

        def __str__(self):
            return '%s-%s' % (self.account, self.host_user_bind)
    
    class AuditLog(models.Model):
        """Audit log entry (placeholder: no fields are defined yet)."""
    
    
    class Account(models.Model):
        """Bastion account, attached one-to-one to a Django auth ``User``.

        The User model is extended through this related model rather than by
        subclassing it. Access is granted per host binding and per host group,
        reachable as ``user.account.host_user_binds`` / ``user.account.host_groups``.
        """

        # on_delete is mandatory from Django 2.0; CASCADE matches the old implicit default.
        user = models.OneToOneField(User, on_delete=models.CASCADE)
        name = models.CharField(max_length=64)

        host_user_binds = models.ManyToManyField("HostUserBind", blank=True)
        host_groups = models.ManyToManyField("HostGroup", blank=True)
    models.py

     二、通过堡垒机远程登录Linux主机

    2种堡垒机登录方式:

    命令行登录堡垒机方式:

    方式1:修改 OpenSSH 源码,为 ssh 扩展 -Z 选项,用它给每个 ssh 进程打上唯一标识;再使用 Linux 的 strace 命令跟踪这个 ssh 进程,并生成日志文件;

    0.用户执行audit_shell后出现交互界面,提示用户选择主机组和主机;

    import os
    import sys

    import django

    # This script runs outside manage.py, so Django settings must be configured and
    # django.setup() called before anything that touches models is imported.
    os.environ.setdefault("DJANGO_SETTINGS_MODULE", "zhanggen_audit.settings")
    django.setup()

    from audit.backend import user_interactive  # must come after django.setup()


    if __name__ == '__main__':
        user_interactive.UserShell(sys.argv).start()
    audit_shell.py
    from django.contrib.auth import authenticate
    
    class UserShell(object):
        '''Menu shell a user gets after logging into the bastion host.'''

        def __init__(self, sys_argv):
            self.sys_argv = sys_argv
            self.user = None

        def auth(self):
            """Authenticate against Django's auth backend, allowing 3 attempts.

            On success the user is stored in ``self.user`` and True is returned.
            After three failures a message is printed and None is returned.
            """
            import getpass  # read the password without echoing it to the terminal

            count = 0
            while count < 3:
                username = input('username:').strip()
                password = getpass.getpass('password:').strip()
                user = authenticate(username=username, password=password)
                # authenticate() returns None on failure and the User object on success
                if not user:
                    count += 1
                    print('无效的用户名或者,密码!')
                else:
                    self.user = user
                    return True
            else:
                print('输入次数超过3次!')

        def start(self):
            """Authenticate, then loop over the group -> host selection menus.

            Entries are picked by number; at the host prompt 'b' returns to the
            group menu.
            """
            if not self.auth():
                return
            while True:
                account = self.user.account
                host_groups = account.host_groups.all()
                for index, group in enumerate(host_groups):
                    print("%s.\t%s[%s]" % (index, group, group.host_user_binds.count()))
                # the extra last entry lists hosts bound directly to the account
                print("%s.\t未分组机器[%s]" % (len(host_groups), account.host_user_binds.count()))

                choice = input("select group>:").strip()
                # isdecimal(), not isdigit(): isdigit() also accepts characters such
                # as '²', which int() rejects with ValueError
                if not choice.isdecimal():
                    continue
                choice = int(choice)
                host_bind_list = None
                if choice < len(host_groups):
                    host_bind_list = host_groups[choice].host_user_binds.all()
                elif choice == len(host_groups):  # the "ungrouped hosts" entry
                    host_bind_list = account.host_user_binds.all()

                if host_bind_list:
                    while True:
                        for index, host in enumerate(host_bind_list):
                            print("%s.\t%s" % (index, host,))
                        choice2 = input("select host>:").strip()
                        if choice2.isdecimal():
                            choice2 = int(choice2)
                            if choice2 < len(host_bind_list):
                                selected_host = host_bind_list[choice2]
                                print("selected host", selected_host)
                        elif choice2 == 'b':
                            break
    user_interactive.py

    知识点:

    在Django视图之外,调用Django功能设置环境变量!(切记放在和Django manage.py 同级目录); 

    import os
    import sys

    import django

    # Outside a view / manage.py, point Django at its settings and call setup()
    # before importing any app module.
    os.environ.setdefault("DJANGO_SETTINGS_MODULE", "Sensors_Data.settings")
    django.setup()

    from app01 import models

    objs = models.AlarmInfo.objects.all()
    for alarm in objs:
        print(alarm.comment)

    注意:在Django启动时会自动加载一些 文件,比如每个app中admin.py,不能在这些文件里面设置加载环境变量,因为已经加载完了,如果违反这个规则会导致Django程序启动失败;

    1.实现ssh用户指令检测

    1.0  修改OpenSSH源码,为 ssh 扩展 -Z <唯一标识符> 选项;(这样每次ssh远程登录都带有唯一标识符,可以据此分辨出每个ssh会话进程;)

    修改OpenSSH源码中的ssh.c文件(行号以所用版本为准):在第608、609行 getopt 的选项字符串末尾分别加入 z: 和 Z:,并在第935行附近的选项 switch 中增加 case 'Z': 分支(只接收该参数,不做任何处理);
        while ((opt = getopt(ac, av, "1246ab:c:e:fgi:kl:m:no:p:qstvxz:"
            "ACD:E:F:GI:J:KL:MNO:PQ:R:S:TVw:W:XYyZ:")) != -1) {
    
            case 'Z':
                break;
    ssh.c 

    知识点:

    OpenSSH 是 SSH (Secure SHell) 协议的免费开源实现项目。

    1.1  修改openssh之后,编译、安装

    chmod 755 configure
    ./configure --prefix=/usr/local/openssh
    make
    chmod 755 mkinstalldirs
    make install
    sshpass -p xxxxxx123 /usr/local/openssh/bin/ssh root@172.17.10.112 -Z s1123ssssd212

    1.2  每个ssh会话进程可以唯一标识之后,在堡垒机上用会话检测shell脚本找出该ssh会话进程;(用strace命令进行监控,并生成log日志文件);

    #!/bin/sh
    # Session tracker: wait for the ssh process started with "-Z <tag>", then strace it.
    # Usage: sh <this script> <tag> <log name>
    #   $1  random tag given to the patched ssh client via -Z
    #   $2  log file name without extension; the trace is written to "$2.log"
    # Polls up to 30 times, 5 seconds apart. Kept POSIX-sh compatible because the
    # bastion starts it as "/bin/sh <script> <tag> <session id>".

    for i in $(seq 1 30); do
        echo "$i $1"
        # Drop this script's own ps line (matched by 'ession_check.sh'), the grep
        # processes and the sshpass wrapper; if several lines still match, only the
        # first pid is traced (strace -p takes a single pid here).
        process_id=$(ps -ef | grep -- "$1" | grep -v 'ession_check.sh' | grep -v grep | grep -v sshpass | awk '{print $2}' | head -n 1)

        echo "process: $process_id"

        if [ -n "$process_id" ]; then
            echo 'start run strace.....'
            strace -fp "$process_id" -t -o "$2.log"
            break
        fi

        sleep 5
    done
    ssh 会话检测脚本

    知识点:

    strace 检测进程的IO调用,监控用户shell输入的命令字符;

     strace -fp 60864 -o /ssh.log 
     cat /ssh.log |grep 'write(8'
     rz -E #从xshell上传文件

     sshpass无需提示输入密码登录

    [root@localhost sshpass-1.06]# sshpass -p wsnb ssh root@172.16.22.1  -o StrictHostKeyChecking=no 
    Last login: Tue Jul 10 16:39:53 2018 from 192.168.113.84
    [root@ecdb ~]# 

    python生成唯一标识符

    s = string.ascii_lowercase + string.digits  # alphabet: a-z followed by 0-9
    picked = random.sample(s, k=10)             # 10 distinct characters, no repeats
    random_tag = "".join(picked)

    解决普通用户无法执行 strace 命令的问题;

    方式1:给 strace 执行文件加上 +s(setuid)权限

    chmod u+s `which strace`

    注意:setuid 的 strace 可以被任何用户用来附加、篡改任意进程(包括 root 进程),等同于把 root 权限交给所有用户,生产环境慎用;

    方式2:修改sudo配置文件,使普通用户sudo时无需输入密码!

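    修改sudo配置文件时,为防止修改出错,一定要切换到root用户,建议使用 visudo 编辑(保存时会做语法检查);
    
    
    %普通用户  ALL=(ALL)       NOPASSWD: ALL
    # 说明:以 % 开头表示用户组,针对单个用户时去掉 %;
    # NOPASSWD: ALL 等于给该用户免密的 root 权限,更安全的做法是只放行 strace,例如:
    # 普通用户  ALL=(root)  NOPASSWD: /usr/bin/strace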
    
    wq! #退出
    vim /etc/sudoers
    #!/usr/bin/python3
    # -*- coding: utf-8 -*
    from django.contrib.auth import authenticate
    import subprocess,string,random
    from audit import models
    from django.conf import settings
    class UserShell(object):
        '''Menu shell a user gets after logging into the bastion host.

        Selecting a host starts the strace-based session tracker script and then
        an interactive ssh session through sshpass and the patched OpenSSH client.
        '''

        def __init__(self, sys_argv):
            self.sys_argv = sys_argv
            self.user = None

        def auth(self):
            """Authenticate against Django's auth backend, allowing 3 attempts.

            On success the user is stored in ``self.user`` and True is returned.
            After three failures a message is printed and None is returned.
            """
            import getpass  # read the password without echoing it to the terminal

            count = 0
            while count < 3:
                username = input('username:').strip()
                password = getpass.getpass('password:').strip()
                user = authenticate(username=username, password=password)
                # authenticate() returns None on failure and the User object on success
                if not user:
                    count += 1
                    print('无效的用户名或者,密码!')
                else:
                    self.user = user
                    return True
            else:
                print('输入次数超过3次!')

        def _run_tracked_session(self, selected_host):
            """Log a SessionLog row, start the tracker script, then ssh to *selected_host*.

            The same random tag goes to the tracker script and to ssh's custom
            ``-Z`` option, so the script can find this ssh process and strace it
            into a log named after the SessionLog primary key.
            """
            import os

            s = string.ascii_lowercase + string.digits
            random_tag = ''.join(random.sample(s, 10))
            session_obj = models.SessionLog.objects.create(account=self.user.account,
                                                           host_user_bind=selected_host)

            # Argument lists with shell=False: nothing taken from the database
            # (password, username, address, port) is ever parsed by a shell.
            session_tracker_process = subprocess.Popen(
                ['/bin/sh', settings.SESSION_TRACKER_SCRIPT, random_tag, str(session_obj.pk)],
                stdout=subprocess.PIPE, stderr=subprocess.PIPE)

            host_user = selected_host.host_user
            # `sshpass -e` takes the password from $SSHPASS, so it does not appear in
            # the process list the way `sshpass -p <password>` does. The process is
            # still named sshpass, so the tracker's `grep -v sshpass` keeps working.
            env = dict(os.environ, SSHPASS=host_user.password or '')
            cmd = ['sshpass', '-e', '/usr/local/openssh/bin/ssh',
                   '%s@%s' % (host_user.username, selected_host.host.ip_addr),
                   '-p', str(selected_host.host.port),
                   '-o', 'stricthostkeychecking=no',
                   '-Z', random_tag]
            subprocess.run(cmd, env=env)  # interactive child; returns when ssh exits

            # communicate() drains both pipes (no deadlock) and reaps the tracker process
            out, err = session_tracker_process.communicate()
            print(out.splitlines(True), err.splitlines(True))

        def start(self):
            """Authenticate, then loop over the group -> host selection menus.

            Picking a host opens a tracked ssh session; at the host prompt 'b'
            returns to the group menu.
            """
            if not self.auth():
                return
            while True:
                account = self.user.account
                host_groups = account.host_groups.all()
                for index, group in enumerate(host_groups):
                    print("%s.\t%s[%s]" % (index, group, group.host_user_binds.count()))
                # the extra last entry lists hosts bound directly to the account
                print("%s.\t未分组机器[%s]" % (len(host_groups), account.host_user_binds.count()))

                choice = input("select group>:").strip()
                # isdecimal(), not isdigit(): isdigit() also accepts characters such
                # as '²', which int() rejects with ValueError
                if not choice.isdecimal():
                    continue
                choice = int(choice)
                host_bind_list = None
                if choice < len(host_groups):
                    host_bind_list = host_groups[choice].host_user_binds.all()
                elif choice == len(host_groups):  # the "ungrouped hosts" entry
                    host_bind_list = account.host_user_binds.all()

                if host_bind_list:
                    while True:
                        for index, host in enumerate(host_bind_list):
                            print("%s.\t%s" % (index, host,))
                        choice2 = input("select host>:").strip()
                        if choice2.isdecimal():
                            choice2 = int(choice2)
                            if choice2 < len(host_bind_list):
                                self._run_tracked_session(host_bind_list[choice2])
                        elif choice2 == 'b':
                            break
    汇总

    2.shell远程登录程序检查日志文件,分析;

    Tab 补全得到的命令需要从 write(5 的输出中获取(搜索 write(5)。该脚本的实现思路是:逐个按键去尝试,在循环中通过多种条件判断还原出用户输入的命令;

    import re
    
    
    class AuditLogHandler(object):
        '''Rebuild the commands a user typed from an strace session log.

        The log is the output of ``strace -fp <ssh pid> -t -o <file>``: one syscall
        per line, ``<pid> <HH:MM:SS> <call>(<fd>, "<data>", ...) = <n>``.
        ``read`` calls on *read_fd* carry the user's keystrokes. ``write`` calls on
        *write_fd* carry server output, which is used to recover Tab completions
        and recalled history commands. The defaults (4 and 5) are the descriptors
        this parser was written against; pass others if your log differs.
        '''

        # strace prints control bytes as escape *text* (a backslash followed by
        # letters/octal digits), so these tokens must be raw strings. A plain
        # '"\177",' literal would contain the DEL byte itself and never match.
        _KEY_NAMES = {
            r'"\177",': '[1<-del]',                           # Backspace (DEL)
            r'"\33OB",': '[down 1]',                          # vim: down arrow
            r'"\33OA",': '[up 1]',                            # vim: up arrow
            r'"\33OC",': '[->1]',                             # vim: right arrow
            r'"\33OD",': '[1<-]',                             # vim: left arrow
            r'"\33[>1;95;0c",': '[----enter vim mode-----]',  # entering vim
            r'"\33[A",': '[up 1]',                            # shell: up arrow
            r'"\33[B",': '[down 1]',                          # shell: down arrow
            r'"\33[C",': '[->1]',                             # shell: move right 1
            r'"\33[D",': '[1<-]',                             # shell: move left 1
        }
        # Shell history keys: the next server write holds the recalled command.
        _HISTORY_KEYS = frozenset([r'"\33[A",', r'"\33[B",'])
        # Sequences skipped entirely (seen when entering vim).
        _IGNORED = frozenset([r'"\33[2;2R",'])
        _TAB = r'"\t",'
        _ENTER = r'"\r",'
        _BELL = r'"\7",'  # server rings the bell when Tab has nothing to complete

        def __init__(self, log_file, read_fd=4, write_fd=5):
            self.log_file = log_file
            self.read_fd = read_fd
            self.write_fd = write_fd
            self.log_file_obj = self._get_file(log_file)

        def _get_file(self, log_file):
            """Open *log_file* for reading as text."""
            return open(log_file, encoding='utf-8', errors='replace')

        def parse(self):
            """Parse the log; print and return the commands as ``[[time, command], ...]``.

            A command is emitted each time Enter is read. Control keys appear as
            bracketed markers such as ``[1<-del]``. Malformed lines are reported
            and skipped. The file is closed afterwards and reopened if ``parse``
            is called again.
            """
            if self.log_file_obj.closed:
                self.log_file_obj = self._get_file(self.log_file)
            # Exact token match ("read(4,") so that fd 40, 41, ... are not picked up.
            read_call = 'read(%s,' % self.read_fd
            write_call = 'write(%s,' % self.write_fd

            cmd_list = []
            cmd_str = ''
            catch_write_flag = False  # next server write holds a Tab/history completion
            with self.log_file_obj as log:
                for line in log:
                    try:
                        _pid, time_clock, io_call, char = line.split()[0:4]
                    except ValueError as e:
                        print("\033[031;1mSession log record err,please contact your IT admin,\033[0m", e)
                        continue

                    if io_call == read_call:
                        if char in self._IGNORED:
                            continue
                        if char in self._HISTORY_KEYS:
                            catch_write_flag = True
                        if char == self._TAB:
                            catch_write_flag = True
                            continue
                        if char == self._ENTER:
                            cmd_list.append([time_clock, cmd_str])
                            cmd_str = ''  # reset for the next command
                        elif char == '"':
                            # a typed space: whitespace splitting turns '" ",' into '"' and '",'
                            cmd_str += ' '
                        else:
                            cmd_str += self._KEY_NAMES.get(char, char).strip('"",')

                    if catch_write_flag and io_call == write_call:
                        if char != self._BELL:
                            cmd_str += char.strip('"",')
                        catch_write_flag = False

            for cmd in cmd_list:
                print(cmd)
            return cmd_list
    if __name__ == "__main__":
        import sys

        # Usage: python <this file> <strace log file>. Falls back to the sample path
        # hard-coded by the original author when no argument is given.
        log_path = sys.argv[1] if len(sys.argv) > 1 else r'D:zhanggen_auditlog6.log'
        parser = AuditLogHandler(log_path)
        parser.parse()
    日志分析

    3.修改bashrc文件,限制用户登录行为;

    alias rm='rm -i'
    alias cp='cp -i'
    alias mv='mv -i'

    # Source global definitions
    if [ -f /etc/bashrc ]; then
            . /etc/bashrc
    fi

    # Ignore Ctrl-C / Ctrl-Z / Ctrl-\ from here on: an interrupt while this file is
    # being sourced would abort the rest of it and leave the user at an ordinary
    # shell prompt, bypassing the audit shell. The ignored state is inherited by
    # the audit program as well.
    trap '' INT TSTP QUIT

    echo '-----------------------welcome  to  zhanggen  audit  --------------------------'

    python3 /root/zhanggen_audit/audit_shell.py

    echo 'bye'

    # `exit` ends both login and non-login shells; `logout` errors out in a
    # non-login shell and would leave the user at a prompt.
    exit
    vim ~/.bashrc

    缺陷:

    虽然限制了用户shell登录,但无法阻止用户使用程序(paramiko)上传恶意文件!

    方式2:提取paramiko源码demos文件,对其进行修改支持交互式操作;

    from django.db import models
    from django.contrib.auth.models import  User
    # Create your models here.
    
    
    class IDC(models.Model):
        """A data centre (IDC); every Host is assigned to exactly one."""

        name = models.CharField(max_length=64, unique=True)

        def __str__(self):
            return self.name
    
    class Host(models.Model):
        """A remote Linux host that users reach through the bastion."""

        hostname = models.CharField(max_length=64, unique=True)
        ip_addr = models.GenericIPAddressField(unique=True)
        port = models.IntegerField(default=22)  # SSH port used when connecting
        # on_delete is mandatory from Django 2.0 on; CASCADE was the implicit
        # default before that, so behaviour on older versions is unchanged.
        idc = models.ForeignKey("IDC", on_delete=models.CASCADE)
        enabled = models.BooleanField(default=True)

        def __str__(self):
            return "%s-%s" % (self.hostname, self.ip_addr)
    
    class HostGroup(models.Model):
        """Named set of host/login bindings that can be granted to accounts as a unit."""

        name = models.CharField(max_length=64, unique=True)
        host_user_binds = models.ManyToManyField("HostUserBind")

        def __str__(self):
            return self.name
    
    
    class HostUser(models.Model):
        """A login (username + credential) usable on remote hosts.

        The same username may appear several times with different passwords
        (e.g. several distinct ``root`` logins); uniqueness is enforced on the
        (username, password) pair.
        """

        auth_type_choices = ((0, 'ssh-password'), (1, 'ssh-key'))
        auth_type = models.SmallIntegerField(choices=auth_type_choices)
        username = models.CharField(max_length=32)
        # Stored in plain text because the bastion replays it when it opens SSH sessions.
        password = models.CharField(blank=True, null=True, max_length=128)

        def __str__(self):
            # Never include the password. This text shows up in the admin and,
            # through HostUserBind.__str__, in the host menu printed to users, who
            # must never see remote passwords. The pk keeps same-named logins
            # distinguishable.
            return "%s-%s(#%s)" % (self.get_auth_type_display(), self.username, self.pk)

        class Meta:
            unique_together = ('username', 'password')
    
    
    class HostUserBind(models.Model):
        """Pairs a Host with a HostUser: "this login may be used on this host"."""

        # on_delete is mandatory from Django 2.0; CASCADE matches the old implicit default.
        host = models.ForeignKey("Host", on_delete=models.CASCADE)
        host_user = models.ForeignKey("HostUser", on_delete=models.CASCADE)

        def __str__(self):
            return "%s-%s" % (self.host, self.host_user)

        class Meta:
            unique_together = ('host', 'host_user')
    
    
    class AuditLog(models.Model):
        """One command line recorded during a bastion session."""

        # on_delete is mandatory from Django 2.0; CASCADE matches the old implicit default.
        session = models.ForeignKey("SessionLog", on_delete=models.CASCADE)
        cmd = models.TextField()
        date = models.DateTimeField(auto_now_add=True)

        def __str__(self):
            return "%s-%s" % (self.session, self.cmd)
    
    
    class SessionLog(models.Model):
        """One login session from a bastion account to a host binding."""

        # on_delete is mandatory from Django 2.0; CASCADE matches the old implicit default.
        account = models.ForeignKey("Account", on_delete=models.CASCADE)
        host_user_bind = models.ForeignKey("HostUserBind", on_delete=models.CASCADE)
        start_date = models.DateTimeField(auto_now_add=True)
        end_date = models.DateTimeField(blank=True, null=True)

        def __str__(self):
            return "%s-%s" % (self.account, self.host_user_bind)
    
    
    class Account(models.Model):
        """Bastion account, attached one-to-one to a Django auth ``User``.

        The User model is extended through this related model rather than by
        subclassing it. Access is granted per host binding and per host group,
        reachable as ``user.account.host_user_binds`` / ``user.account.host_groups``.
        """

        # on_delete is mandatory from Django 2.0; CASCADE matches the old implicit default.
        user = models.OneToOneField(User, on_delete=models.CASCADE)
        name = models.CharField(max_length=64)

        host_user_binds = models.ManyToManyField("HostUserBind", blank=True)
        host_groups = models.ManyToManyField("HostGroup", blank=True)
    model.py
    __author__ = 'Administrator'
    import subprocess,random,string
    from django.contrib.auth import authenticate
    from django.conf import settings 
    from audit import models
    from audit.backend import ssh_interactive 
    
    class UserShell(object):
        """Shell a user gets after logging into the bastion host."""

        def __init__(self, sys_argv):
            self.sys_argv = sys_argv
            self.user = None

        def auth(self):
            """Authenticate against Django's auth backend, allowing 3 attempts.

            On success the user is stored in ``self.user`` and True is returned.
            After three failures a message is printed and None is returned.
            """
            import getpass  # read the password without echoing it to the terminal

            count = 0
            while count < 3:
                username = input("username:").strip()
                password = getpass.getpass("password:").strip()
                user = authenticate(username=username, password=password)
                # None means authentication failed; otherwise it is the User object
                if not user:
                    count += 1
                    print("Invalid username or password!")
                else:
                    self.user = user
                    return True
            else:
                print("too many attempts.")

        def start(self):
            """Authenticate, then loop over the group -> host selection menus.

            Picking a host opens an audited session through
            ``ssh_interactive.ssh_session``. At the host prompt 'b' returns to the
            group menu. Ctrl-C abandons the current menu; Ctrl-D (EOF) leaves the
            shell.
            """
            if not self.auth():
                return
            while True:
                account = self.user.account
                host_groups = account.host_groups.all()
                for index, group in enumerate(host_groups):
                    print("%s.\t%s[%s]" % (index, group, group.host_user_binds.count()))
                # the extra last entry lists hosts bound directly to the account
                print("%s.\t未分组机器[%s]" % (len(host_groups), account.host_user_binds.count()))
                try:
                    choice = input("select group>:").strip()
                    # isdecimal(), not isdigit(): isdigit() also accepts characters
                    # such as '²', which int() rejects with ValueError
                    if not choice.isdecimal():
                        continue
                    choice = int(choice)
                    host_bind_list = None
                    if choice < len(host_groups):
                        host_bind_list = host_groups[choice].host_user_binds.all()
                    elif choice == len(host_groups):  # the "ungrouped hosts" entry
                        host_bind_list = account.host_user_binds.all()
                    if host_bind_list:
                        while True:
                            for index, host in enumerate(host_bind_list):
                                print("%s.\t%s" % (index, host,))
                            choice2 = input("select host>:").strip()
                            if choice2.isdecimal():
                                choice2 = int(choice2)
                                if choice2 < len(host_bind_list):
                                    ssh_interactive.ssh_session(host_bind_list[choice2], self.user)
                            elif choice2 == 'b':
                                break
                except KeyboardInterrupt:
                    pass
                except EOFError:
                    # Ctrl-D / closed stdin at a prompt: leave cleanly instead of
                    # dying with a traceback
                    return
    user_interactive.py
    #!/usr/bin/env python
    
    # Copyright (C) 2003-2007  Robey Pointer <robeypointer@gmail.com>
    #
    # This file is part of paramiko.
    #
    # Paramiko is free software; you can redistribute it and/or modify it under the
    # terms of the GNU Lesser General Public License as published by the Free
    # Software Foundation; either version 2.1 of the License, or (at your option)
    # any later version.
    #
    # Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
    # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
    # A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
    # details.
    #
    # You should have received a copy of the GNU Lesser General Public License
    # along with Paramiko; if not, write to the Free Software Foundation, Inc.,
    # 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA.
    
    
    import base64
    from binascii import hexlify
    import getpass
    import os
    import select
    import socket
    import sys
    import time
    import traceback
    from paramiko.py3compat import input
    from audit import models
    import paramiko
    
    try:
        import interactive
    except ImportError:
        from . import interactive
    
    
    def manual_auth(t, username, password):
        """Authenticate transport *t* with the password stored for the remote login.

        The credentials come from the bastion's HostUser record (see ssh_session),
        so the user is never prompted. Only password authentication is wired up.
        NOTE(review): HostUser.auth_type also allows 'ssh-key'; such logins are
        not handled here yet.

        :param t: a started ``paramiko.Transport``.
        :param username: remote login name.
        :param password: remote login password.
        """
        t.auth_password(username, password)
    
    
    def ssh_session(bind_host_user, user_obj):
        """Open an audited interactive SSH session for a bastion user.

        :param bind_host_user: HostUserBind row. Supplies the target host's IP and
            port and the remote login's username/password.
        :param user_obj: the authenticated bastion ``User``. ``user_obj.account``
            is recorded on the SessionLog row created for this session.

        Failures (connect, SSH negotiation, changed host key, authentication,
        unexpected errors) are reported on the terminal and the function returns.
        The caller's menu keeps running instead of the whole bastion shell
        exiting through ``sys.exit``.
        """
        hostname = bind_host_user.host.ip_addr  # target filled in automatically
        port = bind_host_user.host.port
        username = bind_host_user.host_user.username
        password = bind_host_user.host_user.password

        try:
            # create_connection() picks the address family itself, so IPv6
            # addresses (allowed by GenericIPAddressField) work as well as IPv4.
            sock = socket.create_connection((hostname, port))
        except Exception as e:
            print('*** Connect failed: ' + str(e))
            traceback.print_exc()
            return

        t = None
        try:
            t = paramiko.Transport(sock)  # drive the SSH protocol over our socket
            try:
                t.start_client()
            except paramiko.SSHException:
                print('*** SSH negotiation failed.')
                return

            try:
                keys = paramiko.util.load_host_keys(os.path.expanduser('~/.ssh/known_hosts'))
            except IOError:
                try:
                    keys = paramiko.util.load_host_keys(os.path.expanduser('~/ssh/known_hosts'))
                except IOError:
                    print('*** Unable to open host keys file')
                    keys = {}

            # check server's host key -- this is important.
            key = t.get_remote_server_key()
            if hostname not in keys:
                print('*** WARNING: Unknown host key!')
            elif key.get_name() not in keys[hostname]:
                print('*** WARNING: Unknown host key!')
            elif keys[hostname][key.get_name()] != key:
                print('*** WARNING: Host key has changed!!!')
                return
            else:
                print('*** Host key OK.')

            if not t.is_authenticated():
                manual_auth(t, username, password)  # password from the HostUser record
            if not t.is_authenticated():
                print('*** Authentication failed. :(')
                return

            chan = t.open_session()
            chan.get_pty()  # request a remote terminal
            chan.invoke_shell()
            print('*** Here we go!\n')

            session_obj = models.SessionLog.objects.create(account=user_obj.account,
                                                           host_user_bind=bind_host_user)
            try:
                interactive.interactive_shell(chan, session_obj)  # relay + audit until EOF
            finally:
                chan.close()

        except Exception as e:
            print('*** Caught exception: ' + str(e.__class__) + ': ' + str(e))
            traceback.print_exc()
        finally:
            # Release the transport, or the bare socket if Transport() was never created.
            if t is not None:
                t.close()
            else:
                sock.close()
    ssh_interactive.py
    # Copyright (C) 2003-2007  Robey Pointer <robeypointer@gmail.com>
    #
    # This file is part of paramiko.
    #
    # Paramiko is free software; you can redistribute it and/or modify it under the
    # terms of the GNU Lesser General Public License as published by the Free
    # Software Foundation; either version 2.1 of the License, or (at your option)
    # any later version.
    #
    # Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
    # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
    # A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
    # details.
    #
    # You should have received a copy of the GNU Lesser General Public License
    # along with Paramiko; if not, write to the Free Software Foundation, Inc.,
    # 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA.
    
    
    import socket
    import sys
    from paramiko.py3compat import u
    from audit import models
    # windows does not have termios...
    try:
        import termios
        import tty
        has_termios = True
    except ImportError:
        has_termios = False
    
    
    def interactive_shell(chan, session_obj):
        """Hand *chan* to the terminal relay that suits this platform.

        With termios available the audited raw-mode POSIX relay is used.
        Otherwise the Windows line-buffered relay runs, and it does not receive
        *session_obj*.
        """
        if not has_termios:
            windows_shell(chan)
        else:
            posix_shell(chan, session_obj)
    
    
    def posix_shell(chan, session_obj):
        """Relay the local terminal to *chan* in raw mode and audit typed commands.

        Every character read from stdin is sent to the server unchanged. It is
        also collected into ``cmd``. On Enter (a carriage return in raw mode) an
        AuditLog row is written for *session_obj* and the buffer is reset. After
        a Tab, the next chunk of server output is appended to ``cmd`` so the
        completion gets recorded. The original terminal settings are restored on
        exit.
        """
        import codecs
        import select

        oldtty = termios.tcgetattr(sys.stdin)
        try:
            tty.setraw(sys.stdin.fileno())
            tty.setcbreak(sys.stdin.fileno())
            chan.settimeout(0.0)
            # Decode incrementally: recv(1024) can stop in the middle of a multi-byte
            # UTF-8 character (e.g. Chinese output). Decoding each chunk on its own
            # would raise UnicodeDecodeError and tear down the session.
            decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
            flag = False  # set after Tab: the next server output is the completion
            cmd = ''
            while True:
                # block until the server or the local user has something for us
                r, w, e = select.select([chan, sys.stdin], [], [])

                if chan in r:  # output from the remote side
                    try:
                        data = chan.recv(1024)
                        if len(data) == 0:
                            sys.stdout.write('\r\n*** EOF\r\n')
                            break
                        x = decoder.decode(data)
                        if flag and x:
                            cmd += x
                            flag = False
                        sys.stdout.write(x)
                        sys.stdout.flush()
                    except socket.timeout:
                        pass

                if sys.stdin in r:  # a local keystroke
                    x = sys.stdin.read(1)
                    if len(x) == 0:
                        break
                    if x == '\r':  # Enter: record the command line
                        models.AuditLog.objects.create(session=session_obj, cmd=cmd)
                        cmd = ''
                    elif x == '\t':  # Tab: the completion comes back from the server
                        flag = True
                    else:
                        cmd += x
                    chan.send(x)

        finally:
            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, oldtty)
    
        
    # thanks to Mike Looijmans for this code
    def windows_shell(chan):
        """Line-buffered fallback shell for platforms without ``termios``.

        A background thread copies everything received on ``chan`` to stdout.
        The calling thread forwards stdin to ``chan`` one character at a time
        until EOF.  Commands typed here are not audited.
        """
        import codecs
        import threading

        sys.stdout.write("Line-buffered terminal emulation. Press F6 or ^Z to send EOF.\r\n\r\n")

        def writeall(sock):
            # recv() returns bytes on Python 3 while stdout expects text; decode
            # incrementally so characters split across reads are not corrupted.
            decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
            while True:
                data = sock.recv(256)
                if not data:
                    sys.stdout.write('\r\n*** EOF ***\r\n\r\n')
                    sys.stdout.flush()
                    break
                sys.stdout.write(decoder.decode(data))
                sys.stdout.flush()

        writer = threading.Thread(target=writeall, args=(chan,))
        writer.start()

        try:
            while True:
                d = sys.stdin.read(1)
                if not d:
                    break
                chan.send(d)
        except EOFError:
            # user hit ^Z or F6
            pass
    interactive.py

    程序流程:用户界面 ----------> ssh自动输入用户名和登录密码 ----------> 进入shell命令交互模式

    知识点:

    1对1:      1个 对应  1个   (1个女人嫁给了1个男人,生活慢慢平淡下来,)

    1对多:      1个 对应  N个   (这个女人隐瞒丈夫相继出轨了N个男人,这个男人发现老婆出轨了,很愤懑)

    多对多:     双方都存在1对多关系 (也相继找了N个女情人,而这些女情人中就有他老婆出轨男人的老婆,故事结束。)

    感悟:

    这个故事很混乱!怎么设计男、女表结构?其实在做数据库表关系设计的时候,与其纠结两张表到底应该设计成什么关系,倒不如多加几张关系绑定表!

    表结构怎么设计,完全取决于你的程序在运行过程中到底要向用户展示什么信息!

    web页面使用堡垒机方式:

    web开发模式

    1.MTV/MVC 前后端不分离模式,由后端渲染模板(适合面向公司内部的OA系统);

    优势:简单,一人全栈;

    缺陷:前后端耦合性高,性能低、单点压力

    2.前后端分离(面向大众用户)

    优势:前、后端开发人员商定好接口和数据格式,并行开发,效率高;解决了后端独自渲染模板的压力;

    缺陷:招前端得花钱

    3.hostlist 展示主机组和主机

      <div class="panel col-lg-3">
                <div class="panel-heading">
                    <h3 class="panel-title">主机组</h3>
                </div>
                <div class="panel-body">
                    <ul class="list-group">
                    {# Keep every entry a sibling <li>: GetHostlist() marks the clicked one .active and removes .active from its siblings. #}
                    {# One entry per host group of the current account; the badge shows how many host/user bindings the group has. #}
                    {% for group in  request.user.account.host_groups.all %}
    
                        <li class="list-group-item " onclick="GetHostlist({{ group.id }},this)"><span class="badge badge-success">{{ group.host_user_binds.count }}</span>{{ group.name }}</li>
                    {% endfor %}
                        {# Hosts bound directly to the account (Account.host_user_binds); this entry passes gid -1 to GetHostlist(). #}
                        <li class="list-group-item " onclick="GetHostlist(-1,this)"> <span class="badge badge-success">{{ request.user.account.host_user_binds.count }}</span>未分组主机</li>
    
                    </ul>
                </div>
            </div>
    在标签上绑定事件
    <script>

    // Load the host/user bindings of one host group and render them into #hostlist.
    // gid: HostGroup id, or -1 for the "未分组主机" entry; self: the clicked <li>.
    function GetHostlist(gid,self) {
        var onLoaded = function (callback) {
            var rows = JSON.parse(callback);
            console.log(rows);
            var lines = $.map(rows, function (row) {
                return "<tr><td>" + row.host__hostname
                    + "</td><td>" + row.host__ip_addr
                    + "</td><td>" + row.host__idc__name
                    + "</td><td>" + row.host__port
                    + "</td><td>" + row.host_user__username
                    + "</td><td>Login</td></tr>";
            });
            $("#hostlist").html(lines.join(''));
        };
        $.get("{% url 'get_host_list' %}", {'gid': gid}, onLoaded);

        // Highlight the clicked group and clear the highlight on the others.
        $(self).addClass("active").siblings().removeClass('active');
    }

    </script>
    通过ajax向后端请求数据

    知识点:

    如果给标签绑定的事件需要传参数,可以直接在标签上绑定并传参(例如 onclick="GetHostlist({{ group.id }},this)")。

    # The token view in views.py is named get_token; the route name stays "get_tocken" to match {% url "get_tocken" %} in the template below.
    url(r'^get_tocken$', views.get_token, name="get_tocken"),
    Django路由别名
    function GetToken(self, bind_host_id) {
        // POST the chosen host/user binding to the token view; Django renders the URL from its route name.
        var url = '{% url "get_tocken" %}';
        var payload = {
            'bind_host_id': bind_host_id,
            'csrfmiddlewaretoken': "{{ csrf_token }}"   // required by CsrfViewMiddleware for POST requests
        };
        $.post(url, payload, function (callback) {
            console.log(callback);   // for now the JSON response is only logged
        });
    }
    Django模板语言
    @login_required
    def get_token(request):
        """Return (creating it if necessary) a short-lived login token as JSON.

        The token ties the current user's ``Account`` to the ``HostUserBind``
        whose id is POSTed as ``bind_host_id``.  A token created in the last 300
        seconds for the same account and binding is reused.

        Responses:
          * ``{"token": "<8 chars>"}`` on success;
          * HTTP 400 when ``bind_host_id`` is missing or not an integer;
          * HTTP 403 when the binding is not granted to the user, either
            directly or through one of the user's host groups.
        """
        from django.utils import timezone

        bind_host_id = request.POST.get('bind_host_id', '')
        if not bind_host_id.isdigit():
            return HttpResponse(json.dumps({'error': 'invalid bind_host_id'}), status=400)

        account = request.user.account
        # Only issue tokens for hosts this account may use: bindings granted
        # directly (Account.host_user_binds) or via Account.host_groups.
        authorized = (
            account.host_user_binds.filter(id=bind_host_id).exists()
            or models.HostUserBind.objects.filter(id=bind_host_id, hostgroup__account=account).exists()
        )
        if not authorized:
            return HttpResponse(json.dumps({'error': 'permission denied'}), status=403)

        # Aware "now": the project runs with USE_TZ = True, so a naive now() would
        # be interpreted in TIME_ZONE rather than the server's local time.
        time_obj = timezone.now() - datetime.timedelta(seconds=300)  # 5mins ago
        exist_token_objs = models.Token.objects.filter(account_id=account.id,
                                                       host_user_bind_id=bind_host_id,
                                                       date__gt=time_obj)
        if exist_token_objs:  # has token already
            token_data = {'token': exist_token_objs[0].val}
        else:
            # The token is a credential: draw it from the OS CSPRNG, not the
            # default Mersenne Twister generator.
            token_val = ''.join(random.SystemRandom().sample(string.ascii_lowercase + string.digits, 8))
            models.Token.objects.create(
                host_user_bind_id=bind_host_id,
                account=account,
                val=token_val)
            token_data = {"token": token_val}

        return HttpResponse(json.dumps(token_data))
    生成5分钟之内有效的token

    4.点击主机登录,通过Shellinabox 插件以web页面的形式远程登录Linux主机;

    4.0 安装shellinabox

    yum install -y git openssl-devel pam-devel zlib-devel autoconf automake libtool
    
    git clone https://github.com/shellinabox/shellinabox.git && cd shellinabox
    
    autoreconf -i
    
    ./configure && make
    
    make install
    
    # -b: run in the background; -t: serve plain http instead of https.
    # By default it runs as the nobody user and listens on TCP port 4200.
    shellinaboxd -b -t
    
    netstat -ntpl |grep shell

    5.Django结合shellinabox

    5.1: 用户在Django的hostlist页面点击生成token(绑定了account + host_user_bind),并记录到数据库。

    5.2: 用户在Django的hostlist页面点击login跳转至shellinabox;由于修改了bashrc,登录之后就会执行python用户交互程序,该程序提示用户输入token;

    5.3: 用户输入token之后,python用户交互程序去数据库查询token,进而查询到host_user_bind对应的ip、用户名、密码,调用paramiko的demo.py自动输入ip、用户名、密码,进入shell交互界面;

    from django.db import models
    from django.contrib.auth.models import  User
    # Create your models here.
    
    
    class IDC(models.Model):
        """A data center (IDC) that hosts machines; each name is unique."""
        name = models.CharField(unique=True, max_length=64)

        def __str__(self):
            label = self.name
            return label
    
    class Host(models.Model):
        """A managed remote host (stores all host information)."""
        hostname = models.CharField(max_length=64,unique=True)
        ip_addr = models.GenericIPAddressField(unique=True)
        port = models.IntegerField(default=22)  # SSH port
        # on_delete spelled out: CASCADE is the implicit default up to Django 1.11,
        # and the argument is mandatory from Django 2.0 on.
        idc = models.ForeignKey("IDC", on_delete=models.CASCADE)
        enabled = models.BooleanField(default=True)

        def __str__(self):
            return "%s-%s" %(self.hostname,self.ip_addr)
    
    class HostGroup(models.Model):
        """A named collection of HostUserBind entries (a host group)."""
        name = models.CharField(unique=True, max_length=64)
        # Groups hold host/user bindings rather than bare hosts, so membership
        # also fixes which remote user is used on each host.
        host_user_binds = models.ManyToManyField("HostUserBind")

        def __str__(self):
            return "%s" % self.name
    
    
    class HostUser(models.Model):
        """Login credentials used on remote hosts.

        The same username may occur several times with different passwords
        (e.g. ``root``/123, ``root``/abc, ``root``/sfsfs), so uniqueness is on
        the (username, password) pair.
        """
        auth_type_choices = ((0,'ssh-password'),(1,'ssh-key'))
        auth_type = models.SmallIntegerField(choices=auth_type_choices)
        username = models.CharField(max_length=32)
        password = models.CharField(blank=True,null=True,max_length=128)

        def __str__(self):
            # Never include the password: this text appears in the admin and,
            # through HostUserBind.__str__, in the host menus shown to users.
            return "%s-%s" %(self.get_auth_type_display(),self.username)

        class Meta:
            unique_together = ('username','password')
    
    
    class HostUserBind(models.Model):
        """Pairs a Host with a HostUser; this pair is what accounts and groups are granted."""
        # on_delete spelled out (implicit CASCADE before Django 2.0, mandatory after).
        host = models.ForeignKey("Host", on_delete=models.CASCADE)
        host_user = models.ForeignKey("HostUser", on_delete=models.CASCADE)

        def __str__(self):
            return "%s-%s" %(self.host,self.host_user)

        class Meta:
            unique_together = ('host','host_user')
    
    
    class AuditLog(models.Model):
        """Audit log: one command typed during an audited SSH session."""
        # on_delete spelled out (implicit CASCADE before Django 2.0, mandatory after).
        session = models.ForeignKey("SessionLog", on_delete=models.CASCADE)
        cmd = models.TextField()
        date = models.DateTimeField(auto_now_add=True)  # set once, when the row is created

        def __str__(self):
            return "%s-%s" %(self.session,self.cmd)
    
    
    class SessionLog(models.Model):
        """One SSH session opened by an account on a host/user binding."""
        # on_delete spelled out (implicit CASCADE before Django 2.0, mandatory after).
        account = models.ForeignKey("Account", on_delete=models.CASCADE)
        host_user_bind = models.ForeignKey("HostUserBind", on_delete=models.CASCADE)
        start_date = models.DateTimeField(auto_now_add=True)
        # NULL until filled in; presumably set when the session ends -- confirm with the session code.
        end_date = models.DateTimeField(blank=True,null=True)

        def __str__(self):
            return "%s-%s" %(self.account,self.host_user_bind)
    
    
    class Account(models.Model):
        """Bastion account attached to a Django ``User``.

        Custom user data can be added by 1. extending ``User`` with a one-to-one
        model (done here) or 2. inheriting from it.  Reached as
        ``user.account``; hosts are granted individually (``host_user_binds``)
        or through ``host_groups``.
        """

        user = models.OneToOneField(User, on_delete=models.CASCADE)
        name = models.CharField(max_length=64)

        host_user_binds = models.ManyToManyField("HostUserBind",blank=True)
        host_groups = models.ManyToManyField("HostGroup",blank=True)

        def __str__(self):
            # Used by SessionLog.__str__ and the admin; without it they show "Account object".
            return self.name
    
    
    
    class Token(models.Model):
        """Short-lived login token tying an Account to one HostUserBind."""
        # on_delete spelled out (implicit CASCADE before Django 2.0, mandatory after).
        host_user_bind = models.ForeignKey("HostUserBind", on_delete=models.CASCADE)
        val = models.CharField(max_length=128,unique=True)
        account = models.ForeignKey("Account", on_delete=models.CASCADE)
        # NOTE(review): get_token and UserShell.token_auth use a fixed 300 s window
        # and never read this field.
        expire = models.IntegerField("超时时间(s)",default=300)
        date = models.DateTimeField(auto_now_add=True)

        def __str__(self):
            return "%s-%s" %(self.host_user_bind,self.val)
    models.py
    @login_required
    def get_token(request):
        """Return (creating it if necessary) a short-lived login token as JSON.

        The token ties the current user's ``Account`` to the ``HostUserBind``
        whose id is POSTed as ``bind_host_id``.  A token created in the last 300
        seconds for the same account and binding is reused.

        Responses:
          * ``{"token": "<8 chars>"}`` on success;
          * HTTP 400 when ``bind_host_id`` is missing or not an integer;
          * HTTP 403 when the binding is not granted to the user, either
            directly or through one of the user's host groups.
        """
        from django.utils import timezone

        bind_host_id = request.POST.get('bind_host_id', '')
        if not bind_host_id.isdigit():
            return HttpResponse(json.dumps({'error': 'invalid bind_host_id'}), status=400)

        account = request.user.account
        # Only issue tokens for hosts this account may use: bindings granted
        # directly (Account.host_user_binds) or via Account.host_groups.
        authorized = (
            account.host_user_binds.filter(id=bind_host_id).exists()
            or models.HostUserBind.objects.filter(id=bind_host_id, hostgroup__account=account).exists()
        )
        if not authorized:
            return HttpResponse(json.dumps({'error': 'permission denied'}), status=403)

        # Aware "now": the project runs with USE_TZ = True, so a naive now() would
        # be interpreted in TIME_ZONE rather than the server's local time.
        time_obj = timezone.now() - datetime.timedelta(seconds=300)  # 5mins ago
        exist_token_objs = models.Token.objects.filter(account_id=account.id,
                                                       host_user_bind_id=bind_host_id,
                                                       date__gt=time_obj)
        if exist_token_objs:  # has token already
            token_data = {'token': exist_token_objs[0].val}
        else:
            # The token is a credential: draw it from the OS CSPRNG, not the
            # default Mersenne Twister generator.
            token_val = ''.join(random.SystemRandom().sample(string.ascii_lowercase + string.digits, 8))
            models.Token.objects.create(
                host_user_bind_id=bind_host_id,
                account=account,
                val=token_val)
            token_data = {"token": token_val}

        return HttpResponse(json.dumps(token_data))
    View.py
    {% extends 'index.html' %}
    
    
    
    {% block content-container %}
    {# Host list page: the left panel lists the account's host groups; rows of #hostlist on the right are filled in by GetHostlist(). #}
        <div id="page-title">
            <h1 class="page-header text-overflow">主机列表</h1>
    
            <!--Searchbox-->
            <div class="searchbox">
                <div class="input-group custom-search-form">
                    <input type="text" class="form-control" placeholder="Search..">
                    <span class="input-group-btn">
                        <button class="text-muted" type="button"><i class="pli-magnifi-glass"></i></button>
                    </span>
                </div>
            </div>
        </div>
        <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
        <!--End page title-->
            <!--Breadcrumb-->
        <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
        <ol class="breadcrumb">
            <li><a href="#">Home</a></li>
            <li><a href="#">Library</a></li>
            <li class="active">主机列表</li>
        </ol>
        <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
        <!--End breadcrumb-->
    
        <div id="page-content">
    
            <div class="panel col-lg-3">
                <div class="panel-heading">
                    <h3 class="panel-title">主机组</h3>
                </div>
                <div class="panel-body">
                    <ul class="list-group">
                    {# Each entry calls GetHostlist(group id, this); the "未分组主机" entry passes -1 for hosts bound directly to the account (Account.host_user_binds). #}
                    {% for group in  request.user.account.host_groups.all %}
    
                        <li class="list-group-item " onclick="GetHostlist({{ group.id }},this)"><span class="badge badge-success">{{ group.host_user_binds.count }}</span>{{ group.name }}</li>
                    {% endfor %}
                        <li class="list-group-item " onclick="GetHostlist(-1,this)"> <span class="badge badge-success">{{ request.user.account.host_user_binds.count }}</span>未分组主机</li>
    
                    </ul>
                </div>
            </div>
            <div class="panel col-lg-9">
                <div class="panel-heading">
                    <h3 class="panel-title">主机列表</h3>
                </div>
                <div class="panel-body">
    
                    <div class="table-responsive">
                        <table class="table table-striped">
                            <thead>
                                <tr>
                                    <th>Hostname</th>
                                    <th>IP</th>
                                    <th>IDC</th>
                                    <th>Port</th>
                                    <th>Username</th>
                                    <th>Login</th>
                                    <th>Token</th>
                                </tr>
                            </thead>
                            {# Rows are injected by GetHostlist(). GetToken() writes the token into the cell right after the button cell, so Token must stay the column after Login. #}
                            <tbody id="hostlist">
    {#                            <tr>#}
    {#                                <td><a href="#fakelink" class="btn-link">Order #53451</a></td>#}
    {#                                <td>Scott S. Calabrese</td>#}
    {#                                <td>$24.98</td>#}
    {#                            </tr>#}
    
                            </tbody>
                        </table>
                    </div>
    
                </div>
            </div>
    
        </div>
    
    
    
    <script>

    // Base URL of the shellinabox web terminal (shellinaboxd listens on TCP 4200 by default).
    // NOTE(review): hard-coded host; adjust it for your deployment.
    var SHELLINABOX_URL = 'http://192.168.226.135:4200/';

    // Escape a value before concatenating it into HTML markup.
    function EscapeHtml(text) {
        return $('<div>').text(text == null ? '' : String(text)).html();
    }

    // Load the host/user bindings of group `gid` (-1 = ungrouped hosts) into #hostlist.
    // Each row gets a Token button (GetToken) and a login link to shellinabox;
    // the row's last cell receives the token.
    function GetHostlist(gid,self) {
        $.get("{% url 'get_host_list' %}",{'gid':gid},function(callback){
            var data  = JSON.parse(callback);
            console.log(data);
            var trs = '';
            $.each(data,function (index,i) {
                var tr = "<tr><td>" + EscapeHtml(i.host__hostname) + "</td><td>" + EscapeHtml(i.host__ip_addr)
                        + "</td><td>" + EscapeHtml(i.host__idc__name) + "</td><td>" + EscapeHtml(i.host__port)
                        + "</td><td>" + EscapeHtml(i.host_user__username) + "</td><td>"
                        + "<a class='btn btn-sm btn-info' onclick=\"GetToken(this,'" + EscapeHtml(i.id) + "')\">Token</a>"
                        + "<a href='" + SHELLINABOX_URL + "' class='btn btn-sm btn-info'>login</a>"
                        + "</td><td></td></tr>";
                trs += tr;
            });
            $("#hostlist").html(trs);
        });//end get
        $(self).addClass("active").siblings().removeClass('active');
    }

    // Request a login token for one host/user binding and show it in that row's Token cell.
    function GetToken(self,bind_host_id) {
        $.post(
            '{% url "get_token" %}',     // URL rendered from the route name
            {'bind_host_id':bind_host_id,'csrfmiddlewaretoken':"{{ csrf_token }}"},  // request parameters
            function (callback) {
                console.log(callback);
                var data = JSON.parse(callback); // JSON body returned by the get_token view
                // self is the Token button: parent() is its <td>, next() is the Token column.
                $(self).parent().next().text(data.token);
            }
        );
    }

    </script>
    {% endblock %}
    hostlist.html
    import subprocess,random,string,datetime
    from django.contrib.auth import authenticate
    from django.conf import settings
    from audit import models
    from audit.backend import ssh_interactive
    
    class UserShell(object):
        """Interactive shell shown to a user who logs in to the bastion host."""

        def __init__(self,sys_argv):
            self.sys_argv = sys_argv
            self.user = None  # Django User, set by auth() or token_auth()

        def auth(self):
            """Authenticate with a Django username/password, at most 3 attempts.

            Sets ``self.user`` and returns True on success; prints a message and
            returns None after 3 failures.
            """
            import getpass

            count = 0
            while count < 3:
                username = input("username:").strip()
                # getpass does not echo, so the password is not shown on screen.
                password = getpass.getpass("password:").strip()
                user = authenticate(username=username,password=password)
                # authenticate() returns None when the credentials are rejected,
                # otherwise the matching User object.
                if not user:
                    count += 1
                    print("Invalid username or password!")
                else:
                    self.user = user
                    return  True
            else:
                print("too many attempts.")

        def token_auth(self):
            """Log in with a token issued on the web host list (at most 3 tries).

            Empty input skips token login and returns None.  A token matches when
            it has the entered value and was created less than 300 seconds ago;
            then ``self.user`` becomes the token owner's user and the Token is
            returned.  Returns None after 3 unsuccessful tries.
            """
            from django.utils import timezone

            count = 0
            while count < 3:
                user_input = input("请输入token:").strip()
                if len(user_input) == 0:
                    return
                if len(user_input) != 8:
                    print("token length is 8")
                else:
                    # Aware "now": the project runs with USE_TZ = True.
                    time_obj = timezone.now() - datetime.timedelta(seconds=300)  # 5mins ago
                    token_obj = models.Token.objects.filter(val=user_input, date__gt=time_obj).first()
                    # Compare again in Python: depending on the database collation
                    # the SQL comparison may be case-insensitive.
                    if token_obj and token_obj.val == user_input:
                        self.user = token_obj.account.user  # the SSH session is attributed to this user
                        return token_obj
                count+=1

        def start(self):
            """Run the bastion program.

            Token login is tried first; on success the SSH session is opened and
            the program exits.  Otherwise the user logs in with username and
            password and then repeatedly picks a host group and a host; entering
            'b' at the host prompt returns to the group list.
            """
            import sys

            token_obj = self.token_auth()
            if token_obj:
                ssh_interactive.ssh_session(token_obj.host_user_bind, self.user)
                # sys.exit() rather than the site-module exit() helper, which is
                # meant for the interactive prompt and is absent under python -S.
                sys.exit()
            if self.auth():
                while True:
                    host_groups = self.user.account.host_groups.all()
                    for index,group in enumerate(host_groups):
                        print("%s.\t%s[%s]"%(index,group,group.host_user_binds.count()))
                    # The extra last entry stands for hosts bound directly to the account.
                    print("%s.\t未分组机器[%s]"%(len(host_groups),self.user.account.host_user_binds.count()))
                    try:
                        choice = input("select group>:").strip()
                        if choice.isdigit():
                            choice = int(choice)
                            host_bind_list = None
                            if choice >=0 and choice < len(host_groups):
                                selected_group = host_groups[choice]
                                host_bind_list = selected_group.host_user_binds.all()
                            elif choice == len(host_groups):  # ungrouped hosts
                                host_bind_list = self.user.account.host_user_binds.all()
                            if host_bind_list:
                                while True:
                                    for index,host in enumerate(host_bind_list):
                                        # Print username and host explicitly instead of str(host), so the
                                        # menu never depends on HostUserBind/HostUser.__str__, which may
                                        # include the stored password.
                                        print("%s.\t%s@%s"%(index,host.host_user.username,host.host))
                                    choice2 = input("select host>:").strip()
                                    if choice2.isdigit():
                                        choice2 = int(choice2)
                                        if choice2 >=0 and choice2 < len(host_bind_list):
                                            selected_host = host_bind_list[choice2]
                                            ssh_interactive.ssh_session(selected_host,self.user)
                                    elif choice2 == 'b':
                                        break

                    except KeyboardInterrupt:
                        # Ctrl-C only redraws the menu -- presumably deliberate, so the
                        # user cannot interrupt their way out of the bastion shell.
                        pass
    user_interactive.py

     

     

     三、通过堡垒机批量执行Linux命令

    1.批量执行命令前端页面

    {% extends 'index.html' %}
    
    
    
    {% block content-container %}
    {# Batch-command page: tick hosts on the left (each host checkbox value is a HostUserBind id), type a command and PostTask('cmd') submits it. #}
    {#    {% csrf_token %}#}
        <div id="page-title">
            <h1 class="page-header text-overflow">主机列表</h1>
    
            <!--Searchbox-->
            <div class="searchbox">
                <div class="input-group custom-search-form">
                    <input type="text" class="form-control" placeholder="Search..">
                    <span class="input-group-btn">
                        <button class="text-muted" type="button"><i class="pli-magnifi-glass"></i></button>
                    </span>
                </div>
            </div>
        </div>
        <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
        <!--End page title-->
            <!--Breadcrumb-->
        <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
        <ol class="breadcrumb">
            <li><a href="#">Home</a></li>
            <li><a href="#">Library</a></li>
            <li class="active">主机列表</li>
        </ol>
        <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
        <!--End breadcrumb-->
    
        <div id="page-content">
            <div class="panel col-lg-3">
                <div class="panel-heading">
                    <h3 class="panel-title">主机组 <span id="selected_hosts"></span></h3>
                </div>
                <div class="panel-body">
    
                    <ul class="list-group" id="host_groups">
                    {# Only checkboxes inside the nested <ul>s are counted and submitted; the group-level checkbox just toggles them via CheckAll(). #}
                    {% for group in  request.user.account.host_groups.all %}
    
                        <li class="list-group-item " ><span class="badge badge-success">{{ group.host_user_binds.count }}</span>
                            <input type="checkbox" onclick="CheckAll(this)">
                            <a onclick="DisplayHostList(this)">{{ group.name }}</a>  <!--点击组名,组名下的 主机列表通过toggleclass 展示/隐藏 -->
                            <ul class="hide">
                                {% for bind_host in group.host_user_binds.all %}
                                    <li><input onclick="ShowCheckedHostCount()" type="checkbox" value="{{ bind_host.id }}">{{ bind_host.host.ip_addr }}</li>
                                {% endfor %}
                            </ul>
                        </li>
    
                    {% endfor %}
                        <li class="list-group-item " > <span class="badge badge-success">{{ request.user.account.host_user_binds.count }}</span>
                           <input type="checkbox" onclick="CheckAll(this)">
                            <a onclick="DisplayHostList(this)">未分组主机</a>
                            <ul class="hide">
                                {% for bind_host in request.user.account.host_user_binds.all %}
                                    <li><input onclick="ShowCheckedHostCount()" type="checkbox" value="{{ bind_host.id }}">{{ bind_host.host.ip_addr }}</li>
                                {% endfor %}
                            </ul>
                        </li>
    
                    </ul>
    
    
    
                </div>
            </div>
    
            <div class="col-lg-9">
                <div class="panel">
                    <div class="panel-heading">
                        <h3 class="panel-title">命令</h3>
                    </div>
                    <div class="panel-body">
                        <textarea class="form-control" id="cmd"></textarea>
                        <button onclick="PostTask('cmd')" class="btn btn-info pull-right">执行</button>
                        {# NOTE(review): the 终止 (stop) button has no click handler yet. #}
                        <button  class="btn btn-danger ">终止</button>
    
                    </div>
                </div>
                <div class="panel">
                    <div class="panel-heading">
                        <h3 class="panel-title">任务结果</h3>
                    </div>
                    <div class="panel-body">
    
                        {# Per-host results of the submitted task are rendered into this element by the page script. #}
                        <div id="task_result">
                    </div>
                </div>
            </div>
    
            </div>
        </div>
    
    
    <script>
        // Polling timer of the most recently submitted task (see PostTask).
        var result_timer = null;

        // Show/hide the host list under a group name.
        function  DisplayHostList(self) {
            $(self).next().toggleClass("hide");
        }

        // Group checkbox: (un)check every host checkbox of that group.
        function CheckAll(self){
            console.log($(self).prop('checked'));
            $(self).parent().find("ul :checkbox").prop('checked',$(self).prop('checked'));

            ShowCheckedHostCount();
        }

        // Values (HostUserBind ids) of the checked hosts without duplicates: the same
        // binding can be listed under several groups and under "未分组主机", but must
        // be submitted, and therefore executed, only once.
        function GetSelectedHostIds(){
            var ids = [];
            $("#host_groups ul").find(":checked").each(function () {
                var id = $(this).val();
                if ($.inArray(id, ids) === -1) {
                    ids.push(id);
                }
            });
            return ids;
        }

        // Refresh the selected-host counter in the panel title and return it.
        function ShowCheckedHostCount(){
            var selected_host_count = GetSelectedHostIds().length;
            console.log(selected_host_count);
            $("#selected_hosts").text(selected_host_count);
            return selected_host_count
        }

        // Fetch the per-host results of a task and render them into #task_result.
        // The output comes from remote hosts, so it is inserted as text, never as HTML.
        function GetTaskResult(task_id) {
            $.getJSON("{% url 'get_task_result' %}",{'task_id':task_id},function(callback){
                console.log(callback);
                var container = $("#task_result").empty();
                $.each(callback,function (index,i) {
                    var p_ele = $("<p>").text(i.host_user_bind__host__hostname + "(" + i.host_user_bind__host__ip_addr + ") ------" + i.status);
                    var res_ele = $("<pre>").text(i.result);
                    container.append(p_ele, res_ele);
                });
            });//end getJSON
        }

        function  PostTask(task_type) {
            //1. make sure hosts are selected and a command was entered
            //2. submit the task to the backend, then poll its results
            var selected_host_ids = GetSelectedHostIds();
            console.log(selected_host_ids);
            if ( selected_host_ids.length == 0){
                alert("主机未选择!");
                return false;
            }
            var cmd_text = $.trim($("#cmd").val());
            if ( cmd_text.length == 0){
                alert("未输入命令!");
                return false;
            }

            var task_data = {
                'task_type':task_type,
                'selected_host_ids': selected_host_ids,
                'cmd': cmd_text
            };

            $.post("{% url 'multitask' %}",{'csrfmiddlewaretoken':"{{ csrf_token }}",'task_data':JSON.stringify(task_data)},
                function(callback){
                    console.log(callback);
                    var data = JSON.parse(callback);
                    if (data.task_id === undefined || data.task_id === null) {
                        // Validation failed: the multitask view returned its errors instead of a task id.
                        alert("任务提交失败: " + JSON.stringify(data));
                        return;
                    }
                    // Stop polling a previously submitted task before tracking the new one.
                    if (result_timer !== null) {
                        clearInterval(result_timer);
                    }
                    GetTaskResult(data.task_id);
                    result_timer = setInterval(function () {
                        GetTaskResult(data.task_id);
                    },2000);
                } );//end post
        }
    </script>
    {% endblock %}
    multi_cmd.html

    2.前端收集批量执行的主机,通过ajax发送到后台

    @login_required
    def multitask(request):
        """Validate and launch a batch task submitted from the multi-task pages.

        Responds with ``{"task_id": <pk>}`` when the task was started, otherwise
        with the validator's ``errors``.
        """
        task_obj = task_handler.Task(request)
        # Validate *before* serialising the errors.  The original code built the
        # error response first, so it captured whatever task_obj.errors held
        # before validation.  NOTE(review): task_handler.Task is not shown here;
        # this assumes, as described, that is_valid() populates errors.
        if task_obj.is_valid():
            result = task_obj.run()  # run() picks the task type and dispatches to it via getattr()
            return HttpResponse(json.dumps({'task_id': result}))  # pk of the created task
        return HttpResponse(json.dumps(task_obj.errors))
    views.py

    3.后端通过is_valid方法验证数据的合法性

    4.验证失败响应前端self.errors信息,验证成功执行run()选择任务类型;

    5.选择任务类型(cmd/files_transfer)之后初始化数据库(更新Task、TaskLog表数据)

    6.cmd/files_transfer方法启动一个新进程(multitask_execute.py),由该进程开启进程池去批量执行命令;

    7.前端使用定时器不断轮询后台,获取任务执行结果;

    8.程序中断按钮

    """
    Django settings for zhanggen_audit project.
    
    Generated by 'django-admin startproject' using Django 1.11.4.
    
    For more information on this file, see
    https://docs.djangoproject.com/en/1.11/topics/settings/
    
    For the full list of settings and their values, see
    https://docs.djangoproject.com/en/1.11/ref/settings/
    """
    
    import os
    
    # Build paths inside the project like this: os.path.join(BASE_DIR, ...)
    BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    
    
    # Quick-start development settings - unsuitable for production
    # See https://docs.djangoproject.com/en/1.11/howto/deployment/checklist/
    
    # SECURITY WARNING: keep the secret key used in production secret!
    SECRET_KEY = '5ivlngau4a@_3y4vizrcxnnj(&vz2en#edpq%i&jr%99-xxv)&'
    
    # SECURITY WARNING: don't run with debug turned on in production!
    DEBUG = True
    
    ALLOWED_HOSTS = ['*']
    
    
    # Application definition
    
    INSTALLED_APPS = [
        'django.contrib.admin',
        'django.contrib.auth',
        'django.contrib.contenttypes',
        'django.contrib.sessions',
        'django.contrib.messages',
        'django.contrib.staticfiles',
        'audit.apps.AuditConfig',
    ]
    
    MIDDLEWARE = [
        'django.middleware.security.SecurityMiddleware',
        'django.contrib.sessions.middleware.SessionMiddleware',
        'django.middleware.common.CommonMiddleware',
        'django.middleware.csrf.CsrfViewMiddleware',
        'django.contrib.auth.middleware.AuthenticationMiddleware',
        'django.contrib.messages.middleware.MessageMiddleware',
        'django.middleware.clickjacking.XFrameOptionsMiddleware',
    ]
    
    # --- URL routing / WSGI entry point -------------------------------------
    ROOT_URLCONF = 'zhanggen_audit.urls'

    # Project-level template directory (BASE_DIR/templates) plus app templates.
    TEMPLATES = [
        {
            'BACKEND': 'django.template.backends.django.DjangoTemplates',
            'DIRS': [os.path.join(BASE_DIR, 'templates')],
            'APP_DIRS': True,
            'OPTIONS': {
                'context_processors': [
                    'django.template.context_processors.debug',
                    'django.template.context_processors.request',
                    'django.contrib.auth.context_processors.auth',
                    'django.contrib.messages.context_processors.messages',
                ],
            },
        },
    ]

    WSGI_APPLICATION = 'zhanggen_audit.wsgi.application'

    # --- Database -------------------------------------------------------------
    # https://docs.djangoproject.com/en/1.11/ref/settings/#databases
    DATABASES = {
        'default': {
            'ENGINE': 'django.db.backends.sqlite3',
            'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
        },
    }

    # --- Password validation --------------------------------------------------
    # https://docs.djangoproject.com/en/1.11/ref/settings/#auth-password-validators
    AUTH_PASSWORD_VALIDATORS = [
        {'NAME': 'django.contrib.auth.password_validation.' + validator}
        for validator in (
            'UserAttributeSimilarityValidator',
            'MinimumLengthValidator',
            'CommonPasswordValidator',
            'NumericPasswordValidator',
        )
    ]

    # --- Internationalization -------------------------------------------------
    # https://docs.djangoproject.com/en/1.11/topics/i18n/
    LANGUAGE_CODE = 'en-us'
    TIME_ZONE = 'Asia/Shanghai'
    USE_I18N = True
    USE_L10N = True
    USE_TZ = True

    # --- Static files (CSS, JavaScript, Images) -------------------------------
    # https://docs.djangoproject.com/en/1.11/howto/static-files/
    STATIC_URL = '/static/'
    STATICFILES_DIRS = (
        os.path.join(BASE_DIR, 'static'),
    )
    
    
    # Shell script that tracks SSH sessions: BASE_DIR/audit/backend/session_check.sh.
    # Built from path components; the old "%s" template was %-formatted *after* joining,
    # so a "%" anywhere in BASE_DIR broke it.
    SESSION_TRACKER_SCRIPT = os.path.join(BASE_DIR, 'audit', 'backend', 'session_check.sh')
    SESSION_TRACKER_SCRIPT_LOG_PATH = os.path.join(BASE_DIR, 'log')  # session tracker log directory
    MULTI_TASK_SCRIPT = os.path.join(BASE_DIR, 'multitask_execute.py')  # batch task executor script

    # Process-group id of the most recently launched batch executor. Reassigned at
    # runtime by the views/task handler; it is a plain module attribute, so it is
    # only visible inside the current server process.
    CURRENT_PGID = None
    settings.py
    """zhanggen_audit URL Configuration
    
    The `urlpatterns` list routes URLs to views. For more information please see:
        https://docs.djangoproject.com/en/1.11/topics/http/urls/
    Examples:
    Function views
        1. Add an import:  from my_app import views
        2. Add a URL to urlpatterns:  url(r'^$', views.home, name='home')
    Class-based views
        1. Add an import:  from other_app.views import Home
        2. Add a URL to urlpatterns:  url(r'^$', Home.as_view(), name='home')
    Including another URLconf
        1. Import the include() function: from django.conf.urls import url, include
        2. Add a URL to urlpatterns:  url(r'^blog/', include('blog.urls'))
    """
    from django.conf.urls import url
    from django.contrib import admin
    from audit import views
    
    urlpatterns = [
        # Admin site, dashboard and login / logout
        url(r'^admin/', admin.site.urls),
        url(r'^$', views.index),
        url(r'^login/$', views.acc_login),
        url(r'^logout/$', views.acc_logout),

        # Host list page
        url(r'^hostlist/$', views.host_list, name="host_list"),

        # Batch tasks: submit, poll results, and the command / file-transfer pages
        url(r'^multitask/$', views.multitask, name="multitask"),
        url(r'^multitask/result/$', views.multitask_result, name="get_task_result"),
        url(r'^multitask/cmd/$', views.multi_cmd, name="multi_cmd"),
        url(r'^multitask/file_transfer/$', views.multi_file_transfer, name="multi_file_transfer"),

        # JSON APIs
        url(r'^api/hostlist/$', views.get_host_list, name="get_host_list"),
        url(r'^api/token/$', views.get_token, name="get_token"),
        url(r'^api/task/file_upload/$', views.task_file_upload, name="task_file_upload"),
        url(r'^api/task/file_download/$', views.task_file_download, name="task_file_download"),

        # Kill the running batch executor
        url(r'^end_cmd/$', views.end_cmd, name="end_cmd"),
    ]
    urls.py
    from django.shortcuts import render,redirect,HttpResponse
    from django.contrib.auth import authenticate,login,logout
    from django.contrib.auth.decorators import login_required
    from django.views.decorators.csrf import csrf_exempt
    from django.conf import settings
    import signal
    
    import json,os
    from audit import models
    import random,string
    import datetime
    from audit import task_handler
    from django import conf
    import zipfile
    from wsgiref.util import FileWrapper #from django.core.servers.basehttp import FileWrapper
    
    @login_required
    def index(request):
        """Render the dashboard landing page for a logged-in user."""
        template_name = 'index.html'
        return render(request, template_name)
    
    
    
    def acc_login(request):
        """Show the login form (GET) or authenticate the posted credentials (POST).

        On success the user is logged in and redirected to the ``next`` query
        parameter when it points at this site, otherwise to ``/``. On failure the
        form is rendered again with an error message.
        """
        from django.utils.http import is_safe_url

        error = ''
        if request.method == "POST":
            username = request.POST.get('username')
            password = request.POST.get('password')
            user = authenticate(username=username, password=password)
            if user:
                login(request, user)
                next_url = request.GET.get('next')
                # ``next`` comes from the query string: only follow it when it stays on
                # this host, otherwise the login page works as an open redirect.
                if not next_url or not is_safe_url(url=next_url, allowed_hosts={request.get_host()}):
                    next_url = '/'
                return redirect(next_url)
            error = "Wrong username or password!"
        return render(request, 'login.html', {'error': error})
    
    
    @login_required
    def acc_logout(request):
        """End the current session and send the user back to the login page."""
        logout(request)
        login_page = '/login/'
        return redirect(login_page)
    
    @login_required
    def host_list(request):
        """Render the host list page."""
        template_name = 'hostlist.html'
        return render(request, template_name)
    
    
    @login_required
    def get_host_list(request):
        """Return the host bindings of one of the user's groups as a JSON list.

        Query parameter ``gid``: ``-1`` selects the bindings attached directly to the
        account ("ungrouped"); any other value is the id of one of the account's host
        groups. A missing, malformed or foreign ``gid`` yields ``[]`` instead of a
        server error.
        """
        gid = request.GET.get('gid')
        account = request.user.account
        host_list = None
        if gid == '-1':  # ungrouped hosts
            host_list = account.host_user_binds.all()
        elif gid and gid.isdecimal():
            # filter().first() instead of get(): an unknown group is not a 500.
            group_obj = account.host_groups.filter(id=gid).first()
            if group_obj is not None:
                host_list = group_obj.host_user_binds.all()

        if host_list is None:
            return HttpResponse(json.dumps([]))

        data = json.dumps(list(host_list.values('id', 'host__hostname', 'host__ip_addr', 'host__idc__name',
                                                'host__port', 'host_user__username')))
        return HttpResponse(data)
    
    @login_required
    def get_token(request):
        """Return a short-lived login token for one of the user's host bindings.

        POST parameter ``bind_host_id`` must name a binding the account can use,
        either directly or through one of its host groups; otherwise HTTP 403.
        A token created for the same account/binding within the last 300 seconds
        is reused; otherwise a new 8-character token is stored and returned.
        """
        from django.utils import timezone

        account = request.user.account
        bind_host_id = request.POST.get('bind_host_id', '')
        if not bind_host_id.isdecimal() or not (
                account.host_user_binds.filter(id=bind_host_id).exists()
                or account.host_groups.filter(host_user_binds__id=bind_host_id).exists()):
            return HttpResponse(json.dumps({'error': 'invalid bind_host_id'}), status=403)

        # USE_TZ is on: compare against an aware "now", not a naive datetime.now().
        time_obj = timezone.now() - datetime.timedelta(seconds=300)  # 5 mins ago
        exist_token_obj = models.Token.objects.filter(account_id=account.id,
                                                      host_user_bind_id=bind_host_id,
                                                      date__gt=time_obj).first()
        if exist_token_obj is not None:  # has token already
            token_val = exist_token_obj.val
        else:
            # Tokens are credentials: draw them from the OS CSPRNG, not the
            # predictable default Mersenne Twister.
            token_val = ''.join(random.SystemRandom().sample(string.ascii_lowercase + string.digits, 8))
            models.Token.objects.create(
                host_user_bind_id=bind_host_id,
                account=account,
                val=token_val)

        return HttpResponse(json.dumps({'token': token_val}))
    
    
    
    @login_required
    def multi_cmd(request):
        """Render the batch command execution page."""
        template_name = 'multi_cmd.html'
        return render(request, template_name)
    
    
    @login_required
    def multitask(request):
        """Validate and launch a batch task posted by the front end.

        Responds with the JSON list of validation errors when ``task_data`` is
        rejected, otherwise with ``{"task_id": <Task pk>, "timeout": <Task.timeout>}``.
        """
        task_form = task_handler.Task(request)
        # Serialize the errors only after is_valid() has filled them in; building the
        # response up front froze the body as "[]" and hid every validation message.
        if not task_form.is_valid():
            return HttpResponse(json.dumps(task_form.errors))
        task_obj = task_form.run()  # dispatches on task_type and returns the Task row
        return HttpResponse(json.dumps({'task_id': task_obj.id, 'timeout': task_obj.timeout}))
    
    
    @login_required
    def multitask_result(request):
        """Return the per-host results of one batch task as a JSON list.

        Each item has the keys ``id``, ``status``, ``host_user_bind__host__hostname``,
        ``host_user_bind__host__ip_addr`` and ``result``. Only tasks created by the
        requesting account are visible; a missing, malformed or foreign ``task_id``
        yields ``[]`` rather than exposing another user's output or raising a 500.
        """
        task_id = request.GET.get('task_id', '')
        if not task_id.isdecimal():
            return HttpResponse(json.dumps([]))

        task_obj = models.Task.objects.filter(id=task_id, account=request.user.account).first()
        if task_obj is None:
            return HttpResponse(json.dumps([]))

        results = list(task_obj.tasklog_set.values('id', 'status',
                                                   'host_user_bind__host__hostname',
                                                   'host_user_bind__host__ip_addr',
                                                   'result'))
        return HttpResponse(json.dumps(results))
    
    
    
    
    
    @login_required
    def multi_file_transfer(request):
        """Render the batch file-transfer page with a fresh 8-char upload token.

        The template context is exactly ``request`` and ``random_str``.
        """
        alphabet = string.ascii_lowercase + string.digits
        random_str = ''.join(random.sample(alphabet, 8))
        context = {'request': request, 'random_str': random_str}
        return render(request, 'multi_file_transfer.html', context)
    
    @login_required
    @csrf_exempt
    def task_file_upload(request):
        """Store one file uploaded from the batch file-transfer page.

        The file is written to ``FILE_UPLOADS/<account id>/<random_str>/<file name>``,
        where ``random_str`` is the per-page token passed in the query string.
        Responds ``{"status": 0}`` on success; ``{"status": 1, "error": ...}`` with
        HTTP 400 when the token or the file is missing or invalid.
        """
        random_str = request.GET.get('random_str', '')
        file_obj = request.FILES.get('file')
        # random_str becomes a directory name: allow only letters/digits so it can
        # contain neither "/" nor ".." and cannot escape the per-user upload folder.
        if not random_str.isalnum() or file_obj is None:
            return HttpResponse(json.dumps({'status': 1, 'error': 'invalid upload'}), status=400)
        # Keep only the last path component of the client-supplied name.
        file_name = os.path.basename(file_obj.name)
        if file_name in ('', '.', '..'):
            return HttpResponse(json.dumps({'status': 1, 'error': 'invalid file name'}), status=400)

        upload_to = os.path.join(conf.settings.FILE_UPLOADS, str(request.user.account.id), random_str)
        os.makedirs(upload_to, exist_ok=True)
        with open(os.path.join(upload_to, file_name), 'wb') as f:
            for chunk in file_obj.chunks():
                f.write(chunk)

        return HttpResponse(json.dumps({'status': 0}))
    
    
    
    
    def send_zipfile(request,task_id,file_path):
        """Zip every entry of *file_path* and stream it as ``task_id_<id>_files.zip``.

        The archive is built in an anonymous temporary file, so concurrent downloads
        do not clobber each other and nothing is left behind in the working
        directory. FileResponse streams it in chunks instead of loading it into
        memory, and closes (and so deletes) the temp file when the response is done.
        Raises FileNotFoundError if *file_path* does not exist.
        """
        import tempfile
        from django.http import FileResponse

        zip_file_name = 'task_id_%s_files' % task_id
        tmp = tempfile.TemporaryFile()
        try:
            # Passing a file object: ZipFile.close() finishes the archive but leaves tmp open.
            with zipfile.ZipFile(tmp, 'w', zipfile.ZIP_DEFLATED) as archive:
                for filename in os.listdir(file_path):
                    archive.write(os.path.join(file_path, filename), arcname=filename)
            tmp.seek(0, os.SEEK_END)
            size = tmp.tell()
            tmp.seek(0)
        except Exception:
            tmp.close()
            raise

        response = FileResponse(tmp, content_type='application/zip')
        response['Content-Disposition'] = 'attachment; filename=%s.zip' % zip_file_name
        response['Content-Length'] = size
        return response
    
    @login_required
    def task_file_download(request):
        """Download, as a zip, the files collected for one of the user's tasks.

        ``task_id`` (query string) must be an integer naming a Task created by the
        requesting account, and ``FILE_DOWNLOADS/<task_id>`` must exist; otherwise
        404. The integer check also keeps the value from steering the path outside
        FILE_DOWNLOADS.
        """
        from django.http import Http404

        try:
            task_id = str(int(request.GET.get('task_id', '')))
        except ValueError:
            raise Http404('invalid task_id')
        if not models.Task.objects.filter(id=task_id, account=request.user.account).exists():
            raise Http404('task not found')

        task_file_path = os.path.join(conf.settings.FILE_DOWNLOADS, task_id)
        if not os.path.isdir(task_file_path):
            raise Http404('no files for this task')
        return send_zipfile(request, task_id, task_file_path)
    
    
    @login_required
    def end_cmd(request):
        """SIGKILL the process group of the most recently launched batch executor.

        Responds with the killed pgid as JSON, or ``null`` when nothing is recorded.
        NOTE(review): CURRENT_PGID is one global per server process, so this stops
        the latest task regardless of which user started it.
        """
        current_task_pgid = settings.CURRENT_PGID
        if current_task_pgid is None:
            return HttpResponse(json.dumps(None))
        # If the executor was started without its own session it shares our process
        # group, and killpg() would SIGKILL the web server itself.
        if current_task_pgid == os.getpgrp():
            return HttpResponse(json.dumps(None), status=409)
        try:
            os.killpg(current_task_pgid, signal.SIGKILL)
        except ProcessLookupError:
            pass  # the executor already exited
        settings.CURRENT_PGID = None
        return HttpResponse(json.dumps(current_task_pgid))
    views.py
    import json,subprocess,os,signal
    from audit import models
    from django.conf import settings
    from django.db.transaction import atomic
    class Task(object):
        """Validate and launch a batch task submitted from the web UI.

        The POST field ``task_data`` carries a JSON object such as
        ``{"task_type": "cmd", "selected_host_ids": ["1", "2"], "cmd": "df"}``.
        Call :meth:`is_valid` first; when it returns True, :meth:`run` dispatches
        to the method named by ``task_type`` and returns the created Task row.
        Validation problems are collected in :attr:`errors` as single-key dicts.
        """

        def __init__(self,request):
            self.request=request
            self.errors=[]
            self.task_data=None

        def is_valid(self):
            """Parse and check ``task_data``; return True when the task may run."""
            raw_task_data = self.request.POST.get('task_data')
            if not raw_task_data:
                self.errors.append({'invalid_data': 'task_data不存在!'})
                return False
            try:
                task_data = json.loads(raw_task_data)
            except ValueError:
                task_data = None
            if not isinstance(task_data, dict):
                self.errors.append({'invalid_data': 'task_data格式错误!'})
                return False
            self.task_data = task_data
            self.task_type = task_data.get('task_type')

            if self.task_type == 'cmd':
                selected_host_ids = task_data.get('selected_host_ids')
                if not isinstance(selected_host_ids, list) or not selected_host_ids or not task_data.get('cmd'):
                    self.errors.append({'invalid_argument': '命令/主机不存在'})
                    return False
                if not self._hosts_authorized(selected_host_ids):
                    self.errors.append({'invalid_argument': '无权操作所选主机!'})
                    return False
                return True

            if self.task_type == 'files_transfer':
                # TODO: validate file paths once files_transfer() is implemented; until
                # then reject the task instead of letting run() crash on a None result.
                self.errors.append({'invalid_argument': '文件传输暂未实现!'})
                return False

            self.errors.append({'invalid_argument': '不支持的任务类型!'})
            return False

        def _hosts_authorized(self, host_ids):
            """Return True if every id in *host_ids* is a host binding this account may use.

            A binding is usable when it is attached to the account directly or belongs
            to one of the account's host groups.
            """
            account = self.request.user.account
            allowed = set(account.host_user_binds.values_list('id', flat=True))
            for group in account.host_groups.all():
                allowed.update(group.host_user_binds.values_list('id', flat=True))
            allowed = {str(pk) for pk in allowed}
            # The front end posts ids as strings; compare as strings.
            return all(str(host_id) in allowed for host_id in host_ids)

        def run(self):
            """Call the method named by ``task_type`` (validated by is_valid) and return its Task row."""
            task_func = getattr(self, self.task_data.get('task_type'))
            return task_func()

        @atomic  # the Task row and all of its TaskLog rows are created together or not at all
        def cmd(self):
            """Create the Task and one TaskLog per selected host, then start the executor."""
            from django.db import transaction

            task_obj = models.Task.objects.create(
                task_type=0,
                account=self.request.user.account,
                content=self.task_data.get('cmd'),
            )

            host_ids = set(self.task_data.get("selected_host_ids"))  # de-duplicate the selection
            models.TaskLog.objects.bulk_create(
                # status 3: the result pages treat this as "not finished yet"
                [models.TaskLog(task_id=task_obj.id, host_user_bind_id=host_id, status=3)
                 for host_id in host_ids],
                100,  # batch_size: rows per INSERT statement
            )

            # The executor is a separate process with its own DB connection: it can only
            # see these rows once this transaction has committed, so start it afterwards.
            transaction.on_commit(lambda: self._start_executor(task_obj.pk))
            return task_obj

        def _start_executor(self, task_id):
            """Launch multitask_execute.py for *task_id* and record its process group."""
            multitask_obj = subprocess.Popen(
                ['python3', settings.MULTI_TASK_SCRIPT, str(task_id)],
                # A new session gives the executor its own process group, so end_cmd's
                # killpg() reaches the executor and its pool workers but not the server.
                start_new_session=True,
                # stdout/stderr are inherited rather than piped: nothing reads the pipes,
                # and a full pipe buffer would block the executor forever.
            )
            settings.CURRENT_PGID = os.getpgid(multitask_obj.pid)

        def run_cmd(self,host_id,cmd):
            pass

        def files_transfer(self):
            pass
    task_handler.py
    {% extends 'index.html' %}



    {% block content-container %}
        {#    {% csrf_token %}#}
        <div id="page-title">
            <h1 class="page-header text-overflow">主机列表</h1>

            <!--Searchbox-->
            <div class="searchbox">
                <div class="input-group custom-search-form">
                    <input type="text" class="form-control" placeholder="Search..">
                    <span class="input-group-btn">
                        <button class="text-muted" type="button"><i class="pli-magnifi-glass"></i></button>
                    </span>
                </div>
            </div>
        </div>
        <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
        <!--End page title-->
        <!--Breadcrumb-->
        <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
        <ol class="breadcrumb">
            <li><a href="#">Home</a></li>
            <li><a href="#">Library</a></li>
            <li class="active">主机列表</li>
        </ol>
        <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
        <!--End breadcrumb-->

        <div id="page-content">
            <div class="panel col-lg-3">
                <div class="panel-heading">
                    <h3 class="panel-title">主机组 <span id="selected_hosts"></span></h3>
                </div>
                <div class="panel-body">

                    <ul class="list-group" id="host_groups">
                        {% for group in  request.user.account.host_groups.all %}

                            <li class="list-group-item "><span
                                    class="badge badge-success">{{ group.host_user_binds.count }}</span>
                                <input type="checkbox" onclick="CheckAll(this)">
                                <a onclick="DisplayHostList(this)">{{ group.name }}</a>
                                <!-- clicking the group name shows/hides its host list (toggleClass "hide") -->
                                <ul class="hide">
                                    {% for bind_host in group.host_user_binds.all %}
                                        <li><input onclick="ShowCheckedHostCount()" type="checkbox"
                                                   value="{{ bind_host.id }}">{{ bind_host.host.ip_addr }}</li>
                                    {% endfor %}
                                </ul>
                            </li>

                        {% endfor %}
                        <li class="list-group-item "><span
                                class="badge badge-success">{{ request.user.account.host_user_binds.count }}</span>
                            <input type="checkbox" onclick="CheckAll(this)">
                            <a onclick="DisplayHostList(this)">未分组主机</a>
                            <ul class="hide">
                                {% for bind_host in request.user.account.host_user_binds.all %}
                                    <li><input onclick="ShowCheckedHostCount()" type="checkbox"
                                               value="{{ bind_host.id }}">{{ bind_host.host.ip_addr }}</li>
                                {% endfor %}
                            </ul>
                        </li>

                    </ul>


                </div>
            </div>

            <div class="col-lg-9">
                <div class="panel">
                    <div class="panel-heading">
                        <h3 class="panel-title">命令</h3>
                    </div>
                    <div class="panel-body">
                        <textarea class="form-control" id="cmd"></textarea>
                        <button onclick="PostTask('cmd')" class="btn btn-info pull-right">执行</button>
                        <button class="btn btn-danger" onclick="End()">终止</button>

                    </div>

                </div>

                <div id="task_result_panel" class="panel">
                    <div class="panel-heading">
                        <h3 class="panel-title">任务结果</h3>
                    </div>
                    <div class="panel-body">
                        <div class="progress">
                            <div id='task_progress' style="width: 0%;" class="progress-bar progress-bar-info"></div>
                        </div>
                        <div id="task_result"></div>

                    </div>
                </div>

            </div>
        </div>
        <!-- end #page-content -->


        <script>
            var result_timer = null;        // setInterval handle for the task being polled
            var active_task_id = null;      // task whose results are currently displayed
            var task_timeout_counter = 0;   // seconds waited so far (+2 per poll)

            function DisplayHostList(self) {
                $(self).next().toggleClass("hide");
            }

            function CheckAll(self) {
                // apply the group checkbox state to every host checkbox of that group
                $(self).parent().find("ul :checkbox").prop('checked', $(self).prop('checked'));
                ShowCheckedHostCount();
            }

            function ShowCheckedHostCount() {
                var selected_host_count = $("#host_groups ul").find(":checked").length;
                $("#selected_hosts").text(selected_host_count);
                return selected_host_count;
            }

            function StopPolling() {
                clearInterval(result_timer);  // no-op when result_timer is null
                result_timer = null;
                active_task_id = null;        // late responses of the old task are ignored
            }

            function GetTaskResult(task_id, task_timeout) {
                $.getJSON("{% url 'get_task_result' %}", {'task_id': task_id}, function (callback) {
                    if (task_id !== active_task_id) {
                        return;  // stale response for a task that is no longer being polled
                    }
                    var $result = $("#task_result").empty();
                    var all_task_finished = true;   // false while any host is still pending
                    var finished_task_count = 0;    // hosts whose status is no longer 3
                    $.each(callback, function (index, i) {
                        // Hostnames and command output come from remote machines: insert them
                        // as text, never as raw HTML (prevents script injection).
                        $("<p>").text(i.host_user_bind__host__hostname + "(" + i.host_user_bind__host__ip_addr +
                            ") ------" + i.status).appendTo($result);
                        // <pre> keeps the line layout of the command output
                        $("<pre>").text(i.result == null ? "" : i.result).appendTo($result);

                        if (i.status == 3) {
                            all_task_finished = false;  // still running
                        } else {
                            finished_task_count += 1;
                        }
                    });

                    // Polls are ~2 s apart; once the task's timeout is reached stop waiting.
                    if (task_timeout_counter < task_timeout) {
                        task_timeout_counter += 2;
                    } else {
                        all_task_finished = true;
                    }
                    if (all_task_finished) {
                        StopPolling();
                        var unexecuted = callback.length - finished_task_count;
                        $.niftyNoty({   // summary: total / finished / not finished
                            type: unexecuted > 0 ? 'danger' : 'success',
                            container: '#task_result_panel',
                            html: '<h4 id="Prompt">' + '执行:' + callback.length + '  ' + '完成:' + finished_task_count + '  ' + '失败:' + unexecuted + '</h4>',
                            closeBtn: false
                        });
                    }

                    // guard against an empty list (0/0 would render "NaN%")
                    var total_finished_percent = callback.length ? parseInt(finished_task_count / callback.length * 100) : 0;
                    $("#task_progress").text(total_finished_percent + "%");
                    $("#task_progress").css("width", total_finished_percent + "%");
                });
            }


            function PostTask(task_type) {
                // 1. make sure hosts are selected and a command was entered
                // 2. submit the task, then poll its results every 2 seconds
                $('.alert').remove();
                var selected_host_ids = [];
                $("#host_groups ul").find(":checked").each(function () {
                    selected_host_ids.push($(this).val());
                });
                if (selected_host_ids.length == 0) {
                    alert("主机未选择!");
                    return false;
                }
                var cmd_text = $.trim($("#cmd").val());
                if (cmd_text.length == 0) {
                    alert("未输入命令!");
                    return false;
                }

                var task_data = {
                    'task_type': task_type,
                    'selected_host_ids': selected_host_ids,
                    'cmd': cmd_text
                };

                $.post("{% url 'multitask' %}", {
                        'csrfmiddlewaretoken': "{{ csrf_token }}",
                        'task_data': JSON.stringify(task_data)
                    },
                    function (raw_response) {
                        var response = (typeof raw_response === 'string') ? JSON.parse(raw_response) : raw_response;
                        // a rejected task comes back as the backend's list of validation errors
                        if (Array.isArray(response)) {
                            alert(JSON.stringify(response));
                            return;
                        }

                        StopPolling();  // a new task replaces any poll still running
                        active_task_id = response.task_id;
                        task_timeout_counter = 0;
                        GetTaskResult(response.task_id, response.timeout);
                        result_timer = setInterval(function () {
                            GetTaskResult(response.task_id, response.timeout);
                        }, 2000);

                        // show the download button (when the page has one) for this task's files
                        $("#file-download-btn").removeClass("hide").attr('href', "{% url 'task_file_download' %}?task_id=" + response.task_id);
                    });
            }

            function End() {
                $.getJSON("{% url 'end_cmd' %}", function (callback) {
                    console.log(callback);
                });
            }
        </script>
    {% endblock %}
    multi_cmd.html
    import time
    import sys,os
    import multiprocessing
    import paramiko
    
    def cmd_run(tasklog_id,cmd_str):
        """Run *cmd_str* over SSH on the host of one TaskLog row and store the output.

        Executed in a multiprocessing pool worker with the TaskLog primary key and
        the batch task's command. stdout followed by stderr is decoded to text and
        saved in ``TaskLog.result``. Any error is printed; when it happens during the
        SSH step it is also recorded on the row, so the result page stops waiting.
        """
        try:
            import django
            django.setup()
            from audit import models
            tasklog_obj = models.TaskLog.objects.get(id=tasklog_id)
            print(tasklog_obj, cmd_str)
            host_user_bind = tasklog_obj.host_user_bind

            ssh = paramiko.SSHClient()
            # NOTE(review): AutoAddPolicy accepts unknown host keys (no MITM protection).
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            try:
                ssh.connect(host_user_bind.host.ip_addr,
                            host_user_bind.host.port,
                            host_user_bind.host_user.username,
                            host_user_bind.host_user.password,
                            timeout=15)  # 15-second connect timeout
                stdin, stdout, stderr = ssh.exec_command(cmd_str)
                # paramiko returns bytes; decode before saving, otherwise the text field
                # stores the bytes repr ("b'...\\n...'") instead of the output.
                result = (stdout.read() + stderr.read()).decode('utf-8', errors='replace')
                status = 0
            except Exception as e:
                print(e)
                result = 'error: %s' % e
                # NOTE(review): assumes 1 means "failed" in TaskLog.status choices — confirm.
                status = 1
            finally:
                ssh.close()  # release the connection even when connect/exec failed

            print('---------%s--------' % host_user_bind)
            print(result)
            tasklog_obj.result = result or 'cmd has no result output .'  # empty output
            tasklog_obj.status = status
            tasklog_obj.save()
        except Exception as e:
            print(e)
    
    def file_transfer(bind_host_obj, task_content=None):
        """Placeholder for the per-host file-transfer worker (not implemented yet).

        The pool calls every worker as ``task_func(tasklog_pk, task_obj.content)``;
        the optional second parameter accepts that call instead of failing with a
        TypeError that apply_async would silently discard.
        """
        pass
    
    
    if __name__ == '__main__':
        # Usage: python3 multitask_execute.py <task_id>
        # NOTE(review): settings.MULTI_TASK_SCRIPT places this file directly in BASE_DIR,
        # which would make three dirname() calls point two levels above the project.
        # Imports still work because the script's own directory is sys.path[0]; confirm
        # the intended location before relying on BASE_DIR here.
        BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        sys.path.append(BASE_DIR)
        os.environ.setdefault("DJANGO_SETTINGS_MODULE", "zhanggen_audit.settings")
        import django
        django.setup()
        from django.db import connection
        from audit import models

        task_id = int(sys.argv[1])
        # 1. load the batch task
        # 2. collect the ids of its per-host sub-tasks (TaskLog rows)
        # 3. run the worker matching the task type in a pool of 10 processes
        # 4. each worker writes its own result back to its TaskLog row
        task_obj = models.Task.objects.get(id=task_id)
        task_func = cmd_run if task_obj.task_type == 0 else file_transfer
        tasklog_ids = list(task_obj.tasklog_set.values_list('pk', flat=True))

        # Forked workers must not reuse the parent's DB connection; each worker
        # opens its own on first query.
        connection.close()

        pool = multiprocessing.Pool(processes=10)
        for tasklog_id in tasklog_ids:
            pool.apply_async(task_func, args=(tasklog_id, task_obj.content))
        pool.close()
        pool.join()
    multitask_execute.py

     四、通过堡垒机批量上传和下载文件

     1.上传本地文件至多台服务器(批量上传)

    每次访问批量上传页面时,后端都会生成一个唯一的随机字符串并传给页面;

    前端使用 filedropzone 组件实现批量上传界面(UI),并限制文件大小和个数;文件提交到后端时携带该唯一字符串;

    后端按“/固定上传路径/用户ID/唯一字符串/文件名”的规则生成保存路径并写入文件(文件拖入 filedropzone 组件后会自动上传);

    前端点击“执行”后,后端先验证该用户在堡垒机上的上传路径是否合法,然后开启多进程,分别通过 paramiko 将文件发送到各远程服务器的目标路径。

    """
    Django settings for zhanggen_audit project.
    
    Generated by 'django-admin startproject' using Django 1.11.4.
    
    For more information on this file, see
    https://docs.djangoproject.com/en/1.11/topics/settings/
    
    For the full list of settings and their values, see
    https://docs.djangoproject.com/en/1.11/ref/settings/
    """
    
    import os
    
    # Build paths inside the project like this: os.path.join(BASE_DIR, ...)
    BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


    # Quick-start development settings - unsuitable for production
    # See https://docs.djangoproject.com/en/1.11/howto/deployment/checklist/

    # SECURITY WARNING: keep the secret key used in production secret!
    # Deployments should set DJANGO_SECRET_KEY; the literal below is a development fallback only.
    SECRET_KEY = os.environ.get('DJANGO_SECRET_KEY', '5ivlngau4a@_3y4vizrcxnnj(&vz2en#edpq%i&jr%99-xxv)&')

    # SECURITY WARNING: don't run with debug turned on in production!
    # DJANGO_DEBUG=0/false/no disables debug; unset keeps the development default (on).
    DEBUG = os.environ.get('DJANGO_DEBUG', '1').strip().lower() in ('1', 'true', 'yes')

    # Comma-separated host names, e.g. DJANGO_ALLOWED_HOSTS=bastion.example.com,10.0.0.5
    ALLOWED_HOSTS = os.environ.get('DJANGO_ALLOWED_HOSTS', '*').split(',')
    
    
    # --- Application definition -----------------------------------------------
    INSTALLED_APPS = [
        # Django contrib apps
        'django.contrib.admin',
        'django.contrib.auth',
        'django.contrib.contenttypes',
        'django.contrib.sessions',
        'django.contrib.messages',
        'django.contrib.staticfiles',
        # project apps
        'audit.apps.AuditConfig',
    ]

    # Order matters: each middleware wraps the ones listed after it.
    MIDDLEWARE = [
        'django.middleware.security.SecurityMiddleware',
        'django.contrib.sessions.middleware.SessionMiddleware',
        'django.middleware.common.CommonMiddleware',
        'django.middleware.csrf.CsrfViewMiddleware',
        'django.contrib.auth.middleware.AuthenticationMiddleware',
        'django.contrib.messages.middleware.MessageMiddleware',
        'django.middleware.clickjacking.XFrameOptionsMiddleware',
    ]

    ROOT_URLCONF = 'zhanggen_audit.urls'

    # Project-level template directory (BASE_DIR/templates) plus app templates.
    TEMPLATES = [
        {
            'BACKEND': 'django.template.backends.django.DjangoTemplates',
            'DIRS': [os.path.join(BASE_DIR, 'templates')],
            'APP_DIRS': True,
            'OPTIONS': {
                'context_processors': [
                    'django.template.context_processors.debug',
                    'django.template.context_processors.request',
                    'django.contrib.auth.context_processors.auth',
                    'django.contrib.messages.context_processors.messages',
                ],
            },
        },
    ]

    WSGI_APPLICATION = 'zhanggen_audit.wsgi.application'

    # --- Database -------------------------------------------------------------
    # https://docs.djangoproject.com/en/1.11/ref/settings/#databases
    DATABASES = {
        'default': {
            'ENGINE': 'django.db.backends.sqlite3',
            'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
        },
    }

    # --- Password validation --------------------------------------------------
    # https://docs.djangoproject.com/en/1.11/ref/settings/#auth-password-validators
    AUTH_PASSWORD_VALIDATORS = [
        {'NAME': 'django.contrib.auth.password_validation.' + validator}
        for validator in (
            'UserAttributeSimilarityValidator',
            'MinimumLengthValidator',
            'CommonPasswordValidator',
            'NumericPasswordValidator',
        )
    ]

    # --- Internationalization -------------------------------------------------
    # https://docs.djangoproject.com/en/1.11/topics/i18n/
    LANGUAGE_CODE = 'en-us'
    TIME_ZONE = 'Asia/Shanghai'
    USE_I18N = True
    USE_L10N = True
    USE_TZ = True

    # --- Static files (CSS, JavaScript, Images) -------------------------------
    # https://docs.djangoproject.com/en/1.11/howto/static-files/
    STATIC_URL = '/static/'
    STATICFILES_DIRS = (
        os.path.join(BASE_DIR, 'static'),
    )
    
    
    # Shell script that tracks SSH sessions: BASE_DIR/audit/backend/session_check.sh.
    # Built from path components; the old "%s" template was %-formatted *after* joining,
    # so a "%" anywhere in BASE_DIR broke it.
    SESSION_TRACKER_SCRIPT = os.path.join(BASE_DIR, 'audit', 'backend', 'session_check.sh')
    SESSION_TRACKER_SCRIPT_LOG_PATH = os.path.join(BASE_DIR, 'log')  # session tracker log directory
    MULTI_TASK_SCRIPT = os.path.join(BASE_DIR, 'multitask_execute.py')  # batch task executor script

    # Process-group id of the most recently launched batch executor. Reassigned at
    # runtime by the views/task handler; it is a plain module attribute, so it is
    # only visible inside the current server process.
    CURRENT_PGID = None
    FILE_UPLOADS = os.path.join(BASE_DIR, 'uploads')      # upload storage on the bastion host
    FILE_DOWNLOADS = os.path.join(BASE_DIR, 'downloads')  # download storage on the bastion host
    配置堡垒机上传和下载文件的路径
    <script>
        function DisplayHostList(self) {
            // Show/hide the host <ul> that directly follows the clicked group link.
            var $hostList = $(self).next();
            $hostList.toggleClass("hide");
        }
    
        function CheckAll(self) {
            // Apply the group checkbox state to every host checkbox in that group,
            // then refresh the selected-host counter.
            var groupChecked = $(self).prop('checked');
            console.log(groupChecked);
            var $hostBoxes = $(self).parent().find("ul :checkbox");
            $hostBoxes.prop('checked', groupChecked);

            ShowCheckedHostCount();
        }
    
        function ShowCheckedHostCount() {
            // Count checked host checkboxes inside the group lists and display it
            // next to the panel title; the count is also returned.
            var checkedCount = $("#host_groups ul :checked").length;
            console.log(checkedCount);
            $("#selected_hosts").text(checkedCount);
            return checkedCount;
        }
    
    
        function GetTaskResult(task_id, task_timeout) {
            // Poll the per-host results of one batch task and render them.
            // Uses the page-level globals task_timeout_counter and result_timer (set by PostTask).
            $.getJSON("{% url 'get_task_result' %}", {'task_id': task_id}, function (callback) {
                var $result = $("#task_result").empty();
                var all_task_finished = true;
                var finished_task_count = 0;
                $.each(callback, function (index, i) {
                    // Hostnames and command output come from remote machines: insert them
                    // as text, never as raw HTML (prevents script injection).
                    $("<p>").text(i.host_user_bind__host__hostname + "(" + i.host_user_bind__host__ip_addr +
                        ") ------" + i.status).appendTo($result);
                    $("<pre>").text(i.result == null ? "" : i.result).appendTo($result);

                    // status 3 means the sub task has not finished yet
                    if (i.status == 3) {
                        all_task_finished = false;
                    } else {
                        finished_task_count += 1;
                    }
                });

                // Only count towards the timeout while something is still pending, so the
                // "timed out" notice is not shown for a task that actually completed.
                if (!all_task_finished) {
                    if (task_timeout_counter < task_timeout) {
                        task_timeout_counter += 2;  // polls are ~2 s apart
                    } else {
                        all_task_finished = true;   // global timeout reached: stop polling
                        $.niftyNoty({
                            type: 'danger',
                            container : '#task_result_panel',
                            html : '<h4 class="alert-title">Task timed out!</h4><p class="alert-message">The task has timed out!</p><div class="mar-top"><button type="button" class="btn btn-info" data-dismiss="noty">Close this notification</button></div>',
                            closeBtn : false
                        });
                    }
                }

                if (all_task_finished) {
                    clearInterval(result_timer);
                    console.log("timmer canceled....");
                }

                // progress bar; guard against an empty list (0/0 would render "NaN%")
                var total_finished_percent = callback.length ? parseInt(finished_task_count / callback.length * 100) : 0;
                $("#task_progress").text(total_finished_percent + "%");
                $("#task_progress").css("width", total_finished_percent + "%");
            });//end getJSON
        }
    
    
        function  PostTask(task_type) {
            // Validate the form, submit the batch task to the backend and start
            // polling its result every 2 seconds via GetTaskResult.
            //   task_type: 'cmd' for a batch command; anything else is submitted
            //              as a file transfer ('file_transfer').
            var selected_host_ids = [];
            $("#host_groups ul").find(":checked").each(function (index, ele) {
                selected_host_ids.push($(ele).val());
            });
            if (selected_host_ids.length == 0) {
                alert("主机未选择!");
                return false;
            }

            var task_data = {
                'task_type': task_type,
                'selected_host_ids': selected_host_ids
            };
            if (task_type == 'cmd') {
                var cmd_text = $.trim($("#cmd").val());
                if (cmd_text.length == 0) {
                    alert("未输入命令!");
                    return false;
                }
                task_data['cmd'] = cmd_text;
            } else {
                // file_transfer: a remote path is always required
                var remote_path = $.trim($("#remote_path").val());
                if (remote_path.length == 0) {
                    alert("必须输入1个远程路径");
                    return false;
                }
                task_data['file_transfer_type'] = $("select[name='transfer-type']").val();
                // identifies this page's upload directory on the server
                task_data['random_str'] = "{{ random_str }}";
                task_data['remote_path'] = remote_path;
            }

            $.post("{% url 'multitask' %}", {'csrfmiddlewaretoken': "{{ csrf_token }}", 'task_data': JSON.stringify(task_data)},
                function (response) {
                    console.log(response);
                    var task_info = JSON.parse(response);  // carries task_id and timeout

                    // Stop polling a previously submitted task; otherwise every
                    // submit would leave one more 2-second timer running.
                    if (typeof result_timer !== 'undefined') {
                        clearInterval(result_timer);
                    }
                    task_timeout_counter = 0; // GetTaskResult adds 2 on each poll
                    GetTaskResult(task_info.task_id, task_info.timeout);
                    result_timer = setInterval(function () {
                        GetTaskResult(task_info.task_id, task_info.timeout);
                    }, 2000);

                    // Only "get" transfers leave files on the bastion host to download.
                    if (task_data['file_transfer_type'] == 'get') {
                        $("#file-download-btn").removeClass("hide").attr('href', "{% url 'task_file_download' %}?task_id=" + task_info.task_id);
                    } else {
                        $("#file-download-btn").addClass("hide");
                    }
                });//end post
        }
    </script>
    components/multitask_js.html
    {% extends 'index.html' %}
    {% block extra-css %}
        <link href="/static/plugins/dropzone/dropzone.css" rel="stylesheet">
        <script src="/static/plugins/dropzone/dropzone.js"></script>
    {% endblock %}
    
    
    {% block content-container %}
        <div id="page-title">
            <h1 class="page-header text-overflow">主机列表</h1>
    
            <!--Searchbox-->
            <div class="searchbox">
                <div class="input-group custom-search-form">
                    <input type="text" class="form-control" placeholder="Search..">
                    <span class="input-group-btn">
                        <button class="text-muted" type="button"><i class="pli-magnifi-glass"></i></button>
                    </span>
                </div>
            </div>
        </div>
        <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
        <!--End page title-->
        <!--Breadcrumb-->
        <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
        <ol class="breadcrumb">
            <li><a href="#">Home</a></li>
            <li><a href="#">Library</a></li>
            <li class="active">主机列表</li>
        </ol>
        <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
        <!--End breadcrumb-->
        <div id="page-content">
            {% include 'components/hostgroups.html' %}
            <div class="col-lg-9">
                <div class="panel">
                    <div class="panel-heading">
                        <h3 class="panel-title">文件传输</h3>
                    </div>
                    <div class="panel-body">
                        <select name="transfer-type" onchange="ToggleUploadEle(this)">
                            <option value="send">发送文件到远程主机</option>
                            <option value="get">from远程主机下载文件</option>
                        </select>
    
    
                        <form id="filedropzone" class="dropzone">
    
                        </form>
                        {#                    <input type="hidden" value="{{ random_str }}" name="random_str">#}
                        <input id="remote_path" class="form-control" type="text" placeholder="远程路径">
    
                        <button id="file_count" onclick="PostTask('file_transfer')" class="btn btn-info pull-right">执行</button>
                        <button class="btn btn-danger ">终止</button>
                        <a id="file-download-btn" class="btn btn-info hide" href="">下载任务文件到本地</a>
    
    
                    </div>
                </div>
                {% include 'components/taskresult.html' %}
            </div>
    
        </div>
        </div>
    
        {% include 'components/multitask_js.html' %}
        <script>
    
            // Disable Dropzone's auto-discovery before attaching manually: the form
            // below has class="dropzone" and must only be attached once.
            Dropzone.autoDiscover = false;

            // Upload area for "send" transfers. Files are posted to task_file_upload
            // together with this page's random_str, which PostTask later submits so
            // the server can locate the uploaded files.
            $('#filedropzone').dropzone({
                url: "{% url 'task_file_upload' %}?random_str={{ random_str }}", // required
                method: "post",     // "put" would also work
                maxFiles: 10,       // max number of files per upload
                maxFilesize: 2,     // MB; keep the size in dictFileTooBig in sync
                dictMaxFilesExceeded: "您最多只能上传10个文件!",
                // The size is written literally because Dropzone's maxFilesize
                // placeholder uses double braces, which the Django template engine
                // would render as an empty string.
                dictFileTooBig: "文件过大,上传文件最大支持2MB."
            });
    
    
            function ToggleUploadEle(self) {
                // onchange handler of the transfer-type <select>: hide the element
                // right after it (the upload drop zone) when 'get' is chosen,
                // show it again for any other choice.
                var transfer_type = $(self).val();
                console.log(transfer_type);
                $(self).next().toggleClass('hide', transfer_type == 'get');
            }
    
        </script>
    
    {% endblock %}
    multi_file_transfer.html
    from django.conf.urls import url
    from django.contrib import admin
    from audit import views
    
    urlpatterns = [
        # admin site, dashboard and authentication
        url(r'^admin/', admin.site.urls),
        url(r'^$', views.index),
        url(r'^login/$', views.acc_login),
        url(r'^logout/$', views.acc_logout),
        url(r'^hostlist/$', views.host_list, name='host_list'),
        # batch tasks: the page's PostTask() posts task_data to "multitask" ...
        url(r'^multitask/$', views.multitask, name='multitask'),
        # ... and GetTaskResult() polls "get_task_result" every 2 seconds
        url(r'^multitask/result/$', views.multitask_result, name='get_task_result'),
        url(r'^multitask/cmd/$', views.multi_cmd, name='multi_cmd'),
        url(r'^api/hostlist/$', views.get_host_list, name='get_host_list'),
        url(r'^api/token/$', views.get_token, name='get_token'),
        url(r'^multitask/file_transfer/$', views.multi_file_transfer, name='multi_file_transfer'),
        # dropzone upload target (random_str is passed as a query parameter)
        url(r'^api/task/file_upload/$', views.task_file_upload, name='task_file_upload'),
        # zipped download of a task's fetched files (?task_id=...)
        url(r'^api/task/file_download/$', views.task_file_download, name='task_file_download'),
        url(r'^end_cmd/$', views.end_cmd, name='end_cmd'),
    ]
    urls.py
    import json,subprocess,os,signal
    from audit import models
    from django.conf import settings
    from django.db.transaction import atomic
    class Task(object):
        """Validate a batch-task request posted by the web UI and launch it.

        The request's ``task_data`` POST field is a JSON object such as
        ``{"task_type": "cmd", "selected_host_ids": ["1", "2"], "cmd": "df"}`` or
        ``{"task_type": "file_transfer", "selected_host_ids": ["3"],
        "file_transfer_type": "send", "random_str": "iuon9bhm", "remote_path": "/"}``.

        Usage: call :meth:`is_valid`; if it returns True call :meth:`run`, which
        creates the ``Task``/``TaskLog`` rows and starts
        ``python settings.MULTI_TASK_SCRIPT <task_id>`` in a separate process.
        Validation problems are collected in ``self.errors`` as
        ``{error_key: message}`` dicts.
        """

        def __init__(self, request):
            self.request = request
            self.errors = []
            self.task_data = None  # parsed task_data dict, set by is_valid()
            self.task_type = None  # 'cmd' or 'file_transfer', set by is_valid()

        def is_valid(self):
            """Parse and validate ``task_data``; return True if the task can run."""
            raw_task_data = self.request.POST.get('task_data')
            if not raw_task_data:
                self.errors.append({'invalid_data': 'task_data不存在!'})
                return False
            try:
                task_data = json.loads(raw_task_data)
            except ValueError:
                task_data = None
            if not isinstance(task_data, dict):
                self.errors.append({'invalid_data': 'task_data格式错误!'})
                return False

            self.task_data = task_data
            self.task_type = task_data.get('task_type')
            selected_host_ids = task_data.get('selected_host_ids')
            # The ids are de-duplicated with set() later, so insist on a real list
            # (a bare string would be split into single characters).
            has_hosts = isinstance(selected_host_ids, list) and len(selected_host_ids) > 0

            if self.task_type == 'cmd':
                cmd = task_data.get('cmd')
                if has_hosts and isinstance(cmd, str) and cmd.strip():
                    return True
                self.errors.append({'invalid_argument': '命令/主机不存在'})
                return False

            if self.task_type == 'file_transfer':
                return self._file_transfer_is_valid(has_hosts)

            self.errors.append({'invalid_argument': '不支持的任务类型!'})
            return False

        def _file_transfer_is_valid(self, has_hosts):
            """Validate the fields specific to a ``file_transfer`` task."""
            valid = True
            if not has_hosts:
                self.errors.append({'invalid_argument': '远程主机不存在'})
                valid = False

            remote_path = self.task_data.get('remote_path')
            if not isinstance(remote_path, str) or not remote_path.strip():
                self.errors.append({'invalid_argument': '必须输入1个远程路径'})
                valid = False

            transfer_type = self.task_data.get('file_transfer_type')
            if transfer_type == 'send':
                # Only uploads need the staging directory filled by the upload
                # page; a 'get' task has nothing uploaded and must not be
                # rejected because that directory is missing.
                if not self._upload_dir_exists():
                    self.errors.append({'invalid_argument': '上传路径失败,请重新上传'})
                    valid = False
            elif transfer_type != 'get':
                self.errors.append({'invalid_argument': '不支持的文件传输类型!'})
                valid = False
            return valid

        def _upload_dir_exists(self):
            """Return True if FILE_UPLOADS/<account id>/<random_str> is a directory."""
            random_str = self.task_data.get('random_str')
            # random_str comes from the client and is joined into a filesystem
            # path; reject anything but a plain alphanumeric token so values like
            # "../.." cannot point the transfer at arbitrary server directories.
            # NOTE(review): assumes the server generates random_str from letters
            # and digits only (the sample value is 'iuon9bhm') -- confirm.
            if not isinstance(random_str, str) or not random_str.isalnum():
                return False
            account = models.Account.objects.filter(user=self.request.user).first()
            if account is None:
                return False
            upload_dir = os.path.join(settings.FILE_UPLOADS, str(account.pk), random_str)
            return os.path.isdir(upload_dir)

        def run(self):
            """Create the task records, launch the executor and return the Task row.

            Only call after :meth:`is_valid` returned True; that restricts
            ``task_type`` to ``'cmd'`` or ``'file_transfer'``.
            """
            handlers = {'cmd': self.cmd, 'file_transfer': self.file_transfer}
            return handlers[self.task_data.get('task_type')]()

        @atomic  # the Task row and all of its TaskLog rows are created together
        def cmd(self):
            """Create a command task (``task_type=0``) whose content is the command."""
            return self._create_and_launch(0, self.task_data.get('cmd'))

        @atomic  # the Task row and all of its TaskLog rows are created together
        def file_transfer(self):
            """Create a file-transfer task (``task_type=1``).

            The whole ``task_data`` dict is stored as JSON in ``Task.content``;
            the executor reads the transfer type, random_str and remote path
            back from it.
            """
            return self._create_and_launch(1, json.dumps(self.task_data))

        def _create_and_launch(self, task_type_code, content):
            """Insert the Task and one TaskLog per selected host, then launch the executor."""
            task_obj = models.Task.objects.create(
                task_type=task_type_code,
                account=self.request.user.account,
                content=content,
            )
            # One sub-task per distinct host binding, all starting in status 3,
            # which the result page treats as "not finished yet".
            host_ids = set(self.task_data.get('selected_host_ids'))
            models.TaskLog.objects.bulk_create(
                [models.TaskLog(task_id=task_obj.id, host_user_bind_id=host_id, status=3)
                 for host_id in host_ids],
                batch_size=100,  # at most 100 rows per INSERT query (not per commit)
            )
            self._launch_executor(task_obj.pk)
            return task_obj

        @staticmethod
        def _launch_executor(task_id):
            """Start ``python MULTI_TASK_SCRIPT <task_id>`` once the transaction commits."""
            from django.db import transaction

            args = ['python', settings.MULTI_TASK_SCRIPT, str(task_id)]

            def start():
                # Fire and forget: the executor writes results into TaskLog
                # itself. Its output is discarded because nothing reads it, and an
                # unread PIPE would block the child once the pipe buffer fills.
                subprocess.Popen(args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

            # The executor loads the Task by id over its own DB connection, so it
            # must not start before the surrounding @atomic block has committed.
            transaction.on_commit(start)
    task_handler.py
    import time,json
    import sys,os
    import multiprocessing
    import paramiko
    
    def cmd_run(tasklog_id, task_obj_id, cmd_str):
        """Run ``cmd_str`` over SSH on the host bound to one TaskLog and save the output.

        Executed in a multiprocessing pool worker (see the ``__main__`` block).

        :param tasklog_id: primary key of the ``TaskLog`` row (one host of the batch).
        :param task_obj_id: primary key of the parent ``Task``; unused here, but kept
            so ``cmd_run`` and ``file_transfer`` take the same arguments.
        :param cmd_str: command line executed on the remote host.

        On success the combined stdout+stderr is stored in ``TaskLog.result`` and
        ``status`` is set to 0. On any error the exception text is stored in
        ``result`` instead, as ``file_transfer`` does.
        """
        tasklog_obj = None
        try:
            import django
            django.setup()
            from audit import models
            tasklog_obj = models.TaskLog.objects.get(id=tasklog_id)
            host = tasklog_obj.host_user_bind.host
            host_user = tasklog_obj.host_user_bind.host_user

            ssh = paramiko.SSHClient()
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            try:
                ssh.connect(host.ip_addr, host.port, host_user.username, host_user.password,
                            timeout=15)  # give up connecting after 15 seconds
                _stdin, stdout, stderr = ssh.exec_command(cmd_str)
                output = stdout.read() + stderr.read()
            finally:
                ssh.close()  # release the connection even when connect/exec fails

            # paramiko returns bytes; decode to text, otherwise a Django text field
            # stores the "b'...'" repr of the bytes object.
            tasklog_obj.result = output.decode('utf-8', errors='replace') or 'cmd has no result output .'
            tasklog_obj.status = 0
            tasklog_obj.save()
        except Exception as e:
            print(e)
            # NOTE(review): status is left unchanged (the result page treats 3 as
            # "not finished"), matching file_transfer; a failure code may be wanted.
            if tasklog_obj is not None:
                tasklog_obj.result = str(e)
                tasklog_obj.save()
    
    def file_transfer(tasklog_id, task_id, task_content):
        """Upload files to, or download one file from, the host bound to a TaskLog via SFTP.

        Executed in a multiprocessing pool worker (see the ``__main__`` block).
        The parameters are read from the parent task's JSON ``content``:

        * ``file_transfer_type == 'send'``: every regular file in
          ``FILE_UPLOADS/<account id>/<random_str>`` is uploaded into the remote
          directory ``remote_path``.
        * otherwise (``'get'``): the remote file ``remote_path`` is downloaded to
          ``FILE_DOWNLOADS/<task_id>/<host ip>.<remote file name>``.

        :param tasklog_id: primary key of the ``TaskLog`` row (one host of the batch).
        :param task_id: primary key of the parent ``Task``; names the download dir.
        :param task_content: the parent task's content; unused, the same JSON is
            re-read from ``tasklog_obj.task.content``.
        """
        import posixpath  # remote hosts are Linux: always build '/'-separated paths

        import django
        django.setup()
        from django.conf import settings
        from audit import models
        tasklog_obj = models.TaskLog.objects.get(id=tasklog_id)
        try:
            task_data = json.loads(tasklog_obj.task.content)
            remote_path = task_data.get('remote_path')
            host = tasklog_obj.host_user_bind.host
            host_user = tasklog_obj.host_user_bind.host_user

            transport = paramiko.Transport((host.ip_addr, host.port))
            try:
                transport.connect(username=host_user.username, password=host_user.password)
                sftp = paramiko.SFTPClient.from_transport(transport)

                if task_data.get('file_transfer_type') == 'send':
                    local_dir = os.path.join(settings.FILE_UPLOADS,
                                             str(tasklog_obj.task.account.id),
                                             task_data.get('random_str'))
                    for file_name in os.listdir(local_dir):
                        local_file = os.path.join(local_dir, file_name)
                        if os.path.isfile(local_file):  # put() cannot upload directories
                            sftp.put(local_file, posixpath.join(remote_path, file_name))
                    tasklog_obj.result = "send all files done..."
                else:
                    # Every host downloads into the same per-task directory; the
                    # host IP prefix keeps the file names apart.
                    download_dir = os.path.join(settings.FILE_DOWNLOADS, str(task_id))
                    os.makedirs(download_dir, exist_ok=True)
                    remote_filename = posixpath.basename(remote_path)
                    local_path = os.path.join(download_dir, '%s.%s' % (host.ip_addr, remote_filename))
                    sftp.get(remote_path, local_path)
                    tasklog_obj.result = 'get remote file [%s] to local done' % remote_path
            finally:
                transport.close()  # also closes the SFTP channel, even on failure

            tasklog_obj.status = 0
            tasklog_obj.save()
        except Exception as e:
            print("error :", e)
            # NOTE(review): status stays unchanged on failure, so the result page
            # keeps treating this sub-task as unfinished until its timeout.
            tasklog_obj.result = str(e)
            tasklog_obj.save()
    
    
    
    
    if __name__ == '__main__':
        # Entry point started by the web app as ``python <this script> <task_id>``:
        # 1. load the Task by id; 2. collect its TaskLog sub-tasks (one per host);
        # 3. run cmd_run / file_transfer for each sub-task in a 10-process pool;
        # 4. each worker writes its own result back to its TaskLog row.
        BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        sys.path.append(BASE_DIR)
        os.environ.setdefault("DJANGO_SETTINGS_MODULE", "zhanggen_audit.settings")
        import django
        django.setup()
        from django.db import connections
        from audit import models

        task_id = int(sys.argv[1])
        task_obj = models.Task.objects.get(id=task_id)
        # task_type 0 is a command task, anything else a file transfer
        task_func = cmd_run if task_obj.task_type == 0 else file_transfer
        tasklog_ids = list(task_obj.tasklog_set.values_list('id', flat=True))

        # Forked workers must not inherit and share this process's open DB
        # connection; close it so each worker opens its own.
        for conn in connections.all():
            conn.close()

        pool = multiprocessing.Pool(processes=10)  # at most 10 sub-tasks in parallel
        for tasklog_id in tasklog_ids:
            pool.apply_async(task_func, args=(tasklog_id, task_obj.id, task_obj.content))
        pool.close()
        pool.join()
    multitask_execute.py

     2.从多台服务器上get文件至本地(批量下载)

     

    用户输入远程服务器文件路径,堡垒机为每台主机生成本地下载路径(下载根目录/task_id/主机IP.远程文件名);

    开启多进程 通过paramiko下载远程主机的文件 到堡垒机下载路径;

    任务提交后,前端显示"下载任务文件到本地"按钮,其链接携带 ?task_id=批量任务ID 参数;

    用户点击该下载按钮(a 标签)后,后端从请求中获取批量任务 ID,把该批量任务下载到堡垒机上的所有文件打包成 zip,返回给用户浏览器!

    def send_zipfile(request, task_id, file_path):
        """Zip every entry of ``file_path`` and return it as a download response.

        The archive is built in an anonymous temporary file rather than under a
        fixed name in the current working directory. Concurrent downloads of the
        same task therefore cannot overwrite each other, and no archive is left
        behind on disk.

        :param request: the current request (unused).
        :param task_id: batch-task id, used only to name the downloaded archive.
        :param file_path: directory whose entries are stored at the archive root.
        :returns: ``HttpResponse`` with ``application/zip`` content, offered as
            ``task_id_<task_id>_files.zip``.
        """
        import tempfile

        zip_file_name = 'task_id_%s_files' % task_id
        temp = tempfile.TemporaryFile()
        with zipfile.ZipFile(temp, 'w', zipfile.ZIP_DEFLATED) as archive:
            for filename in os.listdir(file_path):
                archive.write(os.path.join(file_path, filename), arcname=filename)
        temp.seek(0, os.SEEK_END)
        archive_size = temp.tell()
        temp.seek(0)

        # HttpResponse reads the whole iterator on construction and then closes
        # it, which closes (and so deletes) the temporary file.
        response = HttpResponse(FileWrapper(temp), content_type='application/zip')
        response['Content-Disposition'] = 'attachment; filename=%s.zip' % zip_file_name  # download as attachment
        response['Content-Length'] = archive_size
        return response
    
    
    
    
    
    
    @login_required
    def task_file_download(request):
        """Return the files a "get" transfer task fetched to the bastion host, as a zip.

        Expects ``?task_id=<int>``; the files live in ``FILE_DOWNLOADS/<task_id>``.
        Raises ``Http404`` for a malformed task id or when the task has no
        download directory.
        """
        from django.http import Http404

        task_id = request.GET.get('task_id', '')
        # task_id is interpolated into a filesystem path: accept digits only so
        # values such as "../../etc" cannot escape FILE_DOWNLOADS.
        # NOTE(review): any logged-in user can download any task's files; an
        # ownership check (Task.account) may be wanted.
        if not task_id.isdigit():
            raise Http404('invalid task_id')
        task_file_path = "%s/%s" % (conf.settings.FILE_DOWNLOADS, task_id)
        if not os.path.isdir(task_file_path):
            raise Http404('no downloaded files for this task')
        return send_zipfile(request, task_id, task_file_path)
    Django响应压缩文件

    3.架构描述

    当前架构缺陷:每个批量任务都会在堡垒机上启动一个 multitask 进程及其进程池,随着用户量的增长,开启的进程数量也会越来越多;

    未来设想:在 Django 和 multitask 之间增加任务队列,由固定数量的 worker 进程消费任务,从而支撑大量用户并发!

    GitHub:https://github.com/zhanggen3714/zhanggen_audit

    GateOne安装

                                             

  • 相关阅读:
    The commands of Disk
    How to build a NFS Service
    Apache的dbutils的架构图
    Spring使用ThreadLocal解决线程安全问题
    NIO流程
    Servlet 生命周期、工作原理
    forward和redirect的区别
    笔记:Java 性能优化权威指南 第9、10、11章 GlassFish、Web应用、Web Service、EJB性能调优
    几个关于Java内存泄露方面的面试题
    MAT 检测 Java内存泄露检测
  • 原文地址:https://www.cnblogs.com/sss4/p/9280118.html
Copyright © 2020-2023  润新知