什么是AST?
AST 即是抽象语法树,是对源代码的抽象语法结构的树状表示,树上每个节点都代表着源代码中的一种结构,比如表达式、语句、声明等。
我们以 C 语言为例,最简单的一个 main 函数:
int main()
{
return 0;
}
转换为 AST 后,会是以下这种结构:
转换为 AST 后,会是以下这种结构:
FileAST:
FuncDef:
Decl: main, [], [], [], []
FuncDecl:
TypeDecl: main, [], None
IdentifierType: ['int']
Compound:
Return:
Constant: int, 0
这里我们可以很清楚地看到 AST 的树状结构,节点与节点之间有父子关系,比如这里的 main 函数,是一个 FuncDef 类型,即函数定义,它有两个子节点, Decl 声明和 Compound 复合语句,而它们下面又有各自的子节点,从而组成一个层级化的树状结构。
AST 的作用
通常源代码都是人写的,但是最终执行的却是计算机,那么怎么让计算机理解人写的源代码呢?这就需要用到一个叫编译器的东西,编译器会将源代码转换成计算机能够理解的机器码,这是一个非常复杂的过程,而这个过程的基础就是构建一棵抽象语法树。
此外,AST 还有如下作用:
- 代码分析:检查代码中的语法错误、代码规范、内存泄漏等
- 代码优化:对代码进行优化,提高执行效率
- 代码转换:将一种语言转换为另一种语言
AST 构建原理
通常,构建 AST 需要经过词法分析和语法分析两个步骤。
词法分析
词法分析实际是一个扫描的过程,读入源代码的字符流,当遇到空格、操作符、特殊符号的时候,就表示一个词(token)扫描结束了,最终全部扫描完成后会形成一个 token 的序列输出。
比如一行 C 代码:
const int n = 10;
词法分析后就形成了以下的 token 序列:
["const", "int", "n", "=", "10"]
语法分析
经过词法分析后,已经将源代码的字符流转换成了 token 流,这时语法分析就会一个 token 一个 token 地读取,并进行标记。比如遇到 const 和 int 则会标记为关键字、n 标记为标识符、= 标记为运算符、10 标记为常量,最终形成 AST 如下:
FileAST:
Decl: n, ['const'], [], [], []
TypeDecl: n, ['const'], None
IdentifierType: ['int']
Constant: int, 10
AST 应用示例
我这里使用 python 的 pycparser 来对 C 语言进行 AST 分析。
语法检测
error.c
int main()
{
int a = 0
int b = 1;
return 0;
}
check_syntax.py
# coding=utf-8
from pycparser import parse_file
from pycparser.plyparser import ParseError
if __name__ == '__main__':
try:
ast = parse_file('c_files/error.c')
except ParseError as e:
print(e)
检测代码中的语法错误:
> check_syntax.py
c_files/error.c:4:2: before: int
函数定位
foobar.c
void foo()
{
}
void bar()
{
}
int main()
{
foo();
bar();
return 0;
}
locate_func.py
locate_func.py
# coding=utf-8
from pycparser import c_ast, parse_file
class FuncDefVisitor(c_ast.NodeVisitor):
def __init__(self, funcname):
self._funcname = funcname
def visit_FuncDef(self, node):
if node.decl.name == self._funcname:
print('%s defined at %s' % (node.decl.name, node.decl.coord))
class FuncCallVisitor(c_ast.NodeVisitor):
def __init__(self, funcname):
self._funcname = funcname
def visit_FuncCall(self, node):
if node.name.name == self._funcname:
print('%s called at %s' % (self._funcname, node.name.coord))
if node.args:
self.visit(node.args)
def locate_func(filename, funcname):
ast = parse_file(filename)
def_v = FuncDefVisitor(funcname)
def_v.visit(ast)
call_v = FuncCallVisitor(funcname)
call_v.visit(ast)
if __name__ == '__main__':
locate_func('c_files/foobar.c', 'foo')
locate_func('c_files/foobar.c', 'bar')
查找代码中函数的定义和调用的位置:
> python locate_func.py
foo defined at c_files/foobar.c:1:6
foo called at c_files/foobar.c:13:2
bar defined at c_files/foobar.c:6:6
bar called at c_files/foobar.c:14:2
函数改写
rewrite_func.py
# coding=utf-8
from pycparser import c_parser, c_ast, c_generator
class ParamAdder(c_ast.NodeVisitor):
def visit_FuncDecl(self, node):
typ = c_ast.TypeDecl(
declname='_new', quals=[], align=[],
type=c_ast.IdentifierType(['int'])
)
new_decl = c_ast.Decl(
name='_new', quals=[], align=[], storage=[],
funcspec=[], type=typ, init=None, bitsize=None,
coord=node.coord
)
if node.args:
node.args.params.append(new_decl)
else:
node.args = c_ast.ParamList(params=[new_decl])
text = r"""
void foo(int a, int b) {
}
void bar() {
}
"""
if __name__ == '__main__':
parser = c_parser.CParser()
ast = parser.parse(text)
v = ParamAdder()
v.visit(ast)
generator = c_generator.CGenerator()
c_code = generator.visit(ast)
print("
[C Code]
%s" % c_code)
为每个函数增加一个新的入参:
> python rewrite_func.py
[C Code]
void foo(int a, int b, int _new)
{
}
void bar(int _new)
{
}
代码格式化
chaos.c
int sum(int n) {
if ( n <=0) return 0;
int a=0;
for(int i=0;i < n; i++)
{
a += i;
} return a;
}
int main(){
int a= sum();
return 0;
}
format_code.py
format_code.py
# code=utf-8
from pycparser import parse_file, c_generator
def translate_to_c(filename):
ast = parse_file(filename)
generator = c_generator.CGenerator()
return generator.visit(ast)
if __name__ == '__main__':
c_code = translate_to_c('c_files/chaos.c')
print("
[C Code]
%s" % c_code)
将格式混乱的代码格式化为标准格式:
> python format_code.py
[C Code]
int sum(int n)
{
if (n <= 0)
return 0;
int a = 0;
for (int i = 0; i < n; i++)
{
a += i;
}
return a;
}
int main()
{
int a = sum();
return 0;
}
代码生成
generate_code.py
# coding=utf-8
from pycparser import c_ast, c_generator
def generate_ast():
constant_zero = c_ast.Constant(type='int', value='0')
return_node = c_ast.Return(expr=constant_zero)
compound_node = c_ast.Compound(block_items=[return_node])
type_decl_node = c_ast.TypeDecl(
declname='main', quals=[],
type=c_ast.IdentifierType(names=['int']),
align=[]
)
func_decl_node = c_ast.FuncDecl(
args=c_ast.ParamList([]),
type=type_decl_node
)
func_def_node = c_ast.Decl(
name='main', quals=[], storage=[], funcspec=[],
type=func_decl_node, init=None,
bitsize=None, align=[]
)
return c_ast.FuncDef(
decl=func_def_node, param_decls=None,
body=compound_node
)
def generate_c_code(ast):
generator = c_generator.CGenerator()
return generator.visit(ast)
if __name__ == '__main__':
ast = generate_ast()
c_code = generate_c_code(ast)
print('
[C Code]
%s' % c_code)
通过 AST 生成 C 语言代码:
> python generate_code.py
[C Code]
int main()
{
return 0;
}
总结
本文介绍了 AST 的基础概念和构建原理,并展示了一些 AST 的应用示例,我们可以发现不仅仅底层的编译需要用到 AST,在应用方面 AST 也能做许多事情,因此不管是做底层还是应用层开发,都可以了解一下 AST,它能为编程带来许多可能性。