This commit is contained in:
Lorenzo Torres 2026-01-14 18:36:27 +01:00
parent 09d6cf4b46
commit ed0ad1d095
14 changed files with 846 additions and 897 deletions

View file

@ -3,8 +3,8 @@
include config.mk
SRC = lc.c utils.c lexer.c parser.c sema.c ir.c
HDR = config.def.h utils.h lexer.h parser.h sema.h ir.h
SRC = lc.c utils.c lexer.c parser.c sema.c codegen.c
HDR = config.def.h utils.h lexer.h parser.h sema.h codegen.h
OBJ = ${SRC:.c=.o}
all: options lc

736
codegen.c Normal file
View file

@ -0,0 +1,736 @@
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include "codegen.h"
#include "sema.h"
#include "stb_ds.h"
/*
 * Entry type of the string-keyed stb_ds hash map `variables`.
 * Maps a local variable's name to the offset handed out by
 * get_var_offset()/get_var_offset_sized().
 */
typedef struct {
/* Variable name (NUL-terminated copy made with strndup). */
char *key;
/* Offset of the variable; the generator emits it as -value(%rbp). */
int value;
} var_map;
/* Name -> offset map for the function currently being generated
 * (freed and reset to NULL at the start of gen_function()). */
static var_map *variables = NULL;
/* Running total of bytes handed out to local variables; the most recently
 * created variable lives at -stack_offset(%rbp). */
static int stack_offset = 0;
/* Next number to use for a generated ".L<n>" label. */
static int label_counter = 0;
/* stb_ds array used as a stack: end labels of the enclosing while loops,
 * innermost last. NODE_BREAK jumps to the top entry. */
static int *break_stack = NULL;
/* Forward declarations: the generators below are mutually recursive. */
void gen_expr(FILE *fp, ast_node *expr);
void gen_unary(FILE *fp, ast_node *expr);
void gen_statement_list(FILE *fp, ast_node *stmt);
int get_var_offset(char *name, usize name_len);
int get_var_offset_sized(char *name, usize name_len, usize size);
/*
 * Saves %rax in a temporary stack slot while another sub-expression is
 * evaluated.  A register cannot be used for this: the sub-expression may
 * itself be a binary expression or a call and overwrite it.  The slot is
 * 16 bytes wide rather than 8 so that %rsp moves by a multiple of 16 and
 * keeps whatever alignment it had before.
 */
static void spill_rax(FILE *fp)
{
	fprintf(fp, "sub $16, %%rsp\n");
	fprintf(fp, "mov %%rax, (%%rsp)\n");
}

/* Loads the value saved by the matching spill_rax() into `reg` and releases the slot. */
static void reload_into(FILE *fp, const char *reg)
{
	fprintf(fp, "mov (%%rsp), %s\n", reg);
	fprintf(fp, "add $16, %%rsp\n");
}

/* Evaluates both operands of `expr`: left ends up in %rcx, right in %rax. */
static void gen_operands(FILE *fp, ast_node *expr)
{
	gen_expr(fp, expr->expr.binary.left);
	spill_rax(fp);
	gen_expr(fp, expr->expr.binary.right);
	reload_into(fp, "%rcx");
}

/*
 * Signed division of %rcx (dividend) by %rax (divisor).
 * Leaves the quotient in %rax and the remainder in %rdx.  The divisor is
 * kept in %rcx; %rbx is not used because it is callee-saved in the
 * System V AMD64 ABI and nothing here saves or restores it.
 */
static void emit_idiv(FILE *fp)
{
	fprintf(fp, "xchg %%rax, %%rcx\n");
	fprintf(fp, "cqo\n");
	fprintf(fp, "idiv %%rcx\n");
}

/* Compares left with right and sets %rax to 0 or 1 using the given setcc mnemonic. */
static void gen_compare(FILE *fp, ast_node *expr, const char *setcc)
{
	gen_operands(fp, expr);
	fprintf(fp, "cmp %%rax, %%rcx\n");
	fprintf(fp, "%s %%al\n", setcc);
	fprintf(fp, "movzx %%al, %%rax\n");
}

/*
 * Logical and/or: both operands are normalised to 0/1 and then combined with
 * the bitwise instruction `insn`.  Both operands are always evaluated (no
 * short-circuit), as before.
 */
static void gen_logical(FILE *fp, ast_node *expr, const char *insn)
{
	gen_expr(fp, expr->expr.binary.left);
	fprintf(fp, "test %%rax, %%rax\n");
	fprintf(fp, "setne %%al\n");
	fprintf(fp, "movzx %%al, %%rax\n");
	spill_rax(fp);
	gen_expr(fp, expr->expr.binary.right);
	fprintf(fp, "test %%rax, %%rax\n");
	fprintf(fp, "setne %%al\n");
	fprintf(fp, "movzx %%al, %%rax\n");
	reload_into(fp, "%rcx");
	fprintf(fp, "%s %%rcx, %%rax\n", insn);
}

/*
 * Prepares the operands of a compound assignment (`x op= rhs`).
 * On success stores the slot offset of `x` in *offset, leaves the current
 * value of `x` (read before rhs is evaluated) in %rcx and rhs in %rax, and
 * returns 1.  If the left side is not an identifier an error comment is
 * emitted and 0 is returned.
 */
static int gen_compound_operands(FILE *fp, ast_node *expr, int *offset)
{
	ast_node *target = expr->expr.binary.left;

	if (target->type != NODE_IDENTIFIER) {
		fprintf(fp, "# ERROR: left side of assignment must be identifier\n");
		return 0;
	}
	*offset = get_var_offset(target->expr.string.start, target->expr.string.len);
	fprintf(fp, "mov -%d(%%rbp), %%rax\n", *offset);
	spill_rax(fp);
	gen_expr(fp, expr->expr.binary.right);
	reload_into(fp, "%rcx");
	return 1;
}

/*
 * Emits code for a binary expression node; the result is left in %rax.
 *
 * fp   - stream the assembly is written to
 * expr - node whose expr.binary.{operator,left,right} are read
 *
 * Assignments additionally store the result into the target's stack slot
 * (or through the pointer for OP_ASSIGN_PTR).
 */
void gen_binary(FILE *fp, ast_node *expr)
{
	int offset = 0;

	switch (expr->expr.binary.operator) {
	case OP_PLUS:
		gen_operands(fp, expr);
		fprintf(fp, "add %%rcx, %%rax\n");
		break;
	case OP_MINUS:
		gen_operands(fp, expr);
		fprintf(fp, "sub %%rax, %%rcx\n");
		fprintf(fp, "mov %%rcx, %%rax\n");
		break;
	case OP_MUL:
		gen_operands(fp, expr);
		fprintf(fp, "imul %%rcx, %%rax\n");
		break;
	case OP_DIV:
		gen_operands(fp, expr);
		emit_idiv(fp);
		break;
	case OP_MOD:
		gen_operands(fp, expr);
		emit_idiv(fp);
		fprintf(fp, "mov %%rdx, %%rax\n");
		break;
	case OP_BOR:
		gen_operands(fp, expr);
		fprintf(fp, "or %%rcx, %%rax\n");
		break;
	case OP_BAND:
		gen_operands(fp, expr);
		fprintf(fp, "and %%rcx, %%rax\n");
		break;
	case OP_BXOR:
		gen_operands(fp, expr);
		fprintf(fp, "xor %%rcx, %%rax\n");
		break;
	case OP_EQ:
		gen_compare(fp, expr, "sete");
		break;
	case OP_NEQ:
		gen_compare(fp, expr, "setne");
		break;
	case OP_LT:
		gen_compare(fp, expr, "setl");
		break;
	case OP_GT:
		gen_compare(fp, expr, "setg");
		break;
	case OP_LE:
		gen_compare(fp, expr, "setle");
		break;
	case OP_GE:
		gen_compare(fp, expr, "setge");
		break;
	case OP_AND:
		gen_logical(fp, expr, "and");
		break;
	case OP_OR:
		gen_logical(fp, expr, "or");
		break;
	case OP_ASSIGN: {
		ast_node *target = expr->expr.binary.left;

		if (target->type != NODE_IDENTIFIER) {
			fprintf(fp, "# ERROR: left side of assignment must be identifier\n");
			break;
		}
		/* The value is generated before the slot lookup so that slots
		 * are handed out in the same order as before. */
		gen_expr(fp, expr->expr.binary.right);
		offset = get_var_offset(target->expr.string.start, target->expr.string.len);
		fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset);
		break;
	}
	case OP_ASSIGN_PTR:
		/* left yields the destination address (%rcx), right the value (%rax) */
		gen_operands(fp, expr);
		fprintf(fp, "mov %%rax, (%%rcx)\n");
		break;
	case OP_PLUS_EQ:
		if (!gen_compound_operands(fp, expr, &offset)) {
			break;
		}
		fprintf(fp, "add %%rcx, %%rax\n");
		fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset);
		break;
	case OP_MINUS_EQ:
		if (!gen_compound_operands(fp, expr, &offset)) {
			break;
		}
		fprintf(fp, "sub %%rax, %%rcx\n");
		fprintf(fp, "mov %%rcx, %%rax\n");
		fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset);
		break;
	case OP_MUL_EQ:
		if (!gen_compound_operands(fp, expr, &offset)) {
			break;
		}
		fprintf(fp, "imul %%rcx, %%rax\n");
		fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset);
		break;
	case OP_DIV_EQ:
		if (!gen_compound_operands(fp, expr, &offset)) {
			break;
		}
		emit_idiv(fp);
		fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset);
		break;
	case OP_MOD_EQ:
		if (!gen_compound_operands(fp, expr, &offset)) {
			break;
		}
		emit_idiv(fp);
		/* Move the remainder into %rax so the expression's value is the
		 * value that was stored, like every other compound assignment. */
		fprintf(fp, "mov %%rdx, %%rax\n");
		fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset);
		break;
	case OP_BOR_EQ:
		if (!gen_compound_operands(fp, expr, &offset)) {
			break;
		}
		fprintf(fp, "or %%rcx, %%rax\n");
		fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset);
		break;
	case OP_BAND_EQ:
		if (!gen_compound_operands(fp, expr, &offset)) {
			break;
		}
		fprintf(fp, "and %%rcx, %%rax\n");
		fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset);
		break;
	case OP_BXOR_EQ:
		if (!gen_compound_operands(fp, expr, &offset)) {
			break;
		}
		fprintf(fp, "xor %%rcx, %%rax\n");
		fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset);
		break;
	case OP_RSHIFT_EQ:
		if (!gen_compound_operands(fp, expr, &offset)) {
			break;
		}
		/* shift count must be in %cl, value in %rax */
		fprintf(fp, "xchg %%rax, %%rcx\n");
		fprintf(fp, "sar %%cl, %%rax\n");
		fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset);
		break;
	case OP_LSHIFT_EQ:
		if (!gen_compound_operands(fp, expr, &offset)) {
			break;
		}
		fprintf(fp, "xchg %%rax, %%rcx\n");
		fprintf(fp, "shl %%cl, %%rax\n");
		fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset);
		break;
	default:
		/* Make an unhandled operator visible in the output instead of
		 * silently generating nothing. */
		fprintf(fp, "# ERROR: unsupported binary operator\n");
		break;
	}
}
/*
 * Returns the frame offset of the variable called name[0..name_len), creating
 * a slot of `size` bytes (rounded up to a multiple of 8) on first use.  The
 * variable is addressed as -offset(%rbp).
 *
 * name     - variable name; does not need to be NUL-terminated
 * name_len - number of bytes of `name` that form the name
 * size     - size of the variable in bytes; only used when the slot is created
 *
 * A size of 0 is treated as 8 so that a new variable never shares its offset
 * with the previous one (or ends up at offset 0).
 * Terminates the process if the temporary key cannot be allocated.
 */
