Python 下的 ORM 大多使用 SQLAlchemy。我用了一段时间，感觉不太顺手，主要问题是 SQLAlchemy 太重，所以自己写了一个 ORM，实现方式与 netsharp 类似；OQL 部分由于代码量较大，尚未完全实现。
下面是源代码
一、系统配置
configuration.py
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""Database connection settings used as defaults by the ORM.

Each setting may be overridden with an environment variable (shown next to
it) so real credentials do not have to live in source control. Without the
variable, the original hard-coded default is used.
"""

import os

host = os.environ.get("ORM_DB_HOST", "192.168.4.1")       # MySQL server host
port = int(os.environ.get("ORM_DB_PORT", "3306"))          # MySQL server port
db = os.environ.get("ORM_DB_NAME", "wolf")                 # default schema
user = os.environ.get("ORM_DB_USER", "root")               # login user
pwd = os.environ.get("ORM_DB_PASSWORD", "xxx")             # login password
# Name stamped into Entity.creator / Entity.updator when rows are saved.
user_name = os.environ.get("ORM_USER_NAME", "xxx")
二、orm.py
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""Minimal MySQL ORM.

Entities are classes deriving from ``Persistable`` whose class attributes are
``Field`` / ``Reference`` / ``Subs`` declarations. ``EntityManager`` turns
those declarations into cached ``NEntity`` metadata, and ``DataAccess`` wraps
a single MySQLdb connection/cursor pair.
"""

import sys
from datetime import *
from enum import Enum
import logging

import MySQLdb
import MySQLdb.cursors

import configuration

#########################################################################


class Field(object):
    """Declarative column description placed as a class attribute on an entity.

    Every keyword argument overrides the attribute of the same name, e.g.
    ``StringFiled(size=100, is_primary_key=True)``.
    """

    property_name = None
    column_name = None        # explicit DB column name; defaults to the attribute name
    group_name = None
    column_type_name = None   # SQL type used when the table is created
    header = None
    memoto = None

    is_primary_key = None
    is_auto = False           # emitted as AUTO_INCREMENT when creating the table
    is_name_equals = None
    is_required = None
    is_unique = None

    size = None
    precision = None

    def __init__(self, **kw):
        # dict.items() works on Python 2 and 3; iteritems() is Python 2 only.
        for k, v in kw.items():
            setattr(self, k, v)


class ShortField(Field):
    """SMALLINT column."""

    def __init__(self, **kw):
        self.column_type_name = "smallint"
        super(ShortField, self).__init__(**kw)


class IntField(Field):
    """INT column."""

    def __init__(self, **kw):
        self.column_type_name = "int"
        super(IntField, self).__init__(**kw)


class LongField(Field):
    """BIGINT column."""

    def __init__(self, **kw):
        self.column_type_name = "bigint"
        super(LongField, self).__init__(**kw)


class StringFiled(Field):
    """NVARCHAR column, 50 characters unless ``size`` is given."""

    def __init__(self, **kw):
        self.column_type_name = "nvarchar"
        self.size = 50
        super(StringFiled, self).__init__(**kw)


class BoolFiled(Field):
    """BOOL column."""

    def __init__(self, **kw):
        self.column_type_name = "bool"
        super(BoolFiled, self).__init__(**kw)


class FloatFiled(Field):
    """FLOAT(8,4) column unless ``size``/``precision`` are given."""

    def __init__(self, **kw):
        self.column_type_name = "float"
        self.size = 8
        self.precision = 4
        super(FloatFiled, self).__init__(**kw)


class DoubleFiled(Field):
    """DOUBLE(8,4) column unless ``size``/``precision`` are given."""

    def __init__(self, **kw):
        self.column_type_name = "double"
        self.size = 8
        self.precision = 4
        super(DoubleFiled, self).__init__(**kw)


class DecimalFiled(Field):
    """DECIMAL(14,8) column unless ``size``/``precision`` are given."""

    def __init__(self, **kw):
        self.column_type_name = "decimal"
        self.size = 14
        self.precision = 8
        super(DecimalFiled, self).__init__(**kw)


class DateTimeFiled(Field):
    """DATETIME column."""

    def __init__(self, **kw):
        self.column_type_name = "datetime"
        super(DateTimeFiled, self).__init__(**kw)


class BinaryFiled(Field):
    """LONGBLOB column."""

    def __init__(self, **kw):
        self.column_type_name = "longblob"
        super(BinaryFiled, self).__init__(**kw)


class Reference(object):
    """Many-to-one association declared on an entity.

    ``foreign_key`` names the field on the declaring entity that holds the
    key; ``primary_key`` names the field on ``reference_type`` it points at
    (``"id"`` when omitted).
    """

    property_name = None
    header = None
    foreign_key = None
    primary_key = None

    reference_type = None

    def __init__(self, **kw):
        for k, v in kw.items():
            setattr(self, k, v)


class Subs(object):
    """One-to-many association declared on the owning entity.

    ``foreign_key`` names the field on ``sub_type`` that stores the owner's key.
    """

    property_name = None
    header = None
    foreign_key = None
    primary_key = None

    sub_type = None

    def __init__(self, **kw):
        for k, v in kw.items():
            setattr(self, k, v)

#########################################################################


class EntityState(Enum):
    # Not managed by the database, or an entity needing no write on save.
    Transient = 0
    # Will be inserted on save.
    New = 1
    # Will be updated on save.
    Persist = 2
    # Will be deleted on save.
    Deleted = 3


