5. 函数与迭代器

本章介绍 Python 中函数、迭代器、生成器的使用,并介绍一些 Python 中函数:

  • 高级函数:经常用到的 lambda,时而会使用的 filter,以及极少用到的 reducemap

  • 错误控制语句:try .. except

本章引入了下述 Python 库/模块。它们均为内置库/模块,无需额外的安装步骤:

[1]:
import timeit
import cProfile
import tracemalloc, linecache
import functools

5.1. 函数的使用

使用 def 定义函数。下例定义了一个名为 func 的函数。

  • 参数:下面的 func 函数包含两个输入参数 a, b 。其中,参数 b 是一个可选参数,在无输入时会自动赋值为 0。

  • 注释:Python 中建议(但不强制)标出函数的注释,一般用一对三个双引号来标记多行注释,作为函数的注释。

  • 返回值:用 return 语句来返回一个对象。

[2]:
def func(a, b=0):
    """
    This is a function that can meow.
    """
    return " ".join(["meow"] * (a + b))

5.1.1. 函数参数

上例的函数接受两个参数:

[3]:
s = func(a=1, b=2)
print(s)
meow meow meow

对应的参数名在传入时,可以省略。比如上例中的 a=b= 可以省略:

[4]:
func(1, 2)
[4]:
'meow meow meow'

由于参数 b 有一个默认值 0,因此函数 func 也可以只传入一个参数(即参数 a):

[5]:
func(1)
[5]:
'meow'

函数也接受以序列(依次传入列表中的项)或者字典(依字典的键传入对应的值)的方式传入参数,使用带星号 * 前缀的序列,或者带双星号 ** 前缀的字典。

[6]:
key_lst = [2]
func(1, *key_lst)
[6]:
'meow meow meow'
[7]:
key_dict = {"a": 1, "b": 2}
func(**key_dict)
[7]:
'meow meow meow'

这种传参方式有时用在字符串的 format() 函数中:

[8]:
key_dict = {"a": 1, "b": 2}
s = "{a} + {b}".format(**key_dict)
print(s)
1 + 2

5.1.2. 函数帮助

help() 函数来查看函数的帮助,即函数开头的、以三个双引号括起的字符串:

[9]:
help(func)
Help on function func in module __main__:

func(a, b=0)
    This is a function that can meow.

5.2. 处理异常:try与raise

try .. except 语句来控制异常。比如在除法运算时,如果除数为0,那么会弹出异常:

[10]:
def func(m, n):
    r = m / n
    return r

func(1, 0)
---------------------------------------------------------------------------
ZeroDivisionError                         Traceback (most recent call last)
<ipython-input-10-e51b4c7d78d1> in <module>
      3     return r
      4
----> 5 func(1, 0)

<ipython-input-10-e51b4c7d78d1> in func(m, n)
      1 def func(m, n):
----> 2     r = m / n
      3     return r
      4
      5 func(1, 0)

ZeroDivisionError: division by zero

Python 在除数为0是会弹出 ZeroDivisionError。常常见到的 Error 有:

异常名称

解释

引发示例

– 引用异常 –

AttributeError

引用不存在的属性。

int.split()

IndexError

引用不存在的索引

"abc"[3]

KeyError

引用不存在的字典键

{}["a"]

NameError

使用不存在的变量名

/

– 输入参数异常 –

TypeError

函数被应用在类型错误的对象上

abs("1")

ValueError

函数传入了类型允许但值不适合的参数

int("a")

– 其他 –

AssertionError

断言语句 assert 失败

assert(1 < 0)

StopIteration

迭代器结束迭代

next(iter(''))

SyntaxError

语法错误

a = ]

KeyboardInterrupt

用户从键盘终止了正运行的代码

Ctrl+C 打断运行

要阅读完整的 Error 列表,请参考官方文档的内置异常页面。

添加 try .. except 来处理 ZeroDivisionError:

[11]:
def func(m, n):
    try:
        r = m / n
    except ZeroDivisionError:
        r = float('nan')
    return r

func(1, 0)
[11]:
nan