int get_var_offset_sized(char *name, usize name_len, usize size)
{
	if (variables == NULL) {
		/* Arena mode: the map keeps its own copy of every key and
		 * releases them in shfree(), so the temporary key below can
		 * always be freed here instead of leaking with the map. */
		sh_new_arena(variables);
	}

	char *var_name = strndup(name, name_len);
	if (var_name == NULL) {
		fprintf(stderr, "codegen: out of memory\n");
		exit(EXIT_FAILURE);
	}

	int offset;
	ptrdiff_t idx = shgeti(variables, var_name);
	if (idx >= 0) {
		offset = variables[idx].value;
	} else {
		usize aligned_size = (size == 0) ? 8 : ((size + 7) & ~(usize)7);
		stack_offset += (int)aligned_size;
		shput(variables, var_name, stack_offset);
		offset = stack_offset;
	}

	free(var_name);
	return offset;
}

/*
 * Returns the frame offset of the variable called name[0..name_len), creating
 * an 8-byte slot for it on first use.
 */
int get_var_offset(char *name, usize name_len)
{
	return get_var_offset_sized(name, name_len, 8);
}
/*
 * Generates code for a statement body.
 *
 * `stmt` is either a chain of NODE_UNIT nodes, in which case every non-NULL
 * unit expression is generated in order, or any other single node, which is
 * handed to gen_expr() directly.  A NULL `stmt` generates nothing.
 */
void gen_statement_list(FILE *fp, ast_node *stmt)
{
	if (stmt == NULL) {
		return;
	}

	if (stmt->type != NODE_UNIT) {
		gen_expr(fp, stmt);
		return;
	}

	for (ast_node *unit = stmt;
	     unit != NULL && unit->type == NODE_UNIT;
	     unit = unit->expr.unit_node.next) {
		if (unit->expr.unit_node.expr) {
			gen_expr(fp, unit->expr.unit_node.expr);
		}
	}
}
/*
 * Resolves the stack slot of a unary operator's identifier operand.
 * Returns 1 and stores the offset in *offset on success.  If `operand` is not
 * an identifier, emits "# ERROR: <what> requires identifier" and returns 0.
 */
static int unary_target_offset(FILE *fp, ast_node *operand, const char *what, int *offset)
{
	if (operand->type != NODE_IDENTIFIER) {
		fprintf(fp, "# ERROR: %s requires identifier\n", what);
		return 0;
	}
	*offset = get_var_offset(operand->expr.string.start, operand->expr.string.len);
	return 1;
}

/*
 * Emits code for a unary expression node; the result is left in %rax.
 *
 * fp   - stream the assembly is written to
 * expr - node whose expr.unary.{operator,right} are read
 *
 * Increment and decrement update the variable's slot and leave the new value
 * in %rax.
 */
