diff --git a/Makefile b/Makefile index 203633a..7b960e9 100644 --- a/Makefile +++ b/Makefile @@ -3,8 +3,8 @@ include config.mk -SRC = lc.c utils.c lexer.c parser.c sema.c ir.c -HDR = config.def.h utils.h lexer.h parser.h sema.h ir.h +SRC = lc.c utils.c lexer.c parser.c sema.c codegen.c +HDR = config.def.h utils.h lexer.h parser.h sema.h codegen.h OBJ = ${SRC:.c=.o} all: options lc diff --git a/codegen.c b/codegen.c new file mode 100644 index 0000000..d1c58fd --- /dev/null +++ b/codegen.c @@ -0,0 +1,1168 @@ +#include +#include +#include +#include "codegen.h" +#include "sema.h" + +#include "stb_ds.h" + +typedef struct { + char *key; + int value; +} var_map; + +static var_map *variables = NULL; +static int stack_offset = 0; +static int label_counter = 0; +static int *break_stack = NULL; + +void gen_expr(FILE *fp, ast_node *expr); +void gen_unary(FILE *fp, ast_node *expr); +void gen_statement_list(FILE *fp, ast_node *stmt); +int get_var_offset(char *name, usize name_len); +int get_var_offset_sized(char *name, usize name_len, usize size); + +void gen_binary(FILE *fp, ast_node *expr) +{ + switch (expr->expr.binary.operator) { + case OP_PLUS: + gen_expr(fp, expr->expr.binary.left); + fprintf(fp, "mov %%rax, %%rcx\n"); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "add %%rcx, %%rax\n"); + break; + case OP_MINUS: + gen_expr(fp, expr->expr.binary.left); + fprintf(fp, "mov %%rax, %%rcx\n"); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "sub %%rax, %%rcx\n"); + fprintf(fp, "mov %%rcx, %%rax\n"); + break; + case OP_MUL: + gen_expr(fp, expr->expr.binary.left); + fprintf(fp, "mov %%rax, %%rcx\n"); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "imul %%rcx, %%rax\n"); + break; + case OP_DIV: + gen_expr(fp, expr->expr.binary.left); + fprintf(fp, "mov %%rax, %%rcx\n"); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "mov %%rax, %%rbx\n"); + fprintf(fp, "mov %%rcx, %%rax\n"); + fprintf(fp, "cqo\n"); + fprintf(fp, "idiv %%rbx\n"); + break; + case OP_MOD: + gen_expr(fp, expr->expr.binary.left); + fprintf(fp, "mov %%rax, %%rcx\n"); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "mov %%rax, %%rbx\n"); + fprintf(fp, "mov %%rcx, %%rax\n"); + fprintf(fp, "cqo\n"); + fprintf(fp, "idiv %%rbx\n"); + fprintf(fp, "mov %%rdx, %%rax\n"); + break; + case OP_BOR: + gen_expr(fp, expr->expr.binary.left); + fprintf(fp, "mov %%rax, %%rcx\n"); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "or %%rcx, %%rax\n"); + break; + case OP_BAND: + gen_expr(fp, expr->expr.binary.left); + fprintf(fp, "mov %%rax, %%rcx\n"); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "and %%rcx, %%rax\n"); + break; + case OP_BXOR: + gen_expr(fp, expr->expr.binary.left); + fprintf(fp, "mov %%rax, %%rcx\n"); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "xor %%rcx, %%rax\n"); + break; + case OP_EQ: + gen_expr(fp, expr->expr.binary.left); + fprintf(fp, "mov %%rax, %%rcx\n"); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "cmp %%rax, %%rcx\n"); + fprintf(fp, "sete %%al\n"); + fprintf(fp, "movzx %%al, %%rax\n"); + break; + case OP_NEQ: + gen_expr(fp, expr->expr.binary.left); + fprintf(fp, "mov %%rax, %%rcx\n"); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "cmp %%rax, %%rcx\n"); + fprintf(fp, "setne %%al\n"); + fprintf(fp, "movzx %%al, %%rax\n"); + break; + case OP_LT: + gen_expr(fp, expr->expr.binary.left); + fprintf(fp, "mov %%rax, %%rcx\n"); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "cmp %%rax, %%rcx\n"); + fprintf(fp, "setl %%al\n"); + fprintf(fp, "movzx %%al, %%rax\n"); + break; + case OP_GT: + gen_expr(fp, expr->expr.binary.left); + fprintf(fp, "mov %%rax, %%rcx\n"); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "cmp %%rax, %%rcx\n"); + fprintf(fp, "setg %%al\n"); + fprintf(fp, "movzx %%al, %%rax\n"); + break; + case OP_LE: + gen_expr(fp, expr->expr.binary.left); + fprintf(fp, "mov %%rax, %%rcx\n"); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "cmp %%rax, %%rcx\n"); + fprintf(fp, "setle %%al\n"); + fprintf(fp, "movzx %%al, %%rax\n"); + break; + case OP_GE: + gen_expr(fp, expr->expr.binary.left); + fprintf(fp, "mov %%rax, %%rcx\n"); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "cmp %%rax, %%rcx\n"); + fprintf(fp, "setge %%al\n"); + fprintf(fp, "movzx %%al, %%rax\n"); + break; + case OP_AND: + gen_expr(fp, expr->expr.binary.left); + fprintf(fp, "test %%rax, %%rax\n"); + fprintf(fp, "setne %%al\n"); + fprintf(fp, "movzx %%al, %%rcx\n"); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "test %%rax, %%rax\n"); + fprintf(fp, "setne %%al\n"); + fprintf(fp, "movzx %%al, %%rax\n"); + fprintf(fp, "and %%rcx, %%rax\n"); + break; + case OP_OR: + gen_expr(fp, expr->expr.binary.left); + fprintf(fp, "test %%rax, %%rax\n"); + fprintf(fp, "setne %%al\n"); + fprintf(fp, "movzx %%al, %%rcx\n"); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "test %%rax, %%rax\n"); + fprintf(fp, "setne %%al\n"); + fprintf(fp, "movzx %%al, %%rax\n"); + fprintf(fp, "or %%rcx, %%rax\n"); + break; + case OP_ASSIGN: { + if (expr->expr.binary.left->type == NODE_IDENTIFIER) { + gen_expr(fp, expr->expr.binary.right); + int offset = get_var_offset(expr->expr.binary.left->expr.string.start, + expr->expr.binary.left->expr.string.len); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset); + } else if (expr->expr.binary.left->type == NODE_ARRAY_SUBSCRIPT) { + ast_node *subscript = expr->expr.binary.left; + usize element_size = 8; + type *base_type = subscript->expr.subscript.expr->expr_type; + bool is_slice = false; + + if (base_type) { + if (base_type->tag == TYPE_PTR && base_type->data.ptr.child) { + element_size = base_type->data.ptr.child->size; + } else if (base_type->tag == TYPE_SLICE && base_type->data.slice.child) { + element_size = base_type->data.slice.child->size; + is_slice = true; + } + } + + if (subscript->expr.subscript.expr->type == NODE_IDENTIFIER && is_slice) { + int base_offset = get_var_offset(subscript->expr.subscript.expr->expr.string.start, + subscript->expr.subscript.expr->expr.string.len); + + fprintf(fp, "mov -%d(%%rbp), %%rcx\n", base_offset); + gen_expr(fp, subscript->expr.subscript.index); + + if (element_size != 1) { + fprintf(fp, "imul $%lu, %%rax\n", element_size); + } + + fprintf(fp, "add %%rcx, %%rax\n"); + fprintf(fp, "push %%rax\n"); + + gen_expr(fp, expr->expr.binary.right); + + fprintf(fp, "pop %%rcx\n"); + + if (subscript->expr_type && subscript->expr_type->size == 4) { + fprintf(fp, "mov %%eax, (%%rcx)\n"); + } else if (subscript->expr_type && subscript->expr_type->size == 2) { + fprintf(fp, "mov %%ax, (%%rcx)\n"); + } else if (subscript->expr_type && subscript->expr_type->size == 1) { + fprintf(fp, "mov %%al, (%%rcx)\n"); + } else { + fprintf(fp, "mov %%rax, (%%rcx)\n"); + } + } else if (subscript->expr.subscript.expr->type == NODE_IDENTIFIER) { + int base_offset = get_var_offset(subscript->expr.subscript.expr->expr.string.start, + subscript->expr.subscript.expr->expr.string.len); + + gen_expr(fp, subscript->expr.subscript.index); + + if (element_size != 1) { + fprintf(fp, "imul $%lu, %%rax\n", element_size); + } + + fprintf(fp, "add $%d, %%rax\n", base_offset); + fprintf(fp, "neg %%rax\n"); + fprintf(fp, "add %%rbp, %%rax\n"); + fprintf(fp, "push %%rax\n"); + + gen_expr(fp, expr->expr.binary.right); + + fprintf(fp, "pop %%rcx\n"); + + if (subscript->expr_type && subscript->expr_type->size == 4) { + fprintf(fp, "mov %%eax, (%%rcx)\n"); + } else if (subscript->expr_type && subscript->expr_type->size == 2) { + fprintf(fp, "mov %%ax, (%%rcx)\n"); + } else if (subscript->expr_type && subscript->expr_type->size == 1) { + fprintf(fp, "mov %%al, (%%rcx)\n"); + } else { + fprintf(fp, "mov %%rax, (%%rcx)\n"); + } + } else { + gen_expr(fp, subscript->expr.subscript.expr); + fprintf(fp, "push %%rax\n"); + + gen_expr(fp, subscript->expr.subscript.index); + + if (element_size != 1) { + fprintf(fp, "imul $%lu, %%rax\n", element_size); + } + + fprintf(fp, "pop %%rcx\n"); + fprintf(fp, "add %%rcx, %%rax\n"); + fprintf(fp, "push %%rax\n"); + + gen_expr(fp, expr->expr.binary.right); + + fprintf(fp, "pop %%rcx\n"); + + if (subscript->expr_type && subscript->expr_type->size == 4) { + fprintf(fp, "mov %%eax, (%%rcx)\n"); + } else if (subscript->expr_type && subscript->expr_type->size == 2) { + fprintf(fp, "mov %%ax, (%%rcx)\n"); + } else if (subscript->expr_type && subscript->expr_type->size == 1) { + fprintf(fp, "mov %%al, (%%rcx)\n"); + } else { + fprintf(fp, "mov %%rax, (%%rcx)\n"); + } + } + } else { + fprintf(fp, "# ERROR: left side of assignment must be identifier\n"); + break; + } + + break; + } + case OP_ASSIGN_PTR: { + gen_expr(fp, expr->expr.binary.left); + fprintf(fp, "push %%rax\n"); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "pop %%rcx\n"); + fprintf(fp, "mov %%rax, (%%rcx)\n"); + break; + } + case OP_PLUS_EQ: { + if (expr->expr.binary.left->type != NODE_IDENTIFIER) { + fprintf(fp, "# ERROR: left side of assignment must be identifier\n"); + break; + } + int offset = get_var_offset(expr->expr.binary.left->expr.string.start, + expr->expr.binary.left->expr.string.len); + fprintf(fp, "mov -%d(%%rbp), %%rcx\n", offset); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "add %%rcx, %%rax\n"); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset); + break; + } + case OP_MINUS_EQ: { + if (expr->expr.binary.left->type != NODE_IDENTIFIER) { + fprintf(fp, "# ERROR: left side of assignment must be identifier\n"); + break; + } + int offset = get_var_offset(expr->expr.binary.left->expr.string.start, + expr->expr.binary.left->expr.string.len); + fprintf(fp, "mov -%d(%%rbp), %%rcx\n", offset); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "sub %%rax, %%rcx\n"); + fprintf(fp, "mov %%rcx, %%rax\n"); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset); + break; + } + case OP_MUL_EQ: { + if (expr->expr.binary.left->type != NODE_IDENTIFIER) { + fprintf(fp, "# ERROR: left side of assignment must be identifier\n"); + break; + } + int offset = get_var_offset(expr->expr.binary.left->expr.string.start, + expr->expr.binary.left->expr.string.len); + fprintf(fp, "mov -%d(%%rbp), %%rcx\n", offset); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "imul %%rcx, %%rax\n"); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset); + break; + } + case OP_DIV_EQ: { + if (expr->expr.binary.left->type != NODE_IDENTIFIER) { + fprintf(fp, "# ERROR: left side of assignment must be identifier\n"); + break; + } + int offset = get_var_offset(expr->expr.binary.left->expr.string.start, + expr->expr.binary.left->expr.string.len); + fprintf(fp, "mov -%d(%%rbp), %%rcx\n", offset); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "mov %%rax, %%rbx\n"); + fprintf(fp, "mov %%rcx, %%rax\n"); + fprintf(fp, "cqo\n"); + fprintf(fp, "idiv %%rbx\n"); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset); + break; + } + case OP_MOD_EQ: { + if (expr->expr.binary.left->type != NODE_IDENTIFIER) { + fprintf(fp, "# ERROR: left side of assignment must be identifier\n"); + break; + } + int offset = get_var_offset(expr->expr.binary.left->expr.string.start, + expr->expr.binary.left->expr.string.len); + fprintf(fp, "mov -%d(%%rbp), %%rcx\n", offset); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "mov %%rax, %%rbx\n"); + fprintf(fp, "mov %%rcx, %%rax\n"); + fprintf(fp, "cqo\n"); + fprintf(fp, "idiv %%rbx\n"); + fprintf(fp, "mov %%rdx, -%d(%%rbp)\n", offset); + break; + } + case OP_BOR_EQ: { + if (expr->expr.binary.left->type != NODE_IDENTIFIER) { + fprintf(fp, "# ERROR: left side of assignment must be identifier\n"); + break; + } + int offset = get_var_offset(expr->expr.binary.left->expr.string.start, + expr->expr.binary.left->expr.string.len); + fprintf(fp, "mov -%d(%%rbp), %%rcx\n", offset); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "or %%rcx, %%rax\n"); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset); + break; + } + case OP_BAND_EQ: { + if (expr->expr.binary.left->type != NODE_IDENTIFIER) { + fprintf(fp, "# ERROR: left side of assignment must be identifier\n"); + break; + } + int offset = get_var_offset(expr->expr.binary.left->expr.string.start, + expr->expr.binary.left->expr.string.len); + fprintf(fp, "mov -%d(%%rbp), %%rcx\n", offset); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "and %%rcx, %%rax\n"); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset); + break; + } + case OP_BXOR_EQ: { + if (expr->expr.binary.left->type != NODE_IDENTIFIER) { + fprintf(fp, "# ERROR: left side of assignment must be identifier\n"); + break; + } + int offset = get_var_offset(expr->expr.binary.left->expr.string.start, + expr->expr.binary.left->expr.string.len); + fprintf(fp, "mov -%d(%%rbp), %%rcx\n", offset); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "xor %%rcx, %%rax\n"); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset); + break; + } + case OP_RSHIFT_EQ: { + if (expr->expr.binary.left->type != NODE_IDENTIFIER) { + fprintf(fp, "# ERROR: left side of assignment must be identifier\n"); + break; + } + int offset = get_var_offset(expr->expr.binary.left->expr.string.start, + expr->expr.binary.left->expr.string.len); + fprintf(fp, "mov -%d(%%rbp), %%rax\n", offset); + fprintf(fp, "push %%rax\n"); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "mov %%rax, %%rcx\n"); + fprintf(fp, "pop %%rax\n"); + fprintf(fp, "sar %%cl, %%rax\n"); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset); + break; + } + case OP_LSHIFT_EQ: { + if (expr->expr.binary.left->type != NODE_IDENTIFIER) { + break; + } + int offset = get_var_offset(expr->expr.binary.left->expr.string.start, + expr->expr.binary.left->expr.string.len); + fprintf(fp, "mov -%d(%%rbp), %%rax\n", offset); + fprintf(fp, "push %%rax\n"); + gen_expr(fp, expr->expr.binary.right); + fprintf(fp, "mov %%rax, %%rcx\n"); + fprintf(fp, "pop %%rax\n"); + fprintf(fp, "shl %%cl, %%rax\n"); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset); + break; + } + } +} + +int get_var_offset(char *name, usize name_len) +{ + char *var_name = strndup(name, name_len); + ptrdiff_t idx = shgeti(variables, var_name); + + if (idx >= 0) { + free(var_name); + return variables[idx].value; + } + + stack_offset += 8; + shput(variables, var_name, stack_offset); + return stack_offset; +} + +int get_var_offset_sized(char *name, usize name_len, usize size) +{ + char *var_name = strndup(name, name_len); + ptrdiff_t idx = shgeti(variables, var_name); + + if (idx >= 0) { + free(var_name); + return variables[idx].value; + } + + usize aligned_size = (size + 7) & ~7; + stack_offset += aligned_size; + shput(variables, var_name, stack_offset); + return stack_offset; +} + +void gen_statement_list(FILE *fp, ast_node *stmt) +{ + if (!stmt) return; + + if (stmt->type == NODE_UNIT) { + ast_node *current = stmt; + while (current && current->type == NODE_UNIT) { + if (current->expr.unit_node.expr) { + gen_expr(fp, current->expr.unit_node.expr); + } + current = current->expr.unit_node.next; + } + } else { + gen_expr(fp, stmt); + } +} + +void gen_unary(FILE *fp, ast_node *expr) +{ + switch (expr->expr.unary.operator) { + case UOP_MINUS: + gen_expr(fp, expr->expr.unary.right); + fprintf(fp, "neg %%rax\n"); + break; + case UOP_NOT: + gen_expr(fp, expr->expr.unary.right); + fprintf(fp, "test %%rax, %%rax\n"); + fprintf(fp, "sete %%al\n"); + fprintf(fp, "movzx %%al, %%rax\n"); + break; + case UOP_INCR: + if (expr->expr.unary.right->type != NODE_IDENTIFIER) { + fprintf(fp, "# ERROR: increment requires identifier\n"); + break; + } + int offset_incr = get_var_offset(expr->expr.unary.right->expr.string.start, + expr->expr.unary.right->expr.string.len); + fprintf(fp, "mov -%d(%%rbp), %%rax\n", offset_incr); + fprintf(fp, "inc %%rax\n"); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset_incr); + break; + case UOP_DECR: + if (expr->expr.unary.right->type != NODE_IDENTIFIER) { + fprintf(fp, "# ERROR: decrement requires identifier\n"); + break; + } + int offset_decr = get_var_offset(expr->expr.unary.right->expr.string.start, + expr->expr.unary.right->expr.string.len); + fprintf(fp, "mov -%d(%%rbp), %%rax\n", offset_decr); + fprintf(fp, "dec %%rax\n"); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset_decr); + break; + case UOP_REF: + if (expr->expr.unary.right->type != NODE_IDENTIFIER) { + fprintf(fp, "# ERROR: address-of requires identifier\n"); + break; + } + int offset_ref = get_var_offset(expr->expr.unary.right->expr.string.start, + expr->expr.unary.right->expr.string.len); + fprintf(fp, "lea -%d(%%rbp), %%rax\n", offset_ref); + break; + case UOP_DEREF: + gen_expr(fp, expr->expr.unary.right); + fprintf(fp, "mov (%%rax), %%rax\n"); + break; + } +} + +void gen_expr(FILE *fp, ast_node *expr) +{ + switch (expr->type) { + case NODE_INTEGER: + fprintf(fp, "mov $%lu, %%rax\n", expr->expr.integer); + break; + case NODE_FLOAT: { + // TODO: do not truncate + i64 int_val = (i64)expr->expr.flt; + fprintf(fp, "mov $%ld, %%rax\n", int_val); + break; + } + case NODE_CHAR: + fprintf(fp, "mov $%d, %%rax\n", (int)(unsigned char)expr->expr.ch); + break; + case NODE_BOOL: + fprintf(fp, "mov $%d, %%rax\n", expr->expr.boolean ? 1 : 0); + break; + case NODE_IDENTIFIER: { + int offset = get_var_offset(expr->expr.string.start, expr->expr.string.len); + fprintf(fp, "mov -%d(%%rbp), %%rax\n", offset); + break; + } + case NODE_BINARY: + gen_binary(fp, expr); + break; + case NODE_UNARY: + gen_unary(fp, expr); + break; + case NODE_CAST: + gen_expr(fp, expr->expr.cast.value); + break; + case NODE_VAR_DECL: { + int offset = 0; + type *var_type = expr->expr_type; + if (!var_type && expr->expr.var_decl.type) { + var_type = expr->expr.var_decl.type->expr_type; + } + + bool is_inline_slice = false; + if (var_type && var_type->tag == TYPE_SLICE && expr->expr.var_decl.value && + (expr->expr.var_decl.value->type == NODE_STRUCT_INIT || + expr->expr.var_decl.value->type == NODE_RANGE)) { + is_inline_slice = true; + } + + if (!is_inline_slice) { + if (var_type && var_type->size > 0) { + offset = get_var_offset_sized(expr->expr.var_decl.name, + expr->expr.var_decl.name_len, var_type->size); + } else { + offset = get_var_offset(expr->expr.var_decl.name, expr->expr.var_decl.name_len); + } + } + + if (expr->expr.var_decl.value) { + if (expr->expr.var_decl.value->type == NODE_RANGE && var_type && var_type->tag == TYPE_SLICE) { + ast_node *range = expr->expr.var_decl.value; + if (range->expr.binary.left->type == NODE_INTEGER && + range->expr.binary.right->type == NODE_INTEGER) { + i64 start = range->expr.binary.left->expr.integer; + i64 end = range->expr.binary.right->expr.integer; + i64 count = end - start + 1; + + type *element_type = var_type->data.slice.child; + usize element_size = element_type ? element_type->size : 8; + + usize data_size = count * element_size; + usize aligned_data_size = (data_size + 7) & ~7; + stack_offset += aligned_data_size; + int data_offset = stack_offset; + + stack_offset += 16; + offset = stack_offset; + + char *var_name = strndup(expr->expr.var_decl.name, expr->expr.var_decl.name_len); + shput(variables, var_name, offset); + + for (i64 i = 0; i < count; i++) { + i64 value = start + i; + int element_offset = data_offset - (i * element_size); + fprintf(fp, "mov $%ld, %%rax\n", value); + if (element_size == 4) { + fprintf(fp, "mov %%eax, -%d(%%rbp)\n", element_offset); + } else if (element_size == 2) { + fprintf(fp, "mov %%ax, -%d(%%rbp)\n", element_offset); + } else if (element_size == 1) { + fprintf(fp, "mov %%al, -%d(%%rbp)\n", element_offset); + } else { + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", element_offset); + } + } + + fprintf(fp, "lea -%d(%%rbp), %%rax\n", data_offset); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset); + fprintf(fp, "mov $%ld, %%rax\n", count); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset - 8); + } + } else if (expr->expr.var_decl.value->type == NODE_STRING && var_type && var_type->tag == TYPE_SLICE) { + ast_node *str = expr->expr.var_decl.value; + usize str_len = str->expr.string.len; + char *str_data = str->expr.string.start; + + usize aligned_data_size = (str_len + 7) & ~7; + stack_offset += aligned_data_size; + int data_offset = stack_offset; + + stack_offset += 16; + offset = stack_offset; + + char *var_name = strndup(expr->expr.var_decl.name, expr->expr.var_decl.name_len); + shput(variables, var_name, offset); + + for (usize i = 0; i < str_len; i++) { + int byte_offset = data_offset - i; + if ((unsigned char)str_data[i] == '\\' && (unsigned char)str_data[i+1] == 'n') { + fprintf(fp, "movb $%d, -%d(%%rbp)\n", (unsigned char)'\n', byte_offset); + i += 1; + } else { + fprintf(fp, "movb $%d, -%d(%%rbp)\n", (unsigned char)str_data[i], byte_offset); + } + } + + fprintf(fp, "lea -%d(%%rbp), %%rax\n", data_offset); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset); + fprintf(fp, "mov $%lu, %%rax\n", str_len); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset - 8); + } else if (expr->expr.var_decl.value->type == NODE_STRUCT_INIT) { + ast_node *member_list = expr->expr.var_decl.value->expr.struct_init.members; + ast_node *current = member_list; + + if (var_type && var_type->tag == TYPE_STRUCT) { + while (current && current->type == NODE_UNIT) { + ast_node *assignment = current->expr.unit_node.expr; + if (assignment && assignment->type == NODE_BINARY && + assignment->expr.binary.operator == OP_ASSIGN) { + ast_node *field = assignment->expr.binary.left; + ast_node *value = assignment->expr.binary.right; + + if (field->type == NODE_IDENTIFIER) { + char *field_name = strndup(field->expr.string.start, + field->expr.string.len); + + member *m = var_type->data.structure.members; + int field_offset = -1; + while (m) { + if (m->name_len == field->expr.string.len && + strncmp(m->name, field->expr.string.start, m->name_len) == 0) { + field_offset = m->offset; + break; + } + m = m->next; + } + + if (field_offset >= 0) { + gen_expr(fp, value); + + type *field_type = shget(var_type->data.structure.member_types, field_name); + + if (field_type && field_type->size == 4) { + fprintf(fp, "mov %%eax, -%d(%%rbp)\n", offset + field_offset); + } else if (field_type && field_type->size == 2) { + fprintf(fp, "mov %%ax, -%d(%%rbp)\n", offset + field_offset); + } else if (field_type && field_type->size == 1) { + fprintf(fp, "mov %%al, -%d(%%rbp)\n", offset + field_offset); + } else { + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset + field_offset); + } + } + + free(field_name); + } + } + current = current->expr.unit_node.next; + } + } + else if (var_type && (var_type->tag == TYPE_PTR || var_type->tag == TYPE_SLICE)) { + usize element_size = 8; + type *element_type = NULL; + + if (var_type->tag == TYPE_PTR && var_type->data.ptr.child) { + element_type = var_type->data.ptr.child; + element_size = element_type->size; + } else if (var_type->tag == TYPE_SLICE && var_type->data.slice.child) { + element_type = var_type->data.slice.child; + element_size = element_type->size; + } + + int element_count = 0; + ast_node *count_node = current; + while (count_node && count_node->type == NODE_UNIT) { + element_count++; + count_node = count_node->expr.unit_node.next; + } + if (var_type->tag == TYPE_SLICE) { + usize data_size = element_count * element_size; + usize aligned_data_size = (data_size + 7) & ~7; + stack_offset += aligned_data_size; + int data_offset = stack_offset; + stack_offset += 16; + offset = stack_offset; + + char *var_name = strndup(expr->expr.var_decl.name, expr->expr.var_decl.name_len); + shput(variables, var_name, offset); + + int index = 0; + while (current && current->type == NODE_UNIT) { + ast_node *value = current->expr.unit_node.expr; + if (value) { + gen_expr(fp, value); + + int element_offset = data_offset - (index * element_size); + + if (element_type && element_type->size == 4) { + fprintf(fp, "mov %%eax, -%d(%%rbp)\n", element_offset); + } else if (element_type && element_type->size == 2) { + fprintf(fp, "mov %%ax, -%d(%%rbp)\n", element_offset); + } else if (element_type && element_type->size == 1) { + fprintf(fp, "mov %%al, -%d(%%rbp)\n", element_offset); + } else { + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", element_offset); + } + } + index++; + current = current->expr.unit_node.next; + } + + fprintf(fp, "lea -%d(%%rbp), %%rax\n", data_offset); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset); + fprintf(fp, "mov $%d, %%rax\n", element_count); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset - 8); + } else { + int index = 0; + while (current && current->type == NODE_UNIT) { + ast_node *value = current->expr.unit_node.expr; + if (value) { + gen_expr(fp, value); + + int element_offset = offset + (index * element_size); + + if (element_type && element_type->size == 4) { + fprintf(fp, "mov %%eax, -%d(%%rbp)\n", element_offset); + } else if (element_type && element_type->size == 2) { + fprintf(fp, "mov %%ax, -%d(%%rbp)\n", element_offset); + } else if (element_type && element_type->size == 1) { + fprintf(fp, "mov %%al, -%d(%%rbp)\n", element_offset); + } else { + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", element_offset); + } + } + index++; + current = current->expr.unit_node.next; + } + } + } + } else { + gen_expr(fp, expr->expr.var_decl.value); + + // If assigning a slice value, copy the 16-byte structure + if (var_type && var_type->tag == TYPE_SLICE) { + fprintf(fp, "mov (%%rax), %%rcx\n"); // Load ptr field + fprintf(fp, "mov 8(%%rax), %%rdx\n"); // Load len field + fprintf(fp, "mov %%rcx, -%d(%%rbp)\n", offset); // Store ptr + fprintf(fp, "mov %%rdx, -%d(%%rbp)\n", offset - 8); // Store len + } else { + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", offset); + } + } + } + break; + } + case NODE_RETURN: { + if (expr->expr.ret.value) { + gen_expr(fp, expr->expr.ret.value); + } + fprintf(fp, "mov %%rbp, %%rsp\n"); + fprintf(fp, "pop %%rbp\n"); + fprintf(fp, "ret\n"); + break; + } + case NODE_IF: { + int label_else = label_counter++; + int label_end = label_counter++; + + gen_expr(fp, expr->expr.if_stmt.condition); + fprintf(fp, "test %%rax, %%rax\n"); + + if (expr->expr.if_stmt.otherwise) { + fprintf(fp, "jz .L%d\n", label_else); + } else { + fprintf(fp, "jz .L%d\n", label_end); + } + + gen_statement_list(fp, expr->expr.if_stmt.body); + + if (expr->expr.if_stmt.otherwise) { + fprintf(fp, "jmp .L%d\n", label_end); + fprintf(fp, ".L%d:\n", label_else); + gen_statement_list(fp, expr->expr.if_stmt.otherwise); + } + + fprintf(fp, ".L%d:\n", label_end); + break; + } + case NODE_WHILE: { + int label_start = label_counter++; + int label_end = label_counter++; + + fprintf(fp, ".L%d:\n", label_start); + + gen_expr(fp, expr->expr.whle.condition); + fprintf(fp, "test %%rax, %%rax\n"); + fprintf(fp, "jz .L%d\n", label_end); + + arrput(break_stack, label_end); + + gen_statement_list(fp, expr->expr.whle.body); + + arrpop(break_stack); + + fprintf(fp, "jmp .L%d\n", label_start); + + fprintf(fp, ".L%d:\n", label_end); + break; + } + case NODE_LABEL: { + char *label_name = strndup(expr->expr.label.name, expr->expr.label.name_len); + fprintf(fp, ".L_%s:\n", label_name); + free(label_name); + break; + } + case NODE_GOTO: { + char *label_name = strndup(expr->expr.label.name, expr->expr.label.name_len); + fprintf(fp, "jmp .L_%s\n", label_name); + free(label_name); + break; + } + case NODE_BREAK: { + if (arrlen(break_stack) > 0) { + int loop_end = break_stack[arrlen(break_stack) - 1]; + fprintf(fp, "jmp .L%d\n", loop_end); + } else { + fprintf(fp, "# ERROR: break outside of loop\n"); + } + break; + } + case NODE_ACCESS: { + ast_node *base = expr->expr.access.expr; + ast_node *member_node = expr->expr.access.member; + + if (base->type == NODE_IDENTIFIER) { + int base_offset = get_var_offset(base->expr.string.start, base->expr.string.len); + type *base_type = base->expr_type; + + if (base_type && base_type->tag == TYPE_SLICE && member_node->type == NODE_IDENTIFIER) { + char *field_name = strndup(member_node->expr.string.start, member_node->expr.string.len); + + if (strcmp(field_name, "ptr") == 0) { + fprintf(fp, "mov -%d(%%rbp), %%rax\n", base_offset); + } else if (strcmp(field_name, "len") == 0) { + fprintf(fp, "mov -%d(%%rbp), %%rax\n", base_offset - 8); + } else { + fprintf(fp, "# ERROR: slice field '%s' not found\n", field_name); + } + + free(field_name); + } + else if (member_node->type == NODE_IDENTIFIER && base_type && + base_type->tag == TYPE_STRUCT) { + member *m = base_type->data.structure.members; + int field_offset = -1; + while (m) { + if (m->name_len == member_node->expr.string.len && + strncmp(m->name, member_node->expr.string.start, m->name_len) == 0) { + field_offset = m->offset; + break; + } + m = m->next; + } + + if (field_offset >= 0) { + fprintf(fp, "mov -%d(%%rbp), %%rax\n", base_offset + field_offset); + } else { + fprintf(fp, "# ERROR: field not found\n"); + } + } else { + fprintf(fp, "# ERROR: not a struct or slice type\n"); + } + } else { + fprintf(fp, "# ERROR: complex struct access not implemented\n"); + } + break; + } + case NODE_RANGE: { + if (expr->expr.binary.left->type == NODE_INTEGER && + expr->expr.binary.right->type == NODE_INTEGER) { + i64 start = expr->expr.binary.left->expr.integer; + i64 end = expr->expr.binary.right->expr.integer; + i64 count = end - start + 1; + + usize element_size = 8; + usize data_size = count * element_size; + usize aligned_data_size = (data_size + 7) & ~7; + stack_offset += aligned_data_size; + int data_offset = stack_offset; + + for (i64 i = 0; i < count; i++) { + i64 value = start + i; + int element_offset = data_offset - (i * element_size); + fprintf(fp, "mov $%ld, %%rax\n", value); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", element_offset); + } + + stack_offset += 16; + int slice_offset = stack_offset; + + fprintf(fp, "lea -%d(%%rbp), %%rax\n", data_offset); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", slice_offset); + fprintf(fp, "mov $%ld, %%rax\n", count); + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", slice_offset - 8); + + fprintf(fp, "lea -%d(%%rbp), %%rax\n", slice_offset); + } else { + fprintf(fp, "# ERROR: range expression requires constant bounds\n"); + } + break; + } + case NODE_STRUCT_INIT: { + fprintf(fp, "# ERROR: struct init outside of variable declaration\n"); + break; + } + case NODE_ARRAY_SUBSCRIPT: { + usize element_size = 8; + type *base_type = expr->expr.subscript.expr->expr_type; + bool is_slice = false; + + if (base_type) { + if (base_type->tag == TYPE_PTR && base_type->data.ptr.child) { + element_size = base_type->data.ptr.child->size; + } else if (base_type->tag == TYPE_SLICE && base_type->data.slice.child) { + element_size = base_type->data.slice.child->size; + is_slice = true; + } + } + + if (expr->expr.subscript.index->type == NODE_RANGE) { + if (expr->expr.subscript.expr->type == NODE_IDENTIFIER) { + int base_offset = get_var_offset(expr->expr.subscript.expr->expr.string.start, + expr->expr.subscript.expr->expr.string.len); + + fprintf(fp, "mov -%d(%%rbp), %%rcx\n", base_offset); + + gen_expr(fp, expr->expr.subscript.index->expr.binary.left); + fprintf(fp, "push %%rax\n"); + + gen_expr(fp, expr->expr.subscript.index->expr.binary.right); + fprintf(fp, "mov %%rax, %%rdx\n"); // rdx = end + fprintf(fp, "pop %%rax\n"); // rax = start + fprintf(fp, "mov %%rdx, %%r8\n"); + fprintf(fp, "sub %%rax, %%r8\n"); + fprintf(fp, "inc %%r8\n"); // r8 = new length + + if (element_size != 1) { + fprintf(fp, "imul $%lu, %%rax\n", element_size); + } + fprintf(fp, "add %%rcx, %%rax\n"); // rax = new ptr + + // Allocate temporary slice struct (16 bytes) + stack_offset += 16; + fprintf(fp, "mov %%rax, -%d(%%rbp)\n", stack_offset); // Store ptr + fprintf(fp, "mov %%r8, -%d(%%rbp)\n", stack_offset - 8); // Store len + fprintf(fp, "lea -%d(%%rbp), %%rax\n", stack_offset); // Return address of temp slice + } + } + else if (expr->expr.subscript.expr->type == NODE_IDENTIFIER && is_slice) { + int base_offset = get_var_offset(expr->expr.subscript.expr->expr.string.start, + expr->expr.subscript.expr->expr.string.len); + + fprintf(fp, "mov -%d(%%rbp), %%rcx\n", base_offset); + + gen_expr(fp, expr->expr.subscript.index); + + if (element_size != 1) { + fprintf(fp, "imul $%lu, %%rax\n", element_size); + } + + fprintf(fp, "add %%rcx, %%rax\n"); + + if (expr->expr_type && expr->expr_type->size == 4) { + fprintf(fp, "movl (%%rax), %%eax\n"); + } else if (expr->expr_type && expr->expr_type->size == 2) { + fprintf(fp, "movzwl (%%rax), %%eax\n"); + } else if (expr->expr_type && expr->expr_type->size == 1) { + fprintf(fp, "movzbl (%%rax), %%eax\n"); + } else { + fprintf(fp, "mov (%%rax), %%rax\n"); + } + } else if (expr->expr.subscript.expr->type == NODE_IDENTIFIER) { + int base_offset = get_var_offset(expr->expr.subscript.expr->expr.string.start, + expr->expr.subscript.expr->expr.string.len); + + gen_expr(fp, expr->expr.subscript.index); + + if (element_size != 1) { + fprintf(fp, "imul $%lu, %%rax\n", element_size); + } + + fprintf(fp, "add $%d, %%rax\n", base_offset); + fprintf(fp, "neg %%rax\n"); + fprintf(fp, "add %%rbp, %%rax\n"); + + if (expr->expr_type && expr->expr_type->size == 4) { + fprintf(fp, "movl (%%rax), %%eax\n"); + } else if (expr->expr_type && expr->expr_type->size == 2) { + fprintf(fp, "movzwl (%%rax), %%eax\n"); + } else if (expr->expr_type && expr->expr_type->size == 1) { + fprintf(fp, "movzbl (%%rax), %%eax\n"); + } else { + fprintf(fp, "mov (%%rax), %%rax\n"); + } + } else { + gen_expr(fp, expr->expr.subscript.expr); + fprintf(fp, "push %%rax\n"); + + gen_expr(fp, expr->expr.subscript.index); + + if (element_size != 1) { + fprintf(fp, "imul $%lu, %%rax\n", element_size); + } + + fprintf(fp, "pop %%rcx\n"); + fprintf(fp, "add %%rcx, %%rax\n"); + + if (expr->expr_type && expr->expr_type->size == 4) { + fprintf(fp, "movl (%%rax), %%eax\n"); + } else if (expr->expr_type && expr->expr_type->size == 2) { + fprintf(fp, "movzwl (%%rax), %%eax\n"); + } else if (expr->expr_type && expr->expr_type->size == 1) { + fprintf(fp, "movzbl (%%rax), %%eax\n"); + } else { + fprintf(fp, "mov (%%rax), %%rax\n"); + } + } + break; + } + case NODE_CALL: { + const char *arg_regs[] = {"%rdi", "%rsi", "%rdx", "%rcx", "%r8", "%r9"}; + + int param_count = 0; + ast_node *param = expr->expr.call.parameters; + while (param && param->type == NODE_UNIT) { + param_count++; + param = param->expr.unit_node.next; + } + + param = expr->expr.call.parameters; + int i = 0; + while (param && param->type == NODE_UNIT) { + if (param->expr.unit_node.expr) { + gen_expr(fp, param->expr.unit_node.expr); + fprintf(fp, "push %%rax\n"); + } + param = param->expr.unit_node.next; + i++; + } + + for (int j = param_count - 1; j >= 0; j--) { + if (j < 6) { + fprintf(fp, "pop %s\n", arg_regs[j]); + } else { + // TODO: handle more than 6 arguments properly + fprintf(fp, "pop %%rax\n"); + fprintf(fp, "push %%rax\n"); + } + } + + fprintf(fp, "call %.*s\n", (int)expr->expr.call.name_len, expr->expr.call.name); + + if (param_count > 6) { + int stack_args = param_count - 6; + fprintf(fp, "add $%d, %%rsp\n", stack_args * 8); + } + + break; + } + } +} + +void gen_function(FILE *fp, ast_node *fn) +{ + if (fn->expr.function.is_extern || fn->expr.function.body == NULL) { + return; + } + + ast_node *current = fn->expr.function.body; + + stack_offset = 0; + label_counter = 0; + shfree(variables); + variables = NULL; + arrfree(break_stack); + break_stack = NULL; + + fprintf(fp, ".global %s\n%s:\n", fn->expr.function.name, fn->expr.function.name); + + fprintf(fp, "push %%rbp\n"); + fprintf(fp, "mov %%rsp, %%rbp\n"); + fprintf(fp, "sub $256, %%rsp\n"); + + const char *param_regs[] = {"%rdi", "%rsi", "%rdx", "%rcx", "%r8", "%r9"}; + member *param = fn->expr.function.parameters; + int param_idx = 0; + while (param && param_idx < 6) { + int offset = get_var_offset(param->name, param->name_len); + fprintf(fp, "mov %s, -%d(%%rbp)\n", param_regs[param_idx], offset); + param = param->next; + param_idx++; + } + + while (current && current->type == NODE_UNIT) { + if (current->expr.unit_node.expr) { + gen_expr(fp, current->expr.unit_node.expr); + } + current = current->expr.unit_node.next; + } + + fprintf(fp, "mov %%rbp, %%rsp\n"); + fprintf(fp, "pop %%rbp\n"); + fprintf(fp, "ret\n"); +} + +void generate(ast_node *node) +{ + FILE *fp = fopen("test.s", "w"); + + ast_node *current = node; + + fprintf(fp, ".section .text\n"); + while (current && current->type == NODE_UNIT) { + if (current->expr.unit_node.expr && current->expr.unit_node.expr->type == NODE_FUNCTION) { + gen_function(fp, current->expr.unit_node.expr); + } + current = current->expr.unit_node.next; + } + + fclose(fp); + + shfree(variables); + variables = NULL; + arrfree(break_stack); + break_stack = NULL; +} diff --git a/codegen.h b/codegen.h new file mode 100644 index 0000000..8e87114 --- /dev/null +++ b/codegen.h @@ -0,0 +1,8 @@ +#ifndef CODEGEN_H +#define CODEGEN_H + +#include "parser.h" + +void generate(ast_node *node); + +#endif diff --git a/done.txt b/done.txt deleted file mode 100644 index e69de29..0000000 diff --git a/ir.c b/ir.c deleted file mode 100644 index 61da14a..0000000 --- a/ir.c +++ /dev/null @@ -1,812 +0,0 @@ -#include "ir.h" -#include -#include -#include "stb_ds.h" -#include "sema.h" - -struct { ir_node key; ir_node *value; } *global_hash = NULL; -static ir_node *graph; -static ir_node *current_memory; -static ir_node *current_control; -static usize current_stack = 0; - -static ir_node *current_scope = NULL; - -static ir_node *build_expression(ast_node *node); - -static struct { - ir_node **return_controls; - ir_node **return_memories; - ir_node **return_values; -} current_func = {0}; - -static void node_name(ir_node *node) -{ - if (!node) { - printf("null [label=\"NULL\", style=filled, fillcolor=red]\n"); - return; - } - printf("%ld ", node->id); - switch (node->code) { - case OC_START: - printf("[label=\"%s\", style=filled, color=orange]\n", node->data.start_name); - break; - case OC_RETURN: - printf("[label=\"return\", style=filled, color=orange]\n"); - break; - case OC_ADD: - printf("[label=\"+\"]\n"); - break; - case OC_NEG: - case OC_SUB: - printf("[label=\"-\"]\n"); - break; - case OC_DIV: - printf("[label=\"/\"]\n"); - break; - case OC_MUL: - printf("[label=\"*\"]\n"); - break; - case OC_MOD: - printf("[label=\"%%\"]\n"); - break; - case OC_BAND: - printf("[label=\"&\"]\n"); - break; - case OC_BOR: - printf("[label=\"|\"]\n"); - break; - case OC_BXOR: - printf("[label=\"^\"]\n"); - break; - case OC_EQ: - printf("[label=\"==\"]\n"); - break; - case OC_CONST_INT: - printf("[label=\"%ld\"]\n", node->data.const_int); - break; - case OC_CONST_FLOAT: - printf("[label=\"%f\"]\n", node->data.const_float); - break; - case OC_FRAME_PTR: - printf("[label=\"frame_ptr\"]\n"); - break; - case OC_STORE: - printf("[label=\"store\", shape=box]\n"); - break; - case OC_LOAD: - printf("[label=\"load\", shape=box]\n"); - break; - case OC_ADDR: - printf("[label=\"addr\"]\n"); - break; - case OC_REGION: - printf("[label=\"region\", shape=diamond, style=filled, color=green]\n"); - break; - case OC_PHI: - printf("[label=\"phi\", shape=triangle]\n"); - break; - case OC_IF: - printf("[label=\"if\", shape=diamond, style=filled, color=lightblue]\n"); - break; - case OC_PROJ: - printf("[label=\"proj\", shape=diamond, style=filled, color=cyan]\n"); - break; - default: - printf("[label=\"%d\"]\n", node->code); - break; - } -} - -static void print_graph(ir_node *node) -{ - for (int i = 0; i < hmlen(global_hash); i++) { - ir_node *node = global_hash[i].value; - node_name(node); - - for (int j = 0; j < arrlen(node->out); j++) { - if (node->out[j]) { - node_name(node->out[j]); - printf("%ld->%ld\n", node->out[j]->id, node->id); - } - } - } -} - -static void push_scope(void) -{ - arrput(current_scope->data.symbol_tables, NULL); -} - -static struct symbol_def *get_def(char *name) -{ - for (int i = arrlen(current_scope->data.symbol_tables) - 1; i >= 0; i--) { - struct symbol_def *def = shget(current_scope->data.symbol_tables[i], name); - if (def) return def; - } - return NULL; -} - -static void set_def(char *name, ir_node *node, bool lvalue) -{ - for (int i = arrlen(current_scope->data.symbol_tables) - 1; i >= 0; i--) { - if (shget(current_scope->data.symbol_tables[i], name)) { - struct symbol_def *def = calloc(1, sizeof(struct symbol_def)); - def->is_lvalue = lvalue; - def->node = node; - shput(current_scope->data.symbol_tables[i], name, def); - return; - } - } - int index = arrlen(current_scope->data.symbol_tables) - 1; - struct symbol_def *def = calloc(1, sizeof(struct symbol_def)); - def->is_lvalue = lvalue; - def->node = node; - shput(current_scope->data.symbol_tables[index], name, def); -} - -static ir_node *copy_scope(ir_node *src) -{ - ir_node *dst = calloc(1, sizeof(ir_node)); - dst->code = OC_SCOPE; - - for (int i=0; i < arrlen(src->data.symbol_tables); i++) { - arrput(dst->data.symbol_tables, NULL); - symbol_table *src_table = src->data.symbol_tables[i]; - for (int j=0; j < shlen(src_table); j++) { - shput(dst->data.symbol_tables[i], src_table[j].key, src_table[j].value); - } - } - return dst; -} - -static void const_fold(ir_node *binary) -{ - ir_node *left = binary->out[0]; - ir_node *right = binary->out[1]; - - if (left->code == OC_CONST_INT && right->code == OC_CONST_INT) { - switch (binary->code) { - case OC_ADD: - binary->data.const_int = left->data.const_int + right->data.const_int; - break; - case OC_SUB: - binary->data.const_int = left->data.const_int - right->data.const_int; - break; - case OC_MUL: - binary->data.const_int = left->data.const_int * right->data.const_int; - break; - case OC_DIV: - if (right->data.const_int != 0) - binary->data.const_int = left->data.const_int / right->data.const_int; - break; - case OC_MOD: - if (right->data.const_int != 0) - binary->data.const_int = left->data.const_int % right->data.const_int; - break; - case OC_BOR: - binary->data.const_int = left->data.const_int | right->data.const_int; - break; - case OC_BAND: - binary->data.const_int = left->data.const_int & right->data.const_int; - break; - case OC_BXOR: - binary->data.const_int = left->data.const_int ^ right->data.const_int; - break; - case OC_EQ: - binary->data.const_int = left->data.const_int == right->data.const_int; - break; - default: - return; - } - binary->code = OC_CONST_INT; - arrfree(binary->out); binary->out = NULL; - arrfree(binary->in); binary->in = NULL; - binary->id = stbds_hash_bytes(binary, sizeof(ir_node), 0xcafebabe); - } - - if (left->code == OC_CONST_FLOAT && right->code == OC_CONST_FLOAT) { - switch (binary->code) { - case OC_ADD: - binary->data.const_float = left->data.const_float + right->data.const_float; - break; - case OC_SUB: - binary->data.const_float = left->data.const_float - right->data.const_float; - break; - case OC_MUL: - binary->data.const_float = left->data.const_float * right->data.const_float; - break; - case OC_DIV: - if (right->data.const_float != 0.0f) - binary->data.const_float = left->data.const_float / right->data.const_float; - break; - default: - return; - } - binary->code = OC_CONST_FLOAT; - arrfree(binary->out); binary->out = NULL; - arrfree(binary->in); binary->in = NULL; - binary->id = stbds_hash_bytes(binary, sizeof(ir_node), 0xcafebabe); - } -} - -static ir_node *build_address(usize base, usize offset) { - ir_node *addr = calloc(1, sizeof(ir_node)); - addr->code = OC_ADDR; - - ir_node *base_node = calloc(1, sizeof(ir_node)); - if (base == -1) { - base_node->code = OC_FRAME_PTR; - base_node->id = stbds_hash_bytes(base_node, sizeof(ir_node), 0xcafebabe); - } else { - base_node->code = OC_CONST_INT; - base_node->data.const_int = base; - base_node->id = stbds_hash_bytes(base_node, sizeof(ir_node), 0xcafebabe); - } - - ir_node *offset_node = calloc(1, sizeof(ir_node)); - offset_node->code = OC_CONST_INT; - offset_node->data.const_int = offset; - offset_node->id = stbds_hash_bytes(offset_node, sizeof(ir_node), 0xcafebabe); - - arrput(addr->out, base_node); - arrput(addr->out, offset_node); - - addr->id = stbds_hash_bytes(addr, sizeof(ir_node), 0xcafebabe); - ir_node *tmp = hmget(global_hash, *addr); - if (tmp) { - free(addr); - return tmp; - } - - return addr; -} - -static ir_node *build_assign_ptr(ast_node *binary) -{ - ir_node *val_node = build_expression(binary->expr.binary.right); - - char *var_name = binary->expr.binary.left->expr.string.start; - - ir_node *existing_def = get_def(var_name)->node; - - ir_node *store = calloc(1, sizeof(ir_node)); - store->code = OC_STORE; - - arrput(store->out, current_control); - - arrput(store->out, current_memory); - arrput(store->out, existing_def); - arrput(store->out, val_node); - - store->id = stbds_hash_bytes(store, sizeof(ir_node), 0xcafebabe); - hmput(global_hash, *store, store); - - current_memory = store; - - return val_node; -} - -static ir_node *build_assign(ast_node *binary) -{ - ir_node *val_node = build_expression(binary->expr.binary.right); - - char *var_name = binary->expr.binary.left->expr.string.start; - - struct symbol_def *def = get_def(var_name); - - if (def && def->is_lvalue) { - ir_node *existing_def = def->node; - ir_node *store = calloc(1, sizeof(ir_node)); - store->code = OC_STORE; - - arrput(store->out, current_control); - - arrput(store->out, current_memory); - arrput(store->out, existing_def); - arrput(store->out, val_node); - - store->id = stbds_hash_bytes(store, sizeof(ir_node), 0xcafebabe); - hmput(global_hash, *store, store); - - current_memory = store; - - return val_node; - } - - set_def(var_name, val_node, false); - return val_node; -} - -static ir_node *build_binary(ast_node *node) -{ - ir_node *n = calloc(1, sizeof(ir_node)); - switch (node->expr.binary.operator) { - case OP_ASSIGN: - free(n); - return build_assign(node); - case OP_ASSIGN_PTR: - free(n); - return build_assign_ptr(node); - case OP_PLUS: - n->code = OC_ADD; - break; - case OP_MINUS: - n->code = OC_SUB; - break; - case OP_DIV: - n->code = OC_DIV; - break; - case OP_MUL: - n->code = OC_MUL; - break; - case OP_MOD: - n->code = OC_MOD; - break; - case OP_BOR: - n->code = OC_BOR; - break; - case OP_BAND: - n->code = OC_BAND; - break; - case OP_BXOR: - n->code = OC_BXOR; - break; - case OP_EQ: - n->code = OC_EQ; - break; - default: - break; - } - arrput(n->out, build_expression(node->expr.binary.left)); - arrput(n->out, build_expression(node->expr.binary.right)); - n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe); - const_fold(n); - ir_node *tmp = hmget(global_hash, *n); - if (tmp) { - free(n); - return tmp; - } - - return n; -} - -static ir_node *build_load(ast_node *node) -{ - ir_node *n = calloc(1, sizeof(ir_node)); - n->code = OC_LOAD; - - arrput(n->out, current_memory); - arrput(n->out, build_expression(node)); - n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabebabecafe); - - ir_node *tmp = hmget(global_hash, *n); - if (tmp) { - free(n); - return tmp; - } - - return n; -} - -static ir_node *build_unary(ast_node *node) -{ - ir_node *n = calloc(1, sizeof(ir_node)); - switch (node->expr.unary.operator) { - case UOP_MINUS: - n->code = OC_NEG; - arrput(n->out, build_expression(node->expr.unary.right)); - break; - case UOP_REF: - free(n); - - if (node->expr.unary.right->type == NODE_IDENTIFIER) { - struct symbol_def *def = get_def(node->expr.unary.right->expr.string.start); - if (def) { - return def->node; - } - } - - return build_expression(node->expr.unary.right); - case UOP_DEREF: - free(n); - return build_load(node->expr.unary.right); - default: - break; - } - - if (n->out && n->out[0]->code == OC_CONST_INT) { - switch (n->code) { - case OC_NEG: - n->data.const_int = -(n->out[0]->data.const_int); - break; - default: - break; - } - n->code = OC_CONST_INT; - arrfree(n->out); n->out = NULL; - } else if (n->out && n->out[0]->code == OC_CONST_FLOAT) { - switch (n->code) { - case OC_NEG: - n->data.const_float = -(n->out[0]->data.const_float); - break; - default: - break; - } - n->code = OC_CONST_FLOAT; - arrfree(n->out); n->out = NULL; - } - - n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe); - ir_node *tmp = hmget(global_hash, *n); - if (tmp) { - free(n); - return tmp; - } - - return n; -} - -static ir_node *build_if(ast_node *node) -{ - ir_node *condition = build_expression(node->expr.if_stmt.condition); - - ir_node *if_node = calloc(1, sizeof(ir_node)); - if_node->code = OC_IF; - arrput(if_node->out, condition); - arrput(if_node->out, current_control); - if_node->id = stbds_hash_bytes(if_node, sizeof(ir_node), 0xcafebabe); - hmput(global_hash, *if_node, if_node); - - ir_node *proj_true = calloc(1, sizeof(ir_node)); - proj_true->code = OC_PROJ; - arrput(proj_true->out, if_node); - proj_true->id = stbds_hash_bytes(proj_true, sizeof(ir_node), 0xcafebabe); - hmput(global_hash, *proj_true, proj_true); - - ir_node *proj_false = calloc(1, sizeof(ir_node)); - proj_false->code = OC_PROJ; - arrput(proj_false->out, if_node); - proj_false->id = stbds_hash_bytes(proj_false, sizeof(ir_node), 0xcafebabe); - hmput(global_hash, *proj_false, proj_false); - - ir_node *base_scope = copy_scope(current_scope); - ir_node *base_mem = current_memory; - - current_control = proj_true; - - ast_node *current = node->expr.if_stmt.body; - while (current && current->type == NODE_UNIT) { - if (current->expr.unit_node.expr) { - build_expression(current->expr.unit_node.expr); - } - current = current->expr.unit_node.next; - } - ir_node *then_scope = current_scope; - ir_node *then_mem = current_memory; - ir_node *then_control = current_control; - - current_scope = copy_scope(base_scope); - current_memory = base_mem; - - current_control = proj_false; - current = node->expr.if_stmt.otherwise; - while (current && current->type == NODE_UNIT) { - if (current->expr.unit_node.expr) { - build_expression(current->expr.unit_node.expr); - } - current = current->expr.unit_node.next; - } - ir_node *else_scope = current_scope; - ir_node *else_mem = current_memory; - ir_node *else_control = current_control; - - ir_node *region = calloc(1, sizeof(ir_node)); - region->code = OC_REGION; - arrput(region->out, then_control); - arrput(region->out, else_control); - region->id = stbds_hash_bytes(region, sizeof(ir_node), 0xcafebabe); - hmput(global_hash, *region, region); - - if (then_mem->id != else_mem->id) { - ir_node *phi = calloc(1, sizeof(ir_node)); - phi->code = OC_PHI; - arrput(phi->out, region); - arrput(phi->out, then_mem); - arrput(phi->out, else_mem); - phi->id = stbds_hash_bytes(phi, sizeof(ir_node), 0xcafebabe); - - hmput(global_hash, *phi, phi); - - current_memory = phi; - } else { - current_memory = then_mem; - } - - current_scope = base_scope; - - for (int i = 0; i < arrlen(current_scope->data.symbol_tables); i++) { - symbol_table *base_table = current_scope->data.symbol_tables[i]; - for (int j = 0; j < shlen(base_table); j++) { - char *key = base_table[j].key; - - ir_node *found_then = NULL; - symbol_table *t_table = then_scope->data.symbol_tables[i]; - if (shget(t_table, key)->node) found_then = shget(t_table, key)->node; - else found_then = base_table[j].value->node; - - ir_node *found_else = NULL; - symbol_table *e_table = else_scope->data.symbol_tables[i]; - if (shget(e_table, key)->node) found_else = shget(e_table, key)->node; - else found_else = base_table[j].value->node; - - if (found_then->id != found_else->id) { - ir_node *phi = calloc(1, sizeof(ir_node)); - phi->code = OC_PHI; - arrput(phi->out, region); - arrput(phi->out, found_then); - arrput(phi->out, found_else); - phi->id = stbds_hash_bytes(phi, sizeof(ir_node), 0xcafebabe); - struct symbol_def *def = calloc(1, sizeof(struct symbol_def)); - def->node = phi; - def->is_lvalue = false; - shput(current_scope->data.symbol_tables[i], key, def); - hmput(global_hash, *phi, phi); - } else { - struct symbol_def *def = calloc(1, sizeof(struct symbol_def)); - def->node = found_then; - def->is_lvalue = false; - shput(current_scope->data.symbol_tables[i], key, def); - } - } - } - - current_control = region; - - return region; -} - -static void build_return(ast_node *node) -{ - ir_node *val = NULL; - - if (node->expr.ret.value) { - val = build_expression(node->expr.ret.value); - } else { - val = calloc(1, sizeof(ir_node)); - val->code = OC_VOID; - val->id = stbds_hash_bytes(val, sizeof(ir_node), 0xcafebabe); - } - - arrput(current_func.return_controls, current_control); - arrput(current_func.return_memories, current_memory); - arrput(current_func.return_values, val); - - current_control = NULL; -} - -static void finalize_function(void) -{ - int count = arrlen(current_func.return_controls); - - if (count == 0) { - return; - } - - ir_node *final_ctrl = NULL; - ir_node *final_mem = NULL; - ir_node *final_val = NULL; - - if (count == 1) { - final_ctrl = current_func.return_controls[0]; - final_mem = current_func.return_memories[0]; - final_val = current_func.return_values[0]; - } - else { - ir_node *region = calloc(1, sizeof(ir_node)); - region->code = OC_REGION; - for (int i=0; iout, current_func.return_controls[i]); - } - hmput(global_hash, *region, region); - final_ctrl = region; - - ir_node *mem_phi = calloc(1, sizeof(ir_node)); - mem_phi->code = OC_PHI; - arrput(mem_phi->out, region); - for (int i=0; iout, current_func.return_memories[i]); - } - hmput(global_hash, *mem_phi, mem_phi); - mem_phi->id = stbds_hash_bytes(mem_phi, sizeof(ir_node), 0xcafebabe); - final_mem = mem_phi; - - ir_node *val_phi = calloc(1, sizeof(ir_node)); - val_phi->code = OC_PHI; - //arrput(val_phi->out, region); - for (int i=0; iout, current_func.return_values[i]); - } - val_phi->id = stbds_hash_bytes(val_phi, sizeof(ir_node), 0xcafebabe); - hmput(global_hash, *val_phi, val_phi); - final_val = val_phi; - - region->id = stbds_hash_bytes(region, sizeof(ir_node), 0xcafebabe); - } - - ir_node *ret = calloc(1, sizeof(ir_node)); - ret->code = OC_RETURN; - arrput(ret->out, final_ctrl); - arrput(ret->out, final_mem); - arrput(ret->out, final_val); - ret->id = stbds_hash_bytes(ret, sizeof(ir_node), 0xcafebabe); - - hmput(global_hash, *ret, ret); -} - -static ir_node *build_function(ast_node *node) -{ - memset(¤t_func, 0x0, sizeof(current_func)); - ast_node *current = node->expr.function.body; - - ir_node *func = calloc(1, sizeof(ir_node)); - func->code = OC_START; - func->id = stbds_hash_bytes(func, sizeof(ir_node), 0xcafebabe); - func->data.start_name = node->expr.function.name; - - ir_node *start_ctrl = calloc(1, sizeof(ir_node)); - start_ctrl->code = OC_PROJ; - start_ctrl->id = stbds_hash_bytes(&start_ctrl, sizeof(usize), 0xcafebabe); - arrput(start_ctrl->out, func); - hmput(global_hash, *start_ctrl, start_ctrl); - - current_control = start_ctrl; - - ir_node *start_mem = calloc(1, sizeof(ir_node)); - start_mem->code = OC_PROJ; - start_mem->id = stbds_hash_bytes(&start_mem, sizeof(usize), 0xcafebabe); - arrput(start_mem->out, func); - hmput(global_hash, *start_mem, start_mem); - - current_memory = start_mem; - - current_scope = calloc(1, sizeof(ir_node)); - current_scope->code = OC_SCOPE; - - push_scope(); - - member *m = node->expr.function.parameters; - while (m) { - ir_node *proj_param = calloc(1, sizeof(ir_node)); - proj_param->code = OC_PROJ; - arrput(proj_param->out, func); - proj_param->id = stbds_hash_bytes(proj_param, sizeof(ir_node), 0xcafebabe); - set_def(m->name, proj_param, false); - hmput(global_hash, *proj_param, proj_param); - - m = m->next; - } - - while (current && current->type == NODE_UNIT) { - if (current->expr.unit_node.expr) { - build_expression(current->expr.unit_node.expr); - } - current = current->expr.unit_node.next; - } - - func->id = stbds_hash_bytes(func, sizeof(ir_node), 0xcafebabe); - - finalize_function(); - - return func; -} - -static ir_node *build_expression(ast_node *node) -{ - ir_node *n = NULL; - ir_node *tmp = NULL; - switch (node->type) { - case NODE_UNARY: - n = build_unary(node); - break; - case NODE_BINARY: - n = build_binary(node); - break; - case NODE_INTEGER: - n = calloc(1, sizeof(ir_node)); - n->code = OC_CONST_INT; - n->data.const_int = node->expr.integer; - n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe); - tmp = hmget(global_hash, *n); - if (tmp) { - free(n); - return tmp; - } - break; - case NODE_VAR_DECL: - n = calloc(1, sizeof(ir_node)); - if (node->address_taken) { - n->code = OC_STORE; - - arrput(n->out, current_memory); - arrput(n->out, build_address(-1, current_stack)); - arrput(n->out, build_expression(node->expr.var_decl.value)); - current_memory = n; - current_stack += node->expr_type->size; - n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe); - hmput(global_hash, *n, n); - n = n->out[1]; - set_def(node->expr.var_decl.name, n, true); - } else { - n = build_expression(node->expr.var_decl.value); - set_def(node->expr.var_decl.name, n, false); - } - - return n; - case NODE_IDENTIFIER: - struct symbol_def *def = get_def(node->expr.string.start); - n = def->node; - - if (n && def->is_lvalue) { - ir_node *addr_node = n; - - n = calloc(1, sizeof(ir_node)); - n->code = OC_LOAD; - - arrput(n->out, current_memory); - arrput(n->out, addr_node); - - n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe); - - ir_node *tmp = hmget(global_hash, *n); - if (tmp) { - free(n); - n = tmp; - } else { - hmput(global_hash, *n, n); - } - } - break; - case NODE_IF: - n = build_if(node); - break; - case NODE_RETURN: - build_return(node); - break; - default: - break; - } - - if (n) hmput(global_hash, *n, n); - return n; -} - -void ir_build(ast_node *ast) -{ - ast_node *current = ast; - - graph = calloc(1, sizeof(ir_node)); - graph->code = OC_START; - graph->id = stbds_hash_bytes(graph, sizeof(ir_node), 0xcafebabe); - graph->data.start_name = "program"; - - current_memory = calloc(1, sizeof(ir_node)); - current_memory->code = OC_FRAME_PTR; - current_memory->id = stbds_hash_bytes(current_memory, sizeof(ir_node), 0xcafebabe); - - current_scope = calloc(1, sizeof(ir_node)); - current_scope->code = OC_SCOPE; - push_scope(); - - while (current && current->type == NODE_UNIT) { - if (current->expr.unit_node.expr && current->expr.unit_node.expr->type == NODE_FUNCTION) { - ir_node *expr = build_function(current->expr.unit_node.expr); - arrput(graph->out, expr); - hmput(global_hash, *expr, expr); - } - current = current->expr.unit_node.next; - } - printf("digraph G {\n"); - print_graph(graph); - printf("}\n"); -} diff --git a/ir.h b/ir.h deleted file mode 100644 index bfd684f..0000000 --- a/ir.h +++ /dev/null @@ -1,65 +0,0 @@ -#ifndef IR_H -#define IR_H - -#include "utils.h" -#include "parser.h" - -struct _ir_node; -struct symbol_def { - struct _ir_node *node; - bool is_lvalue; -}; - -typedef struct { char *key; struct symbol_def *value; } symbol_table; - -typedef enum { - OC_START, - OC_ADD, - OC_SUB, - OC_MUL, - OC_DIV, - OC_MOD, - OC_BAND, - OC_BOR, - OC_BXOR, - OC_NEG, - OC_EQ, - - OC_CONST_INT, - OC_CONST_FLOAT, - OC_VOID, - - OC_FRAME_PTR, - OC_ADDR, - - OC_STORE, - OC_LOAD, - - OC_REGION, - OC_PHI, - - OC_IF, - OC_PROJ, - - OC_STOP, - OC_RETURN, - - OC_SCOPE, -} opcode; - -typedef struct _ir_node { - opcode code; - usize id; - struct _ir_node **in; - struct _ir_node **out; - union { - i64 const_int; - f64 const_float; - symbol_table **symbol_tables; - char *start_name; - } data; -} ir_node; - -void ir_build(ast_node *ast); - -#endif diff --git a/lc.c b/lc.c index 0d0ddbb..ad01157 100644 --- a/lc.c +++ b/lc.c @@ -4,7 +4,7 @@ #include "lexer.h" #include "parser.h" #include "sema.h" -#include "ir.h" +#include "codegen.h" void print_indent(int depth) { for (int i = 0; i < depth; i++) printf(" "); @@ -230,10 +230,10 @@ int main(void) arena a = arena_init(0x1000 * 0x1000 * 64); lexer *l = lexer_init(src, size, &a); parser *p = parser_init(l, &a); - //print_ast(p->ast, 0); + print_ast(p->ast, 0); sema_init(p, &a); - ir_build(p->ast); + generate(p->ast); arena_deinit(a); diff --git a/parser.c b/parser.c index 8061fc1..4fd7c65 100644 --- a/parser.c +++ b/parser.c @@ -1089,6 +1089,7 @@ static ast_node *parse_function(parser *p) { ast_node *fn = arena_alloc(p->allocator, sizeof(ast_node)); fn->type = NODE_FUNCTION; + fn->expr.function.is_extern = false; fn->expr.function.type = parse_type(p); fn->expr.function.name = peek(p)->lexeme; fn->expr.function.name_len = peek(p)->lexeme_len; @@ -1097,7 +1098,14 @@ static ast_node *parse_function(parser *p) advance(p); if (match(p, TOKEN_RPAREN)) { - fn->expr.function.body = parse_compound(p);; + // Check if this is an extern declaration (semicolon) or definition (body) + if (match_peek(p, TOKEN_SEMICOLON)) { + // Extern function - no body, just consume semicolon + advance(p); + fn->expr.function.body = NULL; + } else { + fn->expr.function.body = parse_compound(p); + } fn->expr.function.parameters = NULL; fn->expr.function.parameters_len = 0; return fn; @@ -1111,7 +1119,13 @@ static ast_node *parse_function(parser *p) error(p, "expected `,`."); return NULL; } else { - fn->expr.function.body = parse_compound(p); + // Check if this is an extern declaration (semicolon) or definition (body) + if (match_peek(p, TOKEN_SEMICOLON)) { + advance(p); + fn->expr.function.body = NULL; + } else { + fn->expr.function.body = parse_compound(p); + } return fn; } } @@ -1132,7 +1146,14 @@ static ast_node *parse_function(parser *p) prev = current; } - fn->expr.function.body = parse_compound(p); + + // Check if this is an extern declaration (semicolon) or definition (body) + if (match_peek(p, TOKEN_SEMICOLON)) { + advance(p); + fn->expr.function.body = NULL; + } else { + fn->expr.function.body = parse_compound(p); + } return fn; } @@ -1140,6 +1161,13 @@ static ast_node *parse_function(parser *p) static ast_node *parse_statement(parser *p) { token *cur = peek(p); + + /* Check for extern function declaration */ + bool is_extern = false; + if (match(p, TOKEN_EXTERN)) { + is_extern = true; + } + ast_node *type = parse_type(p); if (type && type->type == NODE_STRUCT && type->expr.structure.name_len > 0) { goto skip_struct; @@ -1148,9 +1176,20 @@ static ast_node *parse_statement(parser *p) if (p->tokens->next && p->tokens->next->type == TOKEN_LPAREN) { /* Function definition. */ p->tokens = cur; - return parse_function(p); + if (is_extern) { + advance(p); // Skip TOKEN_EXTERN + } + ast_node *fn = parse_function(p); + if (fn && is_extern) { + fn->expr.function.is_extern = true; + fn->expr.function.body = NULL; + } + return fn; } p->tokens = cur; + if (is_extern) { + advance(p); // Skip TOKEN_EXTERN for non-function case + } /* Variable declaration. */ ast_node *node = arena_alloc(p->allocator, sizeof(ast_node)); node->type = NODE_VAR_DECL; diff --git a/parser.h b/parser.h index dced7ec..bbe3d0c 100644 --- a/parser.h +++ b/parser.h @@ -230,6 +230,7 @@ typedef struct _ast_node { usize name_len; struct _ast_node *type; struct _ast_node *body; + bool is_extern; } function; struct { variant *variants; diff --git a/report.txt b/report.txt deleted file mode 100644 index e69de29..0000000 diff --git a/sema.c b/sema.c index 9f2033d..4e87625 100644 --- a/sema.c +++ b/sema.c @@ -26,9 +26,11 @@ static type *const_int = NULL; static type *const_float = NULL; static bool in_loop = false; +static bool has_errors = false; static void error(ast_node *n, char *msg) { + has_errors = true; if (n) { printf("\x1b[31m\x1b[1merror\x1b[0m\x1b[1m:%ld:%ld:\x1b[0m %s\n", n->position.row, n->position.column, msg); } else { @@ -133,6 +135,16 @@ static type *get_type(sema *s, ast_node *n) char *name = NULL; type *t = NULL; switch (n->type) { + case NODE_ACCESS: + t = get_type(s, n->expr.access.expr); + name = intern_string(s, n->expr.access.member->expr.string.start, n->expr.access.member->expr.string.len); + if (t->tag != TYPE_STRUCT) { + error(n->expr.access.expr, "expected structure."); + return NULL; + } + t = shget(t->data.structure.member_types, name); + + return t; case NODE_IDENTIFIER: name = intern_string(s, n->expr.string.start, n->expr.string.len); t = shget(type_reg, name); @@ -140,17 +152,18 @@ static type *get_type(sema *s, ast_node *n) return t; case NODE_PTR_TYPE: t = malloc(sizeof(type)); - t->size = sizeof(usize); t->alignment = sizeof(usize); if (n->expr.ptr_type.flags & PTR_RAW) { t->name = "ptr"; t->tag = TYPE_PTR; + t->size = sizeof(usize); t->data.ptr.child = get_type(s, n->expr.ptr_type.type); t->data.ptr.is_const = (n->expr.ptr_type.flags & PTR_CONST) != 0; t->data.ptr.is_volatile = (n->expr.ptr_type.flags & PTR_VOLATILE) != 0; } else { t->name = "slice"; t->tag = TYPE_SLICE; + t->size = sizeof(usize) * 2; // ptr + len = 16 bytes t->data.slice.child = get_type(s, n->expr.ptr_type.type); t->data.slice.is_const = (n->expr.ptr_type.flags & PTR_CONST) != 0; t->data.slice.is_volatile = (n->expr.ptr_type.flags & PTR_VOLATILE) != 0; @@ -365,8 +378,8 @@ static ast_node *get_def(sema *s, char *name) static type *get_string_type(sema *s, ast_node *node) { type *string_type = arena_alloc(s->allocator, sizeof(type)); - string_type->tag = TYPE_PTR; - string_type->size = sizeof(usize); + string_type->tag = TYPE_SLICE; + string_type->size = sizeof(usize) * 2; // ptr + len = 16 bytes string_type->alignment = sizeof(usize); string_type->name = "slice"; string_type->data.slice.child = shget(type_reg, "u8"); @@ -397,6 +410,33 @@ static type *get_access_type(sema *s, ast_node *node) ast_node *member = node->expr.access.member; char *name_start = member->expr.string.start; usize name_len = member->expr.string.len; + + // Handle slice field access + if (t && t->tag == TYPE_SLICE) { + char *name = intern_string(s, name_start, name_len); + if (strcmp(name, "ptr") == 0) { + // Return pointer to element type + type *ptr_type = arena_alloc(s->allocator, sizeof(type)); + ptr_type->tag = TYPE_PTR; + ptr_type->size = 8; + ptr_type->alignment = 8; + ptr_type->name = "ptr"; + ptr_type->data.ptr.child = t->data.slice.child; + ptr_type->data.ptr.is_const = t->data.slice.is_const; + ptr_type->data.ptr.is_volatile = t->data.slice.is_volatile; + free(name); + return ptr_type; + } else if (strcmp(name, "len") == 0) { + // Return usize type + free(name); + return shget(type_reg, "usize"); + } else { + error(node, "slice doesn't have that field"); + free(name); + return NULL; + } + } + if (!t || (t->tag != TYPE_STRUCT && t->tag != TYPE_UNION)) { error(node, "invalid expression."); return NULL; @@ -433,7 +473,8 @@ static bool can_cast(type *source, type *dest) switch (dest->tag) { case TYPE_INTEGER: case TYPE_UINTEGER: - return source->tag == TYPE_INTEGER_CONST; + case TYPE_INTEGER_CONST: + return source->tag == TYPE_INTEGER_CONST || source->tag == TYPE_INTEGER || source->tag == TYPE_UINTEGER; case TYPE_FLOAT: return source->tag == TYPE_FLOAT_CONST; default: @@ -544,6 +585,36 @@ static type *get_expression_type(sema *s, ast_node *node) return t; case NODE_ARRAY_SUBSCRIPT: t = get_expression_type(s, node->expr.subscript.expr); + + // Check if this is range subscripting (creates a slice) + if (node->expr.subscript.index && node->expr.subscript.index->type == NODE_RANGE) { + type *element_type = NULL; + switch (t->tag) { + case TYPE_SLICE: + element_type = t->data.slice.child; + break; + case TYPE_PTR: + element_type = t->data.ptr.child; + break; + default: + error(node, "only pointers and slices can be indexed."); + return NULL; + } + + // Return a slice type + type *slice_type = arena_alloc(s->allocator, sizeof(type)); + slice_type->tag = TYPE_SLICE; + slice_type->size = sizeof(usize) * 2; + slice_type->alignment = sizeof(usize); + slice_type->data.slice.child = element_type; + slice_type->data.slice.is_const = false; + slice_type->data.slice.len = 0; + + node->expr_type = slice_type; + return slice_type; + } + + // Regular subscript - return element type switch (t->tag) { case TYPE_SLICE: t = t->data.slice.child; @@ -558,11 +629,20 @@ static type *get_expression_type(sema *s, ast_node *node) node->expr_type = t; return t; case NODE_CALL: - prot = shget(prototypes, intern_string(s, node->expr.call.name, node->expr.call.name_len)); + node->expr.call.name = intern_string(s, node->expr.call.name, node->expr.call.name_len); + prot = shget(prototypes, node->expr.call.name); if (!prot) { error(node, "unknown function."); return NULL; } + // Process call arguments + ast_node *arg = node->expr.call.parameters; + while (arg && arg->type == NODE_UNIT) { + if (arg->expr.unit_node.expr) { + get_expression_type(s, arg->expr.unit_node.expr); + } + arg = arg->expr.unit_node.next; + } t = prot->type; node->expr_type = t; return t; @@ -709,8 +789,21 @@ static void check_statement(sema *s, ast_node *node) error(node, "redeclaration of variable."); break; } - if (!can_cast(get_expression_type(s, node->expr.var_decl.value), t) && !match(t, get_expression_type(s, node->expr.var_decl.value))) { - error(node, "type mismatch."); + if (t->tag == TYPE_STRUCT) { + // Struct initialization with NODE_STRUCT_INIT is allowed + } else if (node->expr.var_decl.value && node->expr.var_decl.value->type == NODE_STRUCT_INIT && + (t->tag == TYPE_SLICE || t->tag == TYPE_PTR)) { + // Array/slice initialization with NODE_STRUCT_INIT is allowed + } else if (node->expr.var_decl.value && node->expr.var_decl.value->type == NODE_RANGE && + t->tag == TYPE_SLICE) { + // Range initialization for slices is allowed + get_expression_type(s, node->expr.var_decl.value); + } else if (node->expr.var_decl.value && node->expr.var_decl.value->type == NODE_STRING && + t->tag == TYPE_SLICE) { + // String literal can be assigned to slice + get_expression_type(s, node->expr.var_decl.value); + } else if (!can_cast(get_expression_type(s, node->expr.var_decl.value), t) && !match(t, get_expression_type(s, node->expr.var_decl.value))) { + error(node, "type mismatch (decl)."); } shput(current_scope->defs, name, node); break; @@ -735,15 +828,18 @@ static void check_function(sema *s, ast_node *f) param_node->expr_type = p_type; param_node->address_taken = false; param_node->expr.var_decl.name = t_name; - + shput(current_scope->defs, t_name, param_node); param = param->next; } - ast_node *current = f->expr.function.body; - while (current && current->type == NODE_UNIT) { - check_statement(s, current->expr.unit_node.expr); - current = current->expr.unit_node.next; + // Skip body checking for extern functions + if (!f->expr.function.is_extern && f->expr.function.body) { + ast_node *current = f->expr.function.body; + while (current && current->type == NODE_UNIT) { + check_statement(s, current->expr.unit_node.expr); + current = current->expr.unit_node.next; + } } pop_scope(s); @@ -797,6 +893,7 @@ void sema_init(parser *p, arena *a) register_type(s, "u16", create_integer(s, "u16", 16, false)); register_type(s, "u32", create_integer(s, "u32", 32, false)); register_type(s, "u64", create_integer(s, "u64", 64, false)); + register_type(s, "usize", create_integer(s, "usize", 64, false)); register_type(s, "i8", create_integer(s, "i8", 8, true)); register_type(s, "i16", create_integer(s, "i16", 16, true)); register_type(s, "i32", create_integer(s, "i32", 32, true)); @@ -815,4 +912,9 @@ void sema_init(parser *p, arena *a) const_float->data.flt = 0; analyze_unit(s, s->ast); + + if (has_errors) { + printf("Compilation failed.\n"); + exit(1); + } } diff --git a/test b/test new file mode 100755 index 0000000..1b89d35 Binary files /dev/null and b/test differ diff --git a/test.l b/test.l index 6c3aec2..2763f97 100644 --- a/test.l +++ b/test.l @@ -1,12 +1,20 @@ -u32 main(u32 b) +extern i64 write(i32 fd, *u8 buf, u64 count); +extern void exit(i32 code); +extern *u8 malloc(usize size); + +i32 main() { - u32 a = 4; - //return a; - if (b == 3) { - return 3; - } else { - return 4; + [u8] message = "Hello world!\n"; + *u8 message_heap = malloc(message.len); + [u8] new_message = message_heap[0..13]; + u32 i = 0; + + loop while i < message.len { + new_message[i] = message[i]; + i = i + 1; } - - return a; + + write(1, new_message.ptr, new_message.len); + + return 0; } diff --git a/test.s b/test.s new file mode 100644 index 0000000..b86c66e --- /dev/null +++ b/test.s @@ -0,0 +1,90 @@ +.section .text +.global main +main: +push %rbp +mov %rsp, %rbp +sub $256, %rsp +movb $72, -32(%rbp) +movb $101, -31(%rbp) +movb $108, -30(%rbp) +movb $108, -29(%rbp) +movb $111, -28(%rbp) +movb $32, -27(%rbp) +movb $119, -26(%rbp) +movb $111, -25(%rbp) +movb $114, -24(%rbp) +movb $108, -23(%rbp) +movb $100, -22(%rbp) +movb $33, -21(%rbp) +movb $10, -20(%rbp) +lea -32(%rbp), %rax +mov %rax, -48(%rbp) +mov $14, %rax +mov %rax, -40(%rbp) +mov -40(%rbp), %rax +push %rax +pop %rdi +call malloc +mov %rax, -56(%rbp) +mov -56(%rbp), %rcx +mov $0, %rax +push %rax +mov $13, %rax +mov %rax, %rdx +pop %rax +mov %rdx, %r8 +sub %rax, %r8 +inc %r8 +add %rcx, %rax +mov %rax, -88(%rbp) +mov %r8, -80(%rbp) +lea -88(%rbp), %rax +mov (%rax), %rcx +mov 8(%rax), %rdx +mov %rcx, -72(%rbp) +mov %rdx, -64(%rbp) +mov $0, %rax +mov %rax, -96(%rbp) +.L0: +mov -96(%rbp), %rax +mov %rax, %rcx +mov -40(%rbp), %rax +cmp %rax, %rcx +setl %al +movzx %al, %rax +test %rax, %rax +jz .L1 +mov -72(%rbp), %rcx +mov -96(%rbp), %rax +add %rcx, %rax +push %rax +mov -48(%rbp), %rcx +mov -96(%rbp), %rax +add %rcx, %rax +movzbl (%rax), %eax +pop %rcx +mov %al, (%rcx) +mov -96(%rbp), %rax +mov %rax, %rcx +mov $1, %rax +add %rcx, %rax +mov %rax, -96(%rbp) +jmp .L0 +.L1: +mov $1, %rax +push %rax +mov -72(%rbp), %rax +push %rax +mov -64(%rbp), %rax +push %rax +pop %rdx +pop %rsi +pop %rdi +call write +mov $0, %rax +mov %rbp, %rsp +pop %rbp +ret +mov %rbp, %rsp +pop %rbp +ret diff --git a/todo.cfg b/todo.cfg deleted file mode 100644 index 5e35825..0000000 --- a/todo.cfg +++ /dev/null @@ -1,2 +0,0 @@ -export TODO_DIR="." -export TODO_FILE="$TODO_DIR/todo.txt" diff --git a/todo.txt b/todo.txt deleted file mode 100644 index c023562..0000000 --- a/todo.txt +++ /dev/null @@ -1 +0,0 @@ -implement dominator tree for control flow