• 深入解析当下大热的前后端分离组件django-rest_framework系列三


    三剑客之认证、权限与频率组件

    认证组件

    局部视图认证

    在app01.service.auth.py:

    复制代码
    复制代码
    class Authentication(BaseAuthentication):
    
        def authenticate(self,request):
            token=request._request.GET.get("token")
            token_obj=UserToken.objects.filter(token=token).first()
            if not token_obj:
                raise exceptions.AuthenticationFailed("验证失败!")
            return (token_obj.user,token_obj)
    复制代码
    复制代码

    在views.py:

    复制代码
    复制代码
    def get_random_str(user):
        import hashlib,time
        ctime=str(time.time())
    
        md5=hashlib.md5(bytes(user,encoding="utf8"))
        md5.update(bytes(ctime,encoding="utf8"))
    
        return md5.hexdigest()
    
    
    from app01.service.auth import *
    
    from django.http import JsonResponse
    class LoginViewSet(APIView):
        authentication_classes = [Authentication,]
        def post(self,request,*args,**kwargs):
            res={"code":1000,"msg":None}
            try:
                user=request._request.POST.get("user")
                pwd=request._request.POST.get("pwd")
                user_obj=UserInfo.objects.filter(user=user,pwd=pwd).first()
                print(user,pwd,user_obj)
                if not user_obj:
                    res["code"]=1001
                    res["msg"]="用户名或者密码错误"
                else:
                    token=get_random_str(user)
                    UserToken.objects.update_or_create(user=user_obj,defaults={"token":token})
                    res["token"]=token
    
            except Exception as e:
                res["code"]=1002
                res["msg"]=e
    
            return JsonResponse(res,json_dumps_params={"ensure_ascii":False})
    复制代码
    复制代码

    备注:一个知识点:update_or_create(参数1,参数2...,defaults={‘字段’:'对应的值'}),这个方法使用于:如果对象存在,则进行更新操作,不存在,则创建数据。使用时会按照前面的参数进行filter,结果为True,则执行update  defaults中的值,否则,创建。

    使用方式: 

      使用认证组件时,自定制一个类,在自定制的类下,必须实现方法:authenticate(self,request),必须接收一个request参数,结果必须返回一个含有两个值的元组,认证失败,就抛一个异常rest_framework下的exceptions.APIException。然后在对应的视图类中添加一个属性:authentication_classes=[自定制类名]即可。

     备注:from rest_framework.exceptions import AuthenticationFailed   也可以抛出这个异常(这个异常是针对认证的异常)

    使用这个异常时,需要在自定制的类中,定制一个方法:def authenticate_header(self,request):pass,每次这样定制一个方法很麻烦,我们可以在自定义类的时候继承一个类: BaseAuthentication

    调用方式:from rest_framework.authentication import BaseAuthentication

    为什么是这个格式,我们看看这个组件内部的 实现吧!
    不管是认证,权限还是频率都发生在分发之前:

    复制代码
    class APIView(View):
            def dispatch(self, request, *args, **kwargs):
            """
            `.dispatch()` is pretty much the same as Django's regular dispatch,
            but with extra hooks for startup, finalize, and exception handling.
            """
            self.args = args
            self.kwargs = kwargs
            request = self.initialize_request(request, *args, **kwargs)
            self.request = request
            self.headers = self.default_response_headers  # deprecate?
    
            try:
                self.initial(request, *args, **kwargs)
    
                # Get the appropriate handler method
                if request.method.lower() in self.http_method_names:
                    handler = getattr(self, request.method.lower(),
                                      self.http_method_not_allowed)
                else:
                    handler = self.http_method_not_allowed
    
                response = handler(request, *args, **kwargs)
    
            except Exception as exc:
                response = self.handle_exception(exc)
    
            self.response = self.finalize_response(request, response, *args, **kwargs)
            return self.response
    复制代码

    所有的实现都发生在       self.initial(request, *args, **kwargs)        这行代码中

    复制代码
    class APIView(View):
         def initial(self, request, *args, **kwargs):
            """
            Runs anything that needs to occur prior to calling the method handler.
            """
            self.format_kwarg = self.get_format_suffix(**kwargs)
    
            # Perform content negotiation and store the accepted info on the request
            neg = self.perform_content_negotiation(request)
            request.accepted_renderer, request.accepted_media_type = neg
    
            # Determine the API version, if versioning is in use.
           #处理版本信息 version, scheme = self.determine_version(request, *args, **kwargs) request.version, request.versioning_scheme = version, scheme # Ensure that the incoming request is permitted
          # 认证 self.perform_authentication(request)
          # 权限 self.check_permissions(request)
          # 用户访问频率的限制 self.check_throttles(request)
    复制代码

    我们先看认证相关   self.perform_authentication

    复制代码
    class APIView(View):
            def perform_authentication(self, request):
            """
            Perform authentication on the incoming request.
    
            Note that if you override this and simply 'pass', then authentication
            will instead be performed lazily, the first time either
            `request.user` or `request.auth` is accessed.
            """
            request.user
    复制代码

    整个这个方法就执行了一句代码:request.user   ,这个request是新的request,是Request()类的实例对象,  .user 肯定在Request类中有一个user的属性方法。

    复制代码
    class Request(object):
        @property
        def user(self):
            """
            Returns the user associated with the current request, as authenticated
            by the authentication classes provided to the request.
            """
            if not hasattr(self, '_user'):
                with wrap_attributeerrors():
                    self._authenticate()
            return self._user
    复制代码

    这个user方法中,执行了一个 self._authenticate()方法,继续看这个方法中执行了什么:

    复制代码
    class Request(object):
            def _authenticate(self):
            """
            Attempt to authenticate the request using each authentication instance
            in turn.
            """
            for authenticator in self.authenticators:
                try:
                    user_auth_tuple = authenticator.authenticate(self)
                except exceptions.APIException:
                    self._not_authenticated()
                    raise
    
                if user_auth_tuple is not None:
                    self._authenticator = authenticator
                    self.user, self.auth = user_auth_tuple
                    return
    
            self._not_authenticated()
    复制代码

    这个方法中,循环一个东西:self.authenticators  ,这个self,是新的request,我们看看这个self.authenticators是什么。

    复制代码
    class Request(object):
        """
        Wrapper allowing to enhance a standard `HttpRequest` instance.
    
        Kwargs:
            - request(HttpRequest). The original request instance.
            - parsers_classes(list/tuple). The parsers to use for parsing the
              request content.
            - authentication_classes(list/tuple). The authentications used to try
              authenticating the request's user.
        """
    
        def __init__(self, request, parsers=None, authenticators=None,
                     negotiator=None, parser_context=None):
           
            self._request = request
            self.parsers = parsers or ()
            self.authenticators = authenticators or ()
    复制代码

    这个self.authenticators是实例化时初始化的一个属性,这个值是实例化新request时传入的参数。

    复制代码
    class APIView(View):
       
         def initialize_request(self, request, *args, **kwargs):
            """
            Returns the initial request object.
            """
            parser_context = self.get_parser_context(request)
    
            return Request(
                request,
                parsers=self.get_parsers(),
                authenticators=self.get_authenticators(),
                negotiator=self.get_content_negotiator(),
                parser_context=parser_context
            )
    复制代码

    实例化传参时,这个参数来自于self.get_authenticators(),这里的self就不是request,是谁调用它既是谁,看看吧:

    复制代码
    class APIView(View):
            def get_authenticators(self):
            """
            Instantiates and returns the list of authenticators that this view can use.
            """
            return [auth() for auth in self.authentication_classes]
    复制代码

    这个方法,返回了一个列表推导式   [auth() for auth in self.authentication_classes] ,到了这里,是不是很晕,这个self.authentication_classes又是啥,这不是我们在视图函数中定义的属性吗,值是一个列表,里面放着我们自定义认证的类

    class APIView(View):
    
        authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES

    如果我们不定义这个属性,会走默认的值  api_settings.DEFAULT_AUTHENTICATION_CLASSES

      走了一圈了,我们再走回去呗,首先,这个列表推导式中放着我们自定义类的实例,回到这个方法中def _authenticate(self),当中传的值self指的是新的request,这个我们必须搞清楚,然后循环这个包含所有自定义类实例的列表,得到一个个实例对象。

    for authenticator in self.authenticators:
    try: user_auth_tuple = authenticator.authenticate(self) except exceptions.APIException: self._not_authenticated() raise if user_auth_tuple is not None: self._authenticator = authenticator self.user, self.auth = user_auth_tuple return

    然后执行每一个实例对象的authenticate(self)方法,也就是我们在自定义类中必须实现的方法,这就是原因,因为源码执行时,会找这个方法,认证成功,返回一个元组 ,认证失败,捕捉一个异常,APIException。认证成功,这个元组会被self.user,self.auth接收值,所以我们要在认证成功时返回含有两个值的元组。这里的self是我们新的request,这样我们在视图函数和模板中,只要在这个request的生命周期,我们都可以通过request.user得到我们返回的第一个值,通过request.auth得到我们返回的第二个值。这就是认证的内部实现,很牛逼。

    备注:所以我们认证成功后返回值时,第一个值,最好时当前登录人的名字,第二个值,按需设置,一般是token值(标识身份的随机字符串)

     这种方式只是实现了对某一张表的认证,如果我们有100张表,那这个代码我就要写100遍,复用性很差,所以需要在全局中定义。

    全局视图认证组件

    settings.py配置如下:

    1
    2
    3
    REST_FRAMEWORK={
        "DEFAULT_AUTHENTICATION_CLASSES":["app01.service.auth.Authentication",]
    }

    为什么要这样配置呢?我们看看内部的实现吧:

      从哪开始呢?在局部中,我们在视图类中加一个属性authentication_classes=[自定义类],那么在内部,肯定有一个默认的值:

    class APIView(View):
    
        authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES

    默认值是api_settings中的一个属性,看看这个默认值都实现了啥

    api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS)

    首先,这个api_settings是APISettings类的一个实例化对象,传了三个参数,那这个DEFAULTS是啥,看看:

    复制代码
    #rest_framework中的settings.py
    
    DEFAULTS = {
        # Base API policies
        'DEFAULT_RENDERER_CLASSES': (
            'rest_framework.renderers.JSONRenderer',
            'rest_framework.renderers.BrowsableAPIRenderer',
        ),
        'DEFAULT_PARSER_CLASSES': (
            'rest_framework.parsers.JSONParser',
            'rest_framework.parsers.FormParser',
            'rest_framework.parsers.MultiPartParser'
        ),
        'DEFAULT_AUTHENTICATION_CLASSES': (
            'rest_framework.authentication.SessionAuthentication',
            'rest_framework.authentication.BasicAuthentication'
        ),
        'DEFAULT_PERMISSION_CLASSES': (
            'rest_framework.permissions.AllowAny',
        ),
        'DEFAULT_THROTTLE_CLASSES': (),
        'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation',
        'DEFAULT_METADATA_CLASS': 'rest_framework.metadata.SimpleMetadata',
        'DEFAULT_VERSIONING_CLASS': None,
    
        # Generic view behavior
        'DEFAULT_PAGINATION_CLASS': None,
        'DEFAULT_FILTER_BACKENDS': (),
    
        # Schema
        'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema',
    
        # Throttling
        'DEFAULT_THROTTLE_RATES': {
            'user': None,
            'anon': None,
        },
        'NUM_PROXIES': None,
    
        # Pagination
        'PAGE_SIZE': None,
    
        # Filtering
        'SEARCH_PARAM': 'search',
        'ORDERING_PARAM': 'ordering',
    
        # Versioning
        'DEFAULT_VERSION': None,
        'ALLOWED_VERSIONS': None,
        'VERSION_PARAM': 'version',
    
        # Authentication
        'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
        'UNAUTHENTICATED_TOKEN': None,
    
        # View configuration
        'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name',
        'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description',
    
        # Exception handling
        'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler',
        'NON_FIELD_ERRORS_KEY': 'non_field_errors',
    
        # Testing
        'TEST_REQUEST_RENDERER_CLASSES': (
            'rest_framework.renderers.MultiPartRenderer',
            'rest_framework.renderers.JSONRenderer'
        ),
        'TEST_REQUEST_DEFAULT_FORMAT': 'multipart',
    
        # Hyperlink settings
        'URL_FORMAT_OVERRIDE': 'format',
        'FORMAT_SUFFIX_KWARG': 'format',
        'URL_FIELD_NAME': 'url',
    
        # Input and output formats
        'DATE_FORMAT': ISO_8601,
        'DATE_INPUT_FORMATS': (ISO_8601,),
    
        'DATETIME_FORMAT': ISO_8601,
        'DATETIME_INPUT_FORMATS': (ISO_8601,),
    
        'TIME_FORMAT': ISO_8601,
        'TIME_INPUT_FORMATS': (ISO_8601,),
    
        # Encoding
        'UNICODE_JSON': True,
        'COMPACT_JSON': True,
        'STRICT_JSON': True,
        'COERCE_DECIMAL_TO_STRING': True,
        'UPLOADED_FILES_USE_URL': True,
    
        # Browseable API
        'HTML_SELECT_CUTOFF': 1000,
        'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...",
    
        # Schemas
        'SCHEMA_COERCE_PATH_PK': True,
        'SCHEMA_COERCE_METHOD_NAMES': {
            'retrieve': 'read',
            'destroy': 'delete'
        },
    }
    复制代码

    很直观的DEFAULTS是一个字典,包含多组键值,key是一个字符串,value是一个元组,我们需要的数据:

    'DEFAULT_AUTHENTICATION_CLASSES': ( 'rest_framework.authentication.SessionAuthentication', 'rest_framework.authentication.BasicAuthentication' ), 这有啥用呢,api_settings是怎么使用这个 参数的呢

    复制代码
    class APISettings(object):
        """
        A settings object, that allows API settings to be accessed as properties.
        For example:
    
            from rest_framework.settings import api_settings
            print(api_settings.DEFAULT_RENDERER_CLASSES)
    
        Any setting with string import paths will be automatically resolved
        and return the class, rather than the string literal.
        """
        def __init__(self, user_settings=None, defaults=None, import_strings=None):
            if user_settings:
                self._user_settings = self.__check_user_settings(user_settings)
            self.defaults = defaults or DEFAULTS
            self.import_strings = import_strings or IMPORT_STRINGS
            self._cached_attrs = set()
    复制代码

    好像也看不出什么有用的信息,只是一些赋值操作,将DEFAULTS赋给了self.defaults,那我们再回去看,

    在视图函数中,我们不定义authentication_classes 就会执行默认的APIView下的authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES,通过这个api_settings对象调用值时,调不到会怎样,我们看看怎么取值吧,首先会在__getattribute__方法中先找,找不到,取自己的属性中找,找不到去本类中找,在找不到就去父类中找,再找不到就去__getattr__中找,最后报错。按照这个逻辑,执行默认 值时最后会走到类中的__getattr__方法中

    复制代码
    class APISettings(object):    
        def __getattr__(self, attr):
            if attr not in self.defaults:
                raise AttributeError("Invalid API setting: '%s'" % attr)
    
            try:
                # Check if present in user settings
                val = self.user_settings[attr]
            except KeyError:
                # Fall back to defaults
                val = self.defaults[attr]
    
            # Coerce import strings into classes
            if attr in self.import_strings:
                val = perform_import(val, attr)
    
            # Cache the result
            self._cached_attrs.add(attr)
            setattr(self, attr, val)
            return val
    复制代码

    备注:调用__getattr__方法时,会将取的值,赋值给__getattr__中的参数attr,所有,此时的attr是DEFAULT_AUTHENTICATION_CLASSES,这个值在self.defaluts中,所以,代码会执行到try中,val=self.user_settings[attr],这句代码的意思是,val的值是self.user_settings这个东西中取attr,所以self.user_settings一定是个字典,我们看看这个东西做l什么

    复制代码
    class APISettings(object):
        @property
        def user_settings(self):
            if not hasattr(self, '_user_settings'):
                self._user_settings = getattr(settings, 'REST_FRAMEWORK', {})
            return self._user_settings
    复制代码

    这个属性方法,返回self._user_settings,这个值从哪里来,看代码逻辑,通过反射,来取settings中的REST_FRAMEWORK的值,取不到,返回一个空字典。那这个settings是哪的,是我们项目的配置文件。

      综合起来,意思就是如果在api_settings对象中找不到这个默认值,就从全局settings中找一个变量REST_FRAMEWORK,这个变量是个字典,从这个字典中找attr(DEFAULT_AUTHENTICATION_CLASSES)的值,找不到,返回一个空字典。而我们的配置文件settings中并没有这个变量,很明显,我们可以在全局中配置这个变量,从而实现一个全局的认证。怎么配?

      配置格式:REST_FRAMEWORK={“DEFAULT_AUTHENTICATION_CLASSES”:("认证代码所在路径","...")}

    代码继续放下执行,执行到if attr in self.import_strings:       val = perform_import(val, attr)   这两行就开始,根据取的值去通过字符串的形式去找路径,最后得到我们配置认证的类,在通过setattr(self, attr, val)   实现,调用这个默认值,返回配置类的执行。

    复制代码
    def perform_import(val, setting_name):
        """
        If the given setting is a string import notation,
        then perform the necessary import or imports.
        """
        if val is None:
            return None
        elif isinstance(val, six.string_types):
            return import_from_string(val, setting_name)
        elif isinstance(val, (list, tuple)):
            return [import_from_string(item, setting_name) for item in val]
        return val
    
    
    def import_from_string(val, setting_name):
        """
        Attempt to import a class from a string representation.
        """
        try:
            # Nod to tastypie's use of importlib.
            module_path, class_name = val.rsplit('.', 1)
            module = import_module(module_path)
            return getattr(module, class_name)
        except (ImportError, AttributeError) as e:
            msg = "Could not import '%s' for API setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e)
            raise ImportError(msg)
    复制代码

    备注:通过importlib模块的import_module找py文件。

    权限组件

    局部视图权限

    在app01.service.permissions.py中:

    复制代码
    from rest_framework.permissions import BasePermission
    class SVIPPermission(BasePermission):
        message="SVIP才能访问!"
        def has_permission(self, request, view):
            if request.user.user_type==3:
                return True
            return False
    复制代码

    在views.py:

    复制代码
    from app01.service.permissions import *
    
    class BookViewSet(generics.ListCreateAPIView):
        permission_classes = [SVIPPermission,]
        queryset = Book.objects.all()
        serializer_class = BookSerializers
    复制代码

    全局视图权限

    settings.py配置如下:

    1
    2
    3
    4
    REST_FRAMEWORK={
        "DEFAULT_AUTHENTICATION_CLASSES":["app01.service.auth.Authentication",],
        "DEFAULT_PERMISSION_CLASSES":["app01.service.permissions.SVIPPermission",]
    }

    权限组件一样的逻辑,同样的配置,要配置权限组件,就要在视图类中定义一个变量permission_calsses = [自定义类]

    同样的在这个自定义类中,要定义一个固定的方法   def has_permission(self,request,view)  必须传两个参数,一个request,一个是

    当前视图类的对象。权限通过返回True,不通过返回False,我们可以定制错误信息,在自定义类中配置一个静态属性message="错误信息",也可以继承一个类BasePermission。跟认证一样。

    复制代码
        def check_permissions(self, request):
            """
            Check if the request should be permitted.
            Raises an appropriate exception if the request is not permitted.
            """
            for permission in self.get_permissions():
                if not permission.has_permission(request, self):
                    self.permission_denied(
                        request, message=getattr(permission, 'message', None)
                    )
    复制代码

    逻辑很简单,循环这个自定义类实例对象列表,从每一个对象中找has_permission,if not  返回值:表示返回False,,就会抛一个异常,这个异常的信息,会先从对象的message属性中找。

      同样的,配置全局的权限,跟认证一样,在settings文件中的REST_FRAMEWORK字典中配一个键值即可。

    throttle(访问频率)组件

    局部视图throttle

    在app01.service.throttles.py中:

    复制代码
    复制代码
    from rest_framework.throttling import BaseThrottle
    
    VISIT_RECORD={}
    class VisitThrottle(BaseThrottle):
    
        def __init__(self):
            self.history=None
    
        def allow_request(self,request,view):
            remote_addr = request.META.get('REMOTE_ADDR')
            print(remote_addr)
            import time
            ctime=time.time()
    
            if remote_addr not in VISIT_RECORD:
                VISIT_RECORD[remote_addr]=[ctime,]
                return True
    
            history=VISIT_RECORD.get(remote_addr)
            self.history=history
    
            while history and history[-1]<ctime-60:
                history.pop()
    
            if len(history)<3:
                history.insert(0,ctime)
                return True
            else:
                return False
    
        def wait(self):
            import time
            ctime=time.time()
            return 60-(ctime-self.history[-1])
    复制代码
    复制代码

    在views.py中:

    复制代码
    from app01.service.throttles import *
    
    class BookViewSet(generics.ListCreateAPIView):
        throttle_classes = [VisitThrottle,]
        queryset = Book.objects.all()
        serializer_class = BookSerializers
    复制代码

    全局视图throttle

    REST_FRAMEWORK={
        "DEFAULT_AUTHENTICATION_CLASSES":["app01.service.auth.Authentication",],
        "DEFAULT_PERMISSION_CLASSES":["app01.service.permissions.SVIPPermission",],
        "DEFAULT_THROTTLE_CLASSES":["app01.service.throttles.VisitThrottle",]
    }

    同样的,局部使用访问频率限制组件时,也要在视图类中定义一个变量:throttle_classes = [自定义类],同样在自定义类中也要定义一个固定的方法def allow_request(self,request,view),接收两个参数,里面放我们频率限制的逻辑代码,返回True通过,返回False限制,同时要定义一个def wait(self):pass   放限制的逻辑代码。

    内置throttle类

    在app01.service.throttles.py修改为:

    复制代码
    class VisitThrottle(SimpleRateThrottle):
    
        scope="visit_rate"
        def get_cache_key(self, request, view):
    
            return self.get_ident(request)
    复制代码

    settings.py设置:

    复制代码
    REST_FRAMEWORK={
        "DEFAULT_AUTHENTICATION_CLASSES":["app01.service.auth.Authentication",],
        "DEFAULT_PERMISSION_CLASSES":["app01.service.permissions.SVIPPermission",],
        "DEFAULT_THROTTLE_CLASSES":["app01.service.throttles.VisitThrottle",],
        "DEFAULT_THROTTLE_RATES":{
            "visit_rate":"5/m",
        }
    }
    复制代码

    总结

       restframework的三大套件中为我们提供了多个粒度的控制。局部的管控和全局的校验,都可以很灵活的控制。下一个系列中,将会带来restframework中的查漏补缺。

  • 相关阅读:
    Flutter 中那么多组件,难道要都学一遍?
    【Flutter实战】自定义滚动条
    Python 为什么只需一条语句“a,b=b,a”,就能直接交换两个变量?
    一篇文章掌握 Python 内置 zip() 的全部内容
    Python 3.10 的首个 PEP 诞生,内置类型 zip() 迎来新特性
    Python 3.10 版本采纳了首个 PEP,中文翻译即将推出
    Python 为什么不支持 i++ 自增语法,不提供 ++ 操作符?
    Python 为什么推荐蛇形命名法?
    Python 为什么没有 main 函数?为什么我不推荐写 main 函数?
    Python 为什么不用分号作终止符?
  • 原文地址:https://www.cnblogs.com/zhaopanpan/p/9450364.html
Copyright © 2020-2023  润新知