给 CPython 擦屁股


6 月份我个人生活上发生了一些不幸的事情,加上期末毕业季各种事务堆得我完全透不过气,我度过了充斥着绝望、内耗、颓废的十几天,宛如十年。

纵使这期间我极其低产,我还是艰难地推进 hmr 修复了几个硬骨头 bugs。编程的时候可以短暂地让脑子转起来而且一定程度上忽视这个现实世界,这让我很舒服。

前几天参加了 AdventureX 的一个预热活动,一次在深圳开展的 BuilderUp 活动,其实就是一个无主题的创业者分享会。我本来想推销一下 hmr,正好 hmr 也算 stable 了,于是就开了有史以来最大的 PR 把 reactivity 分支(历经半年,154 次提交)合并进来了:Pull Request #319 · promplate/pyth-on-line,但是后来其实也没用上。当然活动是超级棒的,玩的很开心。久旱逢甘霖。

今天这篇随笔 就是介绍下我最新处理的的一个 bug,它源于一个 CPython 搁置的 Issue:

Issue #121306 · python/cpython 是我一年前给 CPython 提的一个 Issue,大意是,在这种情况下:

class _:
    print(a)

访问 a 并不会遵循 locals -> globals -> builtins 这样的顺序:如果 globals 不是一个纯 dict 而是一个 dict 的子类,即使 a 存在,也会访问不到(报 NameError 错误)。

虽然 core devs 的回复说,就是不让 globals 是 dict 的子类的。但事实上除了 ClassDef 中访问全局变量这一个 case 之外,其它都运行得好好的。

说起来,后来我发现其中回复我的 @gaogaotiantian 就是编程区 up 主 码农高天

我自己个人也是在很多情况下使用了这个 feature,比如我最常用的 case 就是用一个 ChainMap 作为 globals,来实现上下文管理。用于 模板渲染、monkey patching 等场景。但这些场景一般不会遇到 ClassDef 语句,所以这个 Issue 对我的影响一直蛮有限的。

直到我要在 hmr 中实现函数级热重载。

通常我们说热重载,hmr 中的 m 指的是 module,也就是说重新运行的最小单位是一个模块。但写 Python 的人比较懒,而我又总是惯着别人,on the other hand 我也喜欢炫技,所以也实现了一个比较 hacky 的函数级热重载:

  1. 它基本上表现就像用 functools.cache 装饰无参的函数
  2. 当其中访问到的全局变量 / 别的模块的变量 invalidate 时,标记为脏
  3. 当这个函数源代码改变时,标记为脏

问题就在,它需要追踪访问全局变量。这本并不难,因为我已经实现了一个 ReactiveProxy 类,会通过 __getitem__ 记录所有变量的取值,但由于 CPython 的这个 bug 存在,我没法在被我标记的函数里创建 class(更准确说,是创建 class 并在 class 的 frame 中访问全局变量)。由于我 hmr 这个库是希望做到尽可能对用户无感,所以肯定是要支持所有用法的。在函数中创建 class 显然是常见的用法。

经过漫长的思索,我发现唯一的 userland solution 就是修改 AST,比如把上面的例子编译成:

class _:
    try:
        a = locals()["a"]
    except KeyError:
        try:
            a = globals()["a"]
        except KeyError:
            raise NameError("a") from None
    print(a)

换句话说,就是手动从 gloabls()locals() 中拿这个值。这样确保了我们自定义的 __getitem__ 一定会被触发。

这其实是一个过度简化的例子(因为其实 print 也应该进行这样的取值。而且也要除了 locals()globals() 外还应该尝试 __builtins__

实现

LLM 比我更熟悉 Python 的 AST 用法。而且我需求也挺清晰,于是这部分元编程就先交给各类 vibe coding 工具来做,我提思路和小的修改,迭代一二十次差不多就成了。我又清理和重构了下,目前的核心实现大概是这样:

import ast
from inspect import cleandoc
from typing import override


class ClassTransformer(ast.NodeTransformer):
    @override
    def visit_ClassDef(self, node: ast.ClassDef):
        node.body = [
            name_lookup_function,
            *map(ClassBodyTransformer().visit, node.body),
            ast.Delete(targets=[ast.Name(id="__name_lookup", ctx=ast.Del())]),
        ]
        return node


class ClassBodyTransformer(ast.NodeTransformer):
    @override
    def visit_Name(self, node: ast.Name):
        if isinstance(node.ctx, ast.Load) and node.id != "__name_lookup":
            return build_name_lookup(node.id)
        return node

    @override
    def visit_FunctionDef(self, node: ast.FunctionDef):
        node.decorator_list = [self.visit(d) for d in node.decorator_list]
        self.visit(node.args)
        if node.returns:
            node.returns = self.visit(node.returns)
        return node

    visit_AsyncFunctionDef = visit_FunctionDef  # type: ignore

    @override
    def visit_Lambda(self, node: ast.Lambda):
        self.visit(node.args)
        return node


def build_name_lookup(name: str) -> ast.Call:
    return ast.Call(func=ast.Name(id="__name_lookup"), args=[ast.Constant(value=name)])


name_lookup_function = ast.FunctionDef(
    name="__name_lookup",
    args=ast.arguments(args=[ast.arg(arg="name")]),
    body=ast.parse(
        cleandoc("""

            from inspect import currentframe
            f = currentframe().f_back
            l, g, b = f.f_locals, f.f_globals, f.f_builtins
            m = object()
            if (v := l.get(name, m)) is not m or (v := g.get(name, m)) is not m or (v := b.get(name, m)) is not m:
                return v
            raise NameError(name)

        """)
    ).body,
    decorator_list=[ast.Name(id="staticmethod")],
)

简单来说,就是首先注入一个固定的 __name_lookup 函数,首先获得 class 所在的 frame,然后依次尝试 f_localsf_globalsf_builtins,并在 class 的结尾删掉这个 __name_lookup 函数。

接下来递归地把所有 ast.Name 都变成 __name_lookup(...) 这样的调用,就完成啦。编译上面的例子的结果就是:

class _:
    @staticmethod
    def __name_lookup(name):
        from inspect import currentframe
        f = currentframe().f_back
        l, g, b = (f.f_locals, f.f_globals, f.f_builtins)
        m = object()
        if (v := l.get(name, m)) is not m or (v := g.get(name, m)) is not m or (v := b.get(name, m)) is not m:
            return v
        raise NameError(name)

    __name_lookup('print')(__name_lookup('a'))

    del __name_lookup

完成啦 🥂 你学会了吗 :)