void gen_unary(FILE *fp, ast_node *expr)
{
	ast_node *operand = expr->expr.unary.right;
	int offset = 0;

	switch (expr->expr.unary.operator) {
	case UOP_MINUS:
		gen_expr(fp, operand);
		fprintf(fp, "neg %%rax\n");
		break;
	case UOP_NOT:
		gen_expr(fp, operand);
		fprintf(fp, "test %%rax, %%rax\n");
		fprintf(fp, "sete %%al\n");
		fprintf(fp, "movzx %%al, %%rax\n");
		break;
	case UOP_INCR:
		if (!unary_target_offset(fp, operand, "increment", &offset)) {
			break;
		}
		fprintf(fp, "mov -%d(%%rbp), %%rax\n", offset);
		fprintf(fp, "inc %%rax\n");
		fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset);
		break;
	case UOP_DECR:
		if (!unary_target_offset(fp, operand, "decrement", &offset)) {
			break;
		}
		fprintf(fp, "mov -%d(%%rbp), %%rax\n", offset);
		fprintf(fp, "dec %%rax\n");
		fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset);
		break;
	case UOP_REF:
		if (!unary_target_offset(fp, operand, "address-of", &offset)) {
			break;
		}
		fprintf(fp, "lea -%d(%%rbp), %%rax\n", offset);
		break;
	case UOP_DEREF:
		gen_expr(fp, operand);
		fprintf(fp, "mov (%%rax), %%rax\n");
		break;
	default:
		/* Make an unhandled operator visible in the output instead of
		 * silently generating nothing. */
		fprintf(fp, "# ERROR: unsupported unary operator\n");
		break;
	}
}
void gen_expr(FILE *fp, ast_node *expr)
{
switch (expr->type) {
case NODE_INTEGER:
fprintf(fp, "mov $%lu, %%rax\n", expr->expr.integer);
break;
case NODE_FLOAT: {
// TODO: do not truncate
i64 int_val = (i64)expr->expr.flt;
fprintf(fp, "mov $%ld, %%rax\n", int_val);
break;
}
case NODE_CHAR:
fprintf(fp, "mov $%d, %%rax\n", (int)(unsigned char)expr->expr.ch);
break;
case NODE_BOOL:
fprintf(fp, "mov $%d, %%rax\n", expr->expr.boolean ? 1 : 0);
break;
case NODE_IDENTIFIER: {
int offset = get_var_offset(expr->expr.string.start, expr->expr.string.len);
fprintf(fp, "mov -%d(%%rbp), %%rax\n", offset);
break;
}
case NODE_BINARY:
gen_binary(fp, expr);
break;
case NODE_UNARY:
gen_unary(fp, expr);
break;
case NODE_CAST:
gen_expr(fp, expr->expr.cast.value);
break;
case NODE_VAR_DECL: {
int offset;
if (expr->expr.var_decl.type && expr->expr.var_decl.type->expr_type) {
usize var_size = expr->expr.var_decl.type->expr_type->size;
offset = get_var_offset_sized(expr->expr.var_decl.name,
expr->expr.var_decl.name_len, var_size);
} else {
offset = get_var_offset(expr->expr.var_decl.name, expr->expr.var_decl.name_len);
}
if (expr->expr.var_decl.value) {
if (expr->expr.var_decl.value->type == NODE_STRUCT_INIT) {
ast_node *member_list = expr->expr.var_decl.value->expr.struct_init.members;
ast_node *current = member_list;
type *struct_type = expr->expr_type;
if (!struct_type && expr->expr.var_decl.type) {
struct_type = expr->expr.var_decl.type->expr_type;
}
while (current && current->type == NODE_UNIT) {
ast_node *assignment = current->expr.unit_node.expr;
if (assignment && assignment->type == NODE_BINARY &&
assignment->expr.binary.operator == OP_ASSIGN) {
ast_node *field = assignment->expr.binary.left;
ast_node *value = assignment->expr.binary.right;
if (field->type == NODE_IDENTIFIER && struct_type &&
struct_type->tag == TYPE_STRUCT) {
char *field_name = strndup(field->expr.string.start,
field->expr.string.len);
member *m = struct_type->data.structure.members;
int field_offset = -1;
while (m) {
if (m->name_len == field->expr.string.len &&
strncmp(m->name, field->expr.string.start, m->name_len) == 0) {
field_offset = m->offset;
break;
}
m = m->next;
}
if (field_offset >= 0) {
gen_expr(fp, value);
type *field_type = shget(struct_type->data.structure.member_types, field_name);
if (field_type && field_type->size == 4) {
fprintf(fp, "mov %%eax, -%d(%%rbp)\n", offset + field_offset);
} else if (field_type && field_type->size == 2) {
fprintf(fp, "mov %%ax, -%d(%%rbp)\n", offset + field_offset);
} else if (field_type && field_type->size == 1) {
fprintf(fp, "mov %%al, -%d(%%rbp)\n", offset + field_offset);
} else {
fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset + field_offset);
}
}
free(field_name);
}
}
current = current->expr.unit_node.next;
}
} else {
gen_expr(fp, expr->expr.var_decl.value);
fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset);
}
}
break;
}
case NODE_RETURN: {
if (expr->expr.ret.value) {
gen_expr(fp, expr->expr.ret.value);
}
fprintf(fp, "mov %%rbp, %%rsp\n");
fprintf(fp, "pop %%rbp\n");
fprintf(fp, "ret\n");
break;
}
case NODE_IF: {
int label_else = label_counter++;
int label_end = label_counter++;
gen_expr(fp, expr->expr.if_stmt.condition);
fprintf(fp, "test %%rax, %%rax\n");
if (expr->expr.if_stmt.otherwise) {
fprintf(fp, "jz .L%d\n", label_else);
} else {
fprintf(fp, "jz .L%d\n", label_end);
}
gen_statement_list(fp, expr->expr.if_stmt.body);
if (expr->expr.if_stmt.otherwise) {
fprintf(fp, "jmp .L%d\n", label_end);
fprintf(fp, ".L%d:\n", label_else);
gen_statement_list(fp, expr->expr.if_stmt.otherwise);
}
fprintf(fp, ".L%d:\n", label_end);
break;
}
case NODE_WHILE: {
int label_start = label_counter++;
int label_end = label_counter++;
fprintf(fp, ".L%d:\n", label_start);
gen_expr(fp, expr->expr.whle.condition);
fprintf(fp, "test %%rax, %%rax\n");
fprintf(fp, "jz .L%d\n", label_end);
arrput(break_stack, label_end);
gen_statement_list(fp, expr->expr.whle.body);
arrpop(break_stack);
fprintf(fp, "jmp .L%d\n", label_start);
fprintf(fp, ".L%d:\n", label_end);
break;
}
case NODE_LABEL: {
char *label_name = strndup(expr->expr.label.name, expr->expr.label.name_len);
fprintf(fp, ".L_%s:\n", label_name);
free(label_name);
break;
}
case NODE_GOTO: {
char *label_name = strndup(expr->expr.label.name, expr->expr.label.name_len);
fprintf(fp, "jmp .L_%s\n", label_name);
free(label_name);
break;
}
case NODE_BREAK: {
if (arrlen(break_stack) > 0) {
int loop_end = break_stack[arrlen(break_stack) - 1];
fprintf(fp, "jmp .L%d\n", loop_end);
} else {
fprintf(fp, "# ERROR: break outside of loop\n");
}
break;
}
case NODE_ACCESS: {
ast_node *base = expr->expr.access.expr;
ast_node *member_node = expr->expr.access.member;
if (base->type == NODE_IDENTIFIER) {
int base_offset = get_var_offset(base->expr.string.start, base->expr.string.len);
type *struct_type = base->expr_type;
if (member_node->type == NODE_IDENTIFIER && struct_type &&
struct_type->tag == TYPE_STRUCT) {
member *m = struct_type->data.structure.members;
int field_offset = -1;
while (m) {
if (m->name_len == member_node->expr.string.len &&
strncmp(m->name, member_node->expr.string.start, m->name_len) == 0) {
field_offset = m->offset;
break;
}
m = m->next;
}
if (field_offset >= 0) {
fprintf(fp, "mov -%d(%%rbp), %%rax\n", base_offset + field_offset);
} else {
fprintf(fp, "# ERROR: field not found\n");
}
} else {
fprintf(fp, "# ERROR: not a struct type\n");
}
} else {
fprintf(fp, "# ERROR: complex struct access not implemented\n");
}
break;
}
case NODE_STRUCT_INIT: {
fprintf(fp, "# ERROR: struct init outside of variable declaration\n");
break;
}
case NODE_CALL: {
const char *arg_regs[] = {"%rdi", "%rsi", "%rdx", "%rcx", "%r8", "%r9"};
int param_count = 0;
ast_node *param = expr->expr.call.parameters;
while (param && param->type == NODE_UNIT) {
param_count++;
param = param->expr.unit_node.next;
}
param = expr->expr.call.parameters;
int i = 0;
while (param && param->type == NODE_UNIT) {
if (param->expr.unit_node.expr) {
gen_expr(fp, param->expr.unit_node.expr);
fprintf(fp, "push %%rax\n");
}
param = param->expr.unit_node.next;
i++;
}
for (int j = param_count - 1; j >= 0; j--) {
if (j < 6) {
fprintf(fp, "pop %s\n", arg_regs[j]);
} else {
// TODO: handle more than 6 arguments properly
fprintf(fp, "pop %%rax\n");
fprintf(fp, "push %%rax\n");
}
}
fprintf(fp, "call %.*s\n", (int)expr->expr.call.name_len, expr->expr.call.name);
if (param_count > 6) {
int stack_args = param_count - 6;
fprintf(fp, "add $%d, %%rsp\n", stack_args * 8);
}
break;
}
}
}
/*
 * Emits one function: global symbol, prologue, parameter spills, body and a
 * fall-through epilogue.
 *
 * fp - stream the assembly is written to
 * fn - NODE_FUNCTION node; expr.function.{name,parameters,body} are read
 *
 * Per-function state (variable map, stack offset, break stack) is reset.
 * label_counter is deliberately NOT reset: every function is written to the
 * same assembly file, so restarting at .L0 would define the same label more
 * than once.
 *
 * NOTE(review): the frame is a fixed 256 bytes; locals beyond that are not
 * accounted for.
 * NOTE(review): the name is printed with %s while other nodes carry an
 * explicit name length - confirm that function names are NUL-terminated.
 */
void gen_function(FILE *fp, ast_node *fn)
{
	static const char *const param_regs[] = {"%rdi", "%rsi", "%rdx", "%rcx", "%r8", "%r9"};
	enum { MAX_REG_PARAMS = 6 };

	stack_offset = 0;
	shfree(variables);
	variables = NULL;
	arrfree(break_stack);
	break_stack = NULL;

	fprintf(fp, ".global %s\n%s:\n", fn->expr.function.name, fn->expr.function.name);
	fprintf(fp, "push %%rbp\n");
	fprintf(fp, "mov %%rsp, %%rbp\n");
	fprintf(fp, "sub $256, %%rsp\n");

	/* Copy every parameter into its own local slot so the body can treat
	 * parameters like ordinary variables. */
	int param_idx = 0;
	for (member *param = fn->expr.function.parameters; param; param = param->next) {
		int offset = get_var_offset(param->name, param->name_len);
		if (param_idx < MAX_REG_PARAMS) {
			fprintf(fp, "mov %s, -%d(%%rbp)\n", param_regs[param_idx], offset);
		} else {
			/* In the System V AMD64 convention the seventh argument
			 * is at 16(%rbp), above the saved %rbp and the return
			 * address, and later ones follow in 8-byte steps. */
			fprintf(fp, "mov %d(%%rbp), %%rax\n",
			        16 + (param_idx - MAX_REG_PARAMS) * 8);
			fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset);
		}
		param_idx++;
	}

	for (ast_node *current = fn->expr.function.body;
	     current && current->type == NODE_UNIT;
	     current = current->expr.unit_node.next) {
		if (current->expr.unit_node.expr) {
			gen_expr(fp, current->expr.unit_node.expr);
		}
	}

	/* epilogue for bodies that fall off the end without a return */
	fprintf(fp, "mov %%rbp, %%rsp\n");
	fprintf(fp, "pop %%rbp\n");
	fprintf(fp, "ret\n");
}
/*
 * Writes assembly for every function in the translation unit to "test.s".
 *
 * node - head of a NODE_UNIT chain; units whose expression is a
 *        NODE_FUNCTION are generated, everything else is skipped
 *
 * If the output file cannot be opened, or closing it reports an error, the
 * problem is reported on stderr.
 */
