From 849e0b68633f6ee2b12ee12f33df3e43ec1964f0 Mon Sep 17 00:00:00 2001 From: Lorenzo Torres Date: Sun, 7 Dec 2025 16:48:48 +0100 Subject: [PATCH] preliminary work on sea of nodes based intermediate representation. --- Makefile | 8 +- config.mk | 2 +- ir.c | 662 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ ir.h | 63 ++++++ lc.c | 9 +- lexer.c | 6 +- lexer.h | 2 +- parser.c | 34 ++- parser.h | 8 + sema.c | 161 ++++++++++--- sema.h | 2 +- test.l | 15 +- utils.c | 4 +- 13 files changed, 918 insertions(+), 58 deletions(-) create mode 100644 ir.c create mode 100644 ir.h diff --git a/Makefile b/Makefile index 2a3094b..203633a 100644 --- a/Makefile +++ b/Makefile @@ -3,8 +3,8 @@ include config.mk -SRC = lc.c utils.c lexer.c parser.c sema.c -HDR = config.def.h utils.h lexer.h parser.h sema.h +SRC = lc.c utils.c lexer.c parser.c sema.c ir.c +HDR = config.def.h utils.h lexer.h parser.h sema.h ir.h OBJ = ${SRC:.c=.o} all: options lc @@ -51,5 +51,9 @@ install: all uninstall: rm -f ${DESTDIR}${PREFIX}/bin/lc\ ${DESTDIR}${MANPREFIX}/man1/lc.1 +graph: clean all + ./lc > graph.dot + dot -Tpdf graph.dot > graph.pdf + zathura ./graph.pdf .PHONY: all options clean dist install uninstall diff --git a/config.mk b/config.mk index c797027..d6cbc51 100644 --- a/config.mk +++ b/config.mk @@ -15,7 +15,7 @@ INCS = -I. 
LIBS = # flags CPPFLAGS = -DVERSION=\"${VERSION}\" -CFLAGS := -std=c99 -pedantic -Wall -O0 ${INCS} ${CPPFLAGS} +CFLAGS := -std=c23 -pedantic -Wall -O0 ${INCS} ${CPPFLAGS} CFLAGS := ${CFLAGS} -g LDFLAGS = ${LIBS} diff --git a/ir.c b/ir.c new file mode 100644 index 0000000..3d13fe3 --- /dev/null +++ b/ir.c @@ -0,0 +1,662 @@ +#include "ir.h" +#include +#include +#include "stb_ds.h" +#include "sema.h" + +struct { ir_node key; ir_node *value; } *global_hash = NULL; +static ir_node *graph; +static ir_node *current_memory; +static ir_node *current_control; +static usize current_stack = 0; + +static ir_node *current_scope = NULL; + +static ir_node *build_expression(ast_node *node); + +static void node_name(ir_node *node) +{ + if (!node) { + printf("null [label=\"NULL\", style=filled, fillcolor=red]\n"); + return; + } + printf("%ld ", node->id); + switch (node->code) { + case OC_START: + printf("[label=\"start\", style=filled, color=orange]\n"); + break; + case OC_ADD: + printf("[label=\"+\"]\n"); + break; + case OC_NEG: + case OC_SUB: + printf("[label=\"-\"]\n"); + break; + case OC_DIV: + printf("[label=\"/\"]\n"); + break; + case OC_MUL: + printf("[label=\"*\"]\n"); + break; + case OC_MOD: + printf("[label=\"%%\"]\n"); + break; + case OC_BAND: + printf("[label=\"&\"]\n"); + break; + case OC_BOR: + printf("[label=\"|\"]\n"); + break; + case OC_BXOR: + printf("[label=\"^\"]\n"); + break; + case OC_EQ: + printf("[label=\"==\"]\n"); + break; + case OC_CONST_INT: + printf("[label=\"%ld\"]\n", node->data.const_int); + break; + case OC_CONST_FLOAT: + printf("[label=\"%f\"]\n", node->data.const_float); + break; + case OC_FRAME_PTR: + printf("[label=\"frame_ptr\"]\n"); + break; + case OC_STORE: + printf("[label=\"store\", shape=box]\n"); + break; + case OC_LOAD: + printf("[label=\"load\", shape=box]\n"); + break; + case OC_ADDR: + printf("[label=\"addr\"]\n"); + break; + case OC_REGION: + printf("[label=\"region\", shape=diamond, style=filled, color=green]\n"); + break; + case 
OC_PHI: + printf("[label=\"phi\", shape=triangle]\n"); + break; + case OC_IF: + printf("[label=\"if\", shape=diamond, style=filled, color=lightblue]\n"); + break; + case OC_PROJ: + printf("[label=\"proj\", shape=diamond, style=filled, color=cyan]\n"); + break; + default: + printf("[label=\"%d\"]\n", node->code); + break; + } +} + +static void print_graph(ir_node *node) +{ + for (int i = 0; i < hmlen(global_hash); i++) { + ir_node *node = global_hash[i].value; + node_name(node); + + for (int j = 0; j < arrlen(node->out); j++) { + if (node->out[j]) { + node_name(node->out[j]); + printf("%ld->%ld\n", node->out[j]->id, node->id); + } + } + } +} + +static void push_scope(void) +{ + arrput(current_scope->data.symbol_tables, NULL); +} + +static struct symbol_def *get_def(char *name) +{ + for (int i = arrlen(current_scope->data.symbol_tables) - 1; i >= 0; i--) { + struct symbol_def *def = shget(current_scope->data.symbol_tables[i], name); + if (def) return def; + } + return NULL; +} + +static void set_def(char *name, ir_node *node, bool lvalue) +{ + for (int i = arrlen(current_scope->data.symbol_tables) - 1; i >= 0; i--) { + if (shget(current_scope->data.symbol_tables[i], name)) { + struct symbol_def *def = calloc(1, sizeof(struct symbol_def)); + def->is_lvalue = lvalue; + def->node = node; + shput(current_scope->data.symbol_tables[i], name, def); + return; + } + } + int index = arrlen(current_scope->data.symbol_tables) - 1; + struct symbol_def *def = calloc(1, sizeof(struct symbol_def)); + def->is_lvalue = lvalue; + def->node = node; + shput(current_scope->data.symbol_tables[index], name, def); +} + +static ir_node *copy_scope(ir_node *src) +{ + ir_node *dst = calloc(1, sizeof(ir_node)); + dst->code = OC_SCOPE; + + for (int i=0; i < arrlen(src->data.symbol_tables); i++) { + arrput(dst->data.symbol_tables, NULL); + symbol_table *src_table = src->data.symbol_tables[i]; + for (int j=0; j < shlen(src_table); j++) { + shput(dst->data.symbol_tables[i], src_table[j].key, 
src_table[j].value); + } + } + return dst; +} + +static void const_fold(ir_node *binary) +{ + ir_node *left = binary->out[0]; + ir_node *right = binary->out[1]; + + if (left->code == OC_CONST_INT && right->code == OC_CONST_INT) { + switch (binary->code) { + case OC_ADD: + binary->data.const_int = left->data.const_int + right->data.const_int; + break; + case OC_SUB: + binary->data.const_int = left->data.const_int - right->data.const_int; + break; + case OC_MUL: + binary->data.const_int = left->data.const_int * right->data.const_int; + break; + case OC_DIV: + if (right->data.const_int != 0) + binary->data.const_int = left->data.const_int / right->data.const_int; + break; + case OC_MOD: + if (right->data.const_int != 0) + binary->data.const_int = left->data.const_int % right->data.const_int; + break; + case OC_BOR: + binary->data.const_int = left->data.const_int | right->data.const_int; + break; + case OC_BAND: + binary->data.const_int = left->data.const_int & right->data.const_int; + break; + case OC_BXOR: + binary->data.const_int = left->data.const_int ^ right->data.const_int; + break; + case OC_EQ: + binary->data.const_int = left->data.const_int == right->data.const_int; + break; + default: + return; + } + binary->code = OC_CONST_INT; + arrfree(binary->out); binary->out = NULL; + arrfree(binary->in); binary->in = NULL; + binary->id = stbds_hash_bytes(binary, sizeof(ir_node), 0xcafebabe); + } + + if (left->code == OC_CONST_FLOAT && right->code == OC_CONST_FLOAT) { + switch (binary->code) { + case OC_ADD: + binary->data.const_float = left->data.const_float + right->data.const_float; + break; + case OC_SUB: + binary->data.const_float = left->data.const_float - right->data.const_float; + break; + case OC_MUL: + binary->data.const_float = left->data.const_float * right->data.const_float; + break; + case OC_DIV: + if (right->data.const_float != 0.0f) + binary->data.const_float = left->data.const_float / right->data.const_float; + break; + default: + return; + } + 
binary->code = OC_CONST_FLOAT; + arrfree(binary->out); binary->out = NULL; + arrfree(binary->in); binary->in = NULL; + binary->id = stbds_hash_bytes(binary, sizeof(ir_node), 0xcafebabe); + } +} + +static ir_node *build_address(usize base, usize offset) { + ir_node *addr = calloc(1, sizeof(ir_node)); + addr->code = OC_ADDR; + + ir_node *base_node = calloc(1, sizeof(ir_node)); + if (base == -1) { + base_node->code = OC_FRAME_PTR; + base_node->id = stbds_hash_bytes(base_node, sizeof(ir_node), 0xcafebabe); + } else { + base_node->code = OC_CONST_INT; + base_node->data.const_int = base; + base_node->id = stbds_hash_bytes(base_node, sizeof(ir_node), 0xcafebabe); + } + + ir_node *offset_node = calloc(1, sizeof(ir_node)); + offset_node->code = OC_CONST_INT; + offset_node->data.const_int = offset; + offset_node->id = stbds_hash_bytes(offset_node, sizeof(ir_node), 0xcafebabe); + + arrput(addr->out, base_node); + arrput(addr->out, offset_node); + + addr->id = stbds_hash_bytes(addr, sizeof(ir_node), 0xcafebabe); + ir_node *tmp = hmget(global_hash, *addr); + if (tmp) { + free(addr); + return tmp; + } + + return addr; +} + +static ir_node *build_assign_ptr(ast_node *binary) +{ + ir_node *val_node = build_expression(binary->expr.binary.right); + + char *var_name = binary->expr.binary.left->expr.string.start; + + ir_node *existing_def = get_def(var_name)->node; + + ir_node *store = calloc(1, sizeof(ir_node)); + store->code = OC_STORE; + + arrput(store->out, current_memory); + arrput(store->out, existing_def); + arrput(store->out, val_node); + + store->id = stbds_hash_bytes(store, sizeof(ir_node), 0xcafebabe); + hmput(global_hash, *store, store); + + current_memory = store; + + return val_node; +} + +static ir_node *build_assign(ast_node *binary) +{ + ir_node *val_node = build_expression(binary->expr.binary.right); + + char *var_name = binary->expr.binary.left->expr.string.start; + + struct symbol_def *def = get_def(var_name); + + if (def && def->is_lvalue) { + ir_node 
*existing_def = def->node; + ir_node *store = calloc(1, sizeof(ir_node)); + store->code = OC_STORE; + + arrput(store->out, current_memory); + arrput(store->out, existing_def); + arrput(store->out, val_node); + + store->id = stbds_hash_bytes(store, sizeof(ir_node), 0xcafebabe); + hmput(global_hash, *store, store); + + current_memory = store; + + return val_node; + } + + set_def(var_name, val_node, false); + return val_node; +} + +static ir_node *build_binary(ast_node *node) +{ + ir_node *n = calloc(1, sizeof(ir_node)); + switch (node->expr.binary.operator) { + case OP_ASSIGN: + free(n); + return build_assign(node); + case OP_ASSIGN_PTR: + free(n); + return build_assign_ptr(node); + case OP_PLUS: + n->code = OC_ADD; + break; + case OP_MINUS: + n->code = OC_SUB; + break; + case OP_DIV: + n->code = OC_DIV; + break; + case OP_MUL: + n->code = OC_MUL; + break; + case OP_MOD: + n->code = OC_MOD; + break; + case OP_BOR: + n->code = OC_BOR; + break; + case OP_BAND: + n->code = OC_BAND; + break; + case OP_BXOR: + n->code = OC_BXOR; + break; + case OP_EQ: + n->code = OC_EQ; + break; + default: + break; + } + arrput(n->out, build_expression(node->expr.binary.left)); + arrput(n->out, build_expression(node->expr.binary.right)); + n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe); + const_fold(n); + ir_node *tmp = hmget(global_hash, *n); + if (tmp) { + free(n); + return tmp; + } + + return n; +} + +static ir_node *build_load(ast_node *node) +{ + ir_node *n = calloc(1, sizeof(ir_node)); + n->code = OC_LOAD; + + arrput(n->out, current_memory); + arrput(n->out, build_expression(node)); + n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabebabecafe); + + ir_node *tmp = hmget(global_hash, *n); + if (tmp) { + free(n); + return tmp; + } + + return n; +} + +static ir_node *build_unary(ast_node *node) +{ + ir_node *n = calloc(1, sizeof(ir_node)); + switch (node->expr.unary.operator) { + case UOP_MINUS: + n->code = OC_NEG; + arrput(n->out, 
build_expression(node->expr.unary.right)); + break; + case UOP_REF: + free(n); + + if (node->expr.unary.right->type == NODE_IDENTIFIER) { + struct symbol_def *def = get_def(node->expr.unary.right->expr.string.start); + if (def) { + return def->node; + } + } + + return build_expression(node->expr.unary.right); + case UOP_DEREF: + free(n); + return build_load(node->expr.unary.right); + default: + break; + } + + if (n->out && n->out[0]->code == OC_CONST_INT) { + switch (n->code) { + case OC_NEG: + n->data.const_int = -(n->out[0]->data.const_int); + break; + default: + break; + } + n->code = OC_CONST_INT; + arrfree(n->out); n->out = NULL; + } else if (n->out && n->out[0]->code == OC_CONST_FLOAT) { + switch (n->code) { + case OC_NEG: + n->data.const_float = -(n->out[0]->data.const_float); + break; + default: + break; + } + n->code = OC_CONST_FLOAT; + arrfree(n->out); n->out = NULL; + } + + n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe); + ir_node *tmp = hmget(global_hash, *n); + if (tmp) { + free(n); + return tmp; + } + + return n; +} + +static ir_node *build_if(ast_node *node) +{ + ir_node *condition = build_expression(node->expr.if_stmt.condition); + + ir_node *if_node = calloc(1, sizeof(ir_node)); + if_node->code = OC_IF; + arrput(if_node->out, condition); + arrput(if_node->out, current_control); + if_node->id = stbds_hash_bytes(if_node, sizeof(ir_node), 0xcafebabe); + hmput(global_hash, *if_node, if_node); + + ir_node *proj_true = calloc(1, sizeof(ir_node)); + proj_true->code = OC_PROJ; + arrput(proj_true->out, if_node); + proj_true->id = stbds_hash_bytes(proj_true, sizeof(ir_node), 0xcafebabe); + hmput(global_hash, *proj_true, proj_true); + + ir_node *proj_false = calloc(1, sizeof(ir_node)); + proj_false->code = OC_PROJ; + arrput(proj_false->out, if_node); + proj_false->id = stbds_hash_bytes(proj_false, sizeof(ir_node), 0xcafebabe); + hmput(global_hash, *proj_false, proj_false); + + ir_node *base_scope = copy_scope(current_scope); + ir_node *base_mem = 
current_memory; + + current_control = proj_true; + + ast_node *current = node->expr.if_stmt.body; + while (current && current->type == NODE_UNIT) { + if (current->expr.unit_node.expr) { + ir_node *expr = build_expression(current->expr.unit_node.expr); + arrput(graph->out, expr); + } + current = current->expr.unit_node.next; + } + ir_node *then_scope = current_scope; + ir_node *then_mem = current_memory; + ir_node *then_control = current_control; + + current_scope = copy_scope(base_scope); + current_memory = base_mem; + + current_control = proj_false; + current = node->expr.if_stmt.otherwise; + while (current && current->type == NODE_UNIT) { + if (current->expr.unit_node.expr) { + ir_node *expr = build_expression(current->expr.unit_node.expr); + arrput(graph->out, expr); + } + current = current->expr.unit_node.next; + } + ir_node *else_scope = current_scope; + ir_node *else_mem = current_memory; + ir_node *else_control = current_control; + + ir_node *region = calloc(1, sizeof(ir_node)); + region->code = OC_REGION; + arrput(region->out, then_control); + arrput(region->out, else_control); + region->id = stbds_hash_bytes(region, sizeof(ir_node), 0xcafebabe); + hmput(global_hash, *region, region); + + if (then_mem->id != else_mem->id) { + ir_node *phi = calloc(1, sizeof(ir_node)); + phi->code = OC_PHI; + arrput(phi->out, region); + arrput(phi->out, then_mem); + arrput(phi->out, else_mem); + phi->id = stbds_hash_bytes(phi, sizeof(ir_node), 0xcafebabe); + + hmput(global_hash, *phi, phi); + + current_memory = phi; + } else { + current_memory = then_mem; + } + + current_scope = base_scope; + + for (int i = 0; i < arrlen(current_scope->data.symbol_tables); i++) { + symbol_table *base_table = current_scope->data.symbol_tables[i]; + for (int j = 0; j < shlen(base_table); j++) { + char *key = base_table[j].key; + + ir_node *found_then = NULL; + symbol_table *t_table = then_scope->data.symbol_tables[i]; + if (shget(t_table, key)->node) found_then = shget(t_table, key)->node; + 
else found_then = base_table[j].value->node; + + ir_node *found_else = NULL; + symbol_table *e_table = else_scope->data.symbol_tables[i]; + if (shget(e_table, key)->node) found_else = shget(e_table, key)->node; + else found_else = base_table[j].value->node; + + if (found_then->id != found_else->id) { + ir_node *phi = calloc(1, sizeof(ir_node)); + phi->code = OC_PHI; + arrput(phi->out, region); + arrput(phi->out, found_then); + arrput(phi->out, found_else); + phi->id = stbds_hash_bytes(phi, sizeof(ir_node), 0xcafebabe); + struct symbol_def *def = calloc(1, sizeof(struct symbol_def)); + def->node = phi; + def->is_lvalue = false; + shput(current_scope->data.symbol_tables[i], key, def); + hmput(global_hash, *phi, phi); + } else { + struct symbol_def *def = calloc(1, sizeof(struct symbol_def)); + def->node = found_then; + def->is_lvalue = false; + shput(current_scope->data.symbol_tables[i], key, def); + } + } + } + + current_control = region; + + return region; +} + +static ir_node *build_expression(ast_node *node) +{ + ir_node *n = NULL; + ir_node *tmp = NULL; + switch (node->type) { + case NODE_UNARY: + n = build_unary(node); + break; + case NODE_BINARY: + n = build_binary(node); + break; + case NODE_INTEGER: + n = calloc(1, sizeof(ir_node)); + n->code = OC_CONST_INT; + n->data.const_int = node->expr.integer; + n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe); + tmp = hmget(global_hash, *n); + if (tmp) { + free(n); + return tmp; + } + break; + case NODE_VAR_DECL: + n = calloc(1, sizeof(ir_node)); + if (node->address_taken) { + n->code = OC_STORE; + + arrput(n->out, current_memory); + arrput(n->out, build_address(-1, current_stack)); + arrput(n->out, build_expression(node->expr.var_decl.value)); + current_memory = n; + current_stack += node->expr_type->size; + n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe); + hmput(global_hash, *n, n); + n = n->out[1]; + set_def(node->expr.var_decl.name, n, true); + } else { + n = 
build_expression(node->expr.var_decl.value); + set_def(node->expr.var_decl.name, n, false); + } + + return n; + case NODE_IDENTIFIER: + struct symbol_def *def = get_def(node->expr.string.start); + n = def->node; + + if (n && def->is_lvalue) { + ir_node *addr_node = n; + + n = calloc(1, sizeof(ir_node)); + n->code = OC_LOAD; + + arrput(n->out, current_memory); + arrput(n->out, addr_node); + + n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe); + + ir_node *tmp = hmget(global_hash, *n); + if (tmp) { + free(n); + n = tmp; + } else { + hmput(global_hash, *n, n); + } + } + break; + case NODE_IF: + n = build_if(node); + break; + default: + break; + } + + if (n) hmput(global_hash, *n, n); + return n; +} + +void ir_build(ast_node *ast) +{ + ast_node *current = ast; + + graph = calloc(1, sizeof(ir_node)); + graph->code = OC_START; + graph->id = stbds_hash_bytes(graph, sizeof(ir_node), 0xcafebabe); + current_control = graph; + + current_memory = calloc(1, sizeof(ir_node)); + current_memory->code = OC_FRAME_PTR; + current_memory->id = stbds_hash_bytes(current_memory, sizeof(ir_node), 0xcafebabe); + + current_scope = calloc(1, sizeof(ir_node)); + current_scope->code = OC_SCOPE; + push_scope(); + + while (current && current->type == NODE_UNIT) { + if (current->expr.unit_node.expr) { + ir_node *expr = build_expression(current->expr.unit_node.expr); + arrput(graph->out, expr); + } + current = current->expr.unit_node.next; + } + printf("digraph G {\n"); + print_graph(graph); + printf("}\n"); +} diff --git a/ir.h b/ir.h new file mode 100644 index 0000000..fc18bf8 --- /dev/null +++ b/ir.h @@ -0,0 +1,63 @@ +#ifndef IR_H +#define IR_H + +#include "utils.h" +#include "parser.h" + +struct _ir_node; +struct symbol_def { + struct _ir_node *node; + bool is_lvalue; +}; + +typedef struct { char *key; struct symbol_def *value; } symbol_table; + +typedef enum { + OC_START, + OC_ADD, + OC_SUB, + OC_MUL, + OC_DIV, + OC_MOD, + OC_BAND, + OC_BOR, + OC_BXOR, + OC_NEG, + OC_EQ, + + 
OC_CONST_INT, + OC_CONST_FLOAT, + + OC_FRAME_PTR, + OC_ADDR, + + OC_STORE, + OC_LOAD, + + OC_REGION, + OC_PHI, + + OC_IF, + OC_PROJ, + + OC_STOP, + OC_RETURN, + + OC_SCOPE, +} opcode; + +typedef struct _ir_node { + opcode code; + usize id; + struct _ir_node **in; + struct _ir_node **out; + union { + i64 const_int; + f64 const_float; + symbol_table **symbol_tables; + } data; +} ir_node; + +void ir_build(ast_node *ast); + +#endif diff --git a/lc.c b/lc.c index e4a0229..0ec50fb 100644 --- a/lc.c +++ b/lc.c @@ -4,6 +4,7 @@ #include "lexer.h" #include "parser.h" #include "sema.h" +#include "ir.h" void print_indent(int depth) { for (int i = 0; i < depth; i++) printf(" "); @@ -17,6 +18,7 @@ const char* get_op_str(binary_op op) { case OP_MUL: return "*"; case OP_EQ: return "=="; case OP_ASSIGN: return "="; + case OP_ASSIGN_PTR: return "<-"; case OP_AND: return "&&"; case OP_OR: return "||"; case OP_NEQ: return "!="; @@ -215,7 +217,7 @@ void print_ast(ast_node *node, int depth) { int main(void) { - FILE *fp = fopen("examples/hello_world.l", "r"); + FILE *fp = fopen("test.l", "r"); usize size = 0; fseek(fp, 0, SEEK_END); size = ftell(fp); @@ -228,8 +230,11 @@ int main(void) arena a = arena_init(0x1000 * 0x1000 * 64); lexer *l = lexer_init(src, size, &a); parser *p = parser_init(l, &a); - print_ast(p->ast, 0); + //print_ast(p->ast, 0); sema *s = sema_init(p, &a); + //(void) s; + + ir_build(p->ast); arena_deinit(a); diff --git a/lexer.c b/lexer.c index 07d1a9d..22063fd 100644 --- a/lexer.c +++ b/lexer.c @@ -133,9 +133,6 @@ static bool parse_special(lexer *l) } else if (l->source[l->index+1] == '-') { add_token(l, TOKEN_MINUS_MINUS, 2); l->index += 2; - } else if (l->source[l->index+1] == '>') { - add_token(l, TOKEN_ARROW, 2); - l->index += 2; } else { add_token(l, TOKEN_MINUS, 1); l->index += 1; @@ -231,6 +228,9 @@ static bool parse_special(lexer *l) if (l->source[l->index+1] == '=') { add_token(l, TOKEN_LESS_EQ, 2); l->index += 2; + } else if (l->source[l->index+1] == '-') { 
+ add_token(l, TOKEN_ARROW, 2); + l->index += 2; } else if (l->source[l->index+1] == '<') { if (l->source[l->index+2] == '=') { add_token(l, TOKEN_LSHIFT_EQ, 3); diff --git a/lexer.h b/lexer.h index a3859ba..72277df 100644 --- a/lexer.h +++ b/lexer.h @@ -16,10 +16,10 @@ typedef enum { TOKEN_AND, // & TOKEN_HAT, // ^ TOKEN_PIPE, // | - TOKEN_ARROW, // -> TOKEN_LSHIFT, // << TOKEN_RSHIFT, // >> TOKEN_DOUBLE_EQ, // == + TOKEN_ARROW, // <- TOKEN_EQ, // = TOKEN_LESS_THAN, // < TOKEN_GREATER_THAN, // > diff --git a/parser.c b/parser.c index 96b5c87..8061fc1 100644 --- a/parser.c +++ b/parser.c @@ -367,9 +367,21 @@ ast_node *parse_term(parser *p) { ast_node *left = parse_unary(p); - while (match_peek(p, TOKEN_STAR) || match_peek(p, TOKEN_SLASH)) - { - binary_op op = peek(p)->type == TOKEN_STAR ? OP_MUL : OP_DIV; + while (match_peek(p, TOKEN_STAR) || match_peek(p, TOKEN_SLASH) || match_peek(p, TOKEN_PERC)) { + binary_op op; + switch (peek(p)->type) { + case TOKEN_STAR: + op = OP_MUL; + break; + case TOKEN_SLASH: + op = OP_DIV; + break; + case TOKEN_PERC: + op = OP_MOD; + break; + default: + continue; + } advance(p); ast_node *right = parse_factor(p); ast_node *node = arena_alloc(p->allocator, sizeof(ast_node)); @@ -561,6 +573,9 @@ ast_node *parse_expression(parser *p) binary_op op; switch (p->tokens->type) { + case TOKEN_ARROW: + op = OP_ASSIGN_PTR; + break; case TOKEN_EQ: op = OP_ASSIGN; break; @@ -854,8 +869,12 @@ static ast_node *parse_if(parser *p) ast_node *node = arena_alloc(p->allocator, sizeof(ast_node)); node->type = NODE_IF; node->position = p->previous->position; - node->expr.whle.body = body; - node->expr.whle.condition = condition; + node->expr.if_stmt.body = body; + node->expr.if_stmt.condition = condition; + if (match(p, TOKEN_ELSE)) { + body = parse_compound(p); + node->expr.if_stmt.otherwise = body; + } return node; } @@ -1302,11 +1321,6 @@ static void parse(parser *p) ast_node *tail = p->ast; ast_node *expr = parse_statement(p); while (expr) { - if 
(expr->type != NODE_FUNCTION && expr->type != NODE_VAR_DECL && expr->type != NODE_IMPORT && - expr->type != NODE_STRUCT && expr->type != NODE_UNION && expr->type != NODE_ENUM) { - error(p, "expected function, struct, enum, union, global variable or import statement."); - return; - } tail->expr.unit_node.next = arena_alloc(p->allocator, sizeof(ast_node)); tail->expr.unit_node.next->expr.unit_node.expr = expr; tail = tail->expr.unit_node.next; diff --git a/parser.h b/parser.h index 5428101..dced7ec 100644 --- a/parser.h +++ b/parser.h @@ -19,6 +19,7 @@ typedef enum { OP_BXOR, // ^ OP_ASSIGN, // = + OP_ASSIGN_PTR, // <- OP_RSHIFT_EQ, // >>= OP_LSHIFT_EQ, // <<= OP_PLUS_EQ, // += @@ -126,6 +127,7 @@ typedef struct _ast_node { node_type type; source_pos position; struct _type *expr_type; + bool address_taken; // used in IR generation. union { struct { struct _ast_node *type; @@ -200,6 +202,12 @@ typedef struct _ast_node { struct _ast_node *body; u8 flags; } whle; // while + struct { + struct _ast_node *condition; + struct _ast_node *body; + struct _ast_node *otherwise; + u8 flags; + } if_stmt; // if struct { struct _ast_node **statements; usize stmt_len; diff --git a/sema.c b/sema.c index 15ea070..bb14821 100644 --- a/sema.c +++ b/sema.c @@ -27,7 +27,6 @@ static type *const_float = NULL; static bool in_loop = false; -/* Print the error message and sync the parser.
*/ static void error(ast_node *n, char *msg) { if (n) { @@ -78,7 +77,6 @@ static type *create_float(sema *s, char *name, u8 bits) return t; } -/* https://en.wikipedia.org/wiki/Topological_sorting */ static void order_type(sema *s, ast_node *node) { if (node->type == NODE_STRUCT || node->type == NODE_UNION) { @@ -350,12 +348,12 @@ static void pop_scope(sema *s) current_scope = current_scope->parent; } -static type *get_def(sema *s, char *name) +static ast_node *get_def(sema *s, char *name) { scope *current = current_scope; while (current) { - type *t = shget(current->defs, name); - if (t) return t; + ast_node *def = shget(current->defs, name); + if (def) return def; current = current->parent; } @@ -416,11 +414,13 @@ static type *get_identifier_type(sema *s, ast_node *node) { char *name_start = node->expr.string.start; usize name_len = node->expr.string.len; - type *t = get_def(s, intern_string(s, name_start, name_len)); - if (!t) { + char *name = intern_string(s, name_start, name_len); + node->expr.string.start = name; + ast_node *def = get_def(s, name); + if (!def) { error(node, "unknown identifier."); } - return t; + return def->expr_type; } static bool match(type *t1, type *t2); @@ -450,60 +450,129 @@ static type *get_expression_type(sema *s, ast_node *node) prototype *prot = NULL; switch (node->type) { case NODE_IDENTIFIER: - return get_identifier_type(s, node); + t = get_identifier_type(s, node); + node->expr_type = t; + return t; case NODE_INTEGER: + node->expr_type = const_int; return const_int; case NODE_FLOAT: + node->expr_type = const_float; return const_float; case NODE_STRING: - return get_string_type(s, node); + t = get_string_type(s, node); + node->expr_type = t; + return t; case NODE_CHAR: - return shget(type_reg, "u8"); + t = shget(type_reg, "u8"); + node->expr_type = t; + return t; case NODE_BOOL: - return shget(type_reg, "bool"); + t = shget(type_reg, "bool"); + node->expr_type = t; + return t; case NODE_CAST: - return get_type(s, 
node->expr.cast.type); + t = get_type(s, node->expr.cast.type); + node->expr_type = t; + return t; case NODE_POSTFIX: case NODE_UNARY: - return get_expression_type(s, node->expr.unary.right); + t = get_expression_type(s, node->expr.unary.right); + if (node->expr.unary.operator == UOP_REF) { + ast_node *target = node->expr.unary.right; + while (target->type == NODE_ACCESS) { + target = target->expr.access.expr; + } + + if (target->type != NODE_IDENTIFIER) { + error(node, "expected identifier."); + return NULL; + } + + char *name = target->expr.string.start; + ast_node *def = get_def(s, name); + + if (def) { + def->address_taken = true; + target->address_taken = true; + } + + type *tmp = t; + t = arena_alloc(s->allocator, sizeof(type)); + t->tag = TYPE_PTR; + t->size = sizeof(usize); + t->alignment = sizeof(usize); + t->name = "ptr"; + t->data.ptr.is_const = false; + t->data.ptr.is_volatile = false; + t->data.ptr.child = tmp; + } else if (node->expr.unary.operator == UOP_DEREF) { + if (t->tag != TYPE_PTR) { + error(node, "only pointers can be dereferenced."); + return NULL; + } + t = t->data.ptr.child; + } + node->expr_type = t; + return t; case NODE_BINARY: t = get_expression_type(s, node->expr.binary.left); if (!t) return NULL; - if (!match(t, get_expression_type(s, node->expr.binary.right))) { + if (node->expr.binary.operator == OP_ASSIGN_PTR) { + if (t->tag != TYPE_PTR) { + error(node, "expected pointer."); + return NULL; + } + t = t->data.ptr.child; + } + if (!can_cast(get_expression_type(s, node->expr.binary.right), t) && !match(t, get_expression_type(s, node->expr.binary.right))) { error(node, "type mismatch."); + node->expr_type = NULL; return NULL; } if (node->expr.binary.operator >= OP_EQ) { - return shget(type_reg, "bool"); + t = shget(type_reg, "bool"); } else if (node->expr.binary.operator >= OP_ASSIGN && node->expr.binary.operator <= OP_MOD_EQ) { - return shget(type_reg, "void"); - } else { - return t; + t = shget(type_reg, "void"); } + node->expr_type 
= t; + return t; case NODE_RANGE: - return get_range_type(s, node); + t = get_range_type(s, node); + node->expr_type = t; + return t; case NODE_ARRAY_SUBSCRIPT: t = get_expression_type(s, node->expr.subscript.expr); switch (t->tag) { case TYPE_SLICE: - return t->data.slice.child; + t = t->data.slice.child; + break; case TYPE_PTR: - return t->data.ptr.child; + t = t->data.ptr.child; + break; default: error(node, "only pointers and slices can be indexed."); return NULL; } + node->expr_type = t; + return t; case NODE_CALL: prot = shget(prototypes, intern_string(s, node->expr.call.name, node->expr.call.name_len)); if (!prot) { error(node, "unknown function."); return NULL; } - return prot->type; + t = prot->type; + node->expr_type = t; + return t; case NODE_ACCESS: - return get_access_type(s, node); + t = get_access_type(s, node); + node->expr_type = t; + return t; default: - return shget(type_reg, "void"); + t = shget(type_reg, "void"); + node->expr_type = t; + return t; } } @@ -567,7 +636,14 @@ static void check_for(sema *s, ast_node *node) while (current_capture) { type *c_type = get_expression_type(s, current_slice->expr.unit_node.expr); char *c_name = intern_string(s, current_capture->expr.unit_node.expr->expr.string.start, current_capture->expr.unit_node.expr->expr.string.len); - shput(current_scope->defs, c_name, c_type); + + ast_node *cap_node = arena_alloc(s->allocator, sizeof(ast_node)); + cap_node->type = NODE_VAR_DECL; + cap_node->expr_type = c_type; + cap_node->address_taken = false; + cap_node->expr.var_decl.name = c_name; + + shput(current_scope->defs, c_name, cap_node); current_capture = current_capture->expr.unit_node.next; current_slice = current_slice->expr.unit_node.next; } @@ -611,12 +687,23 @@ static void check_statement(sema *s, ast_node *node) check_body(s, node->expr.whle.body); in_loop = false; break; + case NODE_IF: + if (!match(get_expression_type(s, node->expr.if_stmt.condition), shget(type_reg, "bool"))) { + error(node, "expected boolean 
value."); + return; + } + + check_body(s, node->expr.if_stmt.body); + if (node->expr.if_stmt.otherwise) check_body(s, node->expr.if_stmt.otherwise); + break; case NODE_FOR: check_for(s, node); break; case NODE_VAR_DECL: t = get_type(s, node->expr.var_decl.type); + node->expr_type = t; name = intern_string(s, node->expr.var_decl.name, node->expr.var_decl.name_len); + node->expr.var_decl.name = name; if (get_def(s, name)) { error(node, "redeclaration of variable."); break; @@ -624,7 +711,7 @@ static void check_statement(sema *s, ast_node *node) if (!can_cast(get_expression_type(s, node->expr.var_decl.value), t) && !match(t, get_expression_type(s, node->expr.var_decl.value))) { error(node, "type mismatch."); } - shput(current_scope->defs, name, t); + shput(current_scope->defs, name, node); break; default: get_expression_type(s, node); @@ -641,7 +728,14 @@ static void check_function(sema *s, ast_node *f) while (param) { type *p_type = get_type(s, param->type); char *t_name = intern_string(s, param->name, param->name_len); - shput(current_scope->defs, t_name, p_type); + + ast_node *param_node = arena_alloc(s->allocator, sizeof(ast_node)); + param_node->type = NODE_VAR_DECL; + param_node->expr_type = p_type; + param_node->address_taken = false; + param_node->expr.var_decl.name = t_name; + + shput(current_scope->defs, t_name, param_node); param = param->next; } @@ -658,7 +752,8 @@ static void analyze_unit(sema *s, ast_node *node) { ast_node *current = node; while (current && current->type == NODE_UNIT) { - order_type(s, current->expr.unit_node.expr); + if (current->expr.unit_node.expr) + order_type(s, current->expr.unit_node.expr); current = current->expr.unit_node.next; } @@ -666,7 +761,7 @@ static void analyze_unit(sema *s, ast_node *node) current = node; while (current && current->type == NODE_UNIT) { - if (current->expr.unit_node.expr->type == NODE_FUNCTION) { + if (current->expr.unit_node.expr && current->expr.unit_node.expr->type == NODE_FUNCTION) { 
create_prototype(s, current->expr.unit_node.expr); } current = current->expr.unit_node.next; @@ -674,8 +769,10 @@ static void analyze_unit(sema *s, ast_node *node) current = node; while (current && current->type == NODE_UNIT) { - if (current->expr.unit_node.expr->type == NODE_FUNCTION) { + if (current->expr.unit_node.expr && current->expr.unit_node.expr->type == NODE_FUNCTION) { check_function(s, current->expr.unit_node.expr); + } else { + check_statement(s, current->expr.unit_node.expr); } current = current->expr.unit_node.next; } @@ -720,5 +817,3 @@ sema *sema_init(parser *p, arena *a) return s; } - - diff --git a/sema.h b/sema.h index a1f4285..84fab1c 100644 --- a/sema.h +++ b/sema.h @@ -63,7 +63,7 @@ typedef struct { typedef struct _scope { struct _scope *parent; - struct { char *key; type *value; } *defs; + struct { char *key; ast_node *value; } *defs; } scope; typedef struct { diff --git a/test.l b/test.l index dbc32f1..fb9e411 100644 --- a/test.l +++ b/test.l @@ -1,5 +1,12 @@ -u32 a() -{ - [u32] v = {1, 2, 3}; - return z[0]; +u32 a = 2; + +if (a == 3) { + a = 5; + if (a == 4) { + a = 3; + } +} else { + a = 1; } + +u32 d = a; diff --git a/utils.c b/utils.c index e30188a..c6f0781 100644 --- a/utils.c +++ b/utils.c @@ -112,10 +112,12 @@ static usize align_forward(usize ptr, usize align) { arena arena_init(usize size) { + void *memory = malloc(size); + memset(memory, 0x0, size); return (arena){ .capacity = size, .position = 0, - .memory = malloc(size), + .memory = memory, }; }