AST抽象语法树的原理与应用

2024-08-23 16:27 杨晓琪 367

什么是AST?

AST 即是抽象语法树,是对源代码的抽象语法结构的树状表示,树上每个节点都代表着源代码中的一种结构,比如表达式、语句、声明等。

我们以 C 语言为例,最简单的一个 main 函数:

int main()
{
  return 0;
}
转换为 AST 后,会是以下这种结构:

转换为 AST 后,会是以下这种结构:

FileAST: 
  FuncDef: 
    Decl: main, [], [], [], []
      FuncDecl: 
        TypeDecl: main, [], None
          IdentifierType: ['int']
    Compound: 
      Return: 
        Constant: int, 0

这里我们可以很清楚地看到 AST 的树状结构,节点与节点之间有父子关系,比如这里的 main 函数,是一个 FuncDef 类型,即函数定义,它有两个子节点, Decl 声明和 Compound 复合语句,而它们下面又有各自的子节点,从而组成一个层级化的树状结构。

AST 的作用

通常源代码都是人写的,但是最终执行的却是计算机,那么怎么让计算机理解人写的源代码呢?这就需要用到一个叫编译器的东西,编译器会将源代码转换成计算机能够理解的机器码,这是一个非常复杂的过程,而这个过程的基础就是构建一棵抽象语法树。

此外,AST 还有如下作用:

  • 代码分析:检查代码中的语法错误、代码规范、内存泄漏等
  • 代码优化:对代码进行优化,提高执行效率
  • 代码转换:将一种语言转换为另一种语言

AST 构建原理

通常,构建 AST 需要经过词法分析和语法分析两个步骤。

词法分析

词法分析实际是一个扫描的过程,读入源代码的字符流,当遇到空格、操作符、特殊符号的时候,就表示一个词(token)扫描结束了,最终全部扫描完成后会形成一个 token 的序列输出。

比如一行 C 代码:

const int n = 10;

词法分析后就形成了以下的 token 序列:

["const", "int", "n", "=", "10"]

语法分析

经过词法分析后,已经将源代码的字符流转换成了 token 流,这时语法分析就会一个 token 一个 token 地读取,并进行标记。比如遇到 const 和 int 则会标记为关键字、n 标记为标识符、= 标记为运算符、10 标记为常量,最终形成 AST 如下:

FileAST: 
  Decl: n, ['const'], [], [], []
    TypeDecl: n, ['const'], None
      IdentifierType: ['int']
    Constant: int, 10
   

AST 应用示例

我这里使用 python 的 pycparser 来对 C 语言进行 AST 分析。

语法检测

error.c

int main()
{ 
    int a = 0
    int b = 1;

    return 0;
}

check_syntax.py

# coding=utf-8
from pycparser import parse_file
from pycparser.plyparser import ParseError


if __name__ == '__main__':
    try:
        ast = parse_file('c_files/error.c')
    except ParseError as e:
        print(e)

     检测代码中的语法错误:

> check_syntax.py
c_files/error.c:4:2: before: int

函数定位

foobar.c

void foo()
{ 

}

void bar()
{

}

int main()
{
    foo();
    bar();

    return 0;
}
locate_func.py

locate_func.py

# coding=utf-8
from pycparser import c_ast, parse_file


class FuncDefVisitor(c_ast.NodeVisitor):
    def __init__(self, funcname):
        self._funcname = funcname

    def visit_FuncDef(self, node):
        if node.decl.name == self._funcname:
            print('%s defined at %s' % (node.decl.name, node.decl.coord))

class FuncCallVisitor(c_ast.NodeVisitor):
    def __init__(self, funcname):
        self._funcname = funcname

    def visit_FuncCall(self, node):
        if node.name.name == self._funcname:
            print('%s called at %s' % (self._funcname, node.name.coord))
        if node.args:
            self.visit(node.args)

def locate_func(filename, funcname):
    ast = parse_file(filename)

    def_v = FuncDefVisitor(funcname)
    def_v.visit(ast)

    call_v = FuncCallVisitor(funcname)
    call_v.visit(ast)