void generate(ast_node *node)
{
	static const char *const out_path = "test.s";

	FILE *fp = fopen(out_path, "w");
	if (!fp) {
		perror(out_path);
		return;
	}

	fprintf(fp, ".section .text\n");
	for (ast_node *current = node;
	     current && current->type == NODE_UNIT;
	     current = current->expr.unit_node.next) {
		ast_node *item = current->expr.unit_node.expr;
		if (item && item->type == NODE_FUNCTION) {
			gen_function(fp, item);
		}
	}

	/* fclose flushes the stream, so a write error can surface here */
	if (fclose(fp) != 0) {
		perror(out_path);
	}

	shfree(variables);
	variables = NULL;
	arrfree(break_stack);
	break_stack = NULL;
}

8
codegen.h Normal file
View file

@ -0,0 +1,8 @@
/* Public interface of the assembly code generator. */
#ifndef CODEGEN_H
#define CODEGEN_H
#include "parser.h"
/*
 * Generates assembly for every NODE_FUNCTION found in the NODE_UNIT chain
 * starting at `node` and writes it to the file "test.s".
 */
void generate(ast_node *node);
#endif

812
ir.c
View file

@ -1,812 +0,0 @@
#include "ir.h"
#include <stdlib.h>
#include <stdio.h>
#include "stb_ds.h"
#include "sema.h"
/* stb_ds hash map keyed by node value; the builders look nodes up in it
 * (hmget) to reuse an existing node and register new ones with hmput. */
struct { ir_node key; ir_node *value; } *global_hash = NULL;
/* NOTE(review): not referenced in the part of the file visible here. */
static ir_node *graph;
/* Node representing the current memory state; each store becomes the new one. */
static ir_node *current_memory;
/* Node representing the current control state; added as an input of stores and ifs. */
static ir_node *current_control;
/* NOTE(review): not referenced in the part of the file visible here. */
static usize current_stack = 0;
/* Scope node whose data.symbol_tables array holds one symbol table per nesting level. */
static ir_node *current_scope = NULL;
/* Forward declaration: the builders below recurse through it. */
static ir_node *build_expression(ast_node *node);
/* Arrays collected for the function being built.
 * NOTE(review): presumably one entry per return statement - confirm at the use sites. */
static struct {
ir_node **return_controls;
ir_node **return_memories;
ir_node **return_values;
} current_func = {0};
/*
 * Prints one Graphviz node statement for `node` on stdout: the node id
 * followed by an attribute list chosen from its opcode.  A NULL node is
 * printed as a red node named "null".
 */
static void node_name(ir_node *node)
{
	const char *attrs = NULL;

	if (node == NULL) {
		printf("null [label=\"NULL\", style=filled, fillcolor=red]\n");
		return;
	}

	printf("%ld ", node->id);

	switch (node->code) {
	/* opcodes whose label depends on node data are printed directly */
	case OC_START:
		printf("[label=\"%s\", style=filled, color=orange]\n", node->data.start_name);
		return;
	case OC_CONST_INT:
		printf("[label=\"%ld\"]\n", node->data.const_int);
		return;
	case OC_CONST_FLOAT:
		printf("[label=\"%f\"]\n", node->data.const_float);
		return;
	case OC_MOD:
		printf("[label=\"%%\"]\n");
		return;

	/* everything else has a fixed attribute list */
	case OC_RETURN:
		attrs = "[label=\"return\", style=filled, color=orange]\n";
		break;
	case OC_ADD:
		attrs = "[label=\"+\"]\n";
		break;
	case OC_NEG:
	case OC_SUB:
		attrs = "[label=\"-\"]\n";
		break;
	case OC_DIV:
		attrs = "[label=\"/\"]\n";
		break;
	case OC_MUL:
		attrs = "[label=\"*\"]\n";
		break;
	case OC_BAND:
		attrs = "[label=\"&\"]\n";
		break;
	case OC_BOR:
		attrs = "[label=\"|\"]\n";
		break;
	case OC_BXOR:
		attrs = "[label=\"^\"]\n";
		break;
	case OC_EQ:
		attrs = "[label=\"==\"]\n";
		break;
	case OC_FRAME_PTR:
		attrs = "[label=\"frame_ptr\"]\n";
		break;
	case OC_STORE:
		attrs = "[label=\"store\", shape=box]\n";
		break;
	case OC_LOAD:
		attrs = "[label=\"load\", shape=box]\n";
		break;
	case OC_ADDR:
		attrs = "[label=\"addr\"]\n";
		break;
	case OC_REGION:
		attrs = "[label=\"region\", shape=diamond, style=filled, color=green]\n";
		break;
	case OC_PHI:
		attrs = "[label=\"phi\", shape=triangle]\n";
		break;
	case OC_IF:
		attrs = "[label=\"if\", shape=diamond, style=filled, color=lightblue]\n";
		break;
	case OC_PROJ:
		attrs = "[label=\"proj\", shape=diamond, style=filled, color=cyan]\n";
		break;
	default:
		/* unknown opcode: fall back to its numeric value */
		printf("[label=\"%d\"]\n", node->code);
		return;
	}

	fputs(attrs, stdout);
}
/*
 * Dumps every node registered in global_hash, together with an edge from
 * each of its non-NULL `out` entries to the node, as Graphviz statements on
 * stdout.
 *
 * node - unused; the whole table is printed regardless of it
 */
static void print_graph(ir_node *node)
{
	(void)node;

	for (ptrdiff_t i = 0; i < hmlen(global_hash); i++) {
		ir_node *entry = global_hash[i].value;

		node_name(entry);
		if (entry == NULL) {
			/* node_name() has printed the placeholder; there are no edges to follow */
			continue;
		}

		for (ptrdiff_t j = 0; j < arrlen(entry->out); j++) {
			ir_node *input = entry->out[j];
			if (input == NULL) {
				continue;
			}
			node_name(input);
			printf("%ld->%ld\n", input->id, entry->id);
		}
	}
}
static void push_scope(void)
{
arrput(current_scope->data.symbol_tables, NULL);
}
/*
 * Looks `name` up in the symbol tables of current_scope, innermost table
 * first.  Returns the definition found, or NULL if no table contains it.
 */
static struct symbol_def *get_def(char *name)
{
	int level = arrlen(current_scope->data.symbol_tables);

	while (--level >= 0) {
		struct symbol_def *found = shget(current_scope->data.symbol_tables[level], name);
		if (found != NULL) {
			return found;
		}
	}
	return NULL;
}
/*
 * Binds `name` to a fresh definition {node, lvalue}.
 *
 * The binding replaces the entry in the innermost table that already
 * contains `name`; if no table does, it is added to the innermost table.
 * A new symbol_def is allocated every time rather than updating the old one
 * in place, since copy_scope() copies the definition pointers, not the
 * definitions.
 */
static void set_def(char *name, ir_node *node, bool lvalue)
{
	int innermost = arrlen(current_scope->data.symbol_tables) - 1;
	int target = innermost;

	for (int level = innermost; level >= 0; level--) {
		if (shget(current_scope->data.symbol_tables[level], name)) {
			target = level;
			break;
		}
	}

	struct symbol_def *fresh = calloc(1, sizeof *fresh);
	fresh->node = node;
	fresh->is_lvalue = lvalue;
	shput(current_scope->data.symbol_tables[target], name, fresh);
}
/*
 * Returns a newly allocated OC_SCOPE node with the same number of symbol
 * tables as `src`, each holding the same key -> definition entries.  The
 * definition pointers themselves are copied, not the definitions.
 */
static ir_node *copy_scope(ir_node *src)
{
	ir_node *clone = calloc(1, sizeof *clone);
	clone->code = OC_SCOPE;

	int level = 0;
	while (level < arrlen(src->data.symbol_tables)) {
		symbol_table *from = src->data.symbol_tables[level];

		arrput(clone->data.symbol_tables, NULL);
		int entry = 0;
		while (entry < shlen(from)) {
			shput(clone->data.symbol_tables[level], from[entry].key, from[entry].value);
			entry++;
		}
		level++;
	}

	return clone;
}
/*
 * Folds a binary node whose two inputs are constants of the same kind.
 *
 * When both inputs are integer constants (or both float constants) and the
 * opcode is one that can be evaluated here, `binary` is rewritten in place
 * into the corresponding constant node: its value is computed, its input and
 * output arrays are released and its id is recomputed.  In every other case
 * the node is left untouched.
 *
 * A division or modulo by a constant zero is never folded; previously the
 * node was still turned into a constant without a computed value.
 */