class Persistable(object):
    """Base class of every entity; tracks what ``save`` should do with it.

    Instances start in ``EntityState.New`` (the class-level default). Keyword
    arguments are assigned as attributes.
    """

    __entity_status__ = EntityState.New

    def __init__(self, **kw):
        for k, v in kw.items():
            setattr(self, k, v)

    def to_new(self):
        self.__entity_status__ = EntityState.New

    def to_persist(self):
        self.__entity_status__ = EntityState.Persist

    def to_delete(self):
        self.__entity_status__ = EntityState.Deleted

    def to_transient(self):
        self.__entity_status__ = EntityState.Transient


class Entity(Persistable):
    """Entity with an auto-increment ``id`` and audit columns."""

    id = IntField(is_primary_key=True, is_auto=True)
    creator = StringFiled(size=100)
    create_time = DateTimeFiled()
    update_time = DateTimeFiled()
    updator = StringFiled(size=100)


class BizEntity(Entity):
    """Entity with the common ``code`` / ``name`` / ``memoto`` columns."""

    code = StringFiled()
    name = StringFiled(size=100)
    memoto = StringFiled(size=500)

#########################################################################


class NColumn(object):
    """Parsed metadata of one column (built by EntityManager.parse_field)."""

    property_name = None
    column_name = None
    group_name = None
    column_type_name = None
    header = None
    memoto = None

    is_primary_key = None
    is_auto = False
    is_name_equals = None
    is_required = None
    is_unique = None

    size = None
    precision = None

    def __repr__(self):
        return self.property_name + "[" + self.column_name + "]"


class NEntity(object):
    """Parsed metadata of one entity class (built by EntityManager)."""

    name = None
    table_name = None
    entity_id = None
    header = None
    is_view = None
    is_refcheck = None
    key_column = None
    auto_column = None
    order_by = None
    type = None

    def __init__(self):
        # Per-instance containers; class-level dicts would be shared by
        # every NEntity object.
        self.columns = {}       # column name -> NColumn, key column excluded
        self.fields = {}        # property name -> NColumn, key column excluded
        self.full_columns = {}  # column name -> NColumn, all columns
        self.full_fields = {}   # property name -> NColumn, all columns
        self.subs = {}          # property name -> NSubs
        self.references = {}    # property name -> NReference


class NReference(object):
    """Parsed metadata of a Reference declaration."""

    property_name = None
    header = None
    foreign_key = None
    primary_key = None

    foreign_key_column = None   # NColumn on the declaring entity
    primary_key_column = None   # NColumn on the referenced entity

    reference_type = None


class NSubs(object):
    """Parsed metadata of a Subs declaration."""

    property_name = None
    header = None
    foreign_key = None
    primary_key = None

    foreign_key_column = None   # NColumn on the sub entity
    primary_key_column = None   # key NColumn of the owning entity

    sub_type = None


class EntityManager(object):
    """Parses entity classes into NEntity metadata and caches the result."""

    # entity class -> NEntity. Keyed by the class object (not its name) so
    # two entity classes with the same name in different modules do not
    # overwrite each other.
    entityMap = {}

    @classmethod
    def get_meta(cls, type):
        """Return the (cached) NEntity describing entity class *type*.

        NOTE(review): the entity is cached only after parsing finishes, so a
        Reference/Subs pointing back at the declaring class (or any cycle)
        recurses without bound.
        """
        ne = cls.entityMap.get(type)
        if ne is None:
            ne = cls.parse_entity(type)
            cls.entityMap[type] = ne
        return ne

    @classmethod
    def parse_entity(cls, type):
        """Build an NEntity from the Field/Reference/Subs attributes of *type*.

        Raises:
            ORMException: *type* declares no ``__table_name__``.
        """
        table_name = getattr(type, "__table_name__", None)
        if table_name is None:
            raise ORMException("%s does not declare __table_name__" % type.__name__)

        ne = NEntity()
        ne.type = type
        ne.table_name = table_name
        ne.name = type.__name__
        ne.entity_id = type.__name__
        ne.header = type.__doc__
        ne.is_view = False
        ne.is_refcheck = False
        ne.key_column = None
        ne.auto_column = None
        ne.order_by = None

        # dir() also lists inherited attributes, so columns declared on base
        # classes (e.g. Entity.id) are included.
        for k in dir(type):
            v = getattr(type, k)

            if isinstance(v, Field):
                c = cls.parse_field(k, v)
                ne.full_fields[k] = c
                ne.full_columns[c.column_name] = c
                if c.is_primary_key:
                    ne.key_column = c
                else:
                    ne.fields[k] = c
                    ne.columns[c.column_name] = c
                if c.is_auto:
                    ne.auto_column = c
            elif isinstance(v, Reference):
                ne.references[k] = cls.parse_reference(k, v)
            elif isinstance(v, Subs):
                ne.subs[k] = cls.parse_subs(k, v)

        for r in ne.references.values():
            r.foreign_key_column = ne.full_fields[r.foreign_key]
        for s in ne.subs.values():
            s.primary_key_column = ne.key_column

        return ne

    @classmethod
    def parse_field(cls, name, field):
        """Convert the Field declared as attribute *name* into an NColumn."""
        c = NColumn()
        c.property_name = name
        # Honour an explicit column_name; fall back to the attribute name.
        c.column_name = field.column_name if field.column_name is not None else name
        c.is_name_equals = c.column_name == c.property_name
        c.column_type_name = field.column_type_name
        c.size = field.size
        c.precision = field.precision
        c.group_name = field.group_name
        c.header = field.header
        c.memoto = field.memoto
        c.is_primary_key = bool(field.is_primary_key)
        c.is_auto = bool(field.is_auto)
        c.is_required = bool(field.is_required)
        c.is_unique = bool(field.is_unique)
        return c

    @classmethod
    def parse_reference(cls, name, field):
        """Convert the Reference declared as attribute *name* into an NReference.

        ``foreign_key_column`` is resolved later by parse_entity, once the
        declaring entity's fields are known.
        """
        r = NReference()
        r.property_name = name
        r.foreign_key = field.foreign_key
        r.header = field.header
        r.primary_key = field.primary_key if field.primary_key is not None else "id"
        r.reference_type = field.reference_type
        r.foreign_key_column = None

        rne = EntityManager.get_meta(field.reference_type)
        r.primary_key_column = rne.full_fields[r.primary_key]
        return r

    @classmethod
    def parse_subs(cls, name, field):
        """Convert the Subs declared as attribute *name* into an NSubs.

        ``primary_key_column`` is filled in later by parse_entity with the
        owner's key column.
        """
        s = NSubs()
        s.property_name = name
        s.foreign_key = field.foreign_key
        s.header = field.header
        s.primary_key = field.primary_key if field.primary_key is not None else "id"
        s.primary_key_column = None
        s.sub_type = field.sub_type

        rne = EntityManager.get_meta(field.sub_type)
        s.foreign_key_column = rne.full_fields[field.foreign_key]
        return s


