学完了 Python 的装饰器,打算深入一下 LRU_cache 这个内置装饰器的实现与原理

顺便回顾一下双链表这个数据结构

斐波那契数列:

def fibonacci(n):
	if n < 2:
		return n
	return fibonacci(n-2) + fibonacci(n-1)

时间复杂度大概要 \(2^n\),递归方式非常耗时

现在要求优化它的速度

我们可以想到用字典缓存的装饰器

def cache(func):
    data = {}

    def wrapper(n):
        if n in data:
            return data[n]
        else:
            res = func(n)
            data[n] = res
            return res
    return wrapper


@cache
def fibonacci(n):
    if n < 2:
        return n
    return fibonacci(n - 1) + fibonacci(n - 2)

如果内存空间有限怎么办?

我们需要策略去处理缓存满的情况:Least Recently Used(LRU)、Least Frequently Used(LFU)

这里用的是 LRU 策略,一段时间不用的缓存会被剔除掉

实现一个 LRU,把最远的元素给剔除

循环双链表

class Node(object):
    """结点"""

    def __init__(self, prev=None, next=None, key=None, value=None):
        self.prev, self.next, self.key, self.value = prev, next, key, value


class CircularDoubleLinkedList(object):
    """循环双端队列"""

    def __init__(self):
        node = Node()
        node.prev, node.next = node, node
        self.rootnode = node

    def headnode(self):
        """访问首结点(根节点下一个)"""
        return self.rootnode.next

    def tailnode(self):
        """访问尾结点"""
        return self.rootnode.prev

    def remove(self, node):
        if node is self.rootnode:
            return
        else:
            node.prev.next = node.next
            node.next.prev = node.prev

    def append(self, node):
        tailnode = self.tailnode()
        tailnode.next = node
        node.next = self.rootnode
        node.prev = tailnode
        self.rootnode.prev = node


class LRUCache(object):
    def __init__(self, maxsize=16):
        self.maxsize = maxsize
        self.cache = {}
        self.access = CircularDoubleLinkedList()
        self.isfull = len(self.cache) >= self.maxsize

    def __call__(self, func):
        def wrapper(n):
            cachenode = self.cache.get(n)
            if cachenode is not None:  # 命中缓存
                self.access.remove(cachenode)
                self.access.append(cachenode)
                res = cachenode.value
                return res
            else:  # 没有命中
                value = func(n)
                if not self.isfull:  # 缓存未满
                    tailnode = self.access.tailnode()
                    newnode = Node(tailnode, self.access.rootnode, n, value)
                    self.access.append(newnode)
                    self.cache[n] = newnode
                    self.isfull = len(self.cache) >= self.maxsize
                else:  # 缓存满了
                    lru_node = self.access.headnode()
                    del self.cache[lru_node.key]
                    self.access.remove(lru_node)
                    tailnode = self.access.tailnode()
                    newnode = Node(tailnode, self.access.rootnode, n, value)
                    self.access.append(newnode)
                    self.cache[n] = newnode
                return value
        return wrapper