static void const_fold(ir_node *binary)
{
	ir_node *left = binary->out[0];
	ir_node *right = binary->out[1];

	if (left->code == OC_CONST_INT && right->code == OC_CONST_INT) {
		switch (binary->code) {
		case OC_ADD:
			binary->data.const_int = left->data.const_int + right->data.const_int;
			break;
		case OC_SUB:
			binary->data.const_int = left->data.const_int - right->data.const_int;
			break;
		case OC_MUL:
			binary->data.const_int = left->data.const_int * right->data.const_int;
			break;
		case OC_DIV:
			if (right->data.const_int == 0) {
				return;
			}
			binary->data.const_int = left->data.const_int / right->data.const_int;
			break;
		case OC_MOD:
			if (right->data.const_int == 0) {
				return;
			}
			binary->data.const_int = left->data.const_int % right->data.const_int;
			break;
		case OC_BOR:
			binary->data.const_int = left->data.const_int | right->data.const_int;
			break;
		case OC_BAND:
			binary->data.const_int = left->data.const_int & right->data.const_int;
			break;
		case OC_BXOR:
			binary->data.const_int = left->data.const_int ^ right->data.const_int;
			break;
		case OC_EQ:
			binary->data.const_int = left->data.const_int == right->data.const_int;
			break;
		default:
			return;
		}
		binary->code = OC_CONST_INT;
	} else if (left->code == OC_CONST_FLOAT && right->code == OC_CONST_FLOAT) {
		switch (binary->code) {
		case OC_ADD:
			binary->data.const_float = left->data.const_float + right->data.const_float;
			break;
		case OC_SUB:
			binary->data.const_float = left->data.const_float - right->data.const_float;
			break;
		case OC_MUL:
			binary->data.const_float = left->data.const_float * right->data.const_float;
			break;
		case OC_DIV:
			if (right->data.const_float == 0.0f) {
				return;
			}
			binary->data.const_float = left->data.const_float / right->data.const_float;
			break;
		default:
			return;
		}
		binary->code = OC_CONST_FLOAT;
	} else {
		return;
	}

	/* The node is now a leaf: drop its edges and recompute its id. */
	arrfree(binary->out);
	binary->out = NULL;
	arrfree(binary->in);
	binary->in = NULL;
	binary->id = stbds_hash_bytes(binary, sizeof(ir_node), 0xcafebabe);
}
/*
 * Builds an OC_ADDR node with two inputs: a base and a constant offset.
 *
 * base   - (usize)-1 selects the frame pointer as base; any other value is
 *          used as an integer constant base
 * offset - constant offset added to the base
 *
 * Returns the matching node from global_hash if the lookup finds one,
 * otherwise the newly built node.  When an existing node is returned, all
 * temporary nodes built here are released.
 */
static ir_node *build_address(usize base, usize offset)
{
	ir_node *base_node = calloc(1, sizeof *base_node);
	if (base == (usize)-1) {
		base_node->code = OC_FRAME_PTR;
	} else {
		base_node->code = OC_CONST_INT;
		base_node->data.const_int = base;
	}
	base_node->id = stbds_hash_bytes(base_node, sizeof(ir_node), 0xcafebabe);

	ir_node *offset_node = calloc(1, sizeof *offset_node);
	offset_node->code = OC_CONST_INT;
	offset_node->data.const_int = offset;
	offset_node->id = stbds_hash_bytes(offset_node, sizeof(ir_node), 0xcafebabe);

	ir_node *addr = calloc(1, sizeof *addr);
	addr->code = OC_ADDR;
	arrput(addr->out, base_node);
	arrput(addr->out, offset_node);
	addr->id = stbds_hash_bytes(addr, sizeof(ir_node), 0xcafebabe);

	ir_node *existing = hmget(global_hash, *addr);
	if (existing) {
		/* base_node and offset_node are referenced only by the node being
		 * discarded, so they have to go with it. */
		arrfree(addr->out);
		free(offset_node);
		free(base_node);
		free(addr);
		return existing;
	}
	return addr;
}
/*
 * Builds the store for an assignment through a pointer.
 *
 * The right-hand side is built first.  A store node with the inputs
 * (control, memory, definition of the left-hand name, value) is then
 * registered in global_hash and becomes the current memory state.
 *
 * Returns the node of the assigned value.  If the left-hand name has no
 * definition, an error is reported on stderr and no store is built.
 */
static ir_node *build_assign_ptr(ast_node *binary)
{
	ir_node *val_node = build_expression(binary->expr.binary.right);
	char *var_name = binary->expr.binary.left->expr.string.start;

	struct symbol_def *def = get_def(var_name);
	if (def == NULL) {
		fprintf(stderr, "ir: assignment through undefined symbol\n");
		return val_node;
	}

	ir_node *store = calloc(1, sizeof *store);
	store->code = OC_STORE;
	arrput(store->out, current_control);
	arrput(store->out, current_memory);
	arrput(store->out, def->node);
	arrput(store->out, val_node);
	store->id = stbds_hash_bytes(store, sizeof(ir_node), 0xcafebabe);
	hmput(global_hash, *store, store);

	current_memory = store;
	return val_node;
}
/*
 * Builds a plain assignment `name = value`.
 *
 * If `name` is not currently defined as an lvalue, the name is simply
 * rebound to the value node.  Otherwise a store node with the inputs
 * (control, memory, existing definition, value) is registered in
 * global_hash and becomes the current memory state.
 *
 * Returns the node of the assigned value.
 */
static ir_node *build_assign(ast_node *binary)
{
	ir_node *value = build_expression(binary->expr.binary.right);
	char *name = binary->expr.binary.left->expr.string.start;
	struct symbol_def *def = get_def(name);

	if (def == NULL || !def->is_lvalue) {
		set_def(name, value, false);
		return value;
	}

	ir_node *target = def->node;
	ir_node *store = calloc(1, sizeof(ir_node));
	store->code = OC_STORE;
	arrput(store->out, current_control);
	arrput(store->out, current_memory);
	arrput(store->out, target);
	arrput(store->out, value);
	store->id = stbds_hash_bytes(store, sizeof(ir_node), 0xcafebabe);
	hmput(global_hash, *store, store);

	current_memory = store;
	return value;
}
/*
 * Builds the IR node for a binary AST node.
 *
 * The two assignment operators are delegated to build_assign() and
 * build_assign_ptr().  For every other operator a node with the left and
 * right operands as inputs is created, handed to const_fold(), and then
 * looked up in global_hash: an existing match is returned instead of the new
 * node.  Operators without a mapping leave the opcode at its zero value.
 */
static ir_node *build_binary(ast_node *node)
{
	if (node->expr.binary.operator == OP_ASSIGN) {
		return build_assign(node);
	}
	if (node->expr.binary.operator == OP_ASSIGN_PTR) {
		return build_assign_ptr(node);
	}

	ir_node *result = calloc(1, sizeof(ir_node));

	switch (node->expr.binary.operator) {
	case OP_PLUS:  result->code = OC_ADD;  break;
	case OP_MINUS: result->code = OC_SUB;  break;
	case OP_DIV:   result->code = OC_DIV;  break;
	case OP_MUL:   result->code = OC_MUL;  break;
	case OP_MOD:   result->code = OC_MOD;  break;
	case OP_BOR:   result->code = OC_BOR;  break;
	case OP_BAND:  result->code = OC_BAND; break;
	case OP_BXOR:  result->code = OC_BXOR; break;
	case OP_EQ:    result->code = OC_EQ;   break;
	default:
		break;
	}

	arrput(result->out, build_expression(node->expr.binary.left));
	arrput(result->out, build_expression(node->expr.binary.right));
	result->id = stbds_hash_bytes(result, sizeof(ir_node), 0xcafebabe);
	const_fold(result);

	ir_node *existing = hmget(global_hash, *result);
	if (existing == NULL) {
		return result;
	}
	free(result);
	return existing;
}
/*
 * Builds an OC_LOAD node whose inputs are the current memory state and the
 * node built from `node`.  Returns the matching node from global_hash if the
 * lookup finds one, otherwise the new node.
 */
static ir_node *build_load(ast_node *node)
{
	ir_node *load = calloc(1, sizeof(ir_node));

	load->code = OC_LOAD;
	/* the memory input is recorded before the operand is built */
	arrput(load->out, current_memory);
	arrput(load->out, build_expression(node));
	load->id = stbds_hash_bytes(load, sizeof(ir_node), 0xcafebabebabecafe);

	ir_node *existing = hmget(global_hash, *load);
	if (existing == NULL) {
		return load;
	}
	free(load);
	return existing;
}
/*
 * Builds the IR node for a unary AST node.
 *
 * - UOP_REF: returns the definition node of an identifier operand if one
 *   exists, otherwise the node built from the operand.
 * - UOP_DEREF: returns a load of the operand (see build_load()).
 * - UOP_MINUS: builds a negation; if the operand is an integer or float
 *   constant the negation is folded into a constant node.
 * - any other operator: yields a node with a zero opcode and no inputs.
 *
 * For the last two cases the result is looked up in global_hash and an
 * existing match is returned instead of the new node.
 */