还可以用 except ERROR as ... 并 print 来输出异常信息:

[12]:
def func(m, n):
    try:
        r = m / n
    except ZeroDivisionError as e:
        r = float('nan')
        print(f"Error: {e}")
    return r

func(1, 0)
Error: division by zero
[12]:
nan

其中的 except 语句可以处理多种 error 类型。可以:

  • 把多个 error 组成一个元组,或者

  • 使用多个 except 语句。

[13]:
def func(m, n):
    try:
        r = m / n
    except (ZeroDivisionError, TypeError) as e:
        r = float('nan')
        print(f"Error: {e}")
    return r

func(1, 'x'), func(1, 0), func(1, 2)
Error: unsupported operand type(s) for /: 'int' and 'str'
Error: division by zero
[13]:
(nan, nan, 0.5)
[14]:
def func(m, n):
    try:
        r = m / n
    except ZeroDivisionError:
        r = float('nan')
    except TypeError:
        r = None
    return r

func(1, 'x'), func(1, 0), func(1, 2)
[14]:
(None, nan, 0.5)

异常处理的最后一个 except 语句可以:

  • 不接任何 error 类型(即 except: ),以表示处理所有未被之前 except 所处理的异常;或者

  • 接受基类型 Exception(即 except Exception as e:),并输出异常信息 e

[15]:
def func(m, n):
    try:
        r = m / n
    except ZeroDivisionError:
        r = float('nan')
    except Exception as e:
        r = None
        print(f'Unhandled error in func({m}, {n}): {e}')
    return r

x = [1, 2, 3]
func(1, 'x'), func(1, 0), func(1, 2)
Unhandled error in func(1, x): unsupported operand type(s) for /: 'int' and 'str'
[15]:
(None, nan, 0.5)

最后,用户也可以用 raise 语句强制抛出异常:

[16]:
raise(ValueError("This is a forced error."))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-16-eb20ffe9fe10> in <module>
----> 1 raise(ValueError("This is a forced error."))

ValueError: This is a forced error.

5.3. 迭代器

迭代器(iterator)是可用 next() 函数依次访问数据的一种数据对象。从用法上讲,它与序列的应用场合非常近似,但有性能优化方面的潜力。

迭代器可以用以下方式创建:

  • 使用 iter() 函数强制转换一个序列对象

  • 使用迭代器解析(或称生成器解析)

  • 使用生成器,参考下方生成器一节的内容

  • 定义一个含有 __next__() 方法的类

[17]:
s = "abc"
x = iter(s)

for _ in range(len(s)):
    print(next(x))
a
b
c

在被遍历一次后,迭代器的“指针”会抵达其末尾。如果继续访问,将会抛出 StopIteration 异常:

[18]:
next(x)
---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
<ipython-input-18-92de4e9f6b1e> in <module>
----> 1 next(x)

StopIteration:

顺带一提,在迭代器上应用 next() 实质上是调用了迭代器内部的 __next__() 方法:

[19]:
x = iter("abc")
x.__next__()
[19]:
'a'

通常,我们使用循环语句(而不是 next() 函数)来遍历迭代器:

[20]:
x = iter("abc")
for c in x:
    print(c)
a
b
c

除了用 iter() 强制转换,迭代器也可以使用类似列表解析的“迭代器解析”,不过要将外侧的方括号换成圆括号:

[21]:
x = (2 ** k for k in range(5))

for k, num in enumerate(x):
    print(f"2^{k}:\t{num}")
2^0:    1
2^1:    2
2^2:    4
2^3:    8
2^4:    16

5.3.1. 迭代器的意义:性能优化

迭代器最显著的意义就在于性能优化。由于迭代器解析会依次以(内部的)__next__() 方法调用数据,因此在调用时才计算当前项的值,而不是在创建之时一次性地计算出每一项的值。这对于内存占用与优化控制流都较有意义。

下面是一个较极端的例子,模拟在重负担任务时使用列表与迭代器之间的差异。