if __name__ == '__main__':
    locate_func('c_files/foobar.c', 'foo')
    locate_func('c_files/foobar.c', 'bar')

 查找代码中函数的定义和调用的位置:

> python locate_func.py
foo defined at c_files/foobar.c:1:6
foo called at c_files/foobar.c:13:2
bar defined at c_files/foobar.c:6:6
bar called at c_files/foobar.c:14:2

函数改写

rewrite_func.py

# coding=utf-8
from pycparser import c_parser, c_ast, c_generator


class ParamAdder(c_ast.NodeVisitor):
    def visit_FuncDecl(self, node):
        typ = c_ast.TypeDecl(
            declname='_new', quals=[], align=[],
            type=c_ast.IdentifierType(['int'])
        )
        new_decl = c_ast.Decl(
            name='_new', quals=[], align=[], storage=[],
            funcspec=[], type=typ, init=None, bitsize=None,
            coord=node.coord
        )

        if node.args:
            node.args.params.append(new_decl)
        else:
            node.args = c_ast.ParamList(params=[new_decl])


text = r"""
void foo(int a, int b) {
}

void bar() {
}
"""


if __name__ == '__main__':
    parser = c_parser.CParser()
    ast = parser.parse(text)

    v = ParamAdder()
    v.visit(ast)


    generator = c_generator.CGenerator()
    c_code = generator.visit(ast)
    print("
[C Code]
%s" % c_code)

为每个函数增加一个新的入参:

> python rewrite_func.py
[C Code]
void foo(int a, int b, int _new)
{
}

void bar(int _new)
{
}

代码格式化

chaos.c

int  sum(int  n) {
  if ( n <=0) return 0;
  int a=0;
    for(int i=0;i < n; i++)
    {
    a += i;
} return a;
}

int main(){
    int a= sum();
  return 0;
}
format_code.py

format_code.py

# code=utf-8
from pycparser import parse_file, c_generator


def translate_to_c(filename):
    ast = parse_file(filename)
    generator = c_generator.CGenerator()
    return generator.visit(ast)


if __name__ == '__main__':
    c_code = translate_to_c('c_files/chaos.c')
    print("
[C Code]
%s" % c_code)

将格式混乱的代码格式化为标准格式:

> python format_code.py
[C Code]
int sum(int n)
{
  if (n <= 0)
    return 0;
  int a = 0;
  for (int i = 0; i < n; i++)
  {
    a += i;
  }

  return a;
}

int main()
{
  int a = sum();
  return 0;
}

代码生成

generate_code.py

# coding=utf-8
from pycparser import c_ast, c_generator


def generate_ast():
    constant_zero = c_ast.Constant(type='int', value='0')
    return_node = c_ast.Return(expr=constant_zero)
    compound_node = c_ast.Compound(block_items=[return_node])

    type_decl_node = c_ast.TypeDecl(
        declname='main', quals=[],
        type=c_ast.IdentifierType(names=['int']),
        align=[]
    )
    func_decl_node = c_ast.FuncDecl(
        args=c_ast.ParamList([]),
        type=type_decl_node
    )
    func_def_node = c_ast.Decl(
        name='main', quals=[], storage=[], funcspec=[],
        type=func_decl_node, init=None,
        bitsize=None, align=[]
    )

    return c_ast.FuncDef(
        decl=func_def_node, param_decls=None,
        body=compound_node
    )

def generate_c_code(ast):
    generator = c_generator.CGenerator()
    return generator.visit(ast)


if __name__ == '__main__':
    ast = generate_ast()
    c_code = generate_c_code(ast)
    print('
[C Code]
%s' % c_code)

通过 AST 生成 C 语言代码:

> python generate_code.py
[C Code]
int main()
{
  return 0;
}

总结

本文介绍了 AST 的基础概念和构建原理,并展示了一些 AST 的应用示例,我们可以发现不仅仅底层的编译需要用到 AST,在应用方面 AST 也能做许多事情,因此不管是做底层还是应用层开发,都可以了解一下 AST,它能为编程带来许多可能性。