static ir_node *build_unary(ast_node *node)
{
	ast_node *operand = node->expr.unary.right;

	if (node->expr.unary.operator == UOP_REF) {
		if (operand->type == NODE_IDENTIFIER) {
			struct symbol_def *def = get_def(operand->expr.string.start);
			if (def) {
				return def->node;
			}
		}
		return build_expression(operand);
	}
	if (node->expr.unary.operator == UOP_DEREF) {
		return build_load(operand);
	}

	ir_node *result = calloc(1, sizeof(ir_node));

	if (node->expr.unary.operator == UOP_MINUS) {
		result->code = OC_NEG;
		arrput(result->out, build_expression(operand));

		/* fold the negation of a constant into a constant */
		if (result->out[0]->code == OC_CONST_INT) {
			result->data.const_int = -(result->out[0]->data.const_int);
			result->code = OC_CONST_INT;
			arrfree(result->out);
			result->out = NULL;
		} else if (result->out[0]->code == OC_CONST_FLOAT) {
			result->data.const_float = -(result->out[0]->data.const_float);
			result->code = OC_CONST_FLOAT;
			arrfree(result->out);
			result->out = NULL;
		}
	}

	result->id = stbds_hash_bytes(result, sizeof(ir_node), 0xcafebabe);

	ir_node *existing = hmget(global_hash, *result);
	if (existing == NULL) {
		return result;
	}
	free(result);
	return existing;
}
/* Build every statement of a NODE_UNIT-linked branch body against the current builder state. */
static void build_if_branch(ast_node *body)
{
	for (ast_node *stmt = body; stmt && stmt->type == NODE_UNIT; stmt = stmt->expr.unit_node.next) {
		if (stmt->expr.unit_node.expr) {
			build_expression(stmt->expr.unit_node.expr);
		}
	}
}

/*
 * Return the definition of `key` in symbol table `level` of `scope`, or
 * `fallback` when the table or key is missing or the definition has no node.
 */
static struct symbol_def *if_branch_def(ir_node *scope, int level, char *key, struct symbol_def *fallback)
{
	if (level >= arrlen(scope->data.symbol_tables)) {
		return fallback;
	}
	symbol_table *table = scope->data.symbol_tables[level];
	if (!table) {
		return fallback;
	}
	struct symbol_def *def = shget(table, key);
	return (def && def->node) ? def : fallback;
}

/*
 * Build the IR for an if/else statement.
 *
 * Creates an OC_IF node (inputs: condition, current control) with two
 * OC_PROJ nodes for the taken / not-taken paths, builds both branches from
 * the same starting scope and memory state, and joins them with an
 * OC_REGION. Where the two branches disagree, an OC_PHI (inputs: region,
 * then-value, else-value) is created for the memory state and for every
 * symbol of the pre-branch scope.
 *
 * On return, current_control is the region, current_scope is the merged
 * scope and current_memory is the merged memory state.
 *
 * Returns the region node.
 *
 * NOTE(review): a branch that ends in a return leaves current_control NULL
 * (see build_return), so the region can receive NULL control inputs.
 */
static ir_node *build_if(ast_node *node)
{
	ir_node *condition = build_expression(node->expr.if_stmt.condition);

	ir_node *if_node = calloc(1, sizeof(ir_node));
	if_node->code = OC_IF;
	arrput(if_node->out, condition);
	arrput(if_node->out, current_control);
	if_node->id = stbds_hash_bytes(if_node, sizeof(ir_node), 0xcafebabe);
	hmput(global_hash, *if_node, if_node);

	ir_node *proj_true = calloc(1, sizeof(ir_node));
	proj_true->code = OC_PROJ;
	arrput(proj_true->out, if_node);
	proj_true->id = stbds_hash_bytes(proj_true, sizeof(ir_node), 0xcafebabe);
	hmput(global_hash, *proj_true, proj_true);

	ir_node *proj_false = calloc(1, sizeof(ir_node));
	proj_false->code = OC_PROJ;
	arrput(proj_false->out, if_node);
	proj_false->id = stbds_hash_bytes(proj_false, sizeof(ir_node), 0xcafebabe);
	hmput(global_hash, *proj_false, proj_false);

	/* Snapshot the state both branches start from. */
	ir_node *base_scope = copy_scope(current_scope);
	ir_node *base_mem = current_memory;

	/* "then" branch: mutates the live scope object. */
	current_control = proj_true;
	build_if_branch(node->expr.if_stmt.body);
	ir_node *then_scope = current_scope;
	ir_node *then_mem = current_memory;
	ir_node *then_control = current_control;

	/* "else" branch: starts again from a fresh copy of the snapshot. */
	current_scope = copy_scope(base_scope);
	current_memory = base_mem;
	current_control = proj_false;
	build_if_branch(node->expr.if_stmt.otherwise);
	ir_node *else_scope = current_scope;
	ir_node *else_mem = current_memory;
	ir_node *else_control = current_control;

	ir_node *region = calloc(1, sizeof(ir_node));
	region->code = OC_REGION;
	arrput(region->out, then_control);
	arrput(region->out, else_control);
	region->id = stbds_hash_bytes(region, sizeof(ir_node), 0xcafebabe);
	hmput(global_hash, *region, region);

	/* Merge the memory state. */
	if (then_mem->id != else_mem->id) {
		ir_node *phi = calloc(1, sizeof(ir_node));
		phi->code = OC_PHI;
		arrput(phi->out, region);
		arrput(phi->out, then_mem);
		arrput(phi->out, else_mem);
		phi->id = stbds_hash_bytes(phi, sizeof(ir_node), 0xcafebabe);
		hmput(global_hash, *phi, phi);
		current_memory = phi;
	} else {
		current_memory = then_mem;
	}

	/* Merge every symbol of the pre-branch scope. */
	current_scope = base_scope;
	symbol_table **tables = current_scope->data.symbol_tables;
	for (int i = 0; i < arrlen(tables); i++) {
		for (int j = 0; j < shlen(tables[i]); j++) {
			char *key = tables[i][j].key;
			struct symbol_def *base_def = tables[i][j].value;
			struct symbol_def *then_def = if_branch_def(then_scope, i, key, base_def);
			struct symbol_def *else_def = if_branch_def(else_scope, i, key, base_def);

			/* Nothing to merge for a symbol without a node on either path. */
			if (!then_def || !then_def->node || !else_def || !else_def->node) {
				continue;
			}
			ir_node *found_then = then_def->node;
			ir_node *found_else = else_def->node;

			struct symbol_def *merged = calloc(1, sizeof(struct symbol_def));
			if (found_then->id != found_else->id) {
				ir_node *phi = calloc(1, sizeof(ir_node));
				phi->code = OC_PHI;
				arrput(phi->out, region);
				arrput(phi->out, found_then);
				arrput(phi->out, found_else);
				phi->id = stbds_hash_bytes(phi, sizeof(ir_node), 0xcafebabe);
				merged->node = phi;
				/* A phi is treated as a plain value, as before. */
				merged->is_lvalue = false;
				shput(tables[i], key, merged);
				hmput(global_hash, *phi, phi);
			} else {
				merged->node = found_then;
				/*
				 * Keep the lvalue flag: clearing it would make later
				 * reads of an address-taken variable yield its address
				 * instead of a load.
				 */
				merged->is_lvalue = then_def->is_lvalue;
				shput(tables[i], key, merged);
			}
		}
	}

	current_control = region;
	return region;
}
/*
 * Record a `return` statement for the function being built.
 *
 * The returned value (or a fresh OC_VOID node when the statement has no
 * expression) is appended to current_func together with the control and
 * memory state at this point. current_control is then cleared.
 */
static void build_return(ast_node *node)
{
	ast_node *value_expr = node->expr.ret.value;
	ir_node *result;

	if (!value_expr) {
		result = calloc(1, sizeof(ir_node));
		result->code = OC_VOID;
		result->id = stbds_hash_bytes(result, sizeof(ir_node), 0xcafebabe);
	} else {
		result = build_expression(value_expr);
	}

	arrput(current_func.return_controls, current_control);
	arrput(current_func.return_memories, current_memory);
	arrput(current_func.return_values, result);

	current_control = NULL;
}
/*
 * Emit the single OC_RETURN node for the function being built.
 *
 * With one recorded return, its control / memory / value are used directly.
 * With several, the controls are joined by an OC_REGION and the memories and
 * values by OC_PHI nodes. Does nothing when no return was recorded.
 *
 * Every node has its id computed before it is inserted into global_hash,
 * because the map stores a copy of the node as its key: inserting first
 * would leave a key with id == 0 that no longer matches the node.
 */
