• Python3高级核心技术 第八章 MyOrm


    import numbers
    import collections
    from helper import MysqlHelper
    
    
    # One ORDER BY term consumed by BaseModel.all/filter: ``name`` is the column
    # and ``direction`` is the SQL sort keyword (e.g. "asc" or "desc").
    OrderByTuple = collections.namedtuple("OrderByTuple", "name direction")
    
    
    class Field:
        """Common base class for model column descriptors.

        ModelMetaClass treats every class attribute that is an instance of
        Field as a database column of the model.
        """
    
    
    class IntField(Field):
        """Data descriptor for an integer column with optional inclusive bounds.

        The validated value is stored on the owning instance under the
        attribute ``'i_' + db_column``.

        Args:
            db_column: Name of the database column.
            min_value: Optional lower bound; must be a non-negative int.
            max_value: Optional upper bound; must be a non-negative int and
                not smaller than ``min_value``.

        Raises:
            ValueError: If a bound is not an int, is negative, or
                ``min_value > max_value``.
        """

        def __init__(self, db_column, min_value=None, max_value=None):
            self._name = 'i_' + db_column
            self.min_value = min_value
            self.max_value = max_value
            self.db_column = db_column
            if min_value is not None:
                if not isinstance(min_value, numbers.Integral):
                    raise ValueError("min_value must be int")
                elif min_value < 0:
                    raise ValueError("min_value must be positive int")
            if max_value is not None:
                if not isinstance(max_value, numbers.Integral):
                    raise ValueError("max_value must be int")
                elif max_value < 0:
                    raise ValueError("max_value must be positive int")
            if min_value is not None and max_value is not None:
                if min_value > max_value:
                    raise ValueError("min_value must be smaller than max_value")

        def __get__(self, instance, owner):
            """Return the stored value; raises AttributeError if never set.

            Accessed on the class itself (``instance is None``), the
            descriptor object is returned, as with built-in properties.
            """
            if instance is None:
                return self
            return getattr(instance, self._name)

        def __set__(self, instance, value):
            """Validate *value* and store it on *instance*.

            Raises:
                ValueError: If *value* is not an int or lies outside
                    ``[min_value, max_value]``.
            """
            if not isinstance(value, numbers.Integral):
                raise ValueError("int value need")
            # Named flags keep the bound check on one logical line; the
            # original split an un-parenthesised ``or`` across two lines,
            # which is a SyntaxError.
            too_small = self.min_value is not None and value < self.min_value
            too_large = self.max_value is not None and value > self.max_value
            if too_small or too_large:
                raise ValueError("value must between min_value and max_value")
            setattr(instance, self._name, value)
    
    
    class CharField(Field):
        """Data descriptor for a string column with a maximum length.

        The validated value is stored on the owning instance under the
        attribute ``'s_' + db_column``.

        Args:
            db_column: Name of the database column.
            max_length: Required maximum string length (non-negative int).

        Raises:
            ValueError: If ``max_length`` is missing or not a non-negative int.
        """

        def __init__(self, db_column, max_length=None):
            self._name = 's_' + db_column
            self.db_column = db_column
            if max_length is None:
                raise ValueError("you must specify max_length for CharField")
            # Validate like IntField validates its bounds, so a bad length is
            # reported at declaration time instead of on the first assignment.
            if not isinstance(max_length, numbers.Integral) or max_length < 0:
                raise ValueError("max_length must be a non-negative int")
            self.max_length = max_length

        def __get__(self, instance, owner):
            """Return the stored string; raises AttributeError if never set.

            Accessed on the class itself (``instance is None``), the
            descriptor object is returned, as with built-in properties.
            """
            if instance is None:
                return self
            return getattr(instance, self._name)

        def __set__(self, instance, value):
            """Validate *value* and store it on *instance*.

            Raises:
                ValueError: If *value* is not a str or is longer than
                    ``max_length``.
            """
            if not isinstance(value, str):
                raise ValueError("string value need")
            if len(value) > self.max_length:
                raise ValueError("value len excess len of max_length")
            setattr(instance, self._name, value)
    
    
    class PrimaryKeyMixIn:
        """Mixin that declares an integer ``id`` column descriptor."""

        id = IntField("id")
    
    
    class ModelMetaClass(type):
        """Metaclass that gathers column descriptors and table metadata.

        For every class except the one named ``BaseModel`` it adds two class
        attributes:

        * ``fields`` -- dict of attribute name -> Field declared in the class
          body, always including an ``id`` IntField (overridden if the class
          declares its own ``id`` field).
        * ``_meta`` -- dict with ``db_table``, taken from an optional inner
          ``Meta`` class's ``db_table`` attribute, else the lower-cased class
          name.

        The inner ``Meta`` class, if present, is removed from the class body.
        """

        def __new__(cls, name, bases, attrs):
            if name == "BaseModel":
                return super().__new__(cls, name, bases, attrs)
            fields = {"id": IntField(db_column="id")}
            fields.update(
                (key, value) for key, value in attrs.items() if isinstance(value, Field)
            )
            # pop() instead of del: models without an inner Meta class are valid
            # and previously crashed here with KeyError.
            attrs_meta = attrs.pop("Meta", None)
            db_table = getattr(attrs_meta, "db_table", None)
            if db_table is None:
                db_table = name.lower()
            attrs["_meta"] = {"db_table": db_table}
            attrs["fields"] = fields
            return super().__new__(cls, name, bases, attrs)
    
    
    class BaseModel(metaclass=ModelMetaClass):
        """Base class for ORM models backed by MysqlHelper.

        Subclasses declare Field descriptors; ModelMetaClass supplies the
        class attributes ``fields`` and ``_meta["db_table"]`` used here.
        Every generated SQL statement is printed before it is executed.

        NOTE(review): SQL is still assembled as text. String values are
        escaped by ``_sql_literal``, but if MysqlHelper supports
        parameterised queries (not visible here) those would be safer --
        TODO confirm.
        """

        def __init__(self, *args, **kwargs):
            for key, value in kwargs.items():
                setattr(self, key, value)
            super().__init__()

        # ---- helpers -------------------------------------------------------

        @staticmethod
        def _sql_literal(value):
            """Render *value* as a SQL literal.

            Strings are single-quoted with backslashes and single quotes
            escaped (MySQL default escaping); other values use ``str()``.
            """
            if isinstance(value, str):
                escaped = value.replace("\\", "\\\\").replace("'", "\\'")
                return "'" + escaped + "'"
            return str(value)

        @staticmethod
        def _column_name(key, field):
            """Return the field's db_column, or the lower-cased attribute name if it is None."""
            return key.lower() if field.db_column is None else field.db_column

        @staticmethod
        def _order_by_clause(order_by):
            """Build an ``order by`` clause from OrderByTuple-like items ('' if none)."""
            if not order_by:
                return ""
            return "order by " + ", ".join(item.name + " " + item.direction for item in order_by)

        @classmethod
        def _rows_to_objects(cls, rows):
            """Create one instance per row, setting each returned column as an attribute.

            Assumes each row is a mapping of column name -> value (it is
            iterated with ``.items()``).
            """
            objects = []
            for row in rows:
                obj = cls()
                for column, value in row.items():
                    setattr(obj, column, value)
                objects.append(obj)
            return objects

        # ---- write operations ----------------------------------------------

        def save(self):
            """Insert this object as a new row.

            Fields that are unset (or None) are omitted so the database can
            apply its own default, e.g. an auto-generated ``id``.
            """
            columns = []
            values = []
            for key, field in self.fields.items():
                value = getattr(self, key, None)
                if value is None:
                    continue
                columns.append(self._column_name(key, field))
                values.append(self._sql_literal(value))
            # insert user(id, name) values (1, 'admin')
            sql = "insert {db_table}({fields}) values ({values})".format(db_table=self._meta["db_table"],
                                                                         fields=",".join(columns),
                                                                         values=",".join(values))
            print(sql)
            db = MysqlHelper()
            db.insert(sql)

        def delete(self):
            """Delete rows matching every field that is set on this object.

            Raises:
                ValueError: If no field is set; an empty WHERE clause would
                    otherwise be sent to the database.
            """
            conditions = []
            for key, field in self.fields.items():
                value = getattr(self, key, None)
                if value is not None:
                    conditions.append(self._column_name(key, field) + " = " + self._sql_literal(value))
            if not conditions:
                raise ValueError("cannot delete: no field values are set to build a where condition")
            # delete from user where id = 3
            sql = "delete from " + self._meta["db_table"] + " where " + " and ".join(conditions)
            print(sql)
            db = MysqlHelper()
            db.delete(sql)

        @classmethod
        def batch_delete(cls, ids):
            """Delete all rows whose ``id`` is in *ids*; does nothing for an empty *ids*."""
            id_literals = [cls._sql_literal(p_key) for p_key in ids]
            if not id_literals:
                return
            # delete from user where id in (9, 10, 11)
            sql = "delete from " + cls._meta["db_table"] + " where id in (" + ",".join(id_literals) + ")"
            print(sql)
            db = MysqlHelper()
            db.delete(sql)

        def update(self):
            """Write every set field of this object to the row with the same ``id``.

            Raises:
                KeyError: If ``id`` is unset or None.
            """
            primary_key = getattr(self, 'id', None)
            if primary_key is None:
                raise KeyError("primary key cannot be empty")
            assignments = []
            for key, field in self.fields.items():
                value = getattr(self, key, None)
                if value is None:
                    # Unset fields are left untouched instead of crashing or
                    # emitting the invalid literal ``None``.
                    continue
                assignments.append(self._column_name(key, field) + " = " + self._sql_literal(value))
            # update user set name = 'test' where id = 4
            sql = ("update " + self._meta["db_table"] + " set " + ", ".join(assignments)
                   + " where id = " + self._sql_literal(primary_key))
            print(sql)
            db = MysqlHelper()
            db.update(sql)

        @classmethod
        def batch_update(cls, obs):
            """Update many rows in one statement using ``CASE id`` expressions.

            *obs* is any iterable of model instances, each with ``id`` set.
            A field is updated for the objects on which it is set; for the
            other rows ``ELSE column`` keeps the current value. If several
            objects share an id, the first one wins. Empty input is a no-op.
            """
            obs = list(obs)
            if not obs:
                return
            # id -> object lookup replaces the per-id linear scan.
            by_id = {}
            for item in obs:
                by_id.setdefault(item.id, item)

            conditions = []
            for key, field in cls.fields.items():
                when_thens = []
                for p_key, instance in by_id.items():
                    value = getattr(instance, key, None)
                    if value is not None:
                        when_thens.append("when " + cls._sql_literal(p_key) + " then " + cls._sql_literal(value))
                if when_thens:
                    column = cls._column_name(key, field)
                    conditions.append(column + " = case id " + " ".join(when_thens) + " else " + column + " end")
            if not conditions:
                return

            # update user set
            # id = case id when 1 then 1 when 2 then 2 else id end,
            # title = case id when 1 then 'test1' when 2 then 'test2' else title end
            # where id in (1,2)
            sql = "update " + cls._meta["db_table"] + " set "
            sql += ", ".join(conditions) + " "
            sql += "where id in (" + ",".join(cls._sql_literal(p_key) for p_key in by_id) + ")"
            print(sql)
            db = MysqlHelper()
            db.update(sql)

        # ---- queries -------------------------------------------------------

        @classmethod
        def all(cls, **kwargs):
            """Return every row as an instance; optional ``order_by`` is a list of OrderByTuple."""
            sql = "select * from " + cls._meta["db_table"] + " "
            sql += cls._order_by_clause(kwargs.get("order_by", None))
            # select * from user order by id desc, name asc
            print(sql)
            db = MysqlHelper()
            res = db.get_all_obj(sql, cls._meta["db_table"])
            return cls._rows_to_objects(res)

        @classmethod
        def filter(cls, **kwargs):
            """Return instances matching all keyword conditions.

            String values match with ``like '%value%'``; other values match
            with ``=``. The optional ``order_by`` keyword is a list of
            OrderByTuple. With no conditions, all rows are returned.

            Raises:
                KeyError: If a keyword is not a declared field.
            """
            conditions = []
            for key, value in kwargs.items():
                if key == "order_by":
                    continue
                field = cls.fields.get(key, None)
                if field is None:
                    raise KeyError("the field " + key + " does not exist")
                column = cls._column_name(key, field)
                if isinstance(value, str):
                    conditions.append(column + " like " + cls._sql_literal("%" + value + "%"))
                else:
                    conditions.append(column + " = " + cls._sql_literal(value))
            sql = "select * from " + cls._meta["db_table"] + " "
            if conditions:
                sql += "where " + " and ".join(conditions) + " "
            sql += cls._order_by_clause(kwargs.get("order_by", None))
            # select * from user where name like '%admin%' order by id desc, name asc
            print(sql)
            db = MysqlHelper()
            res = db.get_all_obj(sql, cls._meta["db_table"])
            return cls._rows_to_objects(res)

        @classmethod
        def one(cls, **kwargs):
            """Return one instance whose fields equal the keyword values, or None.

            Unlike ``filter``, strings are matched exactly (the original
            wrapped them in ``%`` while using ``=``, so they never matched).

            Raises:
                KeyError: If a keyword is not a declared field.
            """
            conditions = []
            for key, value in kwargs.items():
                field = cls.fields.get(key, None)
                if field is None:
                    raise KeyError("the field " + key + " does not exist")
                conditions.append(cls._column_name(key, field) + " = " + cls._sql_literal(value))
            # select * from user where name = 'admin'
            sql = "select * from " + cls._meta["db_table"]
            if conditions:
                sql += " where " + " and ".join(conditions)
            print(sql)
            db = MysqlHelper()
            item = db.get_one_obj(sql, cls._meta["db_table"])
            if item is None:
                return None
            return cls._rows_to_objects([item])[0]
  • 相关阅读:
    java程序员怎么创建自己的网站:第一章:总体流程
    技术汇总:第五章:使用angularjs做首页三级分类
    js中Function的apply方法与call方法理解
    常用方法
    Array对象(一)
    一张图理解is_nll isset empty
    解析centos中Apache、php、mysql 默认安装路径
    常用命令
    centos虚拟机启用网卡
    初学Linux笔记
  • 原文地址:https://www.cnblogs.com/yejing-snake/p/13565425.html
Copyright © 2020-2023  润新知