class ORMException(Exception):
    """Error raised for ORM misuse; the message is also kept in ``value``."""

    def __init__(self, value):
        # Pass the message to Exception so str(exc) and tracebacks show it.
        super(ORMException, self).__init__(value)
        self.value = value

#########################################################################


class DataAccess(object):
    """Thin wrapper around one MySQLdb connection and its cursor.

    Every query method raises ``Exception("db is not opened!")`` when called
    before ``open`` or after ``close``. SQL uses MySQLdb ``%s`` placeholders
    with values passed separately in *pars*.
    """

    conn = None
    cursor = None
    isClosed = True

    def open(self, host=configuration.host, port=configuration.port, db=configuration.db,
             user=configuration.user, pwd=configuration.pwd):
        """Connect (utf8) and create the cursor, closing any previous connection."""
        if not self.isClosed:
            self.close()
        self.conn = MySQLdb.connect(host=host, port=port, db=db, user=user, passwd=pwd, charset="utf8")
        self.cursor = self.conn.cursor()
        self.isClosed = False

    def _ensure_open(self):
        if self.isClosed:
            raise Exception("db is not opened!")

    def execute(self, cmd, pars=None):
        """Execute *cmd* and return the cursor's result (affected row count)."""
        self._ensure_open()
        logging.info(cmd)
        return self.cursor.execute(cmd, pars)

    def get_last_row_id(self):
        """Return the id generated by the last INSERT on this cursor."""
        return int(self.cursor.lastrowid)

    def fetchone(self, cmd, pars=None):
        """Execute *cmd* and return the first row, or None when there is none."""
        self._ensure_open()
        logging.info(cmd)
        self.cursor.execute(cmd, pars)
        return self.cursor.fetchone()

    def fetchall(self, cmd, pars=None):
        """Execute *cmd* and return all rows."""
        self._ensure_open()
        logging.info(cmd)
        self.cursor.execute(cmd, pars)
        return self.cursor.fetchall()

    def executeScalar(self, cmd, pars=None):
        """Return the first column of the first row, or None if there is no row."""
        row = self.fetchone(cmd, pars)
        if not row:
            return None
        return row[0]

    def commit(self):
        self.conn.commit()

    def rollback(self):
        self.conn.rollback()

    # Backward-compatible alias for the original (misspelt) method name.
    rolback = rollback

    def close(self):
        """Close the connection; calling it again is a no-op."""
        if self.isClosed:
            return
        try:
            self.conn.close()
        finally:
            self.isClosed = True
            self.conn = None
            self.cursor = None
#########################################################################


class SqlGenerator(object):
    """Builds the parameterised SQL statements used by Persister.

    Statements use MySQLdb ``%s`` placeholders. ``insert_columns`` and
    ``update_columns`` return the exact column order of the generated INSERT
    and UPDATE, so callers bind values in the same order, even when a column
    name differs from its property name.
    """

    @classmethod
    def insert_columns(cls, ne):
        """Return the NColumns written by INSERT, in statement order.

        An auto-increment key column is left out so the database assigns it.
        """
        key = ne.key_column
        source = ne.columns if key is not None and key.is_auto else ne.full_columns
        return list(source.values())

    @classmethod
    def update_columns(cls, ne):
        """Return the non-key NColumns written by UPDATE, in statement order."""
        return list(ne.columns.values())

    @classmethod
    def _key_column(cls, ne):
        """Return the key column of *ne*; raise ORMException when it has none."""
        if ne.key_column is None:
            raise ORMException("%s has no primary key" % ne.table_name)
        return ne.key_column

    @classmethod
    def generate_insert(cls, ne):
        columns = cls.insert_columns(ne)
        return 'insert into %s (%s) values (%s);' % (
            ne.table_name,
            ', '.join(c.column_name for c in columns),
            ', '.join(['%s'] * len(columns)))

    @classmethod
    def generate_update(cls, ne):
        sets = ",".join("%s = %%s" % c.column_name for c in cls.update_columns(ne))
        return 'update %s set %s where %s = %s;' % (
            ne.table_name, sets, cls._key_column(ne).column_name, "%s")

    @classmethod
    def generate_delete(cls, ne):
        return "delete from %s where %s = %s" % (
            ne.table_name, cls._key_column(ne).column_name, "%s")

    @classmethod
    def generate_byid(cls, ne):
        # Columns follow ne.full_columns order, which SetQuery.read_row relies on.
        return "select %s from %s where %s = %s" % (
            ",".join(ne.full_columns.keys()), ne.table_name,
            cls._key_column(ne).column_name, '%s')


