本文是记录专业课“程序语言理论与编译技术”的部分笔记。
LECTURE 22(实现一个从AST到RISCV的编译器)
一、问题分析
1、完整的编译器(如LLVM)需先完成AST到IR的转换,并进行代码优化,再到汇编,如下图:
本次实验实现从AST到RISC-V(不考虑优化)的翻译。对于汇编,我们可以访问Compiler Explorer (godbolt.org)来熟悉一下,或者参看类似下图的手册(包括汇编指令和寄存器类型等内容):
对于本次实验,我们用 x8(即 s0)作为帧指针(fp),固定指向当前函数的栈帧基址;使用 a0 寄存器作为表达式计算结果的默认存放寄存器(函数返回值也约定放在 a0)。为了方便,我们避免寄存器分配,转而使用栈来保存局部变量和临时数据。如下图:
2、首先考虑Binop的翻译,我们参考下图:
对于Int 常数类型,我们直接将常数值加载到寄存器,对应RISC-V 提供的 li(load immediate)伪指令,将一个立即数载入寄存器。而对于Binop,我们考虑处理三种情况:
• Num Binop Num E.g. 2 + 7
• Num Binop Num Binop Num E.g. 2 + 6 * 9
• (Num Binop Num) Binop (Num Binop Num) E.g. (2 + 6) * (8 + 7)
基本流程:计算左操作数 -> 保存左值 -> 计算右操作数 -> 运算合成结果。其中“保存左值”这一步可以采用两种方案:
• 将左值保存在临时寄存器中(例如t0 ),再计算右值到另一个寄存器(例如t1 )
• 采用栈,在计算左值后将其压栈保存,计算右值后再弹栈取出左值进行运算
3、然后考虑Let的翻译:
这里我们需要增加对变量和作用域的支持,使编译器可以处理 let 绑定和变量引用。我们需要考虑栈帧与变量地址:每进入一个新的 let 块,都需在栈上分配新的空间保存局部变量的值(即增大相对帧指针 fp 的偏移 offset)。同时,在编译阶段还需要一个映射(符号表,即 env)来跟踪当前作用域中每个变量名对应的栈偏移或寄存器位置。
据此,我们简化逻辑:让 fp 固定不动,变量 x 的偏移按照进入 let 的次序决定。所谓固定帧,具体而言就是在程序一开始一次性分配所需的最大栈空间,然后令 fp = sp 并保持不变。
我们可以考虑如下流程:先编译 e1,其值位于 a0;然后在栈上分配空间保存该值(addi sp, sp, -8)。接着将变量名 x 映射到新空间的位置,例如 offset += 8 表示又占用了 8 字节,并在 env 中新增绑定 (“x”, offset)。随后编译内部表达式 e2,其中 Var x 按 env 找到偏移;完成后回退栈指针,释放 x 所占的空间(addi sp, sp, 8)。最后恢复符号表,弹出 env 中 x 的绑定。
4、然后考虑If的翻译,我们参考下图:
我们需要扩展编译器支持Bool常量和If条件表达式,正确处理程序的控制流,为每个if表达式生成唯一的标签并正确安插跳转指令,且保证程序无论哪一种路径都会结束。
具体的,思考同前面let/变量的交互:使用标签而非固定地址。条件的各分支内部可以定义自己的变量且作用域仅限于分支内,实现需以递归的方式编译分支表达式并各自管理环境。在分支中寄存器值通过a0返回结果,无残留的临时变量,无需专门清理分支的栈。
5、最后考虑Func和App的翻译。扩展编译器支持函数定义(Func)和函数调用(App),函数的闭包包含函数代码和绑定环境的组合体。而编译Func(x, body)会产生两部分输出:全局函数代码(储存待后续统一输出),以及当前表达式计算时创建闭包的代码。
回顾 RISC-V 寄存器的调用约定:caller 负责传递参数(a0-a7 寄存器),并在需要时保存调用者保存寄存器(如 t0-t6、a0-a7);callee 可以自由使用这些寄存器,但必须保存和恢复被调用者保存寄存器(callee-saved,如 s0-s11)。我们在实现时仅考虑函数调用一个显式参数,并需要注意传递参数值和环境指针。
闭包的环境结构(仅提供一种思路):分配一块内存,大小足够存放一个指针加上所有自由变量的值;将函数的代码地址写入这块内存的开头;将该函数的自由变量当前环境中的值写入内存的后续字段;将指向这块内存的指针作为闭包值放入某寄存器。
注意,Func 需要一个唯一的新函数标签(类似 if 的标签)。对函数体的编译与 main 形式上相同,只是环境由 env 变为 local_env + closure_env。在 main 函数体中,Func 的编译结果替换为上一步的闭包分配代码。
而 App 调用有多种方式:jalr、call 等。以 jalr 为例:将 a0(闭包指针)保存到 t0,a0 放参数;使用 ld t1, 0(t0) 加载闭包偏移为 0 处的内容(代码地址)到 t1;再用 jalr ra, 0(t1) 跳转调用。注意:若计算参数的代码本身会改写 t0(例如 Binop 弹栈到 t0),则应先把闭包指针保存到栈上,待参数算完后再取回 t0。
其余包括free、优化等问题不在本节内容之内。
二、前置代码
1、lib/ast.ml
(** Binary operators of the source language. *)
type binop =
| Add  (** addition, written [+] in source *)
| Sub  (** subtraction, written [-] in source *)
| Mul  (** multiplication *)
| Div  (** division, written [/] in source *)
| Leq  (** less-than-or-equal comparison, written [<=] in source *)
(** Abstract syntax of source expressions, as built by the parser. *)
type expr =
| Int of int                    (** integer literal *)
| Var of string                 (** variable reference *)
| Bool of bool                  (** [true] / [false] *)
| Binop of binop * expr * expr  (** [e1 op e2] *)
| Let of string * expr * expr   (** [let x = e1 in e2] *)
| If of expr * expr * expr      (** [if cond then e1 else e2] *)
| Func of string * expr         (** [fun x -> body]: exactly one parameter *)
| App of expr * expr            (** application [f arg]: exactly one argument *)
2、lib/lexer.mll
{
open Parser
}

