diff --git a/ir.c b/ir.c index 3d13fe3..61da14a 100644 --- a/ir.c +++ b/ir.c @@ -14,6 +14,12 @@ static ir_node *current_scope = NULL; static ir_node *build_expression(ast_node *node); +static struct { + ir_node **return_controls; + ir_node **return_memories; + ir_node **return_values; +} current_func = {0}; + static void node_name(ir_node *node) { if (!node) { @@ -23,7 +29,10 @@ static void node_name(ir_node *node) printf("%ld ", node->id); switch (node->code) { case OC_START: - printf("[label=\"start\", style=filled, color=orange]\n"); + printf("[label=\"%s\", style=filled, color=orange]\n", node->data.start_name); + break; + case OC_RETURN: + printf("[label=\"return\", style=filled, color=orange]\n"); break; case OC_ADD: printf("[label=\"+\"]\n"); @@ -264,6 +273,8 @@ static ir_node *build_assign_ptr(ast_node *binary) ir_node *store = calloc(1, sizeof(ir_node)); store->code = OC_STORE; + arrput(store->out, current_control); + arrput(store->out, current_memory); arrput(store->out, existing_def); arrput(store->out, val_node); @@ -289,6 +300,8 @@ static ir_node *build_assign(ast_node *binary) ir_node *store = calloc(1, sizeof(ir_node)); store->code = OC_STORE; + arrput(store->out, current_control); + arrput(store->out, current_memory); arrput(store->out, existing_def); arrput(store->out, val_node); @@ -465,8 +478,7 @@ static ir_node *build_if(ast_node *node) ast_node *current = node->expr.if_stmt.body; while (current && current->type == NODE_UNIT) { if (current->expr.unit_node.expr) { - ir_node *expr = build_expression(current->expr.unit_node.expr); - arrput(graph->out, expr); + build_expression(current->expr.unit_node.expr); } current = current->expr.unit_node.next; } @@ -481,8 +493,7 @@ static ir_node *build_if(ast_node *node) current = node->expr.if_stmt.otherwise; while (current && current->type == NODE_UNIT) { if (current->expr.unit_node.expr) { - ir_node *expr = build_expression(current->expr.unit_node.expr); - arrput(graph->out, expr); + build_expression(current->expr.unit_node.expr); } current = current->expr.unit_node.next; } @@ -555,6 +566,141 @@ static ir_node *build_if(ast_node *node) return region; } +static void build_return(ast_node *node) +{ + ir_node *val = NULL; + + if (node->expr.ret.value) { + val = build_expression(node->expr.ret.value); + } else { + val = calloc(1, sizeof(ir_node)); + val->code = OC_VOID; + val->id = stbds_hash_bytes(val, sizeof(ir_node), 0xcafebabe); + } + + arrput(current_func.return_controls, current_control); + arrput(current_func.return_memories, current_memory); + arrput(current_func.return_values, val); + + current_control = NULL; +} + +static void finalize_function(void) +{ + int count = arrlen(current_func.return_controls); + + if (count == 0) { + return; + } + + ir_node *final_ctrl = NULL; + ir_node *final_mem = NULL; + ir_node *final_val = NULL; + + if (count == 1) { + final_ctrl = current_func.return_controls[0]; + final_mem = current_func.return_memories[0]; + final_val = current_func.return_values[0]; + } + else { + ir_node *region = calloc(1, sizeof(ir_node)); + region->code = OC_REGION; + for (int i=0; iout, current_func.return_controls[i]); + } + hmput(global_hash, *region, region); + final_ctrl = region; + + ir_node *mem_phi = calloc(1, sizeof(ir_node)); + mem_phi->code = OC_PHI; + arrput(mem_phi->out, region); + for (int i=0; iout, current_func.return_memories[i]); + } + hmput(global_hash, *mem_phi, mem_phi); + mem_phi->id = stbds_hash_bytes(mem_phi, sizeof(ir_node), 0xcafebabe); + final_mem = mem_phi; + + ir_node *val_phi = calloc(1, sizeof(ir_node)); + val_phi->code = OC_PHI; + //arrput(val_phi->out, region); + for (int i=0; iout, current_func.return_values[i]); + } + val_phi->id = stbds_hash_bytes(val_phi, sizeof(ir_node), 0xcafebabe); + hmput(global_hash, *val_phi, val_phi); + final_val = val_phi; + + region->id = stbds_hash_bytes(region, sizeof(ir_node), 0xcafebabe); + } + + ir_node *ret = calloc(1, sizeof(ir_node)); + ret->code = OC_RETURN; + arrput(ret->out, final_ctrl); + arrput(ret->out, final_mem); + arrput(ret->out, final_val); + ret->id = stbds_hash_bytes(ret, sizeof(ir_node), 0xcafebabe); + + hmput(global_hash, *ret, ret); +} + +static ir_node *build_function(ast_node *node) +{ + memset(¤t_func, 0x0, sizeof(current_func)); + ast_node *current = node->expr.function.body; + + ir_node *func = calloc(1, sizeof(ir_node)); + func->code = OC_START; + func->id = stbds_hash_bytes(func, sizeof(ir_node), 0xcafebabe); + func->data.start_name = node->expr.function.name; + + ir_node *start_ctrl = calloc(1, sizeof(ir_node)); + start_ctrl->code = OC_PROJ; + start_ctrl->id = stbds_hash_bytes(&start_ctrl, sizeof(usize), 0xcafebabe); + arrput(start_ctrl->out, func); + hmput(global_hash, *start_ctrl, start_ctrl); + + current_control = start_ctrl; + + ir_node *start_mem = calloc(1, sizeof(ir_node)); + start_mem->code = OC_PROJ; + start_mem->id = stbds_hash_bytes(&start_mem, sizeof(usize), 0xcafebabe); + arrput(start_mem->out, func); + hmput(global_hash, *start_mem, start_mem); + + current_memory = start_mem; + + current_scope = calloc(1, sizeof(ir_node)); + current_scope->code = OC_SCOPE; + + push_scope(); + + member *m = node->expr.function.parameters; + while (m) { + ir_node *proj_param = calloc(1, sizeof(ir_node)); + proj_param->code = OC_PROJ; + arrput(proj_param->out, func); + proj_param->id = stbds_hash_bytes(proj_param, sizeof(ir_node), 0xcafebabe); + set_def(m->name, proj_param, false); + hmput(global_hash, *proj_param, proj_param); + + m = m->next; + } + + while (current && current->type == NODE_UNIT) { + if (current->expr.unit_node.expr) { + build_expression(current->expr.unit_node.expr); + } + current = current->expr.unit_node.next; + } + + func->id = stbds_hash_bytes(func, sizeof(ir_node), 0xcafebabe); + + finalize_function(); + + return func; +} + static ir_node *build_expression(ast_node *node) { ir_node *n = NULL; @@ -624,6 +770,9 @@ static ir_node *build_expression(ast_node *node) case NODE_IF: n = build_if(node); break; + case NODE_RETURN: + build_return(node); + break; default: break; } @@ -639,7 +788,7 @@ void ir_build(ast_node *ast) graph = calloc(1, sizeof(ir_node)); graph->code = OC_START; graph->id = stbds_hash_bytes(graph, sizeof(ir_node), 0xcafebabe); - current_control = graph; + graph->data.start_name = "program"; current_memory = calloc(1, sizeof(ir_node)); current_memory->code = OC_FRAME_PTR; @@ -650,9 +799,10 @@ void ir_build(ast_node *ast) push_scope(); while (current && current->type == NODE_UNIT) { - if (current->expr.unit_node.expr) { - ir_node *expr = build_expression(current->expr.unit_node.expr); + if (current->expr.unit_node.expr && current->expr.unit_node.expr->type == NODE_FUNCTION) { + ir_node *expr = build_function(current->expr.unit_node.expr); arrput(graph->out, expr); + hmput(global_hash, *expr, expr); } current = current->expr.unit_node.next; } diff --git a/ir.h b/ir.h index fc18bf8..bfd684f 100644 --- a/ir.h +++ b/ir.h @@ -27,6 +27,7 @@ typedef enum { OC_CONST_INT, OC_CONST_FLOAT, + OC_VOID, OC_FRAME_PTR, OC_ADDR, @@ -55,6 +56,7 @@ typedef struct _ir_node { i64 const_int; f64 const_float; symbol_table **symbol_tables; + char *start_name; } data; } ir_node; diff --git a/sema.c b/sema.c index 4f6afb9..9f2033d 100644 --- a/sema.c +++ b/sema.c @@ -316,6 +316,7 @@ static void create_prototype(sema *s, ast_node *node) { prototype *p = arena_alloc(s->allocator, sizeof(prototype)); p->name = intern_string(s, node->expr.function.name, node->expr.function.name_len); + node->expr.function.name = p->name; if (shget(prototypes, p->name)) { error(node, "function already defined."); } @@ -668,7 +669,7 @@ static void check_statement(sema *s, ast_node *node) char *name = NULL; switch(node->type) { case NODE_RETURN: - if (!match(get_expression_type(s, node->expr.ret.value), current_return)) { + if (!can_cast(get_expression_type(s, node->expr.ret.value), current_return) && !match(get_expression_type(s, node->expr.ret.value), current_return)) { error(node, "return type doesn't match function's one."); } break; @@ -728,7 +729,7 @@ static void check_function(sema *s, ast_node *f) while (param) { type *p_type = get_type(s, param->type); char *t_name = intern_string(s, param->name, param->name_len); - + param->name = t_name; ast_node *param_node = arena_alloc(s->allocator, sizeof(ast_node)); param_node->type = NODE_VAR_DECL; param_node->expr_type = p_type; diff --git a/test.l b/test.l index fb9e411..6c3aec2 100644 --- a/test.l +++ b/test.l @@ -1,12 +1,12 @@ -u32 a = 2; - -if (a == 3) { - a = 5; - if (a == 4) { - a = 3; +u32 main(u32 b) +{ + u32 a = 4; + //return a; + if (b == 3) { + return 3; + } else { + return 4; } -} else { - a = 1; + + return a; } - -u32 d = a;