CPython 标准库源码分析 collections.Counter

886 阅读5分钟

Counter 是一个专门计数可哈希对象的 dict 子类,元素会被当做 dict 的 key,计数的数量被当做 value 存储。

这是 Counter 的 doc string,直接明确的指出了元素会被存储在 dict 的 key,换句话说只有可哈希的元素才能使用 Counter 来计数。

>>> c = Counter()
>>> c = Counter('gallahad') 
>>> c = Counter({'a': 4, 'b': 2})
>>> c = Counter(a=4, b=2)

Counter 有 4 中初始化方式:空、字符串、字典、关键字参数

class Counter(dict):
	def __init__(*args, **kwds):
        if not args:
            raise TypeError("...")
        self, *args = args
        if len(args) > 1:
            raise TypeError("...")

        super(Counter, self).__init__()
        self.update(*args, **kwds)

从初始化代码中可以看出,Counter 继承自 dict,最后通过 self.update() 方法把参数更新到 Counter 中。

def update(*args, **kwds):
    ... # 参数检查
    iterable = args[0] if args else None
        if iterable is not None:
            if isinstance(iterable, _collections_abc.Mapping):
                if self:
                    self_get = self.get
                    for elem, count in iterable.items():
                        self[elem] = count + self_get(elem, 0)
                else:
                    super(Counter, self).update(iterable)               	
        else:
    	    _count_elements(self, iterable)
    
        if kwds:
            self.update(kwds)

update 函数是需要一个 iteranble 对象,也就是说要是一个可以被 for in 的数据类型。 随后判断了这个 iteranble 是不是一个 Mapping,如果是就使用 .items() 遍历 key 和 value。 如果传入的参数是一个 dict 或者 Counter 的实例就会走到这个条件判断中,通过「mapping」的行为更新计数。

Mapping 是在 collections 中的一个抽象数据类型,这个抽象数据类型并不是用来继承的,是用来判断类型的抽象数据类型。 就像这里 if isinstance(iterable, _collections_abc.Mapping) ,本质是一个 duck typing 的应用。

如果要实现自定义的 dict 类型,一般会继承 collections.abc.User.Dict 来实现。

def _count_elements(mapping, iterable):
    'Tally elements from the iterable.'
    mapping_get = mapping.get
    for elem in iterable:
        mapping[elem] = mapping_get(elem, 0) + 1

非 Mapping 类型使用 _count_elements 函数完成计数跟新。_count_elements 函数使用 iterable 都会实现的迭代器遍历完成。

如果参数是关键字参数会直接调用当前的 update 方法更新,同样走的是 Mapping 类型那条路。

def subtract(*args, **kwds):
    ... # 参数检查
        
    iterable = args[0] if args else None
    if iterable is not None:
        self_get = self.get
        if isinstance(iterable, _collections_abc.Mapping):
            for elem, count in iterable.items():
                self[elem] = self_get(elem, 0) - count
        else:
            for elem in iterable:
                self[elem] = self_get(elem, 0) - 1
    if kwds:
        self.subtract(kwds)

subtract 函数和 update 函数功能相反,但是实现很类似,仅仅是把加换成了减,同时还有还有可能出现 0 值和负值。

>>> c = Counter("abcd")
>>> c.subtract(d=10)
>>> c
Counter({'a': 1, 'b': 1, 'c': 1, 'd': -9})
>>>
def elements(self):
    return _chain.from_iterable(_starmap(_repeat, self.items()))

elements 方法可以把 Counter 转换成迭代器,同时忽略掉了 0 值和负值的计数。

>>> for e in c.elements():
...     print(e)
...
...
a
b
c
>>>

_chain.from_iterable(_starmap(_repeat, self.items())) 用了三个 itertool 里面的三个方法来生成迭代器。

  • _repeat: itertools.repeat,创建一个重复的对象的迭代器,repeat('A', 2) => ['A', 'A']
  • _starmap: itertools._starmap,创建一个迭代器使用可迭代对象中获取的参数,starmap(lambda x: x+x, ['A', 'B']) => ['AA', 'BB']
  • _chain.from_iterable: itertools.chain.from_iterable, 从可迭代对象创建一个迭代器

Counter 中剩下就是一些运算来简化过程,实现了 "+", "-", "&", "|" 和对应原地修改 "+=", "-=", "&=", "|="。

def __add__(self, other):
    if not isinstance(other, Counter):
        return NotImplemented
    result = Counter()
    for elem, count in self.items():
        newcount = count + other[elem]
        if newcount > 0:
            result[elem] = newcount
    for elem, count in other.items():
        if elem not in self and count > 0:
            result[elem] = count
    return result

所有的非原地修改都会生成一个新的 Conter 实例,在加法中,现实相加了 other 中有的元素,然后再把只在 other 中同时大于 0 的也放入新的 Counter 中。

def __sub__(self, other):
    if not isinstance(other, Counter):
        return NotImplemented
    result = Counter()
    for elem, count in self.items():
        newcount = count - other[elem]
        if newcount > 0:
            result[elem] = newcount
    for elem, count in other.items():
        if elem not in self and count < 0:
            result[elem] = 0 - count
    return result

非原地的减法是从被减数中减去计数同时这个计数还要大于 0 才会被放入结果中,如果减数中有负值会反转成正值放入新 Counter 中。

def __or__(self, other):
    if not isinstance(other, Counter):
        return NotImplemented
    result = Counter()
    for elem, count in self.items():
        other_count = other[elem]
        newcount = other_count if count < other_count else count
        if newcount > 0:
            result[elem] = newcount
    for elem, count in other.items():
        if elem not in self and count > 0:
            result[elem] = count
    return result

并集运算的过程是假如没有就放入新的 Counter 中,如果有就对比,哪个计数大,哪个就放入新的 Counter 中,同时也要保证每个计数不能小于 0.

def __and__(self, other):
    if not isinstance(other, Counter):
        return NotImplemented
    result = Counter()
    for elem, count in self.items():
        other_count = other[elem]
        newcount = count if count < other_count else other_count
        if newcount > 0:
            result[elem] = newcount
    return result

差集运算找出同时存两个 Counter 中,计数较小的那个放入新的 Counter 中,同时保证不大于 0。

剩下的就是与之对应的原地方法,并不是创建新的 Counter 而是直接使用老的 Counter,实现过程上比较类似,但是最后是使用 self._keep_positive() 方法来保证返回的计数中不会有负值。

def _keep_positive(self):
    nonpositive = [elem for elem, count in self.items() if not count > 0]
    for elem in nonpositive:
        del self[elem]
    return self

def __iadd__(self, other):
    for elem, count in other.items():
        self[elem] += count
    return self._keep_positive()

最后剩下的一个函数是用的最多的 most_common(), 返回最多的 n 个计数

def most_common(self, n=None):
    if n is None:
        return sorted(self.items(), key=_itemgetter(1), reverse=True)
    return _heapq.nlargest(n, self.items(), key=_itemgetter(1))

实现过程简单暴力,直接根据计数做了个排序,然后使用了最大堆,获取前 N 的元素和计算值。

总结一下,Counter 是基于 dict 的子类使用 key 存储每个元素,所以可计数的元素肯定是可哈希的元素,核心方法是 update() 使用了 duck typing 方式更新不同合法类型的参数。 在重载的运算过程中,总是要保证不会有负计数的出现,唯一可能出现负计数的时候就是调用 subtract。所以在遍历不要直接使用 c.items() 方法,必须使用 c.elements()