(* Lexer for the source language.  ocamllex takes the longest match and, on a
   tie, the earliest rule.  So "->" lexes as ARROW rather than MINUS, and an
   exact keyword such as "let" wins over the identifier rule, while "letx" is
   an ID. *)
rule read = parse
| [' ' '\t' '\n' '\r'] { read lexbuf }  (* '\r' lets CRLF (Windows) files lex *)
| '+' { PLUS }
| '-' { MINUS }
| '*' { TIMES }
| '/' { DIV }
| '(' { LPAREN }
| ')' { RPAREN }
| "<=" { LEQ }
| "true" { TRUE }
| "false" { FALSE }
| "let" { LET }
| "=" { EQUALS }
| "in" { IN }
| "if" { IF }
| "then" { THEN }
| "else" { ELSE }
| "->" { ARROW }
| "fun" { FUNC }
| ['0'-'9']+ as num { INT (int_of_string num) }
| ['a'-'z' 'A'-'Z']+ as id { ID id }
| eof { EOF }
| _ { failwith "Invalid character" }
3、lib/parser.mly
%{
open Ast
(** [make_apply e [e1; e2; ...; en]] builds the left-nested application
    [App (... App (App (e, e1), e2) ..., en)], i.e. [e e1 e2 ... en].
    Requires: the list argument is non-empty.
    @raise Failure if the list is empty. *)
let rec make_apply e = function
| [] -> failwith "precondition violated"
| [e'] -> App (e, e')
| h :: ((_ :: _) as t) -> make_apply (App (e, h)) t
%}
%token <int> INT
%token <string> ID
%token PLUS MINUS TIMES DIV EOF
%token LPAREN RPAREN
%token LEQ
%token TRUE FALSE
%token LET EQUALS IN
%token IF THEN ELSE
%token FUNC ARROW
/* Precedence, lowest first.  IN and ELSE are lowest, so the body of a let and
   the else-branch of an if extend as far to the right as possible.  Then <=,
   then + -, then * /; all binary operators are left-associative. */
