修饰器的作用
修饰器(Decorator)
的作用是在不修改目标函数的前提下,在函数执行前后执行额外的指令。
例如,要实现对目标函数运行时间的计时,最简单的方法是在原函数的前后加上计时:
# 计算斐波那契数列
def fib(cntr):
n, a, b = 0, 0, 1
while n < cntr:
print(b)
a, b = b, a + b
n += 1
# 计算斐波那契数列,计时
def fib(cntr):
start = time.time()
n, a, b = 0, 0, 1
while n < cntr:
print(b)
a, b = b, a + b
n += 1
end = time.time()
print("function finish in %s ms" % (end - start))
而使用修饰器,则可以很优雅地不用几乎不用修改原函数就能实现该功能:
import time
def calc_time(func):
def wrapper(*args, **kwargs):
start = time.time()
res = func(*args, **kwargs)
end = time.time()
print("function finish in %s ms" % (end - start))
return res
return wrapper
# 计算斐波那契数列,计时
@calc_time
def fib(cntr):
n, a, b = 0, 0, 1
while n < cntr:
print(b)
a, b = b, a + b
n += 1
或许你会说,这样一来,代码不是比原来还要复杂了吗?确实,目前看来是如此,但是如果你还有很多个函数,每个都需要计时,那么这种方法很明显是要比第一种方法要简单易用的——你只需在需要计时的函数上加上@calc_time
就完事了。
修饰器的作用原理
在函数fib()
定义处加上@calc_time
,相当于执行了fib = calc_time(fib)
,执行过后的fib已经指向了一个新的函数calc_time(fib)
而不是原来的fib()
,如此一来,我们执行的新函数类似如下:
def calc_time():
def wrapper(*args, **kwargs):
start = time.time()
def fib(cntr):
n, a, b = 0, 0, 1
while n < cntr:
print(b)
a, b = b, a + b
n += 1
res = fib(*args, **kwargs)
end = time.time()
print("function finish in %s ms" % (end - start))
return res
return wrapper
fib = calc_time
fib()(100)
此时,我们调用的时用的是fib()(100)
而不是fib(100)
,这显然与原调用方法不同,我们打印fib
和fib()
的类型以一探究竟:
print(fib)
print(fib())
<function calc_time at 0x000002B9CA4DAB80>
<function calc_time..wrapper at 0x000002B9CA4DAF70>
可以看到,fib的类型是calc_time
,而这个函数的返回值是函数wrapper
(而不是wrapper()的结果),而fib()(100)
返回的才是wrapper(100)
的结果,而由于wrapper
的形参是*args, **kwargs
,内部调用fib(*args, **kwargs)
时会传入所有参数,所以实际上执行的就是fib(100)
了。
我们进一步化简,去掉wrapper
层,这样就可以很明了地看懂修饰器的作用原理了:
def calc_time(*args, **kwargs):
start = time.time()
def fib(cntr):
n, a, b = 0, 0, 1
while n < cntr:
print(b)
a, b = b, a + b
n += 1
res = fib(*args, **kwargs)
end = time.time()
print("function finish in %s ms" % (end - start))
return res
fib = calc_time
fib(100)
修正的修饰器写法
上文使用的修饰器是可以正常使用的,但是如果你打印原函数的类型就会发现有缺陷:
print(fib)
<function calc_time.
.wrapper at 0x000001591E4AA700>
这种缺陷是致命的,比方说,你在捕获错误的时候如果捕获的全是calc_time
类型的错误而不是fib
类型的错误,你根本就无法根据错误堆栈确定产生错误的函数,为了解决这个问题,你需要将函数改成如下(在修饰函数上加上@functools.wraps(func)
):
import time
import functools
def calc_time(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
start = time.time()
res = func(*args, **kwargs)
end = time.time()
print("function finish in %s ms" % (end - start))
return res
return wrapper
# 计算斐波那契数列,计时
@calc_time
def fib(cntr):
n, a, b = 0, 0, 1
while n < cntr:
print(b)
a, b = b, a + b
n += 1
再打印函数的类型,可以发现已经修正了:
print(fib)
<function fib at 0x0000022F4A0AF040>
带参数的修饰器
修饰器可以带参数,我们以web路由为例说明带参数的修饰器的作用
import functools
route_rules = {}
def route(url):
def decorator(func):
route_rules[url] = func
print("已添加路由:%s -> %s" % (url, func))
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
return decorator
@route("/hello")
def hello():
return "hello, world!"
@route("/hello/godfish")
def hello_godfish():
return "hello, godfish"
print("打印路由表:")
for k,v in route_rules.items():
print(k, "->", v)
已添加路由:/hello -> <function hello at 0x000001AB9509E0D0>
已添加路由:/hello/godfish -> <function hello_godfish at 0x000001AB9509E310>
打印路由表:
/hello -> <function hello at 0x000001AB9509E0D0>
/hello/godfish -> <function hello_godfish at 0x000001AB9509E310>