static void finalize_function(void)
{
	int count = arrlen(current_func.return_controls);
	if (count == 0) {
		return;
	}

	ir_node *final_ctrl = NULL;
	ir_node *final_mem = NULL;
	ir_node *final_val = NULL;

	if (count == 1) {
		final_ctrl = current_func.return_controls[0];
		final_mem = current_func.return_memories[0];
		final_val = current_func.return_values[0];
	} else {
		ir_node *region = calloc(1, sizeof(ir_node));
		region->code = OC_REGION;
		for (int i = 0; i < count; i++) {
			arrput(region->out, current_func.return_controls[i]);
		}
		region->id = stbds_hash_bytes(region, sizeof(ir_node), 0xcafebabe);
		hmput(global_hash, *region, region);
		final_ctrl = region;

		ir_node *mem_phi = calloc(1, sizeof(ir_node));
		mem_phi->code = OC_PHI;
		arrput(mem_phi->out, region);
		for (int i = 0; i < count; i++) {
			arrput(mem_phi->out, current_func.return_memories[i]);
		}
		mem_phi->id = stbds_hash_bytes(mem_phi, sizeof(ir_node), 0xcafebabe);
		hmput(global_hash, *mem_phi, mem_phi);
		final_mem = mem_phi;

		/*
		 * NOTE(review): unlike mem_phi, the value phi does not take the
		 * region as its first input (that line was disabled in the
		 * original) -- confirm whether this is intended.
		 */
		ir_node *val_phi = calloc(1, sizeof(ir_node));
		val_phi->code = OC_PHI;
		for (int i = 0; i < count; i++) {
			arrput(val_phi->out, current_func.return_values[i]);
		}
		val_phi->id = stbds_hash_bytes(val_phi, sizeof(ir_node), 0xcafebabe);
		hmput(global_hash, *val_phi, val_phi);
		final_val = val_phi;
	}

	ir_node *ret = calloc(1, sizeof(ir_node));
	ret->code = OC_RETURN;
	arrput(ret->out, final_ctrl);
	arrput(ret->out, final_mem);
	arrput(ret->out, final_val);
	ret->id = stbds_hash_bytes(ret, sizeof(ir_node), 0xcafebabe);
	hmput(global_hash, *ret, ret);
}
/*
 * Build the IR for one function definition.
 *
 * Resets the per-function state, creates the OC_START node plus two OC_PROJ
 * nodes hanging off it (used as the initial control and memory), opens a
 * fresh scope, binds one OC_PROJ per parameter, builds every statement of
 * the body and finally emits the return node via finalize_function().
 *
 * Returns the function's OC_START node.
 */
static ir_node *build_function(ast_node *node)
{
	memset(&current_func, 0x0, sizeof(current_func));

	ir_node *start = calloc(1, sizeof(ir_node));
	start->code = OC_START;
	start->id = stbds_hash_bytes(start, sizeof(ir_node), 0xcafebabe);
	start->data.start_name = node->expr.function.name;

	/* The two projections are identical at this point, so their ids are
	 * derived from the pointer value rather than from the node contents. */
	ir_node *ctrl_proj = calloc(1, sizeof(ir_node));
	ctrl_proj->code = OC_PROJ;
	ctrl_proj->id = stbds_hash_bytes(&ctrl_proj, sizeof(usize), 0xcafebabe);
	arrput(ctrl_proj->out, start);
	hmput(global_hash, *ctrl_proj, ctrl_proj);
	current_control = ctrl_proj;

	ir_node *mem_proj = calloc(1, sizeof(ir_node));
	mem_proj->code = OC_PROJ;
	mem_proj->id = stbds_hash_bytes(&mem_proj, sizeof(usize), 0xcafebabe);
	arrput(mem_proj->out, start);
	hmput(global_hash, *mem_proj, mem_proj);
	current_memory = mem_proj;

	current_scope = calloc(1, sizeof(ir_node));
	current_scope->code = OC_SCOPE;
	push_scope();

	/* Each parameter is a projection of the start node, bound by name. */
	for (member *param = node->expr.function.parameters; param; param = param->next) {
		ir_node *arg = calloc(1, sizeof(ir_node));
		arg->code = OC_PROJ;
		arrput(arg->out, start);
		arg->id = stbds_hash_bytes(arg, sizeof(ir_node), 0xcafebabe);
		set_def(param->name, arg, false);
		hmput(global_hash, *arg, arg);
	}

	for (ast_node *stmt = node->expr.function.body;
	     stmt && stmt->type == NODE_UNIT;
	     stmt = stmt->expr.unit_node.next) {
		if (stmt->expr.unit_node.expr) {
			build_expression(stmt->expr.unit_node.expr);
		}
	}

	/* Re-hash now that the name has been filled in. */
	start->id = stbds_hash_bytes(start, sizeof(ir_node), 0xcafebabe);
	finalize_function();
	return start;
}
/*
 * Build the IR for one AST node and return the node that represents its
 * value (NULL for statements with no value, unknown identifiers and
 * unhandled node types).
 *
 *  - NODE_UNARY / NODE_BINARY / NODE_IF: delegated to their builders.
 *  - NODE_INTEGER: an OC_CONST_INT, shared through global_hash.
 *  - NODE_VAR_DECL: an address-taken variable becomes an OC_STORE (inputs:
 *    memory, address, value) that replaces current_memory, and the name is
 *    bound to the address as an lvalue; any other variable is bound directly
 *    to its initialiser's node.
 *  - NODE_IDENTIFIER: the bound node, wrapped in an OC_LOAD (inputs: memory,
 *    address) when the binding is an lvalue.
 *  - NODE_RETURN: recorded via build_return(); yields NULL.
 *
 * A non-NULL result is (re)registered in global_hash before returning,
 * except on the early-return paths.
 */
static ir_node *build_expression(ast_node *node)
{
	ir_node *n = NULL;
	ir_node *tmp = NULL;

	switch (node->type) {
	case NODE_UNARY:
		n = build_unary(node);
		break;
	case NODE_BINARY:
		n = build_binary(node);
		break;
	case NODE_INTEGER:
		n = calloc(1, sizeof(ir_node));
		n->code = OC_CONST_INT;
		n->data.const_int = node->expr.integer;
		n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe);
		tmp = hmget(global_hash, *n);
		if (tmp) {
			free(n);
			return tmp;
		}
		break;
	case NODE_VAR_DECL:
		n = calloc(1, sizeof(ir_node));
		if (node->address_taken) {
			n->code = OC_STORE;
			arrput(n->out, current_memory);
			arrput(n->out, build_address(-1, current_stack));
			arrput(n->out, build_expression(node->expr.var_decl.value));
			current_memory = n;
			current_stack += node->expr_type->size;
			n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe);
			hmput(global_hash, *n, n);
			/* The variable is bound to its address, not to the store. */
			n = n->out[1];
			set_def(node->expr.var_decl.name, n, true);
		} else {
			/* The placeholder node is not needed for a plain value binding. */
			free(n);
			n = build_expression(node->expr.var_decl.value);
			set_def(node->expr.var_decl.name, n, false);
		}
		return n;
	case NODE_IDENTIFIER: {
		/* Braces: a declaration may not directly follow a case label before C23. */
		struct symbol_def *def = get_def(node->expr.string.start);
		if (!def) {
			/* Unknown name: yield no node instead of dereferencing NULL. */
			break;
		}
		n = def->node;
		if (n && def->is_lvalue) {
			ir_node *addr_node = n;
			n = calloc(1, sizeof(ir_node));
			n->code = OC_LOAD;
			arrput(n->out, current_memory);
			arrput(n->out, addr_node);
			n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe);
			tmp = hmget(global_hash, *n);
			if (tmp) {
				/* Release the input array too; free(n) alone would leak it. */
				arrfree(n->out);
				free(n);
				n = tmp;
			} else {
				hmput(global_hash, *n, n);
			}
		}
		break;
	}
	case NODE_IF:
		n = build_if(node);
		break;
	case NODE_RETURN:
		build_return(node);
		break;
	default:
		break;
	}

	if (n) {
		hmput(global_hash, *n, n);
	}
	return n;
}
/*
 * Entry point of the IR builder.
 *
 * Creates the root OC_START node named "program", an OC_FRAME_PTR node as
 * the initial memory state and a top-level scope, builds every
 * NODE_FUNCTION found in the NODE_UNIT list `ast` and attaches it to the
 * root, then prints the graph to stdout wrapped in `digraph G { ... }`.
 */
void ir_build(ast_node *ast)
{
	graph = calloc(1, sizeof(ir_node));
	graph->code = OC_START;
	graph->id = stbds_hash_bytes(graph, sizeof(ir_node), 0xcafebabe);
	graph->data.start_name = "program";

	current_memory = calloc(1, sizeof(ir_node));
	current_memory->code = OC_FRAME_PTR;
	current_memory->id = stbds_hash_bytes(current_memory, sizeof(ir_node), 0xcafebabe);

	current_scope = calloc(1, sizeof(ir_node));
	current_scope->code = OC_SCOPE;
	push_scope();

	for (ast_node *unit = ast; unit && unit->type == NODE_UNIT; unit = unit->expr.unit_node.next) {
		ast_node *item = unit->expr.unit_node.expr;
		/* Only function definitions are lowered at the top level. */
		if (!item || item->type != NODE_FUNCTION) {
			continue;
		}
		ir_node *fn = build_function(item);
		arrput(graph->out, fn);
		hmput(global_hash, *fn, fn);
	}

	printf("digraph G {\n");
	print_graph(graph);
	printf("}\n");
}

65
ir.h
View file

@ -1,65 +0,0 @@
#ifndef IR_H
#define IR_H
#include "utils.h"
#include "parser.h"
struct _ir_node;
/* Binding of a source-level name to the IR node that currently represents it. */
struct symbol_def {
/* Node bound to the name. */
struct _ir_node *node;
/* When true, `node` is an address: the builder wraps reads of the name in an OC_LOAD. */
bool is_lvalue;
};
/* One entry of a string-keyed stb_ds map (used with shget/shput): name -> definition. */
typedef struct { char *key; struct symbol_def *value; } symbol_table;
/*
 * Operation performed by an IR node.
 * The order matters: nodes are allocated with calloc, so a node whose code
 * was never assigned reads as OC_START (value 0).
 */
