broken! added return nodes

This commit is contained in:
Lorenzo Torres 2025-12-07 19:54:28 +01:00
parent f7689a3f54
commit 89e0d41fd9
4 changed files with 173 additions and 20 deletions

166
ir.c
View file

@ -14,6 +14,12 @@ static ir_node *current_scope = NULL;
static ir_node *build_expression(ast_node *node); static ir_node *build_expression(ast_node *node);
static struct {
ir_node **return_controls;
ir_node **return_memories;
ir_node **return_values;
} current_func = {0};
static void node_name(ir_node *node) static void node_name(ir_node *node)
{ {
if (!node) { if (!node) {
@ -23,7 +29,10 @@ static void node_name(ir_node *node)
printf("%ld ", node->id); printf("%ld ", node->id);
switch (node->code) { switch (node->code) {
case OC_START: case OC_START:
printf("[label=\"start\", style=filled, color=orange]\n"); printf("[label=\"%s\", style=filled, color=orange]\n", node->data.start_name);
break;
case OC_RETURN:
printf("[label=\"return\", style=filled, color=orange]\n");
break; break;
case OC_ADD: case OC_ADD:
printf("[label=\"+\"]\n"); printf("[label=\"+\"]\n");
@ -264,6 +273,8 @@ static ir_node *build_assign_ptr(ast_node *binary)
ir_node *store = calloc(1, sizeof(ir_node)); ir_node *store = calloc(1, sizeof(ir_node));
store->code = OC_STORE; store->code = OC_STORE;
arrput(store->out, current_control);
arrput(store->out, current_memory); arrput(store->out, current_memory);
arrput(store->out, existing_def); arrput(store->out, existing_def);
arrput(store->out, val_node); arrput(store->out, val_node);
@ -289,6 +300,8 @@ static ir_node *build_assign(ast_node *binary)
ir_node *store = calloc(1, sizeof(ir_node)); ir_node *store = calloc(1, sizeof(ir_node));
store->code = OC_STORE; store->code = OC_STORE;
arrput(store->out, current_control);
arrput(store->out, current_memory); arrput(store->out, current_memory);
arrput(store->out, existing_def); arrput(store->out, existing_def);
arrput(store->out, val_node); arrput(store->out, val_node);
@ -465,8 +478,7 @@ static ir_node *build_if(ast_node *node)
ast_node *current = node->expr.if_stmt.body; ast_node *current = node->expr.if_stmt.body;
while (current && current->type == NODE_UNIT) { while (current && current->type == NODE_UNIT) {
if (current->expr.unit_node.expr) { if (current->expr.unit_node.expr) {
ir_node *expr = build_expression(current->expr.unit_node.expr); build_expression(current->expr.unit_node.expr);
arrput(graph->out, expr);
} }
current = current->expr.unit_node.next; current = current->expr.unit_node.next;
} }
@ -481,8 +493,7 @@ static ir_node *build_if(ast_node *node)
current = node->expr.if_stmt.otherwise; current = node->expr.if_stmt.otherwise;
while (current && current->type == NODE_UNIT) { while (current && current->type == NODE_UNIT) {
if (current->expr.unit_node.expr) { if (current->expr.unit_node.expr) {
ir_node *expr = build_expression(current->expr.unit_node.expr); build_expression(current->expr.unit_node.expr);
arrput(graph->out, expr);
} }
current = current->expr.unit_node.next; current = current->expr.unit_node.next;
} }
@ -555,6 +566,141 @@ static ir_node *build_if(ast_node *node)
return region; return region;
} }
static void build_return(ast_node *node)
{
ir_node *val = NULL;
if (node->expr.ret.value) {
val = build_expression(node->expr.ret.value);
} else {
val = calloc(1, sizeof(ir_node));
val->code = OC_VOID;
val->id = stbds_hash_bytes(val, sizeof(ir_node), 0xcafebabe);
}
arrput(current_func.return_controls, current_control);
arrput(current_func.return_memories, current_memory);
arrput(current_func.return_values, val);
current_control = NULL;
}
static void finalize_function(void)
{
int count = arrlen(current_func.return_controls);
if (count == 0) {
return;
}
ir_node *final_ctrl = NULL;
ir_node *final_mem = NULL;
ir_node *final_val = NULL;
if (count == 1) {
final_ctrl = current_func.return_controls[0];
final_mem = current_func.return_memories[0];
final_val = current_func.return_values[0];
}
else {
ir_node *region = calloc(1, sizeof(ir_node));
region->code = OC_REGION;
for (int i=0; i<count; i++) {
arrput(region->out, current_func.return_controls[i]);
}
hmput(global_hash, *region, region);
final_ctrl = region;
ir_node *mem_phi = calloc(1, sizeof(ir_node));
mem_phi->code = OC_PHI;
arrput(mem_phi->out, region);
for (int i=0; i<count; i++) {
arrput(mem_phi->out, current_func.return_memories[i]);
}
hmput(global_hash, *mem_phi, mem_phi);
mem_phi->id = stbds_hash_bytes(mem_phi, sizeof(ir_node), 0xcafebabe);
final_mem = mem_phi;
ir_node *val_phi = calloc(1, sizeof(ir_node));
val_phi->code = OC_PHI;
//arrput(val_phi->out, region);
for (int i=0; i<count; i++) {
arrput(val_phi->out, current_func.return_values[i]);
}
val_phi->id = stbds_hash_bytes(val_phi, sizeof(ir_node), 0xcafebabe);
hmput(global_hash, *val_phi, val_phi);
final_val = val_phi;
region->id = stbds_hash_bytes(region, sizeof(ir_node), 0xcafebabe);
}
ir_node *ret = calloc(1, sizeof(ir_node));
ret->code = OC_RETURN;
arrput(ret->out, final_ctrl);
arrput(ret->out, final_mem);
arrput(ret->out, final_val);
ret->id = stbds_hash_bytes(ret, sizeof(ir_node), 0xcafebabe);
hmput(global_hash, *ret, ret);
}
static ir_node *build_function(ast_node *node)
{
memset(&current_func, 0x0, sizeof(current_func));
ast_node *current = node->expr.function.body;
ir_node *func = calloc(1, sizeof(ir_node));
func->code = OC_START;
func->id = stbds_hash_bytes(func, sizeof(ir_node), 0xcafebabe);
func->data.start_name = node->expr.function.name;
ir_node *start_ctrl = calloc(1, sizeof(ir_node));
start_ctrl->code = OC_PROJ;
start_ctrl->id = stbds_hash_bytes(&start_ctrl, sizeof(usize), 0xcafebabe);
arrput(start_ctrl->out, func);
hmput(global_hash, *start_ctrl, start_ctrl);
current_control = start_ctrl;
ir_node *start_mem = calloc(1, sizeof(ir_node));
start_mem->code = OC_PROJ;
start_mem->id = stbds_hash_bytes(&start_mem, sizeof(usize), 0xcafebabe);
arrput(start_mem->out, func);
hmput(global_hash, *start_mem, start_mem);
current_memory = start_mem;
current_scope = calloc(1, sizeof(ir_node));
current_scope->code = OC_SCOPE;
push_scope();
member *m = node->expr.function.parameters;
while (m) {
ir_node *proj_param = calloc(1, sizeof(ir_node));
proj_param->code = OC_PROJ;
arrput(proj_param->out, func);
proj_param->id = stbds_hash_bytes(proj_param, sizeof(ir_node), 0xcafebabe);
set_def(m->name, proj_param, false);
hmput(global_hash, *proj_param, proj_param);
m = m->next;
}
while (current && current->type == NODE_UNIT) {
if (current->expr.unit_node.expr) {
build_expression(current->expr.unit_node.expr);
}
current = current->expr.unit_node.next;
}
func->id = stbds_hash_bytes(func, sizeof(ir_node), 0xcafebabe);
finalize_function();
return func;
}
static ir_node *build_expression(ast_node *node) static ir_node *build_expression(ast_node *node)
{ {
ir_node *n = NULL; ir_node *n = NULL;
@ -624,6 +770,9 @@ static ir_node *build_expression(ast_node *node)
case NODE_IF: case NODE_IF:
n = build_if(node); n = build_if(node);
break; break;
case NODE_RETURN:
build_return(node);
break;
default: default:
break; break;
} }
@ -639,7 +788,7 @@ void ir_build(ast_node *ast)
graph = calloc(1, sizeof(ir_node)); graph = calloc(1, sizeof(ir_node));
graph->code = OC_START; graph->code = OC_START;
graph->id = stbds_hash_bytes(graph, sizeof(ir_node), 0xcafebabe); graph->id = stbds_hash_bytes(graph, sizeof(ir_node), 0xcafebabe);
current_control = graph; graph->data.start_name = "program";
current_memory = calloc(1, sizeof(ir_node)); current_memory = calloc(1, sizeof(ir_node));
current_memory->code = OC_FRAME_PTR; current_memory->code = OC_FRAME_PTR;
@ -650,9 +799,10 @@ void ir_build(ast_node *ast)
push_scope(); push_scope();
while (current && current->type == NODE_UNIT) { while (current && current->type == NODE_UNIT) {
if (current->expr.unit_node.expr) { if (current->expr.unit_node.expr && current->expr.unit_node.expr->type == NODE_FUNCTION) {
ir_node *expr = build_expression(current->expr.unit_node.expr); ir_node *expr = build_function(current->expr.unit_node.expr);
arrput(graph->out, expr); arrput(graph->out, expr);
hmput(global_hash, *expr, expr);
} }
current = current->expr.unit_node.next; current = current->expr.unit_node.next;
} }

2
ir.h
View file

@ -27,6 +27,7 @@ typedef enum {
OC_CONST_INT, OC_CONST_INT,
OC_CONST_FLOAT, OC_CONST_FLOAT,
OC_VOID,
OC_FRAME_PTR, OC_FRAME_PTR,
OC_ADDR, OC_ADDR,
@ -55,6 +56,7 @@ typedef struct _ir_node {
i64 const_int; i64 const_int;
f64 const_float; f64 const_float;
symbol_table **symbol_tables; symbol_table **symbol_tables;
char *start_name;
} data; } data;
} ir_node; } ir_node;

5
sema.c
View file

@ -316,6 +316,7 @@ static void create_prototype(sema *s, ast_node *node)
{ {
prototype *p = arena_alloc(s->allocator, sizeof(prototype)); prototype *p = arena_alloc(s->allocator, sizeof(prototype));
p->name = intern_string(s, node->expr.function.name, node->expr.function.name_len); p->name = intern_string(s, node->expr.function.name, node->expr.function.name_len);
node->expr.function.name = p->name;
if (shget(prototypes, p->name)) { if (shget(prototypes, p->name)) {
error(node, "function already defined."); error(node, "function already defined.");
} }
@ -668,7 +669,7 @@ static void check_statement(sema *s, ast_node *node)
char *name = NULL; char *name = NULL;
switch(node->type) { switch(node->type) {
case NODE_RETURN: case NODE_RETURN:
if (!match(get_expression_type(s, node->expr.ret.value), current_return)) { if (!can_cast(get_expression_type(s, node->expr.ret.value), current_return) && !match(get_expression_type(s, node->expr.ret.value), current_return)) {
error(node, "return type doesn't match function's one."); error(node, "return type doesn't match function's one.");
} }
break; break;
@ -728,7 +729,7 @@ static void check_function(sema *s, ast_node *f)
while (param) { while (param) {
type *p_type = get_type(s, param->type); type *p_type = get_type(s, param->type);
char *t_name = intern_string(s, param->name, param->name_len); char *t_name = intern_string(s, param->name, param->name_len);
param->name = t_name;
ast_node *param_node = arena_alloc(s->allocator, sizeof(ast_node)); ast_node *param_node = arena_alloc(s->allocator, sizeof(ast_node));
param_node->type = NODE_VAR_DECL; param_node->type = NODE_VAR_DECL;
param_node->expr_type = p_type; param_node->expr_type = p_type;

20
test.l
View file

@ -1,12 +1,12 @@
u32 a = 2; u32 main(u32 b)
{
if (a == 3) { u32 a = 4;
a = 5; //return a;
if (a == 4) { if (b == 3) {
a = 3; return 3;
} else {
return 4;
} }
} else {
a = 1; return a;
} }
u32 d = a;