• Python 轻量级 ORM


    在 Python 下,ORM 多使用 SQLAlchemy。我用了一段时间,感觉不太顺手,主要问题是 SQLAlchemy 太重,所以自己写了一个 ORM。实现方式和 NetSharp 类似;OQL 部分因为代码量比较大,尚未完全实现。

    下面是源代码

    一、系统配置

    configuration.py

    # !/usr/bin/python
    # -*- coding: UTF-8 -*-
    
    # Connection settings. orm.py uses host/port/db/user/pwd as the default
    # arguments of DataAccess.open() and Persister.open().
    host="192.168.4.1"
    port=3306
    db="wolf"
    user="root"
    pwd="xxx"
    # Written into Entity.creator / Entity.updator by Persister.add()/update().
    user_name="xxx"

    二、orm.py

       1 # !/usr/bin/python
       2 # -*- coding: UTF-8 -*-
       3 
       4 import sys
       5 from datetime import *
       6 from enum import Enum
       7 import logging
       8 
       9 import MySQLdb
      10 import MySQLdb.cursors
      11 
      12 import configuration
      13 
      14 #########################################################################
      15         
      class Field(object):
          """Declarative description of one mapped column on an entity class.

          Instances are placed as class attributes on entity classes; every
          keyword argument passed to the constructor is stored verbatim as an
          attribute, overriding the class-level defaults below.
          """

          # Attribute name on the entity class.
          property_name = None
          # Column name in the table.
          column_name = None
          group_name = None
          # SQL type name emitted into the column definition (e.g. "int").
          column_type_name = None
          header = None
          memoto = None

          is_primary_key = None
          is_auto = False
          is_name_equals = None
          is_required = None
          is_unique = None

          # Emitted as "(size)" or "(size,precision)" after the SQL type.
          size = None
          precision = None

          def __init__(self, **kw):
              """Store each keyword argument as an instance attribute."""
              # ``items()`` exists on both Python 2 and Python 3, whereas the
              # original ``iteritems()`` raises AttributeError on Python 3.
              for k, v in kw.items():
                  setattr(self, k, v)
      37 
      class ShortField(Field):
          """Field whose SQL type defaults to ``smallint``."""

          def __init__(self, **kw):
              # The default is assigned before delegating so that an explicit
              # ``column_type_name`` keyword still wins.
              setattr(self, "column_type_name", "smallint")
              super(ShortField, self).__init__(**kw)
      43        
      44 
      class IntField(Field):
          """Field whose SQL type defaults to ``int``."""

          def __init__(self, **kw):
              # The default is assigned before delegating so that an explicit
              # ``column_type_name`` keyword still wins.
              setattr(self, "column_type_name", "int")
              super(IntField, self).__init__(**kw)
      50        
      51 
      class LongField(Field):
          """Field whose SQL type defaults to ``bigint``."""

          def __init__(self, **kw):
              # The default is assigned before delegating so that an explicit
              # ``column_type_name`` keyword still wins.
              setattr(self, "column_type_name", "bigint")
              super(LongField, self).__init__(**kw)
      57               
      58 
      class StringFiled(Field):
          """Field whose SQL type defaults to ``nvarchar`` with size 50."""

          def __init__(self, **kw):
              # Defaults first; keyword arguments (e.g. ``size=100``) are
              # applied afterwards by Field.__init__ and override them.
              self.column_type_name, self.size = "nvarchar", 50
              super(StringFiled, self).__init__(**kw)
      65 
      66 
      class BoolFiled(Field):
          """Field whose SQL type defaults to ``bool``."""

          def __init__(self, **kw):
              # The default is assigned before delegating so that an explicit
              # ``column_type_name`` keyword still wins.
              setattr(self, "column_type_name", "bool")
              super(BoolFiled, self).__init__(**kw)
      71        
      72 
      class FloatFiled(Field):
          """Field whose SQL type defaults to ``float`` with size 8, precision 4."""

          def __init__(self, **kw):
              # Defaults first; keyword arguments are applied afterwards by
              # Field.__init__ and override them.
              self.column_type_name, self.size, self.precision = "float", 8, 4
              super(FloatFiled, self).__init__(**kw)
      79 
      80 
      class DoubleFiled(Field):
          """Field whose SQL type defaults to ``double`` with size 8, precision 4."""

          def __init__(self, **kw):
              # Defaults first; keyword arguments are applied afterwards by
              # Field.__init__ and override them.
              self.column_type_name, self.size, self.precision = "double", 8, 4
              super(DoubleFiled, self).__init__(**kw)
      87 
      88 
      class DecimalFiled(Field):
          """Field whose SQL type defaults to ``decimal`` with size 14, precision 8."""

          def __init__(self, **kw):
              # Defaults first; keyword arguments are applied afterwards by
              # Field.__init__ and override them.
              self.column_type_name, self.size, self.precision = "decimal", 14, 8
              super(DecimalFiled, self).__init__(**kw)
      95 
      class DateTimeFiled(Field):
          """Field whose SQL type defaults to ``datetime``."""

          def __init__(self, **kw):
              # The default is assigned before delegating so that an explicit
              # ``column_type_name`` keyword still wins.
              setattr(self, "column_type_name", "datetime")
              super(DateTimeFiled, self).__init__(**kw)
     100 
     class BinaryFiled(Field):
         """Field whose SQL type defaults to ``longblob``."""

         def __init__(self, **kw):
             # The default is assigned before delegating so that an explicit
             # ``column_type_name`` keyword still wins.
             setattr(self, "column_type_name", "longblob")
             super(BinaryFiled, self).__init__(**kw)
     105        
     class Reference(object):
         """Declares a to-one relation from an entity to another entity class.

         Keyword arguments given to the constructor are stored verbatim as
         attributes, overriding the class-level defaults below.
         """

         property_name = None
         header = None
         # Name of the property on the owning entity that holds the key value.
         foreign_key = None
         # Name of the matched property on the referenced entity; the metadata
         # parser falls back to "id" when this is left as None.
         primary_key = None

         # Entity class being referenced.
         reference_type = None

         def __init__(self, **kw):
             """Store each keyword argument as an instance attribute."""
             # ``items()`` works on Python 2 and 3; ``iteritems()`` is
             # Python 2 only.
             for k, v in kw.items():
                 setattr(self, k, v)
     118 
     class Subs(object):
         """Declares a to-many (child collection) relation on an entity class.

         Keyword arguments given to the constructor are stored verbatim as
         attributes, overriding the class-level defaults below.
         """

         property_name = None
         header = None
         # Name of the property on the child entity that points to the parent.
         foreign_key = None
         # Name of the parent property that is matched; the metadata parser
         # falls back to "id" when this is left as None.
         primary_key = None

         # Entity class of the children.
         sub_type = None

         def __init__(self, **kw):
             """Store each keyword argument as an instance attribute."""
             # ``items()`` works on Python 2 and 3; ``iteritems()`` is
             # Python 2 only.
             for k, v in kw.items():
                 setattr(self, k, v)
     131 
     132 #########################################################################
     133 
     # Persistence state of an entity; Persister.save() branches on it.
     class EntityState(Enum):
         # Transient: not managed by the database, or an entity that does not
         # need to be written when the transaction is committed.
         Transient = 0
         # Will be inserted when the transaction is committed.
         New = 1

         # Will be updated when the transaction is committed.
         Persist = 2

         # Will be deleted when the transaction is committed.
         Deleted = 3
     146 
     class Persistable(object):
         """Base class for objects that carry a persistence state.

         Keyword arguments given to the constructor are stored verbatim as
         attributes. New instances start in ``EntityState.New``.
         """

         __entity_status__ = EntityState.New

         def __init__(self, **kw):
             """Store each keyword argument as an instance attribute."""
             # ``items()`` works on Python 2 and 3; ``iteritems()`` is
             # Python 2 only.
             for k, v in kw.items():
                 setattr(self, k, v)

         def to_new(self):
             """Mark the entity as ``EntityState.New``."""
             self.__entity_status__ = EntityState.New

         def to_persist(self):
             """Mark the entity as ``EntityState.Persist``."""
             self.__entity_status__ = EntityState.Persist

         def to_delete(self):
             """Mark the entity as ``EntityState.Deleted``."""
             self.__entity_status__ = EntityState.Deleted

         def to_transient(self):
             """Mark the entity as ``EntityState.Transient``."""
             self.__entity_status__ = EntityState.Transient
     166 
     167 
     # Base entity with an auto-increment integer key and audit columns.
     # NOTE: documented with comments rather than a docstring because
     # EntityManager.parse_entity() stores the class __doc__ as the entity header.
     class Entity(Persistable) :
         
         # Primary key, generated by the database (AUTO_INCREMENT).
         id = IntField(is_primary_key=True,is_auto=True)
         # creator / create_time are filled in by Persister.add().
         creator = StringFiled(size=100)
         create_time = DateTimeFiled()
         # update_time / updator are filled in by Persister.update().
         update_time = DateTimeFiled()
         updator = StringFiled(size=100)
     175 
     # Entity extended with code / name / memoto string columns.
     # NOTE(review): "memoto" presumably means a free-text memo - confirm.
     class BizEntity(Entity) :
         
         code = StringFiled()
         name = StringFiled(size=100) 
         memoto = StringFiled(size=500)  
     181 
     182 #########################################################################
     183 
     class NColumn:
         """Parsed column metadata produced from a Field declaration."""

         # Attribute name on the entity class.
         property_name = None
         # Column name in the table.
         column_name = None
         group_name = None
         # SQL type name emitted into the column definition.
         column_type_name = None
         header = None
         memoto = None

         is_primary_key = None
         is_auto = False
         is_name_equals = None
         is_required = None
         is_unique = None

         size = None
         precision = None

         def __repr__(self):
             """Return ``property_name[column_name]``.

             Uses %-formatting rather than string concatenation so that repr()
             does not raise TypeError while the names are still None.
             """
             return "%s[%s]" % (self.property_name, self.column_name)
     205 
     206 
     class NEntity:
         """Parsed metadata for one entity class (table, key and column maps)."""

         name = None
         table_name = None
         entity_id = None
         header = None
         is_view = None
         is_refcheck = None
         # NColumn of the primary key / of the auto-increment column.
         key_column = None
         auto_column = None
         order_by = None
         # The entity class this metadata was parsed from.
         type = None

         # Class-level defaults are kept for backward compatibility only; every
         # instance receives its own dictionaries in __init__ so that metadata
         # of different entities can never leak into one shared dict.
         columns = {}
         fields = {}
         subs = {}
         references = {}

         def __init__(self):
             """Create per-instance containers."""
             # Non-key columns, keyed by column name / by property name.
             self.columns = {}
             self.fields = {}
             # All columns including the key, keyed by column / property name.
             self.full_columns = {}
             self.full_fields = {}
             # Child-collection and reference metadata, keyed by property name.
             self.subs = {}
             self.references = {}

         def __repr__(self):
             """Return ``name[table_name]``."""
             return "%s[%s]" % (self.name, self.table_name)
     227 
     class NReference:
         """Parsed metadata of a Reference (to-one) declaration."""

         # Name of the entity attribute that receives the referenced object.
         # Declared here because the metadata parser assigns it and the query
         # code reads it; previously it only existed after being assigned.
         property_name = None
         header = None
         # Property names of the foreign key (owner side) and of the matched
         # key (referenced side).
         foreign_key = None
         primary_key = None

         # NColumn objects resolved for the two property names above.
         foreign_key_column = None
         primary_key_column = None

         # Entity class being referenced.
         reference_type = None
     237 
     class NSubs:
         """Parsed metadata of a Subs (child collection) declaration."""

         # Name of the entity attribute that holds the child list. Declared
         # here because the metadata parser assigns it and the query/save code
         # reads it; previously it only existed after being assigned.
         property_name = None
         header = None
         # Property names of the child-side foreign key and the parent key.
         foreign_key = None
         primary_key = None

         # NColumn objects resolved for the two property names above.
         foreign_key_column = None
         primary_key_column = None

         # Entity class of the children.
         sub_type = None
     247 
     class EntityManager:
         """Builds and caches NEntity metadata for entity classes."""

         # Cache of parsed metadata, keyed by the entity class ``__name__``.
         entityMap = {}

         @classmethod
         def get_meta(cls, type):
             """Return the metadata for entity class *type*.

             The class is parsed on first request and cached afterwards.
             """
             ne = cls.entityMap.get(type.__name__)
             if ne is None:
                 ne = cls.parse_entity(type)
                 cls.entityMap[type.__name__] = ne

             return ne

         @classmethod
         def parse_entity(cls, type):
             """Parse the Field/Reference/Subs attributes of *type*.

             Requires ``type.__table_name__``. Raises KeyError when a Reference
             names a foreign key property that is not a declared Field.
             """
             # Local import: only this method needs it. Ordered dicts make the
             # column order deterministic on every Python version, so the
             # column list of generated SQL and the parameter lists built from
             # these maps always line up.
             from collections import OrderedDict

             ne = NEntity()
             ne.type = type

             ne.table_name = type.__table_name__
             ne.name = type.__name__
             ne.entity_id = type.__name__
             ne.header = type.__doc__
             ne.is_view = False
             ne.is_refcheck = False

             ne.key_column = None
             ne.auto_column = None
             ne.order_by = None

             ne.columns = OrderedDict()
             ne.fields = OrderedDict()
             ne.full_columns = OrderedDict()
             ne.full_fields = OrderedDict()
             ne.subs = OrderedDict()
             ne.references = OrderedDict()

             for k in dir(type):

                 v = getattr(type, k)

                 if isinstance(v, Field):

                     c = cls.parse_field(k, v)

                     ne.full_fields[k] = c
                     ne.full_columns[c.column_name] = c

                     # The key column is kept out of the non-key maps.
                     if c.is_primary_key:
                         ne.key_column = c
                     else:
                         ne.fields[k] = c
                         ne.columns[c.column_name] = c

                     if c.is_auto:
                         ne.auto_column = c

                 if isinstance(v, Reference):
                     ne.references[k] = cls.parse_reference(k, v)

                 if isinstance(v, Subs):
                     ne.subs[k] = cls.parse_subs(k, v)

             # Columns of this entity can only be resolved once all fields
             # have been collected.
             for r in ne.references.values():
                 r.foreign_key_column = ne.full_fields[r.foreign_key]

             for r in ne.subs.values():
                 r.primary_key_column = ne.key_column

             return ne

         @classmethod
         def parse_field(cls, name, field):
             """Convert one Field declaration into an NColumn.

             *name* is the attribute name on the entity class. The column name
             is ``field.column_name`` when given, otherwise *name*.
             """
             c = NColumn()

             c.property_name = name
             # Bug fix: the original assigned ``name`` in both branches, so a
             # custom column name declared on the field was silently ignored.
             if field.column_name is not None:
                 c.column_name = field.column_name
             else:
                 c.column_name = name

             # True whenever no differing column name was declared, which
             # matches the value the original always stored.
             c.is_name_equals = (c.column_name == c.property_name)

             c.column_type_name = field.column_type_name
             c.size = field.size
             c.precision = field.precision
             c.group_name = field.group_name
             c.is_primary_key = field.is_primary_key
             c.is_auto = field.is_auto

             # Carry the remaining declared options over; with the Field
             # defaults these yield the same values as before
             # (False / None / None / None).
             c.is_required = bool(field.is_required)
             c.is_unique = field.is_unique
             c.header = field.header
             c.memoto = field.memoto

             return c

         @classmethod
         def parse_reference(cls, name, field):
             """Convert one Reference declaration into an NReference.

             ``foreign_key_column`` is left None here; parse_entity() fills it
             in after all fields of the owning entity are known.
             """
             r = NReference()

             r.property_name = name
             r.foreign_key = field.foreign_key
             r.header = field.header
             r.primary_key = field.primary_key

             if r.primary_key is None:
                 r.primary_key = "id"

             r.reference_type = field.reference_type

             r.foreign_key_column = None

             # NOTE: parsing the referenced class here means two entity classes
             # that refer to each other recurse without end, because the cache
             # entry is only written after parsing completes.
             rne = cls.get_meta(field.reference_type)
             r.primary_key_column = rne.full_fields[r.primary_key]

             return r

         @classmethod
         def parse_subs(cls, name, field):
             """Convert one Subs declaration into an NSubs.

             ``primary_key_column`` is left None here; parse_entity() fills it
             in with the key column of the owning entity.
             """
             s = NSubs()

             s.property_name = name
             s.foreign_key = field.foreign_key
             s.header = field.header
             s.primary_key = field.primary_key

             if s.primary_key is None:
                 s.primary_key = "id"

             s.primary_key_column = None

             s.sub_type = field.sub_type

             rne = cls.get_meta(field.sub_type)
             s.foreign_key_column = rne.full_fields[field.foreign_key]

             return s
     390 
     class ORMException(Exception):
         """Error raised by the ORM layer; the payload is kept in ``value``."""

         def __init__(self, value):
             # Explicit base initialisation; ``args`` ends up as ``(value,)``
             # exactly as it did when only BaseException.__new__ set it.
             Exception.__init__(self, value)
             self.value = value
     394 
     395 #########################################################################
     396 
     class DataAccess:
         """Thin wrapper around one MySQLdb connection and its cursor."""

         conn = None
         cursor = None
         isClosed = True

         def open(self, host=configuration.host, port=configuration.port,
                  db=configuration.db, user=configuration.user,
                  pwd=configuration.pwd):
             """Connect to MySQL (utf8 charset) and create the cursor.

             All parameters default to the values in the configuration module.
             """
             self.conn = MySQLdb.connect(host=host, port=port, db=db, user=user,
                                         passwd=pwd, charset="utf8")
             self.cursor = self.conn.cursor()
             self.isClosed = False

         def _ensure_open(self):
             """Raise Exception when the connection has not been opened."""
             if self.isClosed:
                 raise Exception("db is not opened!")

         def execute(self, cmd, pars=None):
             """Execute *cmd* with parameters *pars*; return the cursor result."""
             self._ensure_open()

             logging.info(cmd)

             return self.cursor.execute(cmd, pars)

         def get_last_row_id(self):
             """Return ``cursor.lastrowid`` of the last statement as an int."""
             return int(self.cursor.lastrowid)

         def fetchone(self, cmd, pars=None):
             """Execute *cmd* and return the first row (None when empty)."""
             self._ensure_open()

             logging.info(cmd)
             self.cursor.execute(cmd, pars)

             return self.cursor.fetchone()

         def fetchall(self, cmd, pars=None):
             """Execute *cmd* and return all rows."""
             self._ensure_open()

             logging.info(cmd)
             self.cursor.execute(cmd, pars)

             return self.cursor.fetchall()

         def executeScalar(self, cmd, pars=None):
             """Execute *cmd*; return the first column of the first row.

             Returns None when there is no row or the row has no columns.
             """
             self._ensure_open()

             # Logged like the other statement methods, for consistency.
             logging.info(cmd)
             self.cursor.execute(cmd, pars)
             row = self.cursor.fetchone()

             if not row:
                 return None
             return row[0]

         def commit(self):
             """Commit the current transaction."""
             self.conn.commit()

         def rollback(self):
             """Roll back the current transaction."""
             # Bug fix: the original called ``self.conn.rolback()``, which does
             # not exist on a DB-API connection and raised AttributeError.
             self.conn.rollback()

         # Original (misspelled) public name, kept so existing callers work.
         rolback = rollback

         def close(self):
             """Close the connection; a no-op when it is already closed."""
             # Bug fix: the original used ``pass`` here and then dereferenced
             # ``self.conn`` (None), raising AttributeError on a second close.
             if self.isClosed:
                 return

             self.conn.close()
             self.isClosed = True
             self.conn = None
             self.cursor = None
     469 
     470 #########################################################################
     471 
     class SqlGenerator(object):
         """Builds parameterised (``%s`` placeholder) SQL from entity metadata."""

         @classmethod
         def generate_insert(cls, ne):
             """Return the INSERT statement for *ne*.

             The key column is left out when it is auto-generated.
             """
             if ne.key_column.is_auto:
                 target = ne.columns
             else:
                 target = ne.full_columns

             names = ', '.join(target.keys())
             marks = ', '.join(['%s'] * len(target))

             return 'insert into %s (%s) values (%s);' % (ne.table_name, names, marks)

         @classmethod
         def generate_update(cls, ne):
             """Return the UPDATE statement setting every non-key column by key."""
             assignments = ["%s = %s" % (name, '%s') for name in ne.columns.keys()]

             return 'update %s set %s where %s = %s;' % (
                 ne.table_name,
                 ",".join(assignments),
                 ne.key_column.column_name,
                 "%s",
             )

         @classmethod
         def generate_delete(cls, ne):
             """Return the DELETE-by-key statement for *ne*."""
             key_name = ne.key_column.column_name
             return "delete from %s where %s = %s" % (ne.table_name, key_name, "%s")

         @classmethod
         def generate_byid(cls, ne):
             """Return the SELECT-by-key statement listing all columns of *ne*."""
             column_list = ",".join(ne.full_columns.keys())
             key_name = ne.key_column.column_name
             return "select %s from %s where %s = %s" % (
                 column_list, ne.table_name, key_name, '%s')
     509 
     class Db(object):
         """Schema helper: creates/drops databases and entity tables."""

         # Created once at class-definition time, so it is shared by every
         # Db instance.
         dao = DataAccess()

         def open(self):
             """Open the underlying connection with the configured defaults."""
             self.dao.open()

         def create_db(self, db_name):
             """Create schema *db_name* (default character set ucs2)."""
             self.dao.execute("CREATE SCHEMA %s DEFAULT CHARACTER SET ucs2 ;" % db_name)

         def drop_db(self, db_name):
             """Drop database *db_name*."""
             self.dao.execute("DROP DATABASE %s;" % db_name)

         def create_table(self, cls):
             """Create the table for entity class *cls* from its metadata."""
             meta = EntityManager.get_meta(cls)

             definitions = [self.generate_column(col)
                            for col in meta.full_columns.values()]
             statement = "create table %s( %s )" % (meta.table_name,
                                                    ",".join(definitions))

             self.dao.execute(statement, None)

         def generate_column(self, column):
             """Return the column definition fragment for one NColumn."""
             pieces = ["%s %s" % (column.column_name, column.column_type_name)]

             # "(size,precision)" when both are set, "(size)" for size alone.
             if column.size is not None:
                 if column.precision is None:
                     pieces.append("(%d)" % column.size)
                 else:
                     pieces.append("(%d,%d)" % (column.size, column.precision))

             if column.is_auto:
                 pieces.append(" AUTO_INCREMENT")

             if column.is_primary_key:
                 pieces.append(" PRIMARY KEY")

             return "".join(pieces)

         def drop_table(self, cls):
             """Drop the table of entity class *cls* if it exists."""
             meta = EntityManager.get_meta(cls)
             self.dao.execute("drop table if exists %s" % meta.table_name, None)

         def commit(self):
             """Commit the current transaction."""
             self.dao.commit()

         def close(self):
             """Close the underlying connection."""
             self.dao.close()
     567 
     568 #########################################################################
     569 
     class SetQuery(object):
         """Loads an entity together with its references and child collections."""

         dao = None

         def __init__(self, dao):
             """Remember the data-access object used for all queries."""
             self.dao = dao

         def query(self, ne, entity):
             """Reload *entity* by its key and populate its relations.

             Returns the populated entity, or None when no row has that key.
             Raises ORMException when the key value of *entity* is None.
             """
             key_value = getattr(entity, ne.key_column.property_name)
             if key_value is None:
                 raise ORMException("%s.id不能为空" % ne.table_name)

             # The entity itself.
             sql = SqlGenerator.generate_byid(ne)
             row = self.dao.fetchone(sql, [key_value])
             if row is None:
                 return None
             entity = self.read_row(ne, row, entity)

             self.query_iter(entity, ne)

             return entity

         def query_iter(self, entity, ne):
             """Populate references and (recursively) child lists of *entity*."""
             # Referenced entities.
             for rm in ne.references.values():
                 fk_value = getattr(entity, rm.foreign_key)
                 ref = None

                 # Bug fix: a NULL foreign key or a missing referenced row used
                 # to reach read_row() with ``row = None`` and raise TypeError.
                 # Such references are now simply set to None.
                 if fk_value is not None:
                     ref_meta = EntityManager.get_meta(rm.reference_type)
                     # Bug fix: filter on the declared primary_key column.
                     # The original computed this filter but never used it and
                     # always matched on the referenced entity's key column.
                     filter = "%s = %s" % (rm.primary_key_column.column_name, "%s")
                     row = self.dao.fetchone(self._select_sql(ref_meta, filter),
                                             [fk_value])
                     if row is not None:
                         ref = self.read_row(ref_meta, row)

                 setattr(entity, rm.property_name, ref)

             # Child entities.
             for sm in ne.subs.values():
                 filter = "%s = %s" % (sm.foreign_key_column.column_name, "%s")
                 pars = [getattr(entity, ne.key_column.property_name)]
                 sub_meta = EntityManager.get_meta(sm.sub_type)
                 subs = self.do_query(sub_meta, filter, pars)

                 setattr(entity, sm.property_name, subs)

                 for sub in subs:
                     self.query_iter(sub, sub_meta)

         def _select_sql(self, ne, filter):
             """Return ``select <all columns> from <table> where <filter>``."""
             return "select %s from %s where %s" % (
                 ",".join(ne.full_columns.keys()), ne.table_name, filter)

         def do_query(self, ne, filter, pars):
             """Return the entities of *ne* whose rows match *filter*."""
             rows = self.dao.fetchall(self._select_sql(ne, filter), pars)

             return [self.read_row(ne, row) for row in rows]

         def read_row(self, ne, row, entity=None):
             """Copy the cells of *row* onto *entity* and return it.

             A new ``ne.type()`` instance is created when *entity* is None.
             Cells are taken in ``ne.full_columns`` order, which is the order
             used for the select list.
             """
             if entity is None:
                 entity = ne.type()

             for index, c in enumerate(ne.full_columns.values()):
                 setattr(entity, c.property_name, row[index])

             return entity
     645 #########################################################################
     646           
     class Persister(object):
         """Entity-level persistence: query, save (insert/update/delete)."""

         isClosed = True
         dao = DataAccess()

         def query_first(self, type, filters=None, pars=None):
             """Return the first entity matching *filters*, fully loaded.

             The match is reloaded through byid() so that its references and
             child collections are populated. Returns None when nothing matches.
             """
             lis = self.query_list(type, filters, pars)
             if not lis:
                 return None

             return self.byid(lis[0])

         # Entity list query.
         # Only a first-level query based on the main entity is supported.
         def query_list(self, type, filters=None, pars=None):
             """Return the entities of class *type* matching *filters*.

             *filters* is the text of a WHERE clause (without the keyword) and
             *pars* its parameters. Referenced tables are left-joined under the
             reference property name so the filter can mention them; only the
             main entity's columns are selected.
             """
             ne = EntityManager.get_meta(type)

             columns = [ne.name + "." + key for key in ne.full_columns.keys()]

             sqls = ["select %s from %s as %s" % (",".join(columns), ne.table_name, ne.name)]
             for r in ne.references.values():
                 ref_meta = EntityManager.get_meta(r.reference_type)
                 sqls.append("left join %s as %s on %s.%s = %s.%s" % (
                     ref_meta.table_name, r.property_name,
                     ne.name, r.foreign_key_column.column_name,
                     r.property_name, ref_meta.key_column.column_name))

             if filters is not None and filters.strip() != "":
                 sqls.append("where " + filters)

             # Bug fix: the "\n" separator had been turned into a literal line
             # break inside the string, which is a syntax error.
             sql = "\n".join(sqls)

             setter = SetQuery(self.dao)

             rows = self.dao.fetchall(sql, pars)

             return [setter.read_row(ne, row) for row in rows]

         def byid(self, entity):
             """Reload *entity* by key including relations; None if not found."""
             ne = EntityManager.get_meta(entity.__class__)

             return SetQuery(self.dao).query(ne, entity)

         def save(self, entity):
             """Persist *entity* and its child collections according to state.

             New -> insert, Persist -> update, Deleted -> delete (children
             first), Transient -> only the children are visited. Raises
             ORMException when the entity has no state.
             """
             status = entity.__entity_status__
             if status is None:
                 # Message fixed to name the attribute that is actually read.
                 raise ORMException("调用save方法必须设置实体__entity_status__")

             ne = EntityManager.get_meta(entity.__class__)

             self._sync_foreign_keys(entity, ne)

             # The state is captured once and the branches are exclusive: the
             # original chain of plain ``if`` statements re-walked the children
             # a second time after New/Persist had switched to Transient.
             if status == EntityState.New:
                 self.add(entity)
                 self._save_subs(entity, ne, mark_new=True)
                 entity.to_transient()

             elif status == EntityState.Persist:
                 self.update(entity)
                 self._save_subs(entity, ne)
                 entity.to_transient()

             elif status == EntityState.Deleted:
                 # Children are deleted before the parent row.
                 for _, sub in self._iter_subs(entity, ne):
                     sub.to_delete()
                     self.save(sub)
                 self.delete(entity)

             elif status == EntityState.Transient:
                 self._save_subs(entity, ne)

         def _sync_foreign_keys(self, entity, ne):
             """Copy the key of each loaded reference into an unset foreign key."""
             for r in ne.references.values():
                 ref = getattr(entity, r.property_name)
                 fk = getattr(entity, r.foreign_key)

                 # Bug fix: an unassigned reference is still the class-level
                 # Reference declaration, not an entity; reading its key raised
                 # AttributeError.
                 if ref is None or isinstance(ref, Reference):
                     continue

                 if fk is None or isinstance(fk, Field):
                     setattr(entity, r.foreign_key, getattr(ref, r.primary_key))

         def _iter_subs(self, entity, ne):
             """Yield ``(relation, child)`` for every assigned child entity."""
             for r in ne.subs.values():
                 subs = getattr(entity, r.property_name)
                 # Bug fix: an unassigned collection (None, or still the
                 # class-level Subs declaration) is skipped. The original used
                 # ``break``, which also skipped every remaining collection, and
                 # had no guard at all on the delete path.
                 if subs is None or isinstance(subs, Subs):
                     continue
                 for sub in subs:
                     yield r, sub

         def _save_subs(self, entity, ne, mark_new=False):
             """Point every child at *entity* and save it."""
             parent_id = getattr(entity, ne.key_column.property_name)
             for r, sub in self._iter_subs(entity, ne):
                 if mark_new:
                     sub.to_new()
                 setattr(sub, r.foreign_key_column.property_name, parent_id)
                 self.save(sub)

         def _field_value(self, entity, property_name):
             """Return the attribute value, or None if it is still a Field."""
             value = getattr(entity, property_name)
             # An attribute never assigned on the instance resolves to the
             # class-level Field declaration; that is stored as NULL.
             if isinstance(value, Field):
                 return None
             return value

         def add(self, entity):
             """Insert *entity*; write the generated id back when auto."""
             ne = EntityManager.get_meta(entity.__class__)
             sql = SqlGenerator.generate_insert(ne)

             if isinstance(entity, Entity):
                 entity.create_time = datetime.now()
                 entity.creator = configuration.user_name

             # Iterate the very same map SqlGenerator used for the column
             # list, so parameter order always matches column order.
             columns = ne.columns if ne.key_column.is_auto else ne.full_columns
             pars = [self._field_value(entity, c.property_name)
                     for c in columns.values()]

             self.dao.execute(sql, pars)

             if ne.auto_column is not None:
                 # Bug fix: the id is an attribute of the entity, so it is set
                 # by property name rather than by column name.
                 setattr(entity, ne.auto_column.property_name,
                         self.dao.get_last_row_id())

         def update(self, entity):
             """Update every non-key column of *entity* by key."""
             ne = EntityManager.get_meta(entity.__class__)
             sql = SqlGenerator.generate_update(ne)

             if isinstance(entity, Entity):
                 entity.update_time = datetime.now()
                 entity.updator = configuration.user_name

             # Same map and order as SqlGenerator.generate_update().
             pars = [self._field_value(entity, c.property_name)
                     for c in ne.columns.values()]
             pars.append(getattr(entity, ne.key_column.property_name))

             self.dao.execute(sql, pars)

         def delete(self, entity):
             """Delete the row of *entity* by key."""
             ne = EntityManager.get_meta(entity.__class__)
             sql = SqlGenerator.generate_delete(ne)

             pars = [getattr(entity, ne.key_column.property_name)]

             self.dao.execute(sql, pars)

         def execute(self, cmd, pars=None):
             """Pass-through to DataAccess.execute()."""
             return self.dao.execute(cmd, pars)

         def fetchone(self, cmd, pars=None):
             """Pass-through to DataAccess.fetchone()."""
             return self.dao.fetchone(cmd, pars)

         def fetchall(self, cmd, pars=None):
             """Pass-through to DataAccess.fetchall()."""
             return self.dao.fetchall(cmd, pars)

         def executeScalar(self, cmd, pars=None):
             """Pass-through to DataAccess.executeScalar()."""
             return self.dao.executeScalar(cmd, pars)

         def commit(self):
             """Commit the current transaction."""
             self.dao.commit()

         def open(self, host=configuration.host, port=configuration.port,
                  db=configuration.db, user=configuration.user,
                  pwd=configuration.pwd):
             """Open a fresh connection owned by this instance."""
             self.dao = DataAccess()
             self.dao.open(host=host, port=port, db=db, user=user, pwd=pwd)
             self.isClosed = False

         def close(self):
             """Close the connection and mark this persister as closed."""
             self.dao.close()
             self.dao.isClosed = True
             # Bug fix: open() clears this flag but close() never set it back.
             self.isClosed = True
     851 
     852 #########################################################################     
     853 
     # Entity mapped to a table named 'tables'.
     # NOTE(review): presumably information_schema.tables, since
     # EntityGenerator.open() connects to "information_schema" - confirm.
     # Documented with comments rather than a docstring because
     # EntityManager.parse_entity() stores the class __doc__ as the entity header.
     class PTable(Persistable):
         
         __table_name__ = 'tables'

         table_catalog =  StringFiled(size=512)
         table_schema = StringFiled(size=64)
         # Declared as the key column of this entity.
         table_name = StringFiled(size=64, is_primary_key=True)
         table_type = StringFiled(size=64)
         engine = StringFiled(size=64)
         version = LongField()
         row_format = StringFiled(size=10)
         table_rows = LongField()
         avg_row_length = LongField()
         data_length = LongField()
         max_data_length = LongField()
         index_length = LongField()
         data_free = LongField()
         auto_increment = LongField()
         create_time = DateTimeFiled()
         update_time = DateTimeFiled()
         check_time = DateTimeFiled()
         table_collation = StringFiled(size=32)
         checksum = LongField()
         create_options = StringFiled(size=2048)
         table_comment = StringFiled(size=2048)
     879 
     # Entity mapped to a table named 'columns'.
     # NOTE(review): presumably information_schema.columns, since
     # EntityGenerator.open() connects to "information_schema" - confirm.
     # Documented with comments rather than a docstring because
     # EntityManager.parse_entity() stores the class __doc__ as the entity header.
     class PColumn(Persistable):
         
         __table_name__ = 'columns'

         table_schema = StringFiled(size=255)
         table_name = StringFiled(size=255)
         # Declared as the key column of this entity.
         column_name = StringFiled(size=50, is_primary_key=True)
         data_type = StringFiled(size=255)
         character_maximum_length = StringFiled(size=255) # column length, for character types
         column_key = StringFiled(size=255) # PRI = primary key, UNI = unique; the original author asks what MUL means (MySQL docs: first column of a non-unique index)
         column_comment = StringFiled(size=255) # column description
         extra = StringFiled(size=255) # e.g. 'auto_increment'
         numeric_precision = IntField()
         numeric_scale= IntField()
     894 