class Db(object):
    """DDL helper: create and drop databases and entity tables."""

    dao = None

    def __init__(self):
        # One DataAccess per Db. A class-level instance would be shared,
        # connection included, by every Db object.
        self.dao = DataAccess()

    def open(self):
        self.dao.open()

    def create_db(self, db_name):
        sql = "CREATE SCHEMA %s DEFAULT CHARACTER SET ucs2 ;" % db_name
        self.dao.execute(sql)

    def drop_db(self, db_name):
        sql = "DROP DATABASE %s;" % db_name
        self.dao.execute(sql)

    def create_table(self, cls):
        """Create the table for entity class *cls* from its Field declarations."""
        ne = EntityManager.get_meta(cls)
        columns = ",".join(self.generate_column(c) for c in ne.full_columns.values())
        sql = "create table %s( %s )" % (ne.table_name, columns)
        self.dao.execute(sql, None)

    def generate_column(self, column):
        """Return the column definition used inside CREATE TABLE."""
        sql = "%s %s" % (column.column_name, column.column_type_name)
        if column.size is not None:
            if column.precision is not None:
                sql += "(%d,%d)" % (column.size, column.precision)
            else:
                sql += "(%d)" % column.size
        if column.is_auto:
            sql += " AUTO_INCREMENT"
        if column.is_primary_key:
            sql += " PRIMARY KEY"
        return sql

    def drop_table(self, cls):
        ne = EntityManager.get_meta(cls)
        sql = "drop table if exists %s" % ne.table_name
        self.dao.execute(sql, None)

    def commit(self):
        self.dao.commit()

    def close(self):
        if not self.dao.isClosed:
            self.dao.close()

#########################################################################


class SetQuery(object):
    """Loads an entity row plus its references and, recursively, its subs."""

    dao = None

    def __init__(self, dao):
        self.dao = dao

    def query(self, ne, entity):
        """Reload *entity* by its key and fill references/subs in place.

        Returns the entity, or None when no row has that key.

        Raises:
            ORMException: the key attribute is unset.
        """
        key = getattr(entity, ne.key_column.property_name)
        # An unassigned key resolves to the class-level Field declaration.
        if key is None or isinstance(key, Field):
            raise ORMException("%s.id不能为空" % ne.table_name)

        row = self.dao.fetchone(SqlGenerator.generate_byid(ne), [key])
        if row is None:
            return None

        entity = self.read_row(ne, row, entity)
        self.query_iter(entity, ne)
        return entity

    def query_iter(self, entity, ne):
        """Attach referenced entities and sub collections to *entity*."""
        # References: one row each, matched on the referenced primary key.
        # A NULL or dangling foreign key leaves the reference as None.
        for rm in ne.references.values():
            fk = getattr(entity, rm.foreign_key)
            ref = None
            if fk is not None and not isinstance(fk, Field):
                ref_meta = EntityManager.get_meta(rm.reference_type)
                condition = "%s = %s" % (rm.primary_key_column.column_name, "%s")
                found = self.do_query(ref_meta, condition, [fk])
                ref = found[0] if found else None
            setattr(entity, rm.property_name, ref)

        # Sub collections: every sub row pointing at this entity's key.
        for sm in ne.subs.values():
            sub_meta = EntityManager.get_meta(sm.sub_type)
            condition = "%s = %s" % (sm.foreign_key_column.column_name, "%s")
            owner_key = getattr(entity, ne.key_column.property_name)
            subs = self.do_query(sub_meta, condition, [owner_key])
            setattr(entity, sm.property_name, subs)
            for sub in subs:
                self.query_iter(sub, sub_meta)

    def do_query(self, ne, filter, pars):
        """Return entities of *ne* matching the raw SQL *filter* (bound with *pars*)."""
        sql = "select %s from %s where %s" % (",".join(ne.full_columns.keys()), ne.table_name, filter)
        return [self.read_row(ne, row) for row in self.dao.fetchall(sql, pars)]

    def read_row(self, ne, row, entity=None):
        """Copy *row* (in ne.full_columns order) onto *entity* and return it.

        When *entity* is None, a new instance of ``ne.type`` is created and
        marked Transient: it mirrors a stored row, so saving its owner again
        must not re-insert it.
        """
        if entity is None:
            entity = ne.type()
            if isinstance(entity, Persistable):
                entity.to_transient()
        for column, value in zip(ne.full_columns.values(), row):
            setattr(entity, column.property_name, value)
        return entity

#########################################################################