typedef enum {
/* Root of the program graph and of each function; data.start_name holds its name. */
OC_START,
/* Arithmetic, bitwise and comparison operations. */
OC_ADD,
OC_SUB,
OC_MUL,
OC_DIV,
OC_MOD,
OC_BAND,
OC_BOR,
OC_BXOR,
/* Unary minus. */
OC_NEG,
OC_EQ,
/* Constants; the value lives in data.const_int / data.const_float. */
OC_CONST_INT,
OC_CONST_FLOAT,
/* Value of a `return` that has no expression. */
OC_VOID,
/* Initial memory state created by ir_build(). */
OC_FRAME_PTR,
OC_ADDR,
/* Inputs: memory, address, value. The store becomes the new memory state. */
OC_STORE,
/* Inputs: memory, address. */
OC_LOAD,
/* Control merge point; inputs are the incoming controls. */
OC_REGION,
/* Value merge; for if/else the inputs are: region, then-value, else-value. */
OC_PHI,
/* Inputs: condition, control. */
OC_IF,
/* Projection of another node (an OC_IF branch, or a control/memory/parameter of OC_START). */
OC_PROJ,
OC_STOP,
/* Inputs: control, memory, value. */
OC_RETURN,
/* Carrier for the symbol tables of a scope (data.symbol_tables). */
OC_SCOPE,
} opcode;
/*
 * A node of the IR graph.
 * The builder hashes and compares nodes by their raw bytes
 * (stbds_hash_bytes / hmget over the whole struct), so the layout and any
 * field change after insertion affect lookups.
 */
typedef struct _ir_node {
opcode code;
/* Hash of the node's bytes, computed by the builder with stbds_hash_bytes(). */
usize id;
/* NOTE(review): not populated by the builder functions that fill `out` -- confirm whether it is used. */
struct _ir_node **in;
/* stb_ds dynamic array; the builder stores the node's operands here (e.g. memory, address, value for OC_STORE). */
struct _ir_node **out;
/* Opcode-specific payload; which member is valid depends on `code`. */
union {
i64 const_int; /* OC_CONST_INT */
f64 const_float; /* OC_CONST_FLOAT */
symbol_table **symbol_tables; /* OC_SCOPE: stb_ds array of symbol tables */
char *start_name; /* OC_START */
} data;
} ir_node;
/* Build the IR graph for every function in the unit list `ast` and print it to stdout inside a `digraph G { ... }` block. */
void ir_build(ast_node *ast);
#endif

6
lc.c
View file

@ -4,7 +4,7 @@
#include "lexer.h"
#include "parser.h"
#include "sema.h"
#include "ir.h"
#include "codegen.h"
void print_indent(int depth) {
for (int i = 0; i < depth; i++) printf(" ");
@ -230,10 +230,10 @@ int main(void)
arena a = arena_init(0x1000 * 0x1000 * 64);
lexer *l = lexer_init(src, size, &a);
parser *p = parser_init(l, &a);
//print_ast(p->ast, 0);
print_ast(p->ast, 0);
sema_init(p, &a);
ir_build(p->ast);
generate(p->ast);
arena_deinit(a);

36
sema.c
View file

@ -26,9 +26,11 @@ static type *const_int = NULL;
static type *const_float = NULL;
static bool in_loop = false;
static bool has_errors = false;
static void error(ast_node *n, char *msg)
{
has_errors = true;
if (n) {
printf("\x1b[31m\x1b[1merror\x1b[0m\x1b[1m:%ld:%ld:\x1b[0m %s\n", n->position.row, n->position.column, msg);
} else {
@ -133,6 +135,16 @@ static type *get_type(sema *s, ast_node *n)
char *name = NULL;
type *t = NULL;
switch (n->type) {
case NODE_ACCESS:
t = get_type(s, n->expr.access.expr);
name = intern_string(s, n->expr.access.member->expr.string.start, n->expr.access.member->expr.string.len);
if (t->tag != TYPE_STRUCT) {
error(n->expr.access.expr, "expected structure.");
return NULL;
}
t = shget(t->data.structure.member_types, name);
return t;
case NODE_IDENTIFIER:
name = intern_string(s, n->expr.string.start, n->expr.string.len);
t = shget(type_reg, name);
@ -433,7 +445,8 @@ static bool can_cast(type *source, type *dest)
switch (dest->tag) {
case TYPE_INTEGER:
case TYPE_UINTEGER:
return source->tag == TYPE_INTEGER_CONST;
case TYPE_INTEGER_CONST:
return source->tag == TYPE_INTEGER_CONST || source->tag == TYPE_INTEGER || source->tag == TYPE_UINTEGER;
case TYPE_FLOAT:
return source->tag == TYPE_FLOAT_CONST;
default:
@ -558,11 +571,20 @@ static type *get_expression_type(sema *s, ast_node *node)
node->expr_type = t;
return t;
case NODE_CALL:
prot = shget(prototypes, intern_string(s, node->expr.call.name, node->expr.call.name_len));
node->expr.call.name = intern_string(s, node->expr.call.name, node->expr.call.name_len);
prot = shget(prototypes, node->expr.call.name);
if (!prot) {
error(node, "unknown function.");
return NULL;
}
// Process call arguments
ast_node *arg = node->expr.call.parameters;
while (arg && arg->type == NODE_UNIT) {
if (arg->expr.unit_node.expr) {
get_expression_type(s, arg->expr.unit_node.expr);
}
arg = arg->expr.unit_node.next;
}
t = prot->type;
node->expr_type = t;
return t;
@ -709,8 +731,9 @@ static void check_statement(sema *s, ast_node *node)
error(node, "redeclaration of variable.");
break;
}
if (!can_cast(get_expression_type(s, node->expr.var_decl.value), t) && !match(t, get_expression_type(s, node->expr.var_decl.value))) {
error(node, "type mismatch.");
if (t->tag == TYPE_STRUCT) {
} else if (!can_cast(get_expression_type(s, node->expr.var_decl.value), t) && !match(t, get_expression_type(s, node->expr.var_decl.value))) {
error(node, "type mismatch (decl).");
}
shput(current_scope->defs, name, node);
break;
@ -815,4 +838,9 @@ void sema_init(parser *p, arena *a)
const_float->data.flt = 0;
analyze_unit(s, s->ast);
if (has_errors) {
printf("Compilation failed.\n");
exit(1);
}
}

BIN
test Executable file

Binary file not shown.

16
test.l
View file

@ -1,12 +1,10 @@
u32 main(u32 b)
{
u32 a = 4;
//return a;
if (b == 3) {
return 3;
} else {
return 4;
struct point {
i32 x,
i32 y
}
return a;
i32 main()
{
point result = .{ x = 2, y = 1 };
return (result.y) + 2;
}

20
test.s Normal file
View file

@ -0,0 +1,20 @@
/*
 * x86-64 (AT&T syntax) `main`: stores 2 and 1 into two adjacent 32-bit
 * stack slots, then returns (value at -12(%rbp)) + 2.
 * Presumably the compiler's output for test.l (x = 2, y = 1,
 * return result.y + 2) -- confirm against the generator.
 */
.section .text
.global main
main:
/* Prologue: set up the frame and reserve 256 bytes of stack. */
push %rbp
mov %rsp, %rbp
sub $256, %rsp
/* 32-bit store of 2 at -8(%rbp). */
mov $2, %rax
mov %eax, -8(%rbp)
/* 32-bit store of 1 at -12(%rbp). */
mov $1, %rax
mov %eax, -12(%rbp)
/*
 * NOTE(review): this is a 64-bit load from a slot written with a 32-bit
 * store, so the upper half of %rax picks up the bytes at -8(%rbp). Only
 * the low 32 bits (%eax) hold the intended value.
 */
mov -12(%rbp), %rax
mov %rax, %rcx
mov $2, %rax
add %rcx, %rax
/* Epilogue for the `return` statement. */
mov %rbp, %rsp
pop %rbp
ret
/* NOTE(review): second epilogue is unreachable, since it follows an unconditional ret. */
mov %rbp, %rsp
pop %rbp
ret

BIN
test_control Executable file

Binary file not shown.

23
test_control.l Normal file
View file

@ -0,0 +1,23 @@
i32 main()
{
// Control-flow test (while/break, goto/label); by the arithmetic below it returns 10.
i32 x = 0;
i32 i = 0;
// Test while loop with break
// i is incremented first; the break at i == 5 fires before the add, so x = 1+2+3+4 = 10.
while (i < 10) {
i = i + 1;
if (i == 5) {
break;
}
x = x + i;
}
// Test goto and label
// x is 10 here, so the goto is taken and the `x = 999` assignment is skipped.
if (x == 10) {
goto skip;
}
x = 999;
skip:
return x;
}

BIN
test_simple Executable file

Binary file not shown.

13
test_simple.l Normal file
View file

@ -0,0 +1,13 @@
i32 main()
{
// Minimal while/break test: x counts up from 0 and the loop is left when x reaches 5, so 5 is returned.
i32 x = 0;
while (x < 10) {
x = x + 1;
if (x == 5) {
break;
}
}
return x;
}