class GedColumn(Persistable):
    # Row object that EntityGenerator.ged_fields fills in and
    # EntityGenerator.ged_db adds through a connection to the "ged" database.
    # NOTE(review): only `batch` and `dbtype` are declared here and there is no
    # __table_name__, yet ged_fields also assigns column_name, table_name,
    # data_type, etc. - confirm whether this class was meant to extend PColumn.
    batch = IntField()
    dbtype = StringFiled(size=50)    # set from dbtype.code in ged_fields
     898 
class dbtype(Persistable) :
    # Entity mapped to the `dbtype` table. EntityGenerator.ged_db lists these
    # rows from the "ged" database, and ged_fields uses host/port/user/passwd
    # to open a connection and `db` as the schema name to inspect.

    __table_name__ = 'dbtype'

    id = IntField()
    code = StringFiled(size=50, is_primary_key=True)   # copied to GedColumn.dbtype
    name = StringFiled(size=50)
    host = StringFiled(size=50)
    port = IntField()
    user = StringFiled(size=50)
    passwd = StringFiled(size=50)
    db = StringFiled(size=50)
    charset = StringFiled(size=50)
     912 
     913 
     914 class EntityGenerator(object) :
     915 
     916     pm = Persister()
     917     dic = {}
     918 
     919     def open(self) :
     920         
     921         self.pm.open(configuration.host,configuration.port,"information_schema",configuration.user,configuration.pwd)
     922 
     923         self.dic["tinyint"] = "BoolFiled"
     924         self.dic["smallint"] = "ShortField"
     925         self.dic["mediumint"] = "IntField"
     926         self.dic["int"] = "IntField"
     927         self.dic["integer"] = "IntField"
     928         self.dic["bigint"] = "LongField"
     929         self.dic["float"] = "FloatFiled"
     930         self.dic["double"] = "DoubleFiled"
     931         self.dic["decimal"] = "DecimalFiled"
     932         self.dic["date"] = "DateTimeFiled"
     933         self.dic["time"] = "DateTimeFiled"
     934         self.dic["year"] = "IntField"
     935         self.dic["datetime"] = "DateTimeFiled"
     936         self.dic["timestamp"] = "DateTimeFiled"
     937         self.dic["char"] = "StringFiled"
     938         self.dic["varchar"] = "StringFiled"
     939         self.dic["tinyblob"] = "StringFiled" 
     940         self.dic["tinytext"] = "StringFiled"
     941         self.dic["blob"] = "StringFiled"
     942         self.dic["text"] = "StringFiled"
     943         self.dic["mediumblob"] = "BinaryFiled"
     944         self.dic["mediumtext"] = "StringFiled"
     945         self.dic["longblob"] = "BinaryFiled"
     946         self.dic["longtext"] = "StringFiled"       
     947 
     948     def close(self) :
     949         self.pm.close()
     950 
     951     # 根据数据库生成实体
     952     def generate_db(self,db_name) :
     953         ts = self.pm.query_list(PTable,"table_schema = %s",[db_name])
     954 
     955         for t in ts:
     956             self.generate_table(t.table_name,t.table_comment)
     957 
     958     def generate_table(self,table_name,memoto) : 
     959         
     960         cls_name = self.get_class_name(table_name)
     961 
     962         self.out_put( "")
     963         self.out_put(  "#%s" % memoto)
     964         self.out_put(  "class %s(Persistable) : " % cls_name)
     965         self.out_put(  "" )
     966         self.out_put(  "    __table_name__ = '%s'" % table_name )
     967         self.out_put(  "" )
     968         
     969         cs = self.pm.query_list(PColumn,"table_name = %s ",[table_name] )
     970 
     971         for c in cs:
     972             item = self.generate_column(c)
     973             self.out_put(  item )
     974 
     975     def out_put(self,txt) :
     976         print txt
     977     
     978     def get_class_name(self,table_name) :
     979         cls_name = table_name
     980         splits = cls_name.split("_")
     981         if len(splits) > 1 :
     982             items = []
     983             
     984             is_first = True
     985             for item in splits :
     986                 if is_first :
     987                     is_first = False
     988                     continue
     989                     
     990                 items.append(item[0].upper()+item[1:len(item)]) 
     991 
     992             cls_name = "".join(items)
     993 
     994         return cls_name
     995 
     996     def generate_column(self,c) : 
     997         
     998 
     999         properties = []
    1000 
    1001         if c.character_maximum_length != None :
    1002             properties.append("size = %d" % c.character_maximum_length)
    1003         
    1004         if c.data_type == "decimal" :
    1005             properties.append("size = %d" % c.numeric_precision)
    1006             properties.append("precision = %d" % c.numeric_scale)
    1007 
    1008         if c.column_key == "PRI":
    1009             properties.append("is_primary_key=True")
    1010 
    1011         if c.extra == 'auto_increment' :
    1012             properties.append( "is_auto = True" )
    1013 
    1014         item = "    %s = %s( %s )" % (c.column_name.lower(),self.dic[c.data_type],",".join(properties))
    1015         item = item.ljust(60)
    1016 
    1017         if c.column_comment != None :
    1018             item = item +"# "+c.column_comment
    1019         return item
    1020 
    1021     #把数据库表结构生成到ged的columns表中
    1022     def ged_db(self) :
    1023         
    1024         db = Persister()
    1025         db.open(configuration.host,configuration.port,"ged",configuration.user,configuration.pwd)
    1026         ds = db.query_list(dbtype)
    1027 
    1028         columns = []
    1029 
    1030         for d in ds :
    1031             self.ged_fields(d,columns)
    1032 
    1033         for column in columns :
    1034             db.add(column)
    1035 
    1036         db.commit()
    1037         db.close()
    1038 
    1039     def ged_fields(self,d,columns) :
    1040         db = Persister()
    1041         db.open(d.host,d.port,"information_schema",d.user,d.passwd)
    1042         cs = db.query_list(PColumn,"table_schema = %s",[d.db] )
    1043 
    1044         for c in cs :
    1045             gedc = GedColumn()
    1046             gedc.dbtype=d.code
    1047             gedc.column_comment=c.column_comment
    1048             gedc.column_key=c.column_key
    1049             gedc.column_name=c.column_name
    1050             gedc.data_type=c.data_type
    1051             gedc.extra=c.extra
    1052             gedc.table_name=c.table_name
    1053             gedc.table_schema=c.table_schema
    1054             gedc.character_maximum_length=c.character_maximum_length
    1055 
    1056             columns.append(gedc)
    1057 
    1058         db.close()
    1059         
    1060 #########################################################################  
    View Code

    三,test_orm.py

      1 # !/usr/bin/python
      2 # -*- coding: UTF-8 -*-
      3 
      4 import sys
      5 from dao.configuration import *
      6 from dao.orm import *
      7 import startup
      8 import logging
      9 import unittest
     10 
     11 ##############################################################