[22]:
# 2**(2**10) = 1.79769... x10^308
def func_lst():
    x = [2**(n) for n in range(2**15)]
    for i, k in enumerate(x):
        if k > 1.7e308:
            break
    print(i)

def func_iter():
    x = (2**(n) for n in range(2**15))
    for i, k in enumerate(x):
        if k > 1.7e308:
            break
    print(i)

本节中所用到的时间、内存性能测试的模块,都会在之后的标准库章节进行介绍。

5.3.2. 时间开销对比

下面是用 timeit 模块对上述两个函数的运行时间测试结果。由于 func_lst() 函数创建了一个长为 \(2\times 10^{15}\) 的列表,因此速度比 func_iter() 慢上不少。

[23]:
# 运算一次所花费的时间
# -- 运行在我的古董级 CPU i7-6700HQ @2.60GHz 上。

time_lst = timeit.timeit(func_lst, number=1)
time_iter = timeit.timeit(func_iter, number=1)
print(f"List:\t{time_lst:.4f} sec\nIter:\t{time_iter:.4f} sec")
1024
1024
List:   1.7223 sec
Iter:   0.0024 sec

用 cProfile 模块查看 func_lst() 运算开销,可以看到用时最高的一项是 <listcomp>,即 list comprehesion 列表解析。解析这个长度巨大的列表占用了整个函数运行的绝大部分时间。函数 func_iter() 的总开销极短,故在这里省略。

顺便一提, cProfile 并不是一个基准测试工具(请使用 timeit),而主要用来查看各部分运行的耗时占比;它的时间度量可能并不准确。

[24]:
cProfile.run('func_lst()')
1024
         40 function calls in 1.701 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    1.685    1.685 <ipython-input-22-eae10954e071>:2(func_lst)
        1    1.685    1.685    1.685    1.685 <ipython-input-22-eae10954e071>:3(<listcomp>)
        1    0.016    0.016    1.701    1.701 <string>:1(<module>)
        3    0.000    0.000    0.000    0.000 iostream.py:197(schedule)
        2    0.000    0.000    0.000    0.000 iostream.py:310(_is_master_process)
        2    0.000    0.000    0.000    0.000 iostream.py:323(_schedule_flush)
        2    0.000    0.000    0.000    0.000 iostream.py:386(write)
        3    0.000    0.000    0.000    0.000 iostream.py:93(_event_pipe)
        3    0.000    0.000    0.000    0.000 socket.py:342(send)
        3    0.000    0.000    0.000    0.000 threading.py:1017(_wait_for_tstate_lock)
        3    0.000    0.000    0.000    0.000 threading.py:1071(is_alive)
        3    0.000    0.000    0.000    0.000 threading.py:513(is_set)
        1    0.000    0.000    1.701    1.701 {built-in method builtins.exec}
        2    0.000    0.000    0.000    0.000 {built-in method builtins.isinstance}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.print}
        2    0.000    0.000    0.000    0.000 {built-in method nt.getpid}
        3    0.000    0.000    0.000    0.000 {method 'acquire' of '_thread.lock' objects}
        3    0.000    0.000    0.000    0.000 {method 'append' of 'collections.deque' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}


5.3.3. 内存开销对比

最后,我们用 tracemalloc 模块查看一下双方的内存占用,并打印占用内存最大的3条指令。

下例中的 display_top 函数引用自 Python 官方文档 tracemalloc: Pretty Top,仅改动了 limit 参数的默认值。

[25]:
def display_top(snapshot, key_type='lineno', limit=3):
    snapshot = snapshot.filter_traces((
        tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
        tracemalloc.Filter(False, "<unknown>"),
    ))
    top_stats = snapshot.statistics(key_type)

    print("Top %s lines" % limit)
    for index, stat in enumerate(top_stats[:limit], 1):
        frame = stat.traceback[0]
        print("#%s: %s:%s: %.1f KiB"
              % (index, frame.filename, frame.lineno, stat.size / 1024))
        line = linecache.getline(frame.filename, frame.lineno).strip()
        if line:
            print('    %s' % line)

    other = top_stats[limit:]
    if other:
        size = sum(stat.size for stat in other)
        print("%s other: %.1f KiB" % (len(other), size / 1024))
    total = sum(stat.size for stat in top_stats)
    print("Total allocated size: %.1f KiB" % (total / 1024))

