From 032c04ad1d4982544f16e27b827d42a87d6c8854 Mon Sep 17 00:00:00 2001 From: Lorenzo Torres Date: Wed, 14 Jan 2026 18:36:27 +0100 Subject: [PATCH] added global code motion --- ir.c | 1288 +++++++++++++++++++++++++++++++++++++++++++++++++++++++- ir.h | 45 ++ sema.c | 11 +- test.l | 14 +- 4 files changed, 1341 insertions(+), 17 deletions(-) diff --git a/ir.c b/ir.c index 61da14a..a3e2068 100644 --- a/ir.c +++ b/ir.c @@ -62,6 +62,27 @@ static void node_name(ir_node *node) case OC_EQ: printf("[label=\"==\"]\n"); break; + case OC_NEQ: + printf("[label=\"!=\"]\n"); + break; + case OC_LT: + printf("[label=\"<\"]\n"); + break; + case OC_GT: + printf("[label=\">\"]\n"); + break; + case OC_LE: + printf("[label=\"<=\"]\n"); + break; + case OC_GE: + printf("[label=\">=\"]\n"); + break; + case OC_AND: + printf("[label=\"&&\"]\n"); + break; + case OC_OR: + printf("[label=\"||\"]\n"); + break; case OC_CONST_INT: printf("[label=\"%ld\"]\n", node->data.const_int); break; @@ -92,6 +113,12 @@ static void node_name(ir_node *node) case OC_PROJ: printf("[label=\"proj\", shape=diamond, style=filled, color=cyan]\n"); break; + case OC_CALL: + printf("[label=\"call %s\", shape=box, style=filled, color=yellow]\n", node->data.call_name); + break; + case OC_LOOP: + printf("[label=\"loop\", shape=diamond, style=filled, color=purple]\n"); + break; default: printf("[label=\"%d\"]\n", node->code); break; @@ -196,6 +223,27 @@ static void const_fold(ir_node *binary) case OC_EQ: binary->data.const_int = left->data.const_int == right->data.const_int; break; + case OC_NEQ: + binary->data.const_int = left->data.const_int != right->data.const_int; + break; + case OC_LT: + binary->data.const_int = left->data.const_int < right->data.const_int; + break; + case OC_GT: + binary->data.const_int = left->data.const_int > right->data.const_int; + break; + case OC_LE: + binary->data.const_int = left->data.const_int <= right->data.const_int; + break; + case OC_GE: + binary->data.const_int = left->data.const_int >= 
right->data.const_int; + break; + case OC_AND: + binary->data.const_int = left->data.const_int && right->data.const_int; + break; + case OC_OR: + binary->data.const_int = left->data.const_int || right->data.const_int; + break; default: return; } @@ -230,6 +278,294 @@ static void const_fold(ir_node *binary) } } +static ir_node *peephole(ir_node *node) +{ + if (!node || !node->out || arrlen(node->out) < 2) + return node; + + ir_node *left = node->out[0]; + ir_node *right = node->out[1]; + + bool left_is_zero = (left->code == OC_CONST_INT && left->data.const_int == 0); + bool right_is_zero = (right->code == OC_CONST_INT && right->data.const_int == 0); + bool left_is_one = (left->code == OC_CONST_INT && left->data.const_int == 1); + bool right_is_one = (right->code == OC_CONST_INT && right->data.const_int == 1); + bool same_operand = (left->id == right->id); + + switch (node->code) { + case OC_ADD: + // x + 0 = x + if (right_is_zero) { + free(node); + return left; + } + // 0 + x = x + if (left_is_zero) { + free(node); + return right; + } + break; + + case OC_SUB: + // x - 0 = x + if (right_is_zero) { + free(node); + return left; + } + // x - x = 0 + if (same_operand) { + node->code = OC_CONST_INT; + node->data.const_int = 0; + arrfree(node->out); node->out = NULL; + node->id = stbds_hash_bytes(node, sizeof(ir_node), 0xcafebabe); + return node; + } + break; + + case OC_MUL: + // x * 0 = 0 + if (right_is_zero) { + free(node); + return right; + } + // 0 * x = 0 + if (left_is_zero) { + free(node); + return left; + } + // x * 1 = x + if (right_is_one) { + free(node); + return left; + } + // 1 * x = x + if (left_is_one) { + free(node); + return right; + } + break; + + case OC_DIV: + // x / 1 = x + if (right_is_one) { + free(node); + return left; + } + // 0 / x = 0 (when x != 0, but we assume no div by zero) + if (left_is_zero && !right_is_zero) { + free(node); + return left; + } + // x / x = 1 (assuming x != 0) + if (same_operand) { + node->code = OC_CONST_INT; + 
node->data.const_int = 1; + arrfree(node->out); node->out = NULL; + node->id = stbds_hash_bytes(node, sizeof(ir_node), 0xcafebabe); + return node; + } + break; + + case OC_MOD: + // 0 % x = 0 + if (left_is_zero) { + free(node); + return left; + } + // x % 1 = 0 + if (right_is_one) { + node->code = OC_CONST_INT; + node->data.const_int = 0; + arrfree(node->out); node->out = NULL; + node->id = stbds_hash_bytes(node, sizeof(ir_node), 0xcafebabe); + return node; + } + // x % x = 0 + if (same_operand) { + node->code = OC_CONST_INT; + node->data.const_int = 0; + arrfree(node->out); node->out = NULL; + node->id = stbds_hash_bytes(node, sizeof(ir_node), 0xcafebabe); + return node; + } + break; + + case OC_BOR: + // x | 0 = x + if (right_is_zero) { + free(node); + return left; + } + // 0 | x = x + if (left_is_zero) { + free(node); + return right; + } + // x | x = x + if (same_operand) { + free(node); + return left; + } + break; + + case OC_BAND: + // x & 0 = 0 + if (right_is_zero) { + free(node); + return right; + } + // 0 & x = 0 + if (left_is_zero) { + free(node); + return left; + } + // x & x = x + if (same_operand) { + free(node); + return left; + } + break; + + case OC_BXOR: + // x ^ 0 = x + if (right_is_zero) { + free(node); + return left; + } + // 0 ^ x = x + if (left_is_zero) { + free(node); + return right; + } + // x ^ x = 0 + if (same_operand) { + node->code = OC_CONST_INT; + node->data.const_int = 0; + arrfree(node->out); node->out = NULL; + node->id = stbds_hash_bytes(node, sizeof(ir_node), 0xcafebabe); + return node; + } + break; + + case OC_EQ: + // x == x = 1 (always true) + if (same_operand) { + node->code = OC_CONST_INT; + node->data.const_int = 1; + arrfree(node->out); node->out = NULL; + node->id = stbds_hash_bytes(node, sizeof(ir_node), 0xcafebabe); + return node; + } + break; + + case OC_NEQ: + // x != x = 0 (always false) + if (same_operand) { + node->code = OC_CONST_INT; + node->data.const_int = 0; + arrfree(node->out); node->out = NULL; + node->id = 
stbds_hash_bytes(node, sizeof(ir_node), 0xcafebabe); + return node; + } + break; + + case OC_LT: + // x < x = 0 (always false) + if (same_operand) { + node->code = OC_CONST_INT; + node->data.const_int = 0; + arrfree(node->out); node->out = NULL; + node->id = stbds_hash_bytes(node, sizeof(ir_node), 0xcafebabe); + return node; + } + break; + + case OC_GT: + // x > x = 0 (always false) + if (same_operand) { + node->code = OC_CONST_INT; + node->data.const_int = 0; + arrfree(node->out); node->out = NULL; + node->id = stbds_hash_bytes(node, sizeof(ir_node), 0xcafebabe); + return node; + } + break; + + case OC_LE: + // x <= x = 1 (always true) + if (same_operand) { + node->code = OC_CONST_INT; + node->data.const_int = 1; + arrfree(node->out); node->out = NULL; + node->id = stbds_hash_bytes(node, sizeof(ir_node), 0xcafebabe); + return node; + } + break; + + case OC_GE: + // x >= x = 1 (always true) + if (same_operand) { + node->code = OC_CONST_INT; + node->data.const_int = 1; + arrfree(node->out); node->out = NULL; + node->id = stbds_hash_bytes(node, sizeof(ir_node), 0xcafebabe); + return node; + } + break; + + case OC_AND: + // x && 0 = 0 + if (right_is_zero) { + free(node); + return right; + } + // 0 && x = 0 + if (left_is_zero) { + free(node); + return left; + } + // x && 1 = x (if x is boolean) + // 1 && x = x + if (left_is_one) { + free(node); + return right; + } + if (right_is_one) { + free(node); + return left; + } + break; + + case OC_OR: + // x || 1 = 1 + if (right_is_one) { + free(node); + return right; + } + // 1 || x = 1 + if (left_is_one) { + free(node); + return left; + } + // x || 0 = x + if (right_is_zero) { + free(node); + return left; + } + // 0 || x = x + if (left_is_zero) { + free(node); + return right; + } + break; + + default: + break; + } + + return node; +} + static ir_node *build_address(usize base, usize offset) { ir_node *addr = calloc(1, sizeof(ir_node)); addr->code = OC_ADDR; @@ -355,6 +691,27 @@ static ir_node *build_binary(ast_node *node) 
// Build the IR for a `while` statement.
//
// Shape produced:
//   - an OC_LOOP header whose inputs are [entry control, back-edge control]
//   - a memory OC_PHI with inputs [loop, entry memory, back-edge memory]
//   - one OC_PHI per non-lvalue symbol visible in current_scope, with inputs
//     [loop, value on entry, value at end of body]
//   - an OC_IF on the condition, plus two OC_PROJ nodes (body and exit)
//
// The back-edge inputs are appended only if the body leaves current_control
// non-NULL; otherwise the loop and the phis keep just their entry inputs.
//
// On return current_control is the exit projection, current_memory is the
// memory phi, and every tracked variable is bound to its phi.
// Returns the OC_LOOP node.
//
// NOTE(review): loop, mem_phi and the variable phis are published in
// global_hash before their back-edge inputs are appended and their ids
// recomputed; nothing here re-inserts them afterwards. Confirm whether the
// by-value keys in global_hash are expected to match the final node contents.
static ir_node *build_while(ast_node *node)
{
	// Save state before loop
	ir_node *entry_control = current_control;
	ir_node *entry_memory = current_memory;

	// Create loop header region - initially with just entry control
	// Back edge will be added after processing the body
	ir_node *loop = calloc(1, sizeof(ir_node));
	loop->code = OC_LOOP;
	arrput(loop->out, entry_control);
	// Placeholder for back edge - will be updated later
	loop->id = stbds_hash_bytes(loop, sizeof(ir_node), 0xcafebabe);
	hmput(global_hash, *loop, loop);

	// Create memory phi for the loop
	ir_node *mem_phi = calloc(1, sizeof(ir_node));
	mem_phi->code = OC_PHI;
	arrput(mem_phi->out, loop);
	arrput(mem_phi->out, entry_memory);
	// Placeholder for back edge memory - index 2 will be updated later
	mem_phi->id = stbds_hash_bytes(mem_phi, sizeof(ir_node), 0xcafebabe);
	hmput(global_hash, *mem_phi, mem_phi);

	// Create phi nodes for all variables in scope
	// We need to track which phi corresponds to which variable
	struct { char *key; ir_node *value; } *var_phis = NULL;

	for (int i = 0; i < arrlen(current_scope->data.symbol_tables); i++) {
		symbol_table *table = current_scope->data.symbol_tables[i];
		for (int j = 0; j < shlen(table); j++) {
			char *name = table[j].key;
			struct symbol_def *def = table[j].value;
			// Only symbols that are not lvalues get a phi here.
			if (!def->is_lvalue) {
				// Create phi for this variable
				ir_node *var_phi = calloc(1, sizeof(ir_node));
				var_phi->code = OC_PHI;
				arrput(var_phi->out, loop);
				arrput(var_phi->out, def->node);
				// Placeholder for back edge value
				var_phi->id = stbds_hash_bytes(var_phi, sizeof(ir_node), 0xcafebabe);
				hmput(global_hash, *var_phi, var_phi);

				// Update the variable to use the phi
				// (a fresh symbol_def replaces the table entry; the old def is not freed here)
				struct symbol_def *new_def = calloc(1, sizeof(struct symbol_def));
				new_def->node = var_phi;
				new_def->is_lvalue = false;
				shput(current_scope->data.symbol_tables[i], name, new_def);

				// Track the phi for later update
				// (keyed by name only, so a name present in several tables keeps one entry)
				shput(var_phis, name, var_phi);
			}
		}
	}

	// Set current state to loop header
	current_control = loop;
	current_memory = mem_phi;

	// Build the condition expression
	ir_node *condition = build_expression(node->expr.whle.condition);

	// Create if node for the loop condition
	// Input order is [condition, control].
	ir_node *if_node = calloc(1, sizeof(ir_node));
	if_node->code = OC_IF;
	arrput(if_node->out, condition);
	arrput(if_node->out, current_control);
	if_node->id = stbds_hash_bytes(if_node, sizeof(ir_node), 0xcafebabe);
	hmput(global_hash, *if_node, if_node);

	// Create projections for true (body) and false (exit)
	ir_node *proj_body = calloc(1, sizeof(ir_node));
	proj_body->code = OC_PROJ;
	arrput(proj_body->out, if_node);
	proj_body->id = stbds_hash_bytes(proj_body, sizeof(ir_node), 0xcafebabe);
	hmput(global_hash, *proj_body, proj_body);

	ir_node *proj_exit = calloc(1, sizeof(ir_node));
	proj_exit->code = OC_PROJ;
	arrput(proj_exit->out, if_node);
	proj_exit->id = stbds_hash_bytes(proj_exit, sizeof(ir_node), 0xcafebabe);
	hmput(global_hash, *proj_exit, proj_exit);

	// Process the loop body
	current_control = proj_body;

	ast_node *current = node->expr.whle.body;
	while (current && current->type == NODE_UNIT) {
		// Statements after control has terminated (current_control == NULL) are skipped.
		if (current->expr.unit_node.expr && current_control) {
			build_expression(current->expr.unit_node.expr);
		}
		current = current->expr.unit_node.next;
	}

	// After body - add back edge to loop header if control didn't terminate
	if (current_control) {
		// Add back edge control to loop region
		arrput(loop->out, current_control);
		loop->id = stbds_hash_bytes(loop, sizeof(ir_node), 0xcafebabe);

		// Add back edge memory to memory phi
		arrput(mem_phi->out, current_memory);
		mem_phi->id = stbds_hash_bytes(mem_phi, sizeof(ir_node), 0xcafebabe);

		// Update variable phis with back edge values
		for (int i = 0; i < shlen(var_phis); i++) {
			char *name = var_phis[i].key;
			ir_node *phi = var_phis[i].value;

			// Get current value of variable after loop body
			struct symbol_def *current_def = get_def(name);
			if (current_def && current_def->node) {
				arrput(phi->out, current_def->node);
				phi->id = stbds_hash_bytes(phi, sizeof(ir_node), 0xcafebabe);
			}
		}
	}

	// Restore phi values as current definitions for use after the loop
	for (int i = 0; i < shlen(var_phis); i++) {
		char *name = var_phis[i].key;
		ir_node *phi = var_phis[i].value;

		// Find which scope table contains this variable and update it
		// (only the first table that has the name is rebound)
		for (int j = 0; j < arrlen(current_scope->data.symbol_tables); j++) {
			if (shget(current_scope->data.symbol_tables[j], name)) {
				struct symbol_def *def = calloc(1, sizeof(struct symbol_def));
				def->node = phi;
				def->is_lvalue = false;
				shput(current_scope->data.symbol_tables[j], name, def);
				break;
			}
		}
	}

	// Clean up var_phis
	shfree(var_phis);

	// Exit the loop - continue with false projection
	current_control = proj_exit;
	// Memory after loop is the memory phi (represents all possible memory states)
	current_memory = mem_phi;

	return loop;
}
= then_scope; + return then_control; + } + + // Both branches fall through - create merge region + region = calloc(1, sizeof(ir_node)); region->code = OC_REGION; arrput(region->out, then_control); arrput(region->out, else_control); @@ -529,7 +1064,7 @@ static ir_node *build_if(ast_node *node) symbol_table *base_table = current_scope->data.symbol_tables[i]; for (int j = 0; j < shlen(base_table); j++) { char *key = base_table[j].key; - + ir_node *found_then = NULL; symbol_table *t_table = then_scope->data.symbol_tables[i]; if (shget(t_table, key)->node) found_then = shget(t_table, key)->node; @@ -566,6 +1101,55 @@ static ir_node *build_if(ast_node *node) return region; } +static ir_node *build_call(ast_node *node) +{ + ir_node *call = calloc(1, sizeof(ir_node)); + call->code = OC_CALL; + call->data.call_name = node->expr.call.name; + + // Call inputs: control, memory, then arguments + arrput(call->out, current_control); + arrput(call->out, current_memory); + + // Build argument expressions + ast_node *param = node->expr.call.parameters; + while (param && param->type == NODE_UNIT) { + if (param->expr.unit_node.expr) { + ir_node *arg = build_expression(param->expr.unit_node.expr); + arrput(call->out, arg); + } + param = param->expr.unit_node.next; + } + + call->id = stbds_hash_bytes(call, sizeof(ir_node), 0xcafebabe); + hmput(global_hash, *call, call); + + // Create projection for new control + ir_node *call_ctrl = calloc(1, sizeof(ir_node)); + call_ctrl->code = OC_PROJ; + arrput(call_ctrl->out, call); + call_ctrl->id = stbds_hash_bytes(call_ctrl, sizeof(ir_node), 0xcafebabe); + hmput(global_hash, *call_ctrl, call_ctrl); + current_control = call_ctrl; + + // Create projection for new memory state + ir_node *call_mem = calloc(1, sizeof(ir_node)); + call_mem->code = OC_PROJ; + arrput(call_mem->out, call); + call_mem->id = stbds_hash_bytes(call_mem, sizeof(ir_node), 0xcafebabe); + hmput(global_hash, *call_mem, call_mem); + current_memory = call_mem; + + // Create 
projection for return value + ir_node *call_ret = calloc(1, sizeof(ir_node)); + call_ret->code = OC_PROJ; + arrput(call_ret->out, call); + call_ret->id = stbds_hash_bytes(call_ret, sizeof(ir_node), 0xcafebabe); + hmput(global_hash, *call_ret, call_ret); + + return call_ret; +} + static void build_return(ast_node *node) { ir_node *val = NULL; @@ -623,7 +1207,7 @@ static void finalize_function(void) ir_node *val_phi = calloc(1, sizeof(ir_node)); val_phi->code = OC_PHI; - //arrput(val_phi->out, region); + arrput(val_phi->out, region); for (int i=0; iout, current_func.return_values[i]); } @@ -688,7 +1272,7 @@ static ir_node *build_function(ast_node *node) } while (current && current->type == NODE_UNIT) { - if (current->expr.unit_node.expr) { + if (current->expr.unit_node.expr && current_control) { build_expression(current->expr.unit_node.expr); } current = current->expr.unit_node.next; @@ -745,8 +1329,12 @@ static ir_node *build_expression(ast_node *node) return n; case NODE_IDENTIFIER: struct symbol_def *def = get_def(node->expr.string.start); + if (!def) { + fprintf(stderr, "IR error: undefined identifier '%s'\n", node->expr.string.start); + return NULL; + } n = def->node; - + if (n && def->is_lvalue) { ir_node *addr_node = n; @@ -770,9 +1358,15 @@ static ir_node *build_expression(ast_node *node) case NODE_IF: n = build_if(node); break; + case NODE_WHILE: + n = build_while(node); + break; case NODE_RETURN: build_return(node); break; + case NODE_CALL: + n = build_call(node); + break; default: break; } @@ -803,6 +1397,13 @@ void ir_build(ast_node *ast) ir_node *expr = build_function(current->expr.unit_node.expr); arrput(graph->out, expr); hmput(global_hash, *expr, expr); + + // Run GCM on this function + ir_function *scheduled = gcm_schedule(expr); + if (scheduled) { + gcm_print_scheduled(scheduled); + printf("\n"); + } } current = current->expr.unit_node.next; } @@ -810,3 +1411,676 @@ void ir_build(ast_node *ast) print_graph(graph); printf("}\n"); } + +static int 
block_id_counter = 0; + +static basic_block *create_block(ir_node *control) +{ + basic_block *bb = calloc(1, sizeof(basic_block)); + bb->id = block_id_counter++; + bb->control = control; + bb->nodes = NULL; + bb->preds = NULL; + bb->succs = NULL; + bb->idom = NULL; + bb->dom_children = NULL; + bb->dom_depth = 0; + bb->loop_depth = 0; + bb->visited = false; + return bb; +} + +static void add_edge(basic_block *from, basic_block *to) +{ + arrput(from->succs, to); + arrput(to->preds, from); +} + +// Check if a node is a control node (defines a basic block boundary) +static bool is_control_node(ir_node *node) +{ + if (!node) return false; + switch (node->code) { + case OC_START: + case OC_REGION: + case OC_LOOP: + case OC_IF: + case OC_RETURN: + case OC_PROJ: // Projections from IF/CALL are control + return true; + default: + return false; + } +} + +// Check if a node is pinned (must stay in a specific block) +static bool is_pinned(ir_node *node) +{ + if (!node) return false; + switch (node->code) { + case OC_START: + case OC_REGION: + case OC_LOOP: + case OC_IF: + case OC_RETURN: + case OC_PHI: + case OC_PROJ: + case OC_STORE: + case OC_LOAD: + case OC_CALL: + return true; + default: + return false; + } +} + +// Map from control nodes to basic blocks +static struct { ir_node *key; basic_block *value; } *control_to_block = NULL; + +// Collect all nodes reachable from a function +static void collect_nodes(ir_node *node, ir_node ***all_nodes) +{ + if (!node || node->scheduled) return; + node->scheduled = true; // Use as visited marker temporarily + + arrput(*all_nodes, node); + + // Follow inputs (out array in this IR) + for (int i = 0; i < arrlen(node->out); i++) { + collect_nodes(node->out[i], all_nodes); + } +} + +// Build CFG from control nodes +static basic_block *build_cfg(ir_node *func_start, ir_function *func) +{ + ir_node **all_nodes = NULL; + + // Reset scheduled flags and collect all nodes + collect_nodes(func_start, &all_nodes); + + // Reset scheduled flags 
for actual scheduling later + for (int i = 0; i < arrlen(all_nodes); i++) { + all_nodes[i]->scheduled = false; + all_nodes[i]->pinned = is_pinned(all_nodes[i]); + all_nodes[i]->early = NULL; + all_nodes[i]->late = NULL; + all_nodes[i]->block = NULL; + } + + // Create blocks for control nodes + hmfree(control_to_block); + control_to_block = NULL; + + basic_block *entry = NULL; + + for (int i = 0; i < arrlen(all_nodes); i++) { + ir_node *node = all_nodes[i]; + + // Create block for START and control flow merge points + if (node->code == OC_START) { + basic_block *bb = create_block(node); + hmput(control_to_block, node, bb); + entry = bb; + node->block = bb; + arrput(func->blocks, bb); + } + // Projections from START are part of entry block + else if (node->code == OC_PROJ && node->out && arrlen(node->out) > 0) { + ir_node *parent = node->out[0]; + if (parent->code == OC_START) { + basic_block *bb = hmget(control_to_block, parent); + node->block = bb; + // Also add to hashmap so predecessors can find it + hmput(control_to_block, node, bb); + } + else if (parent->code == OC_IF || parent->code == OC_CALL) { + // Create new block for IF/CALL projections + basic_block *bb = create_block(node); + hmput(control_to_block, node, bb); + node->block = bb; + arrput(func->blocks, bb); + } + else if (parent->code == OC_LOOP) { + // Loop projections get their own blocks too + basic_block *bb = create_block(node); + hmput(control_to_block, node, bb); + node->block = bb; + arrput(func->blocks, bb); + } + } + else if (node->code == OC_REGION) { + basic_block *bb = create_block(node); + hmput(control_to_block, node, bb); + node->block = bb; + arrput(func->blocks, bb); + } + else if (node->code == OC_LOOP) { + basic_block *bb = create_block(node); + hmput(control_to_block, node, bb); + node->block = bb; + arrput(func->blocks, bb); + } + else if (node->code == OC_RETURN) { + basic_block *bb = create_block(node); + hmput(control_to_block, node, bb); + node->block = bb; + 
arrput(func->blocks, bb); + } + } + + // Build CFG edges + for (int i = 0; i < arrlen(all_nodes); i++) { + ir_node *node = all_nodes[i]; + basic_block *bb = node->block; + if (!bb) continue; + + // Connect based on control flow + if (node->code == OC_IF) { + // IF has projections as successors - find them + for (int j = 0; j < arrlen(all_nodes); j++) { + ir_node *other = all_nodes[j]; + if (other->code == OC_PROJ && other->out && arrlen(other->out) > 0) { + if (other->out[0] == node) { + basic_block *succ = hmget(control_to_block, other); + if (succ && succ != bb) { + add_edge(bb, succ); + } + } + } + } + } + else if (node->code == OC_REGION || node->code == OC_LOOP) { + // Region/Loop has control inputs as predecessors + for (int j = 0; j < arrlen(node->out); j++) { + ir_node *pred_ctrl = node->out[j]; + if (pred_ctrl) { + basic_block *pred = hmget(control_to_block, pred_ctrl); + if (pred && pred != bb) { + add_edge(pred, bb); + } + } + } + } + else if (node->code == OC_RETURN) { + // Return has control input + if (node->out && arrlen(node->out) > 0) { + ir_node *ctrl_in = node->out[0]; + basic_block *pred = hmget(control_to_block, ctrl_in); + if (pred && pred != bb) { + add_edge(pred, bb); + } + } + } + } + + // Pin PHI nodes to their region's block + for (int i = 0; i < arrlen(all_nodes); i++) { + ir_node *node = all_nodes[i]; + if (node->code == OC_PHI && node->out && arrlen(node->out) > 0) { + ir_node *region = node->out[0]; + node->block = hmget(control_to_block, region); + } + } + + // Pin IF nodes to their control input's block + for (int i = 0; i < arrlen(all_nodes); i++) { + ir_node *node = all_nodes[i]; + if (node->code == OC_IF && node->out && arrlen(node->out) > 1) { + ir_node *ctrl = node->out[1]; // Control is second input + node->block = hmget(control_to_block, ctrl); + } + } + + arrfree(all_nodes); + func->block_count = arrlen(func->blocks); + + return entry; +} + +// Compute dominators using simple iterative algorithm +static void 
// Compute loop depths.
//
// Simple approximation: every block whose control node is an OC_LOOP gets
// loop_depth 1, and that depth is pushed down the dominator tree so that each
// block ends up with at least its immediate dominator's depth. Nested loops
// are not distinguished; all marked blocks get depth 1.
//
// The propagation runs to a fixpoint. A single pass over func->blocks is not
// enough because a block can appear in the array before its immediate
// dominator, in which case it would miss the depth its dominator receives
// later in the same pass.
static void compute_loop_depths(ir_function *func)
{
	// Seed: loop headers and their direct dominator-tree children.
	for (int i = 0; i < func->block_count; i++) {
		basic_block *bb = func->blocks[i];
		if (bb->control && bb->control->code == OC_LOOP) {
			bb->loop_depth = 1;
			for (int j = 0; j < arrlen(bb->dom_children); j++) {
				bb->dom_children[j]->loop_depth = bb->loop_depth;
			}
		}
	}

	// Propagate down the dominator tree until nothing changes. Depths only
	// ever grow and are bounded by the largest seeded value, so this ends.
	bool changed = true;
	while (changed) {
		changed = false;
		for (int i = 0; i < func->block_count; i++) {
			basic_block *bb = func->blocks[i];
			if (bb->idom && bb->idom->loop_depth > bb->loop_depth) {
				bb->loop_depth = bb->idom->loop_depth;
				changed = true;
			}
		}
	}
}
Schedule Early: place each node in earliest legal block +static void schedule_early(ir_node *node, basic_block *entry) +{ + if (!node || node->early) return; + + // Pinned nodes stay in their assigned block + if (node->pinned && node->block) { + node->early = node->block; + return; + } + + // Start with entry block + node->early = entry; + + // For each input, schedule it early and update our earliest block + for (int i = 0; i < arrlen(node->out); i++) { + ir_node *input = node->out[i]; + if (!input) continue; + + schedule_early(input, entry); + + // Our earliest block must be dominated by input's earliest block + if (input->early && input->early->dom_depth > node->early->dom_depth) { + node->early = input->early; + } + } +} + +// Find the Least Common Ancestor in dominator tree +static basic_block *dom_lca(basic_block *a, basic_block *b) +{ + if (!a) return b; + if (!b) return a; + + while (a != b) { + while (a && a->dom_depth > b->dom_depth) a = a->idom; + while (b && b->dom_depth > a->dom_depth) b = b->idom; + if (a != b) { + if (a) a = a->idom; + if (b) b = b->idom; + } + } + return a; +} + +// Find uses of a node +static void find_uses(ir_node *node, ir_node **all_nodes, int count, ir_node ***uses) +{ + for (int i = 0; i < count; i++) { + ir_node *other = all_nodes[i]; + if (other == node) continue; + + for (int j = 0; j < arrlen(other->out); j++) { + if (other->out[j] == node) { + arrput(*uses, other); + break; + } + } + } +} + +// Schedule Late: find latest legal block for each node +static void schedule_late(ir_node *node, ir_node **all_nodes, int count) +{ + if (!node || node->late || node->pinned) { + if (node && node->pinned && !node->late) { + node->late = node->block; + } + return; + } + + // Mark as being processed to prevent infinite recursion + // Use early as a sentinel if we're in progress + node->late = node->early; + if (!node->late) { + // Fallback - shouldn't happen if schedule_early was run + return; + } + + ir_node **uses = NULL; + 
find_uses(node, all_nodes, count, &uses); + + basic_block *lca = NULL; + + for (int i = 0; i < arrlen(uses); i++) { + ir_node *use = uses[i]; + + // Make sure use is scheduled (but avoid cycles) + if (!use->late) { + schedule_late(use, all_nodes, count); + } + + basic_block *use_block = use->early; + if (use->block) use_block = use->block; + else if (use->late) use_block = use->late; + + if (use_block) { + // For PHI nodes, use the predecessor block, not the PHI's block + if (use->code == OC_PHI) { + // Find which input we are + for (int j = 1; j < arrlen(use->out); j++) { + if (use->out[j] == node) { + // We're the j-th input, use j-th predecessor + ir_node *region = use->out[0]; + if (region && j-1 < arrlen(region->out)) { + ir_node *pred_ctrl = region->out[j-1]; + basic_block *pred = hmget(control_to_block, pred_ctrl); + if (pred) use_block = pred; + } + break; + } + } + } + + lca = dom_lca(lca, use_block); + } + } + + arrfree(uses); + + if (lca) { + node->late = lca; + } else { + node->late = node->early; + } +} + +// Select final block between early and late +static void select_block(ir_node *node) +{ + if (!node || node->block) return; // Already placed + + if (!node->early || !node->late) { + node->block = node->early ? 
node->early : node->late; + return; + } + + // Pick block with shallowest loop depth between early and late + basic_block *best = node->late; + basic_block *current = node->late; + + while (current && current->dom_depth >= node->early->dom_depth) { + if (current->loop_depth < best->loop_depth) { + best = current; + } + if (current == node->early) break; + current = current->idom; + } + + node->block = best; +} + +// Schedule nodes within a block (topological sort based on dependencies) +static void schedule_block(basic_block *bb, ir_node **all_nodes, int count) +{ + ir_node **ready = NULL; + ir_node **pending = NULL; + + // Collect nodes scheduled to this block + for (int i = 0; i < count; i++) { + ir_node *node = all_nodes[i]; + if (node->block == bb && !node->scheduled) { + arrput(pending, node); + } + } + + // Topological sort + while (arrlen(pending) > 0) { + // Find a node with all inputs satisfied + int ready_idx = -1; + for (int i = 0; i < arrlen(pending); i++) { + ir_node *node = pending[i]; + bool inputs_ready = true; + + for (int j = 0; j < arrlen(node->out); j++) { + ir_node *input = node->out[j]; + if (!input) continue; + + // Input is ready if it's in a different block or already scheduled + if (input->block == bb && !input->scheduled) { + inputs_ready = false; + break; + } + } + + if (inputs_ready) { + ready_idx = i; + break; + } + } + + if (ready_idx == -1) { + // Cycle detected or all remaining have unsatisfied deps - just pick first + ready_idx = 0; + } + + ir_node *node = pending[ready_idx]; + node->scheduled = true; + arrput(bb->nodes, node); + + // Remove from pending + arrdel(pending, ready_idx); + } + + arrfree(ready); + arrfree(pending); +} + +// Main GCM entry point +ir_function *gcm_schedule(ir_node *func_start) +{ + if (!func_start || func_start->code != OC_START) return NULL; + + block_id_counter = 0; + + ir_function *func = calloc(1, sizeof(ir_function)); + func->name = func_start->data.start_name; + func->blocks = NULL; + + // Build CFG 
+ func->entry = build_cfg(func_start, func); + if (!func->entry) { + free(func); + return NULL; + } + + // Compute dominators + compute_dominators(func); + + // Compute loop depths + compute_loop_depths(func); + + // Collect all nodes for scheduling + ir_node **all_nodes = NULL; + for (int i = 0; i < hmlen(global_hash); i++) { + ir_node *node = global_hash[i].value; + // Check if this node belongs to this function + // (simplified: include all nodes for now) + arrput(all_nodes, node); + node->scheduled = false; + } + int node_count = arrlen(all_nodes); + + // Schedule Early + for (int i = 0; i < node_count; i++) { + schedule_early(all_nodes[i], func->entry); + } + + // Schedule Late + for (int i = 0; i < node_count; i++) { + schedule_late(all_nodes[i], all_nodes, node_count); + } + + // Select final blocks + for (int i = 0; i < node_count; i++) { + select_block(all_nodes[i]); + } + + // Reset scheduled flags for block scheduling + for (int i = 0; i < node_count; i++) { + all_nodes[i]->scheduled = false; + } + + // Schedule nodes within each block + for (int i = 0; i < func->block_count; i++) { + schedule_block(func->blocks[i], all_nodes, node_count); + } + + arrfree(all_nodes); + + return func; +} + +// Print scheduled IR for debugging +void gcm_print_scheduled(ir_function *func) +{ + if (!func) return; + + printf("Function: %s\n", func->name ? 
func->name : ""); + printf("Blocks: %d\n\n", func->block_count); + + for (int i = 0; i < func->block_count; i++) { + basic_block *bb = func->blocks[i]; + printf("BB%d (depth=%d, loop=%d):\n", bb->id, bb->dom_depth, bb->loop_depth); + + // Print predecessors + printf(" preds: "); + for (int j = 0; j < arrlen(bb->preds); j++) { + printf("BB%d ", bb->preds[j]->id); + } + printf("\n"); + + // Print successors + printf(" succs: "); + for (int j = 0; j < arrlen(bb->succs); j++) { + printf("BB%d ", bb->succs[j]->id); + } + printf("\n"); + + // Print idom + if (bb->idom && bb->idom != bb) { + printf(" idom: BB%d\n", bb->idom->id); + } + + // Print scheduled nodes + printf(" instructions:\n"); + for (int j = 0; j < arrlen(bb->nodes); j++) { + ir_node *node = bb->nodes[j]; + printf(" [%ld] ", node->id); + switch (node->code) { + case OC_START: printf("START %s", node->data.start_name); break; + case OC_ADD: printf("ADD"); break; + case OC_SUB: printf("SUB"); break; + case OC_MUL: printf("MUL"); break; + case OC_DIV: printf("DIV"); break; + case OC_MOD: printf("MOD"); break; + case OC_BAND: printf("AND"); break; + case OC_BOR: printf("OR"); break; + case OC_BXOR: printf("XOR"); break; + case OC_NEG: printf("NEG"); break; + case OC_EQ: printf("EQ"); break; + case OC_NEQ: printf("NEQ"); break; + case OC_LT: printf("LT"); break; + case OC_GT: printf("GT"); break; + case OC_LE: printf("LE"); break; + case OC_GE: printf("GE"); break; + case OC_AND: printf("LAND"); break; + case OC_OR: printf("LOR"); break; + case OC_CONST_INT: printf("CONST %ld", node->data.const_int); break; + case OC_CONST_FLOAT: printf("CONST %f", node->data.const_float); break; + case OC_VOID: printf("VOID"); break; + case OC_FRAME_PTR: printf("FRAME_PTR"); break; + case OC_ADDR: printf("ADDR"); break; + case OC_STORE: printf("STORE"); break; + case OC_LOAD: printf("LOAD"); break; + case OC_REGION: printf("REGION"); break; + case OC_PHI: printf("PHI"); break; + case OC_IF: printf("IF"); break; + case OC_PROJ: 
printf("PROJ"); break; + case OC_LOOP: printf("LOOP"); break; + case OC_CALL: printf("CALL %s", node->data.call_name); break; + case OC_RETURN: printf("RETURN"); break; + default: printf("OP_%d", node->code); break; + } + + // Print inputs + if (arrlen(node->out) > 0) { + printf(" ("); + for (int k = 0; k < arrlen(node->out); k++) { + if (k > 0) printf(", "); + if (node->out[k]) { + printf("%ld", node->out[k]->id); + } else { + printf("null"); + } + } + printf(")"); + } + printf("\n"); + } + printf("\n"); + } +} diff --git a/ir.h b/ir.h index bfd684f..7b7dc4d 100644 --- a/ir.h +++ b/ir.h @@ -5,6 +5,8 @@ #include "parser.h" struct _ir_node; +struct _basic_block; + struct symbol_def { struct _ir_node *node; bool is_lvalue; @@ -12,6 +14,28 @@ struct symbol_def { typedef struct { char *key; struct symbol_def *value; } symbol_table; +// Basic block for CFG representation +typedef struct _basic_block { + int id; + struct _ir_node *control; // Control node that starts this block (region, loop, proj, start) + struct _ir_node **nodes; // Scheduled nodes in this block (stb_ds array) + struct _basic_block **preds; // Predecessor blocks (stb_ds array) + struct _basic_block **succs; // Successor blocks (stb_ds array) + struct _basic_block *idom; // Immediate dominator + struct _basic_block **dom_children; // Children in dominator tree (stb_ds array) + int dom_depth; // Depth in dominator tree + int loop_depth; // Loop nesting depth (for GCM optimization) + bool visited; // For graph traversals +} basic_block; + +// Function representation after GCM +typedef struct { + char *name; + basic_block *entry; // Entry block + basic_block **blocks; // All blocks in RPO order (stb_ds array) + int block_count; +} ir_function; + typedef enum { OC_START, OC_ADD, @@ -24,6 +48,13 @@ typedef enum { OC_BXOR, OC_NEG, OC_EQ, + OC_NEQ, + OC_LT, + OC_GT, + OC_LE, + OC_GE, + OC_AND, + OC_OR, OC_CONST_INT, OC_CONST_FLOAT, @@ -40,6 +71,9 @@ typedef enum { OC_IF, OC_PROJ, + OC_LOOP, + + OC_CALL, 
OC_STOP, OC_RETURN, @@ -57,9 +91,20 @@ typedef struct _ir_node { f64 const_float; symbol_table **symbol_tables; char *start_name; + char *call_name; } data; + // GCM scheduling fields + struct _basic_block *early; // Earliest legal block + struct _basic_block *late; // Latest legal block + struct _basic_block *block; // Final scheduled block + bool pinned; // True if node must stay in its block (control nodes, phi, etc.) + bool scheduled; // True if already scheduled } ir_node; void ir_build(ast_node *ast); +// Global Code Motion and Scheduling +ir_function *gcm_schedule(ir_node *func_start); +void gcm_print_scheduled(ir_function *func); + #endif diff --git a/sema.c b/sema.c index 9f2033d..0619b43 100644 --- a/sema.c +++ b/sema.c @@ -558,11 +558,20 @@ static type *get_expression_type(sema *s, ast_node *node) node->expr_type = t; return t; case NODE_CALL: - prot = shget(prototypes, intern_string(s, node->expr.call.name, node->expr.call.name_len)); + node->expr.call.name = intern_string(s, node->expr.call.name, node->expr.call.name_len); + prot = shget(prototypes, node->expr.call.name); if (!prot) { error(node, "unknown function."); return NULL; } + // Process call arguments + ast_node *arg = node->expr.call.parameters; + while (arg && arg->type == NODE_UNIT) { + if (arg->expr.unit_node.expr) { + get_expression_type(s, arg->expr.unit_node.expr); + } + arg = arg->expr.unit_node.next; + } t = prot->type; node->expr_type = t; return t; diff --git a/test.l b/test.l index 6c3aec2..7a3f8c9 100644 --- a/test.l +++ b/test.l @@ -1,12 +1,8 @@ -u32 main(u32 b) +u32 main(u32 n) { - u32 a = 4; - //return a; - if (b == 3) { - return 3; - } else { - return 4; + i32 i = 0; + loop while i < 10 { + i = i + 1; } - - return a; + return i; }