class Persister(object):
    """Unit-of-work style persistence API over one DataAccess connection."""

    isClosed = True
    dao = None

    def __init__(self):
        # Per-instance DataAccess; a class-level one would be shared by every
        # Persister until open() replaced it.
        self.dao = DataAccess()
        self.isClosed = True

    def query_first(self, type, filters=None, pars=None):
        """Return the first entity matching *filters*, fully loaded, or None."""
        entities = self.query_list(type, filters, pars)
        if not entities:
            return None
        # Reload through byid so references and sub collections are filled in.
        return self.byid(entities[0])

    def query_list(self, type, filters=None, pars=None):
        """Return entities of *type* matching *filters* (main entity columns only).

        *filters* is raw SQL appended after WHERE; pass values through *pars*
        instead of formatting them into the string. Main-entity columns are
        qualified by the entity class name, and joined references by their
        property name (e.g. ``"customer.code = %s"``). Only references of the
        main entity are joined; nothing is loaded into them.
        """
        ne = EntityManager.get_meta(type)

        columns = ",".join(ne.name + "." + key for key in ne.full_columns.keys())
        sqls = ["select %s from %s as %s" % (columns, ne.table_name, ne.name)]
        for r in ne.references.values():
            ref_meta = EntityManager.get_meta(r.reference_type)
            sqls.append("left join %s as %s on %s.%s = %s.%s" % (
                ref_meta.table_name, r.property_name,
                ne.name, r.foreign_key_column.column_name,
                r.property_name, r.primary_key_column.column_name))

        if filters is not None and filters.strip() != "":
            sqls.append("where " + filters)

        reader = SetQuery(self.dao)
        rows = self.dao.fetchall(" ".join(sqls), pars)
        return [reader.read_row(ne, row) for row in rows]

    def byid(self, entity):
        """Reload *entity* in place by its key; return it, or None if missing."""
        ne = EntityManager.get_meta(entity.__class__)
        return SetQuery(self.dao).query(ne, entity)

    def save(self, entity):
        """Persist *entity* according to its __entity_status__, cascading to subs.

        * New       -> INSERT, then insert every loaded sub (forced to New).
        * Persist   -> UPDATE, then save every loaded sub by its own state.
        * Deleted   -> delete every sub (loading unloaded collections first),
                       then DELETE the entity.
        * Transient -> no statement for the entity; loaded subs are saved.

        New and Persist entities end in the Transient state. Before that, an
        unset foreign key is filled from an assigned reference object.

        Raises:
            ORMException: the entity state is None.
        """
        if entity.__entity_status__ is None:
            raise ORMException("调用save方法必须设置实体__entity_state__")

        ne = EntityManager.get_meta(entity.__class__)
        self._sync_foreign_keys(ne, entity)

        status = entity.__entity_status__
        if status == EntityState.New:
            self.add(entity)
            self._save_subs(ne, entity, force_new=True)
            entity.to_transient()
        elif status == EntityState.Persist:
            self.update(entity)
            self._save_subs(ne, entity)
            entity.to_transient()
        elif status == EntityState.Deleted:
            for subs in self._subs_for_delete(ne, entity):
                for sub in subs:
                    sub.to_delete()
                    self.save(sub)
            self.delete(entity)
        elif status == EntityState.Transient:
            self._save_subs(ne, entity)

    def _sync_foreign_keys(self, ne, entity):
        """Copy the key of each assigned reference object into its unset FK field."""
        for r in ne.references.values():
            ref = getattr(entity, r.property_name)
            # An unassigned reference resolves to the class-level Reference.
            if ref is None or isinstance(ref, Reference):
                continue
            fk = getattr(entity, r.foreign_key)
            if fk is None or isinstance(fk, Field):
                setattr(entity, r.foreign_key, getattr(ref, r.primary_key))

    def _save_subs(self, ne, entity, force_new=False):
        """Save every loaded sub collection of *entity*, stamping the owner key."""
        if not ne.subs:
            return
        key = getattr(entity, ne.key_column.property_name)
        for sm in ne.subs.values():
            subs = getattr(entity, sm.property_name)
            # Unloaded collections (None or the class-level Subs declaration)
            # are skipped without stopping the remaining collections.
            if subs is None or isinstance(subs, Subs):
                continue
            for sub in subs:
                if force_new:
                    sub.to_new()
                setattr(sub, sm.foreign_key_column.property_name, key)
                self.save(sub)

    def _subs_for_delete(self, ne, entity):
        """Yield each sub collection of *entity*, loading it from the DB if unloaded."""
        if not ne.subs:
            return
        key = getattr(entity, ne.key_column.property_name)
        for sm in ne.subs.values():
            subs = getattr(entity, sm.property_name)
            if subs is None or isinstance(subs, Subs):
                sub_meta = EntityManager.get_meta(sm.sub_type)
                condition = "%s = %s" % (sm.foreign_key_column.column_name, "%s")
                subs = SetQuery(self.dao).do_query(sub_meta, condition, [key])
            yield subs

    @staticmethod
    def _column_value(entity, column):
        """Return the value to bind for *column*; unset attributes become NULL."""
        value = getattr(entity, column.property_name)
        # An unassigned attribute resolves to the class-level Field declaration.
        return None if isinstance(value, Field) else value

    def add(self, entity):
        """INSERT *entity*; stamp audit columns and write back an auto id."""
        ne = EntityManager.get_meta(entity.__class__)
        sql = SqlGenerator.generate_insert(ne)

        if isinstance(entity, Entity):
            entity.create_time = datetime.now()
            entity.creator = configuration.user_name

        pars = [self._column_value(entity, c) for c in SqlGenerator.insert_columns(ne)]
        self.dao.execute(sql, pars)

        if ne.auto_column is not None:
            setattr(entity, ne.auto_column.property_name, self.dao.get_last_row_id())

    def update(self, entity):
        """UPDATE every non-key column of *entity*; stamp audit columns."""
        ne = EntityManager.get_meta(entity.__class__)
        sql = SqlGenerator.generate_update(ne)

        if isinstance(entity, Entity):
            entity.update_time = datetime.now()
            entity.updator = configuration.user_name

        pars = [self._column_value(entity, c) for c in SqlGenerator.update_columns(ne)]
        pars.append(getattr(entity, ne.key_column.property_name))
        self.dao.execute(sql, pars)

    def delete(self, entity):
        """DELETE the row of *entity* by its key (subs are not touched here)."""
        ne = EntityManager.get_meta(entity.__class__)
        sql = SqlGenerator.generate_delete(ne)
        self.dao.execute(sql, [getattr(entity, ne.key_column.property_name)])

    def execute(self, cmd, pars=None):
        return self.dao.execute(cmd, pars)

    def fetchone(self, cmd, pars=None):
        return self.dao.fetchone(cmd, pars)

    def fetchall(self, cmd, pars=None):
        return self.dao.fetchall(cmd, pars)

    def executeScalar(self, cmd, pars=None):
        return self.dao.executeScalar(cmd, pars)

    def commit(self):
        self.dao.commit()

    def open(self, host=configuration.host, port=configuration.port, db=configuration.db,
             user=configuration.user, pwd=configuration.pwd):
        self.dao = DataAccess()
        self.dao.open(host=host, port=port, db=db, user=user, pwd=pwd)
        self.isClosed = False

    def close(self):
        """Close the connection (uncommitted work is not committed)."""
        if not self.dao.isClosed:
            self.dao.close()
        self.isClosed = True