class Customer(BizEntity) :
    'Customer'

    # No fields are declared here; the tests assign `code` and `name`, which
    # presumably come from BizEntity - confirm against dao.orm.
    __table_name__ = "ns_customer"
     16     
     17 
class Product(BizEntity) :
    'Product'

    # No fields are declared here; the tests assign `code` and `name`, which
    # presumably come from BizEntity - confirm against dao.orm.
    __table_name__ = "ns_product"
     22     
     23 
class OrderItem(Entity) :
    'Order line item'

    __table_name__ = "ns_order_item"

    quantity = DecimalFiled()
    price = DecimalFiled()
    amount = DecimalFiled()   

    # Foreign key column plus the Reference that resolves it to a Product.
    product_id = IntField()
    product = Reference(foreign_key = "product_id",header="产品",reference_type = Product)  

    # Named as the foreign_key of SalesOrder.items.
    order_id = IntField()    
     37 
class SalesOrder(Entity) :
    'Sales order'
    
    __table_name__ = "ns_order"

    code = StringFiled()
    quantity = DecimalFiled()
    price = DecimalFiled()
    amount = DecimalFiled()

    # Foreign key column plus the Reference that resolves it to a Customer.
    customer_id = IntField()
    customer = Reference(foreign_key = "customer_id",header="客户",reference_type = Customer)

    # Child collection: OrderItem rows linked through OrderItem.order_id.
    items = Subs(foreign_key = "order_id",header="订单明细",sub_type = OrderItem)
     52 
