• python unittest 源码学习


    主要学习的模块:

    ├── suite.py
    ├── case.py
    ├── main.py
    ├── runner.py
    ├── warnings.py
    ├── signals.py
    ├── result.py
    └── loader.py
    
    • suite.py TestSuite是TestCase的集合

    • case.py 就是我们平时继承的 unittest.TestCase

    • main.py TestProgram 所在文件，负责执行 parseArgs、runTests

    • runner.py 实际跑单测时直接用到的 TextTestResult、TextTestRunner 所在文件

    • warnings.py处理相关警告信息

    • signals.py 处理相关信号

    • result.py 保存结果的基类

    • loader.py 加载测试用例


    学习开始代码:

    class BaiduTestClass1(unittest.TestCase):
        """Two trivial tests that print their own names and always pass."""

        def setUp(self):
            """No per-test fixture is needed."""

        def tearDown(self):
            """Nothing to clean up after each test."""

        def test_baidu1_func1(self):
            print('test_baidu1_func1')
            expected = "He"
            self.assertEqual("He", expected)

        def test_baidu1_func2(self):
            print('test_baidu1_func2')
            expected = "He"
            self.assertEqual("He", expected)
    
    class BaiduTestClass2(unittest.TestCase):
        """Second pair of trivial passing tests, used to show multi-class loading."""

        def setUp(self):
            """No setup required."""

        def tearDown(self):
            """No teardown required."""

        def test_baidu2_func1(self):
            print('test_baidu2_func1')
            actual, expected = "He2", "He2"
            self.assertEqual(actual, expected)

        def test_baidu2_func2(self):
            print('test_baidu2_func2')
            actual, expected = "He2", "He2"
            self.assertEqual(actual, expected)
    
    
    class GetAttrTest(unittest.TestCase):
        """TestCase subclass whose only method is not collected as a test.

        ``getattrtest`` does not start with the loader's ``test`` prefix, so
        the loader finds no test methods in this class.
        """

        def getattrtest(self):
            """Do nothing."""
    
    class GetAttrTest:
        """Plain-object variant of GetAttrTest.

        This definition rebinds the module name ``GetAttrTest`` and hides the
        ``unittest.TestCase`` version defined just before it. Because this class
        is not a TestCase subclass, ``loadTestsFromModule`` skips it.
        """

        def getattrtest(self):
            """Do nothing."""
    
    if __name__ == "__main__":
        unittest.main()
    

    1. 程序开始后进入 main.py：unittest.main 就是 TestProgram，其 __init__ 会依次调用 parseArgs(argv) 和 runTests()
    '''
    main.py
    '''
    main = TestProgram
    ...
    class TestProgram(object):
    ...
    
        def __init__(self, module='__main__', defaultTest=None, argv=None,...
            self.testRunner = testRunner
            self.testLoader = testLoader  
            self.progName = os.path.basename(argv[0])
            self.parseArgs(argv)
            self.runTests()
    

    上面 main.py 中 self.testLoader = testLoader 所保存的 testLoader，默认值是 loader.defaultTestLoader，它来自 loader.py：

    '''
    main.py
    '''
    testLoader=loader.defaultTestLoader # 调用loader.py中的
    
    '''
    loader.py
    '''
    defaultTestLoader = TestLoader()
    testMethodPrefix = 'test'  #确定case的前缀
    
    
    2. 上面第 1 步 main.py 中的 self.parseArgs(argv)，在未指定 defaultTest 时会设置 self.testNames = None，并执行 createTests()
    '''
    main.py
    '''
    elif self.defaultTest is None:
         # createTests will load tests from self.module
         self.testNames = None
    ...
    self.createTests()
    
    

    createTests() 会调用 self.testLoader，并由此给 self.test 赋值：

    '''
    main.py
    '''
    elif self.testNames is None:
    	self.test = self.testLoader.loadTestsFromModule(self.module)
    

    上面的代码会调用 loader.py 中的 loadTestsFromModule，最终返回 tests（suiteClass 的实例，默认即一个 TestSuite）：

    '''
    loader.py
    '''
    ...
    tests = []
    	for name in dir(module):
    	obj = getattr(module, name)
                #tests = [] 增加case 必须 attrname.startswith(self.testMethodPrefix):
                if isinstance(obj, type) and issubclass(obj, case.TestCase):
                    tests.append(self.loadTestsFromTestCase(obj))
            load_tests = getattr(module, 'load_tests', None)      
             # test 列表转换成suite
            tests = self.suiteClass(tests)
    		... 
            return tests
    
    3. 上面第 1 步中的 runTests() 会用 self.testRunner 构建 testRunner 实例，并执行 testRunner.run(self.test)；这里的 self.test 就是上面返回的 tests（suiteClass 实例）
    '''
    main.py
    '''
    try:
          try:
            testRunner = self.testRunner(verbosity=self.verbosity,
                                        failfast=self.failfast,
                                        buffer=self.buffer,
                                        warnings=self.warnings,
                                        tb_locals=self.tb_locals)
    ...
    self.result = testRunner.run(self.test)
    
    4. testRunner.run(self.test) 会执行到如下代码中的 try: test(result)：
    '''
    runner.py
    '''
        def run(self, test):
            "Run the given test case or test suite."
            result = self._makeResult()
            registerResult(result)
            result.failfast = self.failfast
            result.buffer = self.buffer
            result.tb_locals = self.tb_locals
            with warnings.catch_warnings():
                if self.warnings:
                    # if self.warnings is set, use it to filter all the warnings
                    warnings.simplefilter(self.warnings)
                    # if the filter is 'default' or 'always', special-case the
                    # warnings from the deprecated unittest methods to show them
                    # no more than once per module, because they can be fairly
                    # noisy.  The -Wd and -Wa flags can be used to bypass this
                    # only when self.warnings is None.
                    if self.warnings in ['default', 'always']:
                        warnings.filterwarnings('module',
                                category=DeprecationWarning,
                                message=r'Please use assertw+ instead.')
                startTime = time.perf_counter()
                startTestRun = getattr(result, 'startTestRun', None)
                if startTestRun is not None:
                    startTestRun()
                try:
                    test(result)
                finally:
                    stopTestRun = getattr(result, 'stopTestRun', None)
                    if stopTestRun is not None:
                        stopTestRun()
                stopTime = time.perf_counter()
            timeTaken = stopTime - startTime
            result.printErrors()
            if hasattr(result, 'separator2'):
                self.stream.writeln(result.separator2)
            run = result.testsRun
            self.stream.writeln("Ran %d test%s in %.3fs" %
                                (run, run != 1 and "s" or "", timeTaken))
            self.stream.writeln()
    
            expectedFails = unexpectedSuccesses = skipped = 0
            try:
                results = map(len, (result.expectedFailures,
                                    result.unexpectedSuccesses,
                                    result.skipped))
            except AttributeError:
                pass
            else:
                expectedFails, unexpectedSuccesses, skipped = results
    
            infos = []
            if not result.wasSuccessful():
                self.stream.write("FAILED")
                failed, errored = len(result.failures), len(result.errors)
                if failed:
                    infos.append("failures=%d" % failed)
                if errored:
                    infos.append("errors=%d" % errored)
            else:
                self.stream.write("OK")
            if skipped:
                infos.append("skipped=%d" % skipped)
            if expectedFails:
                infos.append("expected failures=%d" % expectedFails)
            if unexpectedSuccesses:
                infos.append("unexpected successes=%d" % unexpectedSuccesses)
            if infos:
                self.stream.writeln(" (%s)" % (", ".join(infos),))
            else:
                self.stream.write("
    ")
            return result
    
    

    由于 test 是 TestSuite 实例，test(result) 会触发 TestSuite 的父类 BaseTestSuite 中定义的 __call__ 方法：

    def __call__(self, *args, **kwds):
        """Let a suite be invoked like a function: ``suite(result)`` runs it."""
        bound_run = self.run
        return bound_run(*args, **kwds)
    
    5. 上面的调用会执行 TestSuite 的 run()；遍历时如果子 test 仍是 TestSuite 实例，就会递归执行 test(result)
    '''
    suite.py
    '''
    ...
    # BaseTestSuite实现了__iter__方法,这里就可以递归执行,test(result) 因为调用了enumerate(self)
    for index, test in enumerate(self):             
                if result.shouldStop:
                    break
                if _isnotsuite(test):
                    self._tearDownPreviousClass(test, result)
                    self._handleModuleFixture(test, result)
                    self._handleClassSetUp(test, result)
                    result._previousTestClass = test.__class__
                    if (getattr(test.__class__, '_classSetupFailed', False) or
                        getattr(result, '_moduleSetUpFailed', False)):
                        continue
    
                if not debug:              
                    test(result)
    

    如果上面的 test 是 TestCase 实例，就会调用 TestCase 的 __call__ 方法，进而触发 run 方法：

    '''
    case.py
    '''
    def __call__(self, *args, **kwds):
        """Make a test case callable: ``case(result)`` delegates to ``case.run(result)``."""
        runner = getattr(self, "run")
        return runner(*args, **kwds)
           
    def run(self, result=None):
    		...
            # 获取测试方法
            testMethod = getattr(self, self._testMethodName)
            
            # 是否跳过
            if (getattr(self.__class__, "__unittest_skip__", False) or
                getattr(testMethod, "__unittest_skip__", False)):
                # If the class or method was skipped.
                try:
                    skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
                                or getattr(testMethod, '__unittest_skip_why__', ''))
                    self._addSkip(result, self, skip_why)
                finally:
                    result.stopTest(self)
                return
          	....
            try:
                self._outcome = outcome
    
                #执行setup
                with outcome.testPartExecutor(self):
                    self._callSetUp()
                if outcome.success:
    			    outcome.expecting_failure = expecting_failure
                    with outcome.testPartExecutor(self, isTest=True):             
                        # amize 这里真正执行测试用例
                        self._callTestMethod(testMethod)
     
    

    一段便于理解 unittest 原理的简化代码（实际分为三个文件：myunittest.py 是框架本身，test_demo.py 是测试用例，最后一段是启动脚本）：

    
    import importlib
    import logging
    
    
    class TestCase:
        """Minimal base class for test classes in this toy framework.

        Subclasses override ``setup``/``teardown`` and define methods whose
        names start with ``test_``.
        """

        def __init__(self, name):
            # The Runner passes the concrete subclass's __name__ here.
            self.name = name

        def setup(self):
            """Hook called before the test methods; a no-op by default."""

        def teardown(self):
            """Hook called after the test methods; a no-op by default."""
    
    
    class Loader(object):
        def __init__(self):
            self.cases = {}
    
        def load(self, path):
            module = importlib.import_module(path)
            for test_class_name in dir(module):
                test_class = getattr(module, test_class_name)
                if (
                        isinstance(test_class, type) and
                        issubclass(test_class, TestCase)
                ):
                    self.cases.update({
                        test_class: self.find_test_method(test_class) or []
                    })
    
        def find_test_method(self, test_class):
            test_methods = []
    
            for method in dir(test_class):
                if method.startswith("test_"):
                    test_methods.append(
                        getattr(test_class, method)
                    )
    
            return test_methods
    
        def __iter__(self):
            for test_class, test_cases in self.cases.items():
                yield test_class, test_cases
    
    
    class Runner(object):
        """Load the test classes of one module and run their test methods.

        Unlike unittest, ``setup``/``teardown`` are called once per test class,
        around all of that class's test methods.
        """

        def __init__(self, path):
            # Module path that the Loader passes to importlib.import_module.
            self.path = path

        def run(self):
            """Run every collected test method, isolating failures per method.

            An exception raised by one test method is logged and the remaining
            methods of the same class still run. Once ``setup`` has succeeded,
            ``teardown`` is always called.
            """
            loader = Loader()
            loader.load(self.path)

            for test_class, test_cases in loader:
                test_instance = test_class(test_class.__name__)
                test_instance.setup()
                try:
                    for test_case in test_cases:
                        # Per-method try: one failing test must not skip its siblings.
                        try:
                            test_case(test_instance)
                        except Exception:
                            logging.exception("error occured, skip this method")
                finally:
                    test_instance.teardown()
    
    from myunittest import TestCase
    
    
    class DemoTestCase(TestCase):
        """Demo test class: one test that passes and one that raises."""

        def test_normal(self):
            print("test normal function")

        def test_exception(self):
            # Always raises; the Runner is expected to log it and continue.
            error = Exception("haha, exception here!")
            raise error

        def setup(self):
            print("setup")

        def teardown(self):
            print("teardown")
    
    
    from myunittest import Runner
    
    
    if __name__ == "__main__":
        runner = Runner("test_demo")
        runner.run()
    

    补充:

    1. case 的排序问题（使用 list 自带的 sort 加上 functools.cmp_to_key 函数）

    2. failfast装饰器

    def failfast(method):
        """Decorate *method* so it calls ``self.stop()`` first when failfast is on.

        The wrapped method still runs and its return value is passed through;
        ``self.failfast`` is read with a default of False.
        """
        @wraps(method)
        def wrapper(self, *args, **kw):
            stop_requested = getattr(self, 'failfast', False)
            if stop_requested:
                self.stop()
            return method(self, *args, **kw)
        return wrapper
    
    3. HTMLTestRunner 里面的 run 实现：
        def run(self, test):
            """Run a test case or suite, record the stop time and build the report.

            Returns the ``_TestResult`` that *test* was run against.
            """
            outcome = _TestResult(self.verbosity)
            test(outcome)
            finished_at = datetime.datetime.now()
            self.stopTime = finished_at
            self.generateReport(test, outcome)
            return outcome
    
    
  • 相关阅读:
    Android开源库
    银行卡的数字检测
    hdu4941 Magical Forest
    android之检测是否有网络
    在Oracle数据库中使用NFS,怎样调优?
    centos+nginx+php-fpm+php include fastcgi_params php页面能访问但空白,被fastcgi_params与fastcgi.conf害惨了
    漫谈反射
    Android 四大组件学习之BroadcastReceiver二
    【LeetCode】two num 利用comparable接口 对对象进行排序
    扩展功能==继承?
  • 原文地址:https://www.cnblogs.com/amize/p/13247060.html
Copyright © 2020-2023  润新知