#########################################################################


class PTable(Persistable):
    """Row of MySQL ``information_schema.tables`` (read-only mapping)."""

    __table_name__ = 'tables'

    table_catalog = StringFiled(size=512)
    table_schema = StringFiled(size=64)
    table_name = StringFiled(size=64, is_primary_key=True)
    table_type = StringFiled(size=64)
    engine = StringFiled(size=64)
    version = LongField()
    row_format = StringFiled(size=10)
    table_rows = LongField()
    avg_row_length = LongField()
    data_length = LongField()
    max_data_length = LongField()
    index_length = LongField()
    data_free = LongField()
    auto_increment = LongField()
    create_time = DateTimeFiled()
    update_time = DateTimeFiled()
    check_time = DateTimeFiled()
    table_collation = StringFiled(size=32)
    checksum = LongField()
    create_options = StringFiled(size=2048)
    table_comment = StringFiled(size=2048)


class PColumn(Persistable):
    """Row of MySQL ``information_schema.columns`` (read-only mapping)."""

    __table_name__ = 'columns'

    table_schema = StringFiled(size=255)
    table_name = StringFiled(size=255)
    column_name = StringFiled(size=50, is_primary_key=True)
    data_type = StringFiled(size=255)
    character_maximum_length = StringFiled(size=255)  # max length for character types
    column_key = StringFiled(size=255)   # PRI = primary key, UNI = unique, MUL = first column of a non-unique index
    column_comment = StringFiled(size=255)  # column comment
    extra = StringFiled(size=255)        # e.g. 'auto_increment'
    numeric_precision = IntField()
    numeric_scale = IntField()


# NOTE(review): GedColumn declares no __table_name__ and no primary key, and
# EntityGenerator.ged_fields assigns attributes (column_name, data_type, ...)
# that are not declared as fields. Persister.add(GedColumn) therefore cannot
# succeed as written. Confirm the target schema in the "ged" database.
class GedColumn(Persistable):
    batch = IntField()
    dbtype = StringFiled(size=50)


class dbtype(Persistable):
    """Connection description of a database to inspect (table ``dbtype``)."""

    __table_name__ = 'dbtype'

    id = IntField()
    code = StringFiled(size=50, is_primary_key=True)
    name = StringFiled(size=50)
    host = StringFiled(size=50)
    port = IntField()
    user = StringFiled(size=50)
    passwd = StringFiled(size=50)
    db = StringFiled(size=50)
    charset = StringFiled(size=50)