class Wdbtype(Persistable) :
    # Entity whose primary key is the string column `code`; it is exercised by
    # orm_test.test_customer_primary_key.
    
    __table_name__ = 'ns_dbtype'

    code = StringFiled( size = 50,is_primary_key=True )
    name = StringFiled( size = 50 )
    host = StringFiled( size = 50 )
    port = IntField(  )
    user = StringFiled( size = 50 )
    passwd = StringFiled( size = 50 )
    db = StringFiled( size = 50 )
    charset = StringFiled( size = 50 )
     65 
     66 ##############################################################
     67 
     68 
     69 __author__ = 'xufangbo'
     70 
     71 class orm_test(unittest.TestCase) :
     72     
     73     item_size = 10
     74     order = None
     75     
     76     def setUp(self):
     77 
     78         self.set_db()
     79         self.test_create()
     80 
     81     def tearDown(self) :
     82         
     83         db = Db()
     84         db.open()
     85 
     86         clss = [SalesOrder,OrderItem,Customer,Product,Wdbtype]
     87 
     88         for cls in clss :
     89             db.drop_table(cls)
     90 
     91         db.commit()
     92         db.close() 
     93 
     94     def test_byid(self) :
     95         
     96         order = self.order
     97         
     98         pm = Persister()
     99         pm.open()        
    100         
    101         pm.byid(order)
    102 
    103         self.assertEqual(order.code,"DD001")
    104         self.assertEqual(len(order.items),self.item_size)
    105         self.assertIsNotNone(order.create_time)
    106         self.assertIsNotNone(order.creator)
    107 
    108         self.assertIsNotNone(order.customer)
    109         self.assertIsNotNone(order.items[0].product)
    110 
    111         pm.commit()
    112         pm.close()        
    113     
    114     def test_persist(self) :
    115 
    116         pm = Persister()
    117         pm.open()
    118 
    119         order = self.order
    120 
    121         order.code="dd002"
    122         order.to_persist()
    123 
    124         order.items[1].to_delete()
    125         order.items[2].price = 18
    126         order.items[2].to_persist()
    127 
    128         pm.save(order)
    129 
    130         self.assertIsNotNone(order.update_time)
    131         self.assertIsNotNone(order.updator)    
    132 
    133         pm.byid(order)
    134 
    135         self.assertEqual(len(order.items),self.item_size-1)
    136         self.assertEqual(order.code,"dd002")
    137         self.assertIsNotNone(order.update_time)
    138         self.assertIsNotNone(order.updator)
    139 
    140         pm.commit()
    141         pm.close()
    142 
    143     def test_delete(self) :
    144     
    145         pm = Persister()
    146         pm.open()
    147 
    148         order = self.order
    149 
    150         order.to_delete()
    151         pm.save(order)
    152 
    153         o = pm.byid(order)
    154         self.assertIsNone(o)
    155 
    156         pm.commit()
    157         pm.close()       
    158 
    159         #=============================
    160         dao = DataAccess()
    161         dao.open()
    162         
    163         sql = 'select count(0) from %s where order_id = %s' % (EntityManager.get_meta(OrderItem).table_name,'%s')
    164         pars = [order.id]
    165         count = int(dao.executeScalar(sql,pars))
    166 
    167         self.assertEqual(count,0)
    168 
    169         dao.close()            
    170 
    171     def test_create(self) :
    172         
    173         pm = Persister()
    174         pm.open()
    175 
    176         customer = Customer()
    177         customer.code = "C001"
    178         customer.name="北京千舟科技发展有限公司"
    179         pm.save(customer)
    180 
    181         product = Product()
    182         product.code = "P001"
    183         product.name="电商ERP标准版4.0"
    184         pm.save(product)
    185 
    186         order = SalesOrder();
    187         order.code="DD001"
    188         order.quantity=2
    189         order.price =2 
    190         order.amount=4.6
    191         order.customer = customer
    192         order.items = []
    193 
    194         for i in range(self.item_size) :
    195             item = OrderItem()
    196             item.quantity=i+1
    197             item.price =i+1
    198             item.amount= (i+1)*(i+1)
    199             item.product = product
    200             order.items.append(item)            
    201 
    202         pm.save(order)
    203 
    204         self.order = order
    205 
    206         self.assertIsNotNone(order.id)
    207         self.assertIsNotNone(order.create_time)
    208         self.assertIsNotNone(order.creator)
    209 
    210         pm.commit()
    211         pm.close()
    212 
    213     def test_query(self) :
    214         
    215         pm = Persister()
    216         pm.open()
    217 
    218         orders = pm.query_list(SalesOrder)
    219         count1 = len(orders)
    220 
    221         customer = Customer()
    222         customer.code = "C003"
    223         customer.name="北京千舟科技发展有限公司"
    224         pm.save(customer)
    225 
    226         product = Product()
    227         product.code = "P003"
    228         product.name="电商ERP标准版4.0"
    229         pm.save(product)
    230 
    231         for i in range(self.item_size) :
    232             order = SalesOrder();
    233             order.code="DD001"
    234             order.quantity=2
    235             order.price =2 
    236             order.amount=4.6
    237             order.customer = customer
    238             order.items = []
    239 
    240             pm.save(order)
    241 
    242         orders = pm.query_list(SalesOrder)
    243         count2 = len(orders)
    244 
    245         self.assertEquals( count1 +self.item_size,count2 ) 
    246 
    247         orders = pm.query_list(SalesOrder,"customer.code = %s",["C003"])
    248         count3 = len(orders)
    249         self.assertEqual(count3,self.item_size)
    250 
    251         pm.commit()
    252         pm.close()
    253     
    254     def test_customer_primary_key(self) :
    255         pm = Persister()
    256         pm.open()
    257 
    258         dt = Wdbtype(code='zl44',name='专利',host='mysql',port=3310,db='patent',user='root',passwd='mysql123',charset='utf8')
    259 
    260         pm.add(dt)
    261         pm.delete(dt)
    262 
    263         pm.commit()
    264         pm.close()        
    265 
    266     def set_db(self) :
    267         
    268         db = Db()
    269         db.open()
    270 
    271         clss = [SalesOrder,OrderItem,Customer,Product,Wdbtype]
    272 
    273         for cls in clss :
    274             db.drop_table(cls)
    275             db.create_table(cls)
    276 
    277         db.commit()
    278         db.close() 
    279 
    280 class EntityGeneratorTest(unittest.TestCase) :
    281 
    282     def test_generator(self) :
    283         
    284         generator = EntityGenerator()
    285         generator.open()
    286         generator.generate_db("wolf")
    287         generator.close()        
    288 
    289 
# Run every TestCase in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
    View Code

    作者    :秋时

    本文版权归作者和博客园共有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出原文链接。

  • 相关阅读:
    Java中替换字符串中特定字符,replaceAll,replace,replaceFirst的区别
    牛客剑指offer 67题(持续更新~)
    从尾到头打印链表
    字符串变形
    缩写
    删除公共字符
    替换空格
    二维数组的查找
    acm博弈论基础总结
    acm模板总结
  • 原文地址:https://www.cnblogs.com/Netsharp/p/8463517.html
Copyright © 2020-2023  润新知