使用列表解析的 func_lst() 占用高达约 70 MB,而占用第二位的只有十几 KB。这说明绝大部分的内存使用都是创建这个巨大的列表的开销。

[26]:
tracemalloc.start()

# func_lst()
x = [2**(n) for n in range(2**15)]
for i, k in enumerate(x):
    if k > 1.7e308:
        break

snapshot = tracemalloc.take_snapshot()
display_top(snapshot)
Top 3 lines
#1: <ipython-input-26-0957e906499f>:4: 71111.4 KiB
    x = [2**(n) for n in range(2**15)]
#2: d:\programming\python38\lib\site-packages\IPython\core\history.py:763: 0.4 KiB
    conn.execute("INSERT INTO history VALUES (?, ?, ?, ?)",
#3: d:\programming\python38\lib\codeop.py:136: 0.2 KiB
    codeob = compile(source, filename, symbol, self.flags, 1)
9 other: 0.7 KiB
Total allocated size: 71112.7 KiB

而使用迭代器解析的 func_iter() 的内存占用非常小,仅 200 余 KB。

[27]:
tracemalloc.start()

# func_iter()
x = (2**(n) for n in range(2**15))
for i, k in enumerate(x):
    if k > 1.7e308:
        break

snapshot = tracemalloc.take_snapshot()
display_top(snapshot)
Top 3 lines
#1: d:\programming\python38\lib\linecache.py:137: 92.0 KiB
    lines = fp.readlines()
#2: d:\programming\python38\lib\site-packages\IPython\core\compilerop.py:101: 5.3 KiB
    return compile(source, filename, symbol, self.flags | PyCF_ONLY_AST, 1)
#3: d:\programming\python38\lib\json\decoder.py:353: 1.5 KiB
    obj, end = self.scan_once(s, idx)
108 other: 25.1 KiB
Total allocated size: 123.8 KiB

更多迭代器的内容,可以参考 itertools 标准库相关的章节。

5.4. 生成器

生成器(generator)是一种创建迭代器的简便工具。它用 yield 关键字来代替 return,使生成器在每次被访问时都返回一个值,以此实现“依次序取出数据”的迭代器效果。

[28]:
def fprinting(seq):
    for i, val in enumerate(seq):
        yield f"{i}-th:\t{val}"

for k in fprinting("hello"):
    print(k)
0-th:   h
1-th:   e
2-th:   l
3-th:   l
4-th:   o

生成器解析(迭代器解析)在上文已经介绍过,可以减少内存使用。它可以直接作为序列参数传入函数(但不要忘记迭代器只能被遍历一次);在下例中,生成器解析的结果被作为参数直接传给了 sum()str.join() 函数。

[29]:
n = sum(k**2 for k in range(5))
s = ' + '.join(f"{k**2}" for k in range(5))
print(s, "=", n)
0 + 1 + 4 + 9 + 16 = 30

5.5. 装饰器

装饰器(decorator)是 Python 中的一种简便语法,本质上也是一种函数。例如,我们定义了两个函数 fg ,那么把函数 f 作为 g 的装饰器,相当于 F = f(g) 并返回 F(x)

从上面的叙述中你可能已经注意到,装饰器以函数作为输入值;因此,它本质上是函数的函数

5.5.1. 装饰器与外部函数

装饰器的功能大多可以通过在已有函数的外部定义一个新函数来实现。但装饰器相比外部函数拥有以下优势:

  • 分离外部功能与核心函数,减少代码修改的工作量

  • 更好的代码复用性

  • 更好的可读性

下面我们用一个例子来展示装饰器的这一优势。例如,我们现在定义了一个函数,将输入的字符串的空格去掉:

[30]:
def remove_space(s):
    return ''.join(s.split())

s = "This is\tPython."
remove_space(s)
[30]:
'ThisisPython.'

现在我们想在该函数之外实现一些功能,例如:测试函数的运行时间。

一种朴素的思维是,我们可以定义一个外部函数 print_time,用来接受与上例的 remove_space 函数相同的参数;在新函数内部实现 remove_space 函数,并在其之后打印运行时间。

[31]:
# 例:未使用装饰器
def print_time(s):
    def remove_space(s):
        return ''.join(s.split())
    # Timing
    start = timeit.default_timer()
    r = remove_space(s)
    end = timeit.default_timer()
    print(f"Run in {end-start:.0g} sec.")
    return r

s = "This is\tPython."
print_time(s)
Run in 6e-06 sec.
[31]:
'ThisisPython.'

这样写非常的直观。但是,一旦 remove_space 函数需要更改,问题就变得复杂了。例如,现在我们要添加一个参数,让用户决定用什么字符代替空格来连接文本,这时候就需要对上述代码的好几处额外进行更改:

[32]:
# 例:未使用装饰器,需要多处修改
def print_time(s, repl=''):          # 添加 repl 参数输入
    def remove_space(s, repl):
        return repl.join(s.split())
    # Timing
    start = timeit.default_timer()
    r = remove_space(s, repl)        # 添加 repl 参数
    end = timeit.default_timer()
    print(f"Run in {end-start:.0g} sec.")
    return r

s = "This is\tPython."
print_time(s, '.')
Run in 6e-06 sec.
[32]:
'This.is.Python.'

5.5.2. 装饰器的创建

使用装饰器可以简化以上流程,避免这些额外的代码更改。装饰器可以接受函数作为参数输入,因此装饰器本身可以与函数分离设计,尽量降低代码的更改量:

[33]:
# 定义装饰器
def print_deco(func):
    def modified_func(*args, **kwarg):
        start = timeit.default_timer()
        r = func(*args, **kwarg)
        end = timeit.default_timer()
        print(f"Run in {end-start:.0g} sec.")
        return r
    return modified_func

由于装饰是函数的函数:传入一个函数 func,返回一个在 func 基础上重新定义的函数 modified_func 。在 Python 中,装饰器以一种简便的 @ 符号,声明在函数定义之前:

[34]:
@print_deco    # 实质上重新定义了被装饰的函数
def remove_space(s):
    return ''.join(s.split())

s = "This is\tPython."
remove_space(s)
Run in 6e-06 sec.
[34]:
'ThisisPython.'

因此,对函数 func 使用装饰器 @decorator,实质上等同于以下调用:

decorator(func)(args)  # 注意它与函数嵌套调用 func2(func(s)) 之间的写法区别!
[35]:
def remove_space(s):    # 此处的函数未被重新定义
    return ''.join(s.split())

s = "This is\tPython."
print_deco(remove_space)(s)
Run in 5e-06 sec.
[35]:
'ThisisPython.'

如果要给函数添加参数,只需要在核心代码上修改。相比未使用装饰器的方案,这省去了逐层传参的步骤。

[36]:
@print_deco    # 实质上重新定义了被装饰的函数
def remove_space(s, repl=''):
    return repl.join(s.split())

s = "This is\tPython."
remove_space(s, '.')
Run in 5e-06 sec.
[36]:
'This.is.Python.'

5.5.3. 装饰器的参数

装饰器还可以接受参数。例如,在上例的基础上,我们添加执行次数作为一个可选参数(缺省时执行1次)。

[37]:
# 朴素的装饰器
def print_deco(n=1):
    def deco(func):
        def modified_func(*args, **kwarg):
            start = timeit.default_timer()
            for _ in range(n):
                r = func(*args, **kwarg)
            end = timeit.default_timer()
            print(f"{n} Run in {end-start:.0g} sec.")
            return r
        return modified_func
    return deco

remove_space 函数上应用上例中的装饰器,等价于:

def remove_space:
    ...
remove_space = print_deco(n)(remove_space)

因此例中的 deco 必须接受一个函数,并返回一个重定义后的函数。这样定义的装饰器在使用时,必须写出括号。

[38]:
@print_deco()  # 无参数时也必须有括号
def remove_space(s, repl=''):
    return repl.join(s.split())

s = "This is\tPython."
remove_space(s)
1 Run in 2e-05 sec.
[38]:
'ThisisPython.'
[39]:
@print_deco(10)
def remove_space(s, repl=''):
    return repl.join(s.split())

s = "This is\tPython."
remove_space(s)
10 Run in 4e-05 sec.
[39]:
'ThisisPython.'

注意到在空参数时,装饰器也添加了括号后缀 @print_deco() 以避免参数的传递错误。


另一种带可选参数的装饰器的写法,是让不加括号时能正确使用装饰器,可以利用标准库 functools 中的 partial 命令;该命令的更多信息,请参考 functools 一节的内容。

使用 partial 命令改写后的装饰器如下:

[40]:
# 装饰器的第一参数必须是函数
def print_deco(func=None, n=1):
    if func is None:
        return functools.partial(print_deco, n=n)
    def modified_func(*args, **kwarg):
        start = timeit.default_timer()
        for _ in range(n):
            r = func(*args, **kwarg)
        end = timeit.default_timer()
        print(f"{n} Run in {end-start:.0g} sec.")
        return r
    return modified_func

注意如上定义的装饰器在使用时,如果使用 @print_deco 而不是 @print_deco()

[41]:
@print_deco
def remove_space(s, repl=''):
    return repl.join(s.split())

s = "This is\tPython."
remove_space(s)
1 Run in 1e-05 sec.
[41]:
'ThisisPython.'
[42]:
remove_space(s, '.')
1 Run in 8e-06 sec.
[42]:
'This.is.Python.'

传入参数的装饰器,注意该参数只能以 key=value 的形式传入:

[43]:
@print_deco(n=10)
def remove_space(s, repl=''):
    return repl.join(s.split())

s = "This is\tPython."
remove_space(s)
10 Run in 3e-05 sec.
[43]:
'ThisisPython.'

5.6. lambda 函数

lambda 称为匿名函数,它用来简洁地声明一个函数。比如下面的匿名函数 f1 与用 def 定义的函数 f2 作用相同:

[44]:
f1 = lambda x,y: x+y
def f2(x, y):
    return x+y

f1(2,3) == f2(2,3)
[44]:
True

5.7. filter 函数

filter 称为过滤函数,将一个返回逻辑类型(True/False)的函数作为过滤器,用来过滤一个序列。序列中的经过函数能返回 True 的项会被保留下来,其余的项会被舍弃。

最后,被保留的项会以一个迭代器的形式返回。使用时请谨记迭代器的特性:迭代器在创建后只能被遍历一次。

[45]:
f = lambda x: x > 0
lst = range(-3, 3)

filter(f, lst)
[45]:
<filter at 0x18b1f9f1670>

你可以将迭代器转为列表,或者直接在循环中处理:

[46]:
x = filter(f, lst)  # 新建迭代器
print(list(x))
[1, 2]
[47]:
x = filter(f, lst)  # 新建迭代器
for k in x:
    print(k)
1
2

5.8. reduce 函数*

reduce 缩减函数将一个接受两个参数、返回一个值的函数应用在一个序列上,并依次以类似“累加”计算的方式来遍历整个序列。我认为这并不是一个常用的函数。

该函数在 Python 2 时代是一个可以直接调用的高级函数,但在 Python 3 版本需要从 functools 模块导入:

[48]:
from functools import reduce

下面是一个 reduce 函数的例子:

[49]:
# (((1+2)+3)+4) = 10.

d = reduce(lambda x,y: x+y, range(5))
print(d)
10

5.9. map 函数*

map 映射函数将一个函数应用在一个序列的每一项上,并返回一个迭代器。由于这个功能可以用迭代器解析的方法实现,因此它并不是一个常用的高级函数。

[50]:
s = "123"
x = map(int, s)

for k in x:
    print(k)
1
2
3

它与迭代器解析的写法并没有显著的不同:

[51]:
s = "123"
x = (int(k) for k in s)

for k in x:
    print(k)
1
2
3