%nonassoc IN
%nonassoc ELSE
%left LEQ
%left PLUS MINUS
%left TIMES DIV
%start main
%type <Ast.expr> main
%%
main:
expr EOF { $1 }
;
/* Application only appears at this level: a head followed by one or more
   arguments, folded into left-nested App nodes by make_apply. */
expr:
| simpl_expr { $1 }
| simpl_expr simpl_expr+ { make_apply $1 $2 }
;
/* NOTE(review): the operands of let and if are simpl_expr, not expr, so an
   application there must be parenthesised, e.g. "let y = (f x) in y".
   The body of fun is a full expr. */
simpl_expr:
| INT { Int $1 }
| ID { Var $1 }
| TRUE { Bool true }
| FALSE { Bool false}
| simpl_expr LEQ simpl_expr { Binop (Leq, $1, $3) }
| simpl_expr TIMES simpl_expr { Binop (Mul, $1, $3) }
| simpl_expr DIV simpl_expr { Binop (Div, $1, $3) }
| simpl_expr PLUS simpl_expr { Binop (Add, $1, $3) }
| simpl_expr MINUS simpl_expr { Binop (Sub, $1, $3) }
| LET ID EQUALS simpl_expr IN simpl_expr { Let ($2, $4, $6) }
| FUNC ID ARROW expr { Func ($2, $4) }
| IF simpl_expr THEN simpl_expr ELSE simpl_expr { If ($2, $4, $6) }
| LPAREN expr RPAREN { $2 }
;
4、lib/dune
; Front-end library: the AST, the ocamllex lexer and the Menhir parser.
(library
(name Simpl_riscv)
(modules parser lexer ast))
; Generate lexer.ml from lexer.mll.
(ocamllex lexer)
; Generate the parser module from parser.mly with Menhir.
(menhir
(modules parser))
5、bin/dune
; The compiler executable (bin/main.ml), linked against the front-end library.
(executable
(public_name Simpl_riscv)
(name main)
(modules main)
(libraries Simpl_riscv)
; Disabled warnings: 32 unused value, 27/26 unused variables, 39 unused rec,
; 8 non-exhaustive pattern match, 37 unused constructor.
; NOTE(review): with 8 disabled, a match that forgets a constructor compiles
; silently and only fails at runtime with Match_failure; consider re-enabling it.
(flags (:standard -w -32-27-26-39-8-37)))
6、bin/main.ml的部分代码
open Simpl_riscv
open Ast
(* Render an AST in constructor syntax for debugging output,
   e.g. "Binop (Add, Int 1, Var x)". *)
let rec string_of_expr (e : expr) : string =
  let binop_name = function
    | Add -> "Add"
    | Sub -> "Sub"
    | Mul -> "Mul"
    | Div -> "Div"
    | Leq -> "Leq"
  in
  (* "Name (a, b, c)" *)
  let node name args = name ^ " (" ^ String.concat ", " args ^ ")" in
  match e with
  | Int n -> "Int " ^ string_of_int n
  | Var id -> "Var " ^ id
  | Bool b -> "Bool " ^ string_of_bool b
  | Binop (op, lhs, rhs) ->
      node "Binop" [ binop_name op; string_of_expr lhs; string_of_expr rhs ]
  | Let (v, bound, body) ->
      node "Let" [ v; string_of_expr bound; string_of_expr body ]
  | If (c, t, f) ->
      node "If" [ string_of_expr c; string_of_expr t; string_of_expr f ]
  | Func (v, body) -> node "Func" [ v; string_of_expr body ]
  | App (fn, arg) -> node "App" [ string_of_expr fn; string_of_expr arg ]
(* Lex and parse a whole source string into an AST.  Exceptions raised by
   [Lexer.read] or [Parser.main] propagate to the caller unchanged. *)
let parse (s : string) : expr =
  Lexing.from_string s |> Parser.main Lexer.read