class EntityGenerator(object):
    """Prints Persistable class definitions reverse-engineered from a MySQL schema."""

    # MySQL data_type -> field class name used in the generated code.
    TYPE_MAP = {
        "tinyint": "BoolFiled",
        "smallint": "ShortField",
        "mediumint": "IntField",
        "int": "IntField",
        "integer": "IntField",
        "bigint": "LongField",
        "float": "FloatFiled",
        "double": "DoubleFiled",
        "decimal": "DecimalFiled",
        "date": "DateTimeFiled",
        "time": "DateTimeFiled",
        "year": "IntField",
        "datetime": "DateTimeFiled",
        "timestamp": "DateTimeFiled",
        "char": "StringFiled",
        "varchar": "StringFiled",
        "tinyblob": "StringFiled",
        "tinytext": "StringFiled",
        "blob": "StringFiled",
        "text": "StringFiled",
        "mediumblob": "BinaryFiled",
        "mediumtext": "StringFiled",
        "longblob": "BinaryFiled",
        "longtext": "StringFiled",
    }
    # Used for MySQL types missing from TYPE_MAP (json, enum, set, bit, ...).
    DEFAULT_FIELD = "StringFiled"

    pm = None
    dic = {}

    def __init__(self):
        # Per-instance state; class-level objects would be shared by every
        # generator.
        self.pm = Persister()
        self.dic = dict(self.TYPE_MAP)

    def open(self):
        """Connect to information_schema using the configured credentials."""
        self.pm.open(configuration.host, configuration.port, "information_schema",
                     configuration.user, configuration.pwd)

    def close(self):
        self.pm.close()

    def generate_db(self, db_name):
        """Print an entity class for every table of schema *db_name*."""
        tables = self.pm.query_list(PTable, "table_schema = %s", [db_name])
        for t in tables:
            self.generate_table(t.table_name, t.table_comment, db_name)

    def generate_table(self, table_name, memoto, table_schema=None):
        """Print the entity class for *table_name*, with *memoto* as a header comment.

        When *table_schema* is given, only that schema's columns are used.
        Otherwise columns of same-named tables in other schemas are mixed in.
        Columns are emitted in table definition order.
        """
        cls_name = self.get_class_name(table_name)

        self.out_put("")
        self.out_put("#%s" % memoto)
        self.out_put("class %s(Persistable) : " % cls_name)
        self.out_put("")
        self.out_put("    __table_name__ = '%s'" % table_name)
        self.out_put("")

        filters = "table_name = %s"
        pars = [table_name]
        if table_schema is not None:
            filters += " and table_schema = %s"
            pars.append(table_schema)
        filters += " order by ordinal_position"

        for c in self.pm.query_list(PColumn, filters, pars):
            self.out_put(self.generate_column(c))

    def out_put(self, txt):
        # print(x) with a single argument behaves the same on Python 2 and 3.
        print(txt)

    def get_class_name(self, table_name):
        """Derive a class name: drop the first ``_`` segment, capitalise the rest.

        ``ns_order_item`` -> ``OrderItem``; a name without ``_`` is returned
        unchanged. Empty segments are ignored. If nothing remains after
        dropping the first segment, *table_name* is returned.
        """
        splits = table_name.split("_")
        if len(splits) <= 1:
            return table_name
        cls_name = "".join(part[0].upper() + part[1:] for part in splits[1:] if part)
        return cls_name or table_name

    def generate_column(self, c):
        """Return one generated field line (padded, with the comment appended)."""
        properties = []

        if c.character_maximum_length is not None:
            properties.append("size = %d" % c.character_maximum_length)

        if c.data_type == "decimal":
            properties.append("size = %d" % c.numeric_precision)
            properties.append("precision = %d" % c.numeric_scale)

        if c.column_key == "PRI":
            properties.append("is_primary_key=True")

        if c.extra == 'auto_increment':
            properties.append("is_auto = True")

        field_cls = self.dic.get(c.data_type, self.DEFAULT_FIELD)
        item = "    %s = %s( %s )" % (c.column_name.lower(), field_cls, ",".join(properties))
        item = item.ljust(60)

        # information_schema reports '' for a missing comment; skip it too.
        if c.column_comment:
            item += "# " + c.column_comment
        return item

    def ged_db(self):
        """Copy the column structure of every database in ged.dbtype into ged."""
        db = Persister()
        db.open(configuration.host, configuration.port, "ged", configuration.user, configuration.pwd)
        try:
            columns = []
            for d in db.query_list(dbtype):
                self.ged_fields(d, columns)

            for column in columns:
                db.add(column)

            db.commit()
        finally:
            db.close()

    def ged_fields(self, d, columns):
        """Append a GedColumn to *columns* for every column of database *d*."""
        db = Persister()
        db.open(d.host, d.port, "information_schema", d.user, d.passwd)
        try:
            for c in db.query_list(PColumn, "table_schema = %s", [d.db]):
                gedc = GedColumn()
                gedc.dbtype = d.code
                gedc.column_comment = c.column_comment
                gedc.column_key = c.column_key
                gedc.column_name = c.column_name
                gedc.data_type = c.data_type
                gedc.extra = c.extra
                gedc.table_name = c.table_name
                gedc.table_schema = c.table_schema
                gedc.character_maximum_length = c.character_maximum_length
                columns.append(gedc)
        finally:
            db.close()

#########################################################################
三、test_orm.py
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""Integration tests for dao.orm against a live MySQL database.

The connection comes from dao.configuration. setUp recreates the ns_* tables
and inserts one sample order; tearDown drops the tables again.
"""

import sys
from dao.configuration import *
from dao.orm import *
import startup
import logging
import unittest

##############################################################


class Customer(BizEntity):
    '客户'

    __table_name__ = "ns_customer"


class Product(BizEntity):
    '产品'

    __table_name__ = "ns_product"


class OrderItem(Entity):
    '订单明细'

    __table_name__ = "ns_order_item"

    quantity = DecimalFiled()
    price = DecimalFiled()
    amount = DecimalFiled()

    product_id = IntField()
    product = Reference(foreign_key="product_id", header="产品", reference_type=Product)

    order_id = IntField()


class SalesOrder(Entity):
    '销售订单'

    __table_name__ = "ns_order"

    code = StringFiled()
    quantity = DecimalFiled()
    price = DecimalFiled()
    amount = DecimalFiled()

    customer_id = IntField()
    customer = Reference(foreign_key="customer_id", header="客户", reference_type=Customer)

    items = Subs(foreign_key="order_id", header="订单明细", sub_type=OrderItem)


class Wdbtype(Persistable):

    __table_name__ = 'ns_dbtype'

    code = StringFiled(size=50, is_primary_key=True)
    name = StringFiled(size=50)
    host = StringFiled(size=50)
    port = IntField()
    user = StringFiled(size=50)
    passwd = StringFiled(size=50)
    db = StringFiled(size=50)
    charset = StringFiled(size=50)

##############################################################


__author__ = 'xufangbo'

# Tables recreated before and dropped after every test.
ENTITY_CLASSES = [SalesOrder, OrderItem, Customer, Product, Wdbtype]


class orm_test(unittest.TestCase):

    item_size = 10
    order = None

    def setUp(self):
        self.set_db()
        self.order = self._create_order()

    def tearDown(self):
        db = Db()
        db.open()
        try:
            for cls in ENTITY_CLASSES:
                db.drop_table(cls)
            db.commit()
        finally:
            db.close()

    def test_byid(self):
        order = self.order

        pm = Persister()
        pm.open()
        try:
            pm.byid(order)

            self.assertEqual(order.code, "DD001")
            self.assertEqual(len(order.items), self.item_size)
            self.assertIsNotNone(order.create_time)
            self.assertIsNotNone(order.creator)

            self.assertIsNotNone(order.customer)
            self.assertIsNotNone(order.items[0].product)

            pm.commit()
        finally:
            pm.close()

    def test_persist(self):
        order = self.order

        pm = Persister()
        pm.open()
        try:
            order.code = "dd002"
            order.to_persist()

            order.items[1].to_delete()
            order.items[2].price = 18
            order.items[2].to_persist()

            pm.save(order)

            self.assertIsNotNone(order.update_time)
            self.assertIsNotNone(order.updator)

            pm.byid(order)

            self.assertEqual(len(order.items), self.item_size - 1)
            self.assertEqual(order.code, "dd002")
            self.assertIsNotNone(order.update_time)
            self.assertIsNotNone(order.updator)

            pm.commit()
        finally:
            pm.close()

    def test_delete(self):
        order = self.order

        pm = Persister()
        pm.open()
        try:
            order.to_delete()
            pm.save(order)

            self.assertIsNone(pm.byid(order))

            pm.commit()
        finally:
            pm.close()

        # Check on a fresh connection that the order items were deleted too.
        dao = DataAccess()
        dao.open()
        try:
            sql = 'select count(0) from %s where order_id = %s' % (EntityManager.get_meta(OrderItem).table_name, '%s')
            count = int(dao.executeScalar(sql, [order.id]))
            self.assertEqual(count, 0)
        finally:
            dao.close()

    def test_create(self):
        order = self._create_order()

        self.assertIsNotNone(order.id)
        self.assertIsNotNone(order.create_time)
        self.assertIsNotNone(order.creator)

    def test_query(self):
        pm = Persister()
        pm.open()
        try:
            count1 = len(pm.query_list(SalesOrder))

            customer = Customer()
            customer.code = "C003"
            customer.name = "北京千舟科技发展有限公司"
            pm.save(customer)

            product = Product()
            product.code = "P003"
            product.name = "电商ERP标准版4.0"
            pm.save(product)

            for _ in range(self.item_size):
                order = SalesOrder()
                order.code = "DD001"
                order.quantity = 2
                order.price = 2
                order.amount = 4.6
                order.customer = customer
                order.items = []

                pm.save(order)

            count2 = len(pm.query_list(SalesOrder))
            self.assertEqual(count1 + self.item_size, count2)

            count3 = len(pm.query_list(SalesOrder, "customer.code = %s", ["C003"]))
            self.assertEqual(count3, self.item_size)

            pm.commit()
        finally:
            pm.close()

    def test_customer_primary_key(self):
        pm = Persister()
        pm.open()
        try:
            dt = Wdbtype(code='zl44', name='专利', host='mysql', port=3310, db='patent',
                         user='root', passwd='mysql123', charset='utf8')
            sql = 'select count(0) from %s where code = %s' % (EntityManager.get_meta(Wdbtype).table_name, '%s')

            # A non-auto string key must be inserted and deleted by its value.
            pm.add(dt)
            self.assertEqual(int(pm.executeScalar(sql, [dt.code])), 1)

            pm.delete(dt)
            self.assertEqual(int(pm.executeScalar(sql, [dt.code])), 0)

            pm.commit()
        finally:
            pm.close()

    def _create_order(self):
        """Insert a customer, a product and an order with item_size items; return the order."""
        pm = Persister()
        pm.open()
        try:
            customer = Customer()
            customer.code = "C001"
            customer.name = "北京千舟科技发展有限公司"
            pm.save(customer)

            product = Product()
            product.code = "P001"
            product.name = "电商ERP标准版4.0"
            pm.save(product)

            order = SalesOrder()
            order.code = "DD001"
            order.quantity = 2
            order.price = 2
            order.amount = 4.6
            order.customer = customer
            order.items = []

            for i in range(self.item_size):
                item = OrderItem()
                item.quantity = i + 1
                item.price = i + 1
                item.amount = (i + 1) * (i + 1)
                item.product = product
                order.items.append(item)

            pm.save(order)
            pm.commit()
            return order
        finally:
            pm.close()

    def set_db(self):
        db = Db()
        db.open()
        try:
            for cls in ENTITY_CLASSES:
                db.drop_table(cls)
                db.create_table(cls)
            db.commit()
        finally:
            db.close()


class EntityGeneratorTest(unittest.TestCase):

    def test_generator(self):
        generator = EntityGenerator()
        generator.open()
        try:
            generator.generate_db("wolf")
        finally:
            generator.close()


if __name__ == '__main__':
    unittest.main()