(* Global counter used to make every generated label unique. *)
let label_count = ref 0

(* [fresh_label prefix] bumps the global counter and returns
   "<prefix>_<counter>", e.g. "Lelse_1" on the first call. *)
let fresh_label prefix =
  label_count := !label_count + 1;
  prefix ^ "_" ^ string_of_int !label_count
(* Global accumulator for the assembly of every compiled function; the
   program driver appends these after main's code. *)
let functions = ref ([] : string list)
(* Simple free-variable analysis (no deduplication; only meant for teaching
   examples).  Returns the variables of [expr] that are not in [bound], in
   left-to-right order of occurrence, repeated once per occurrence.  [Let]
   and [Func] add their binder to [bound] for the sub-expression in scope. *)
let rec free_vars expr bound =
  let go e = free_vars e bound in
  match expr with
  | Int _ | Bool _ -> []
  | Var v -> if List.mem v bound then [] else [ v ]
  | Binop (_, lhs, rhs) | App (lhs, rhs) -> go lhs @ go rhs
  | Let (v, def, body) -> go def @ free_vars body (v :: bound)
  | If (c, t, f) -> go c @ go t @ go f
  | Func (param, body) -> free_vars body (param :: bound)
(*
compiler_expr env cur_offset expr   (skeleton; completed in section 3)
env: association list of (variable, offset), where offset is the byte offset
     below fp at which the variable is stored (it is read as -offset(fp)).
cur_offset: number of bytes already allocated for let-bound variables
     (each variable takes 8 bytes).
The generated assembly leaves the value of the expression in register a0.
*)
let rec compiler_expr (env : (string * int) list) (cur_offset : int) (expr : expr) : string =
match expr with
| Int n ->
Printf.sprintf "\tli a0, %d\n" n
| Bool b ->
if b then "\tli a0, 1\n" else "\tli a0, 0\n"
| Var x ->
(try
let offset = List.assoc x env in
Printf.sprintf "\tld a0, -%d(fp)\n" offset
with Not_found ->
failwith ("Unbound variable: " ^ x))
(* ---- TODO: Binop, Let, If, Func and App cases go here; see the full
   implementation in section 3 of these notes. ---- *)
(* Compiles expressions inside a function body: variables are looked up in
   local_env (fp-relative slots) or closure_env (offsets into the closure
   environment). *)
and compile_expr_func (local_env : (string * int) list) (closure_env : (string * int) list) (cur_offset : int) (expr : expr) : string =
match expr with
(* ---- TODO: all cases omitted in this skeleton.  As written, this fragment
   does not compile: the match has no cases. ---- *)
(* Assemble a complete program: a [main] that evaluates [e] and returns its
   value in a0, followed by the code of every function collected in
   [functions] while compiling [e].

   main reserves a 64-byte frame and saves two registers in its top 16 bytes:
   - ra, because the body may call malloc or a closure (via jal/jalr ra),
     which overwrites it; without the save, the final [ret] jumps to the
     wrong address.
   - fp (s0), because it is callee-saved and main repoints it.
   Let-bound values and spilled temporaries live below fp, with sp moved down
   to cover them; the remaining 48 bytes of the reserved frame are unused. *)
let compiler_program (e : expr) : string =
  (* Compile the body first: this fills [functions] as a side effect. *)
  let body_code = compiler_expr [] 0 e in
  let prologue =
    ".text\n\
     .global main\n\
     main:\n\
     \taddi sp, sp, -64\n\
     \tsd ra, 56(sp)\n\
     \tsd fp, 48(sp)\n\
     \tmv fp, sp\n"
  in
  let epilogue =
    "\tmv sp, fp\n\
     \tld ra, 56(sp)\n\
     \tld fp, 48(sp)\n\
     \taddi sp, sp, 64\n\
     \tret\n"
  in
  let func_code = String.concat "\n" !functions in
  prologue ^ body_code ^ epilogue ^ "\n" ^ func_code
(* Entry point.

   Usage: <prog> OUTPUT.s [INPUT]
   Reads the source program from INPUT (default "test/simpl_test4.in"),
   prints its AST, compiles it and writes the RISC-V assembly to OUTPUT.s. *)
let () =
  let argc = Array.length Sys.argv in
  if argc < 2 then begin
    Printf.eprintf "usage: %s OUTPUT.s [INPUT]\n" Sys.argv.(0);
    exit 2
  end;
  let output_file = Sys.argv.(1) in
  (* Keep the original default so existing one-argument invocations behave
     exactly as before. *)
  let filename = if argc > 2 then Sys.argv.(2) else "test/simpl_test4.in" in
  let in_channel = open_in filename in
  let file_content =
    really_input_string in_channel (in_channel_length in_channel)
  in
  close_in in_channel;
  let ast = parse file_content in
  Printf.printf "AST: %s\n" (string_of_expr ast);
  (* Compile before opening the output file, so a compile error does not leave
     an empty or truncated .s file behind. *)
  let asm_code = compiler_program ast in
  let oc = open_out output_file in
  output_string oc asm_code;
  close_out oc;
  Printf.printf "Generated RISC-V code saved to: %s\n" output_file
三、具体实现
1、bin/main.ml的compiler_expr函数
(* Compile [expr] for the body of [main]; the result is left in a0.

   [env] maps every variable in scope to its slot offset below fp (the value
   lives at -offset(fp)).  [cur_offset] is the number of bytes currently
   allocated below fp by enclosing [Let]s and by spilled temporaries.  The
   generated code relies on the invariant sp = fp - cur_offset, and every
   case leaves sp exactly where it found it.  Both Let-bound values and
   spilled temporaries use the next slot -(cur_offset + 8)(fp), so nested
   bindings never overwrite a pending temporary.

   [Func] bodies are compiled with [compile_expr_func] and appended to the
   global [functions] list.  A closure is a malloc'd block: the code pointer
   at offset 0, then one 8-byte word per free variable.  A call passes the
   argument in a0 and a pointer to the free-variable words in a1.

   NOTE(review): slots are 8 bytes, so sp is only guaranteed 8-byte aligned at
   the malloc / jalr calls, while the RISC-V psABI asks for 16 — confirm this
   is acceptable on the target.

   @raise Failure on an unbound variable or unbound free variable. *)
let rec compiler_expr (env : (string * int) list) (cur_offset : int) (expr : expr) : string =
  match expr with
  | Int n -> Printf.sprintf "\tli a0, %d\n" n
  | Bool b -> if b then "\tli a0, 1\n" else "\tli a0, 0\n"
  | Var x ->
      (try
         let offset = List.assoc x env in
         Printf.sprintf "\tld a0, -%d(fp)\n" offset
       with Not_found -> failwith ("Unbound variable: " ^ x))
  | Binop (op, e1, e2) ->
      (* Spill the left operand into the next fp-relative slot and compile e2
         one slot deeper, so a Let inside e2 allocates below it. *)
      let slot = cur_offset + 8 in
      let code1 = compiler_expr env cur_offset e1 in
      let push_left =
        Printf.sprintf "\taddi sp, sp, -8\n\tsd a0, -%d(fp)\n" slot
      in
      let code2 = compiler_expr env slot e2 in
      let pop_left = Printf.sprintf "\tld t0, -%d(fp)\n\taddi sp, sp, 8\n" slot in
      let op_code =
        match op with
        | Add -> "\tadd a0, t0, a0\n"
        | Sub -> "\tsub a0, t0, a0\n"
        | Mul -> "\tmul a0, t0, a0\n"
        | Div -> "\tdiv a0, t0, a0\n"
        (* t0 <= a0  <=>  not (a0 < t0); the result is 1 or 0. *)
        | Leq -> "\tslt a0, a0, t0\n\txori a0, a0, 1\n"
      in
      code1 ^ push_left ^ code2 ^ pop_left ^ op_code
  | Let (x, e1, e2) ->
      let code1 = compiler_expr env cur_offset e1 in
      let new_offset = cur_offset + 8 in
      let alloc =
        Printf.sprintf "\taddi sp, sp, -8\n\tsd a0, -%d(fp)\n" new_offset
      in
      let code2 = compiler_expr ((x, new_offset) :: env) new_offset e2 in
      let free = "\taddi sp, sp, 8\n" in
      code1 ^ alloc ^ code2 ^ free
  | If (cond, e_then, e_else) ->
      let label_else = fresh_label "Lelse" in
      let label_end = fresh_label "Lend" in
      let code_cond = compiler_expr env cur_offset cond in
      let code_then = compiler_expr env cur_offset e_then in
      let code_else = compiler_expr env cur_offset e_else in
      code_cond
      ^ Printf.sprintf "\tbeq a0, x0, %s\n" label_else
      ^ code_then
      ^ Printf.sprintf "\tj %s\n" label_end
      ^ Printf.sprintf "%s:\n" label_else
      ^ code_else
      ^ Printf.sprintf "%s:\n" label_end
  | Func (x, body) ->
      let fvs = free_vars body [ x ] in
      let func_id = fresh_label "func" in
      (* Free variable i is read as (8*i)(a1), where a1 = closure + 8. *)
      let closure_env = List.mapi (fun i v -> (v, 8 * i)) fvs in
      (* The parameter occupies the first slot below the callee's fp, so the
         body starts with 8 bytes already allocated. *)
      let func_body_code = compile_expr_func [ (x, 8) ] closure_env 8 body in
      (* Save ra/fp, point fp at them, then spill the argument (a0) into the
         parameter slot -8(fp) that the body reads. *)
      let func_prologue =
        Printf.sprintf
          "%s:\n\taddi sp, sp, -16\n\tsd ra, 8(sp)\n\tsd fp, 0(sp)\n\tmv fp, sp\n\taddi sp, sp, -8\n\tsd a0, -8(fp)\n"
          func_id
      in
      (* Release the parameter slot by resetting sp to fp before restoring. *)
      let func_epilogue =
        "\tmv sp, fp\n\tld ra, 8(sp)\n\tld fp, 0(sp)\n\taddi sp, sp, 16\n\tret\n"
      in
      functions := !functions @ [ func_prologue ^ func_body_code ^ func_epilogue ];
      let closure_size = 8 * (1 + List.length fvs) in
      let alloc_code =
        Printf.sprintf "\tli a0, %d\n\tjal ra, malloc\n" closure_size
      in
      let move_closure = "\tmv t0, a0\n" in
      let store_code_ptr =
        Printf.sprintf "\tla t1, %s\n\tsd t1, 0(t0)\n" func_id
      in
      let store_free_vars =
        List.mapi
          (fun i v ->
            let outer_offset =
              try List.assoc v env
              with Not_found -> failwith ("Unbound free var: " ^ v)
            in
            Printf.sprintf "\tld t1, -%d(fp)\n\tsd t1, %d(t0)\n" outer_offset
              (8 * (i + 1)))
          fvs
        |> String.concat ""
      in
      let ret_code = "\tmv a0, t0\n" in
      alloc_code ^ move_closure ^ store_code_ptr ^ store_free_vars ^ ret_code
  | App (e1, e2) ->
      (* Spill the closure pointer: compiling the argument may overwrite t0
         (Binop pops into t0, Func builds closures in t0). *)
      let slot = cur_offset + 8 in
      let code_f = compiler_expr env cur_offset e1 in
      let save_closure =
        Printf.sprintf "\taddi sp, sp, -8\n\tsd a0, -%d(fp)\n" slot
      in
      let code_arg = compiler_expr env slot e2 in
      let restore_closure =
        Printf.sprintf "\tld t0, -%d(fp)\n\taddi sp, sp, 8\n" slot
      in
      let load_env = "\taddi a1, t0, 8\n" in
      let load_code_ptr = "\tld t1, 0(t0)\n" in
      let call = "\tjalr ra, 0(t1)\n" in
      code_f ^ save_closure ^ code_arg ^ restore_closure ^ load_env
      ^ load_code_ptr ^ call
2、bin/main.ml的compile_expr_func函数
(* Compile [expr] as part of a function body; the result is left in a0.

   [local_env] maps the parameter and let-bound names to slots below fp (the
   value is at -offset(fp)).  [closure_env] maps captured variables to byte
   offsets from a1, which the caller points at the closure's free-variable
   words.  [cur_offset] is the number of bytes allocated below fp; the
   generated code assumes sp = fp - cur_offset and leaves sp unchanged.

   Nested functions and calls are not supported.  No call is emitted, so a1
   stays valid for the whole body.

   @raise Failure on an unbound variable, a nested [Func], or an [App]. *)
and compile_expr_func (local_env : (string * int) list) (closure_env : (string * int) list) (cur_offset : int) (expr : expr) : string =
  (* Never hand out a slot that a local binding already occupies, even if a
     caller passes a [cur_offset] smaller than the parameter's slot. *)
  let cur_offset =
    List.fold_left (fun acc (_, off) -> max acc off) cur_offset local_env
  in
  match expr with
  | Int n -> Printf.sprintf "\tli a0, %d\n" n
  | Bool b -> if b then "\tli a0, 1\n" else "\tli a0, 0\n"
  | Var x -> (
      (* Locals (parameter, let-bound) shadow captured variables. *)
      match List.assoc_opt x local_env with
      | Some off -> Printf.sprintf "\tld a0, -%d(fp)\n" off
      | None -> (
          match List.assoc_opt x closure_env with
          | Some off -> Printf.sprintf "\tld a0, %d(a1)\n" off
          | None -> failwith ("Unbound variable in function: " ^ x)))
  | Binop (op, e1, e2) ->
      (* Spill the left operand into the next fp-relative slot and compile e2
         one slot deeper, so a Let inside e2 allocates below it. *)
      let slot = cur_offset + 8 in
      let code1 = compile_expr_func local_env closure_env cur_offset e1 in
      let push_left =
        Printf.sprintf "\taddi sp, sp, -8\n\tsd a0, -%d(fp)\n" slot
      in
      let code2 = compile_expr_func local_env closure_env slot e2 in
      let pop_left = Printf.sprintf "\tld t0, -%d(fp)\n\taddi sp, sp, 8\n" slot in
      let op_code =
        match op with
        | Add -> "\tadd a0, t0, a0\n"
        | Sub -> "\tsub a0, t0, a0\n"
        | Mul -> "\tmul a0, t0, a0\n"
        | Div -> "\tdiv a0, t0, a0\n"
        (* t0 <= a0  <=>  not (a0 < t0); the result is 1 or 0. *)
        | Leq -> "\tslt a0, a0, t0\n\txori a0, a0, 1\n"
      in
      code1 ^ push_left ^ code2 ^ pop_left ^ op_code
  | Let (x, e1, e2) ->
      let code1 = compile_expr_func local_env closure_env cur_offset e1 in
      let slot = cur_offset + 8 in
      let alloc = Printf.sprintf "\taddi sp, sp, -8\n\tsd a0, -%d(fp)\n" slot in
      let code2 =
        compile_expr_func ((x, slot) :: local_env) closure_env slot e2
      in
      code1 ^ alloc ^ code2 ^ "\taddi sp, sp, 8\n"
  | If (cond, e_then, e_else) ->
      let label_else = fresh_label "Lelse" in
      let label_end = fresh_label "Lend" in
      let code_cond = compile_expr_func local_env closure_env cur_offset cond in
      let code_then =
        compile_expr_func local_env closure_env cur_offset e_then
      in
      let code_else =
        compile_expr_func local_env closure_env cur_offset e_else
      in
      String.concat ""
        [ code_cond;
          Printf.sprintf "\tbeq a0, x0, %s\n" label_else;
          code_then;
          Printf.sprintf "\tj %s\n" label_end;
          Printf.sprintf "%s:\n" label_else;
          code_else;
          Printf.sprintf "%s:\n" label_end ]
  | Func _ | App _ ->
      failwith "Nested functions not supported in function bodies"