/*
 * lc/ir.c
 *
 * 2086 lines, 52 KiB, C
 */

#include "ir.h"
#include <stdlib.h>
#include <stdio.h>
#include "stb_ds.h"
#include "sema.h"
/* Registry of IR nodes: keyed by a byte-wise copy of the node, mapping to the
 * node itself. Builders register nodes with hmput() and probe it with hmget()
 * to reuse an existing node; print_graph() iterates it to emit the DOT dump. */
struct { ir_node key; ir_node *value; } *global_hash = NULL;
/* Root OC_START node ("program") created by ir_build(); each built function
 * is appended to its operand list. */
static ir_node *graph;
/* Memory-state node that the next load/store/call will depend on. */
static ir_node *current_memory;
/* Control node that the next control-dependent node will depend on.
 * Set to NULL by build_return() to mark the rest of a block unreachable. */
static ir_node *current_control;
/* Running offset handed to build_address() for address-taken variables;
 * advanced by the declared type's size in build_expression(). */
static usize current_stack = 0;
/* OC_SCOPE node whose data.symbol_tables is a stack of name -> symbol_def maps. */
static ir_node *current_scope = NULL;
static ir_node *build_expression(ast_node *node);
/* Per-function return bookkeeping: build_return() appends one entry to each
 * array, finalize_function() merges them into a single OC_RETURN node. */
static struct {
ir_node **return_controls;
ir_node **return_memories;
ir_node **return_values;
} current_func = {0};
/*
 * Print the Graphviz node statement for `node` on stdout: its id followed by
 * an attribute list chosen from the opcode. A NULL node is printed as a red
 * placeholder named "null".
 */
static void node_name(ir_node *node)
{
    if (node == NULL) {
        printf("null [label=\"NULL\", style=filled, fillcolor=red]\n");
        return;
    }

    printf("%ld ", node->id);

    /* Fixed attribute line for opcodes whose label carries no per-node data. */
    const char *line = NULL;

    switch (node->code) {
    /* Labels that embed data from the node are printed directly. */
    case OC_START:
        printf("[label=\"%s\", style=filled, color=orange]\n", node->data.start_name);
        return;
    case OC_CALL:
        printf("[label=\"call %s\", shape=box, style=filled, color=yellow]\n", node->data.call_name);
        return;
    case OC_CONST_INT:
        printf("[label=\"%ld\"]\n", node->data.const_int);
        return;
    case OC_CONST_FLOAT:
        printf("[label=\"%f\"]\n", node->data.const_float);
        return;
    case OC_MOD:
        /* Kept as a printf so the escaped percent sign stays in one literal. */
        printf("[label=\"%%\"]\n");
        return;

    /* Control / structural nodes. */
    case OC_RETURN:
        line = "[label=\"return\", style=filled, color=orange]\n";
        break;
    case OC_REGION:
        line = "[label=\"region\", shape=diamond, style=filled, color=green]\n";
        break;
    case OC_PHI:
        line = "[label=\"phi\", shape=triangle]\n";
        break;
    case OC_IF:
        line = "[label=\"if\", shape=diamond, style=filled, color=lightblue]\n";
        break;
    case OC_PROJ:
        line = "[label=\"proj\", shape=diamond, style=filled, color=cyan]\n";
        break;
    case OC_LOOP:
        line = "[label=\"loop\", shape=diamond, style=filled, color=purple]\n";
        break;

    /* Memory nodes. */
    case OC_FRAME_PTR:
        line = "[label=\"frame_ptr\"]\n";
        break;
    case OC_STORE:
        line = "[label=\"store\", shape=box]\n";
        break;
    case OC_LOAD:
        line = "[label=\"load\", shape=box]\n";
        break;
    case OC_ADDR:
        line = "[label=\"addr\"]\n";
        break;

    /* Arithmetic and bitwise operators. */
    case OC_ADD:
        line = "[label=\"+\"]\n";
        break;
    case OC_NEG: /* unary and binary minus share one label */
    case OC_SUB:
        line = "[label=\"-\"]\n";
        break;
    case OC_DIV:
        line = "[label=\"/\"]\n";
        break;
    case OC_MUL:
        line = "[label=\"*\"]\n";
        break;
    case OC_BAND:
        line = "[label=\"&\"]\n";
        break;
    case OC_BOR:
        line = "[label=\"|\"]\n";
        break;
    case OC_BXOR:
        line = "[label=\"^\"]\n";
        break;

    /* Comparisons and logical operators. */
    case OC_EQ:
        line = "[label=\"==\"]\n";
        break;
    case OC_NEQ:
        line = "[label=\"!=\"]\n";
        break;
    case OC_LT:
        line = "[label=\"<\"]\n";
        break;
    case OC_GT:
        line = "[label=\">\"]\n";
        break;
    case OC_LE:
        line = "[label=\"<=\"]\n";
        break;
    case OC_GE:
        line = "[label=\">=\"]\n";
        break;
    case OC_AND:
        line = "[label=\"&&\"]\n";
        break;
    case OC_OR:
        line = "[label=\"||\"]\n";
        break;

    default:
        /* Unknown opcode: fall back to its numeric value. */
        printf("[label=\"%d\"]\n", node->code);
        return;
    }

    fputs(line, stdout);
}
/*
 * Dump every node registered in global_hash as Graphviz statements: the node
 * itself, then each non-NULL operand followed by an "operand->user" edge.
 * The `node` argument is not consulted; the walk is driven by global_hash.
 */
static void print_graph(ir_node *node)
{
    (void)node;

    for (int slot = 0; slot < hmlen(global_hash); slot++) {
        ir_node *user = global_hash[slot].value;
        node_name(user);

        int operand_count = arrlen(user->out);
        for (int k = 0; k < operand_count; k++) {
            ir_node *operand = user->out[k];
            if (!operand) {
                continue;
            }
            node_name(operand);
            printf("%ld->%ld\n", operand->id, user->id);
        }
    }
}
static void push_scope(void)
{
arrput(current_scope->data.symbol_tables, NULL);
}
/*
 * Look `name` up in current_scope, innermost table first.
 * Returns the first definition found, or NULL when no table holds the name.
 */
static struct symbol_def *get_def(char *name)
{
    int depth = arrlen(current_scope->data.symbol_tables);
    while (depth-- > 0) {
        struct symbol_def *found = shget(current_scope->data.symbol_tables[depth], name);
        if (found != NULL) {
            return found;
        }
    }
    return NULL;
}
/*
 * Bind `name` to `node` in current_scope. If a table already holds the name
 * (searched innermost first) that table's entry is replaced; otherwise the
 * name is added to the innermost table. A fresh symbol_def is always
 * allocated; the previous one is left untouched.
 */
static void set_def(char *name, ir_node *node, bool lvalue)
{
    int innermost = arrlen(current_scope->data.symbol_tables) - 1;
    int target = innermost;

    for (int depth = innermost; depth >= 0; depth--) {
        if (shget(current_scope->data.symbol_tables[depth], name)) {
            target = depth;
            break;
        }
    }

    struct symbol_def *entry = calloc(1, sizeof(struct symbol_def));
    entry->is_lvalue = lvalue;
    entry->node = node;
    shput(current_scope->data.symbol_tables[target], name, entry);
}
/*
 * Create a new OC_SCOPE node with the same number of symbol tables as `src`
 * and the same key -> symbol_def entries. The copy is shallow: keys and
 * symbol_def pointers are shared with `src`.
 */
static ir_node *copy_scope(ir_node *src)
{
    ir_node *clone = calloc(1, sizeof(ir_node));
    clone->code = OC_SCOPE;

    int level_count = arrlen(src->data.symbol_tables);
    for (int level = 0; level < level_count; level++) {
        arrput(clone->data.symbol_tables, NULL);

        symbol_table *from = src->data.symbol_tables[level];
        int entry_count = shlen(from);
        for (int e = 0; e < entry_count; e++) {
            shput(clone->data.symbol_tables[level], from[e].key, from[e].value);
        }
    }
    return clone;
}
/*
 * Fold a binary node whose two operands are constants of the same kind.
 *
 * On success the node is rewritten in place into OC_CONST_INT or
 * OC_CONST_FLOAT, its operand lists are released and its id is re-hashed.
 * The node is left untouched when:
 *   - either operand is missing or the operands are not both int / both float,
 *   - the opcode is not foldable for that operand kind,
 *   - the operation is a division or modulo by a constant zero (the result
 *     is undefined, so no value is invented for it).
 */
static void const_fold(ir_node *binary)
{
    ir_node *left = binary->out[0];
    ir_node *right = binary->out[1];

    /* An operand that failed to build cannot be folded. */
    if (!left || !right)
        return;

    if (left->code == OC_CONST_INT && right->code == OC_CONST_INT) {
        switch (binary->code) {
        case OC_ADD:
            binary->data.const_int = left->data.const_int + right->data.const_int;
            break;
        case OC_SUB:
            binary->data.const_int = left->data.const_int - right->data.const_int;
            break;
        case OC_MUL:
            binary->data.const_int = left->data.const_int * right->data.const_int;
            break;
        case OC_DIV:
            /* Leave x / 0 as a real division node instead of a bogus constant. */
            if (right->data.const_int == 0)
                return;
            binary->data.const_int = left->data.const_int / right->data.const_int;
            break;
        case OC_MOD:
            /* Same for x % 0. */
            if (right->data.const_int == 0)
                return;
            binary->data.const_int = left->data.const_int % right->data.const_int;
            break;
        case OC_BOR:
            binary->data.const_int = left->data.const_int | right->data.const_int;
            break;
        case OC_BAND:
            binary->data.const_int = left->data.const_int & right->data.const_int;
            break;
        case OC_BXOR:
            binary->data.const_int = left->data.const_int ^ right->data.const_int;
            break;
        case OC_EQ:
            binary->data.const_int = left->data.const_int == right->data.const_int;
            break;
        case OC_NEQ:
            binary->data.const_int = left->data.const_int != right->data.const_int;
            break;
        case OC_LT:
            binary->data.const_int = left->data.const_int < right->data.const_int;
            break;
        case OC_GT:
            binary->data.const_int = left->data.const_int > right->data.const_int;
            break;
        case OC_LE:
            binary->data.const_int = left->data.const_int <= right->data.const_int;
            break;
        case OC_GE:
            binary->data.const_int = left->data.const_int >= right->data.const_int;
            break;
        case OC_AND:
            binary->data.const_int = left->data.const_int && right->data.const_int;
            break;
        case OC_OR:
            binary->data.const_int = left->data.const_int || right->data.const_int;
            break;
        default:
            return;
        }
        binary->code = OC_CONST_INT;
    } else if (left->code == OC_CONST_FLOAT && right->code == OC_CONST_FLOAT) {
        switch (binary->code) {
        case OC_ADD:
            binary->data.const_float = left->data.const_float + right->data.const_float;
            break;
        case OC_SUB:
            binary->data.const_float = left->data.const_float - right->data.const_float;
            break;
        case OC_MUL:
            binary->data.const_float = left->data.const_float * right->data.const_float;
            break;
        case OC_DIV:
            /* Do not fold a float division by constant zero either. */
            if (right->data.const_float == 0.0f)
                return;
            binary->data.const_float = left->data.const_float / right->data.const_float;
            break;
        default:
            return;
        }
        binary->code = OC_CONST_FLOAT;
    } else {
        return;
    }

    /* The node is now a leaf constant: drop its edges and refresh its id. */
    arrfree(binary->out); binary->out = NULL;
    arrfree(binary->in); binary->in = NULL;
    binary->id = stbds_hash_bytes(binary, sizeof(ir_node), 0xcafebabe);
}
/* Release a freshly built node (including its edge lists) that is being
 * replaced by `survivor`, and return `survivor`. */
static ir_node *peephole_forward(ir_node *node, ir_node *survivor)
{
    arrfree(node->out);
    arrfree(node->in);
    free(node);
    return survivor;
}

/* Rewrite `node` in place into the integer constant `value` and re-hash it. */
static ir_node *peephole_to_const(ir_node *node, int value)
{
    node->code = OC_CONST_INT;
    node->data.const_int = value;
    arrfree(node->out); node->out = NULL;
    node->id = stbds_hash_bytes(node, sizeof(ir_node), 0xcafebabe);
    return node;
}

/*
 * Algebraic simplification of a binary node with at least two operands.
 *
 * Returns either
 *   - one of the node's operands (the node and its edge lists are freed),
 *   - the node itself rewritten into an integer constant, or
 *   - the node unchanged when no identity applies.
 * Callers must therefore not touch `node` again when the returned pointer
 * differs from it.
 *
 * NOTE(review): the "same operand" identities (x - x, x / x, x == x, ...)
 * and the logical identities that return x (x && 1, 1 && x, x || 0, 0 || x)
 * assume integer, already-boolean operands respectively; confirm that the
 * front end guarantees this before relying on them for floats or
 * non-boolean integers.
 */
static ir_node *peephole(ir_node *node)
{
    if (!node || !node->out || arrlen(node->out) < 2)
        return node;

    ir_node *left = node->out[0];
    ir_node *right = node->out[1];

    /* An operand that failed to build gives nothing to simplify. */
    if (!left || !right)
        return node;

    bool left_is_zero = (left->code == OC_CONST_INT && left->data.const_int == 0);
    bool right_is_zero = (right->code == OC_CONST_INT && right->data.const_int == 0);
    bool left_is_one = (left->code == OC_CONST_INT && left->data.const_int == 1);
    bool right_is_one = (right->code == OC_CONST_INT && right->data.const_int == 1);
    bool same_operand = (left->id == right->id);

    switch (node->code) {
    case OC_ADD:
        if (right_is_zero) return peephole_forward(node, left);   /* x + 0 = x */
        if (left_is_zero)  return peephole_forward(node, right);  /* 0 + x = x */
        break;

    case OC_SUB:
        if (right_is_zero) return peephole_forward(node, left);   /* x - 0 = x */
        if (same_operand)  return peephole_to_const(node, 0);     /* x - x = 0 */
        break;

    case OC_MUL:
        if (right_is_zero) return peephole_forward(node, right);  /* x * 0 = 0 */
        if (left_is_zero)  return peephole_forward(node, left);   /* 0 * x = 0 */
        if (right_is_one)  return peephole_forward(node, left);   /* x * 1 = x */
        if (left_is_one)   return peephole_forward(node, right);  /* 1 * x = x */
        break;

    case OC_DIV:
        if (right_is_one)
            return peephole_forward(node, left);                  /* x / 1 = x */
        if (left_is_zero && !right_is_zero)
            return peephole_forward(node, left);                  /* 0 / x = 0 */
        /* x / x = 1, except for the constant 0 / 0 which stays a division. */
        if (same_operand && !right_is_zero)
            return peephole_to_const(node, 1);
        break;

    case OC_MOD:
        if (left_is_zero)  return peephole_forward(node, left);   /* 0 % x = 0 */
        if (right_is_one)  return peephole_to_const(node, 0);     /* x % 1 = 0 */
        if (same_operand)  return peephole_to_const(node, 0);     /* x % x = 0 */
        break;

    case OC_BOR:
        if (right_is_zero) return peephole_forward(node, left);   /* x | 0 = x */
        if (left_is_zero)  return peephole_forward(node, right);  /* 0 | x = x */
        if (same_operand)  return peephole_forward(node, left);   /* x | x = x */
        break;

    case OC_BAND:
        if (right_is_zero) return peephole_forward(node, right);  /* x & 0 = 0 */
        if (left_is_zero)  return peephole_forward(node, left);   /* 0 & x = 0 */
        if (same_operand)  return peephole_forward(node, left);   /* x & x = x */
        break;

    case OC_BXOR:
        if (right_is_zero) return peephole_forward(node, left);   /* x ^ 0 = x */
        if (left_is_zero)  return peephole_forward(node, right);  /* 0 ^ x = x */
        if (same_operand)  return peephole_to_const(node, 0);     /* x ^ x = 0 */
        break;

    /* Reflexive comparisons: always true. */
    case OC_EQ:
    case OC_LE:
    case OC_GE:
        if (same_operand)  return peephole_to_const(node, 1);
        break;

    /* Irreflexive comparisons: always false. */
    case OC_NEQ:
    case OC_LT:
    case OC_GT:
        if (same_operand)  return peephole_to_const(node, 0);
        break;

    case OC_AND:
        if (right_is_zero) return peephole_forward(node, right);  /* x && 0 = 0 */
        if (left_is_zero)  return peephole_forward(node, left);   /* 0 && x = 0 */
        if (left_is_one)   return peephole_forward(node, right);  /* 1 && x = x */
        if (right_is_one)  return peephole_forward(node, left);   /* x && 1 = x */
        break;

    case OC_OR:
        if (right_is_one)  return peephole_forward(node, right);  /* x || 1 = 1 */
        if (left_is_one)   return peephole_forward(node, left);   /* 1 || x = 1 */
        if (right_is_zero) return peephole_forward(node, left);   /* x || 0 = x */
        if (left_is_zero)  return peephole_forward(node, right);  /* 0 || x = x */
        break;

    default:
        break;
    }
    return node;
}
/*
 * Build an OC_ADDR node with two operands: a base and a constant offset.
 *
 * base   - (usize)-1 selects the frame pointer (OC_FRAME_PTR); any other
 *          value becomes an integer constant base.
 * offset - constant offset operand.
 *
 * Returns an already registered identical node from global_hash when there
 * is one (the freshly built nodes are then released), otherwise the new
 * node. The new node is not inserted into global_hash here.
 */
static ir_node *build_address(usize base, usize offset) {
    ir_node *base_node = calloc(1, sizeof(ir_node));
    if (base == (usize)-1) {
        base_node->code = OC_FRAME_PTR;
    } else {
        base_node->code = OC_CONST_INT;
        base_node->data.const_int = base;
    }
    base_node->id = stbds_hash_bytes(base_node, sizeof(ir_node), 0xcafebabe);

    ir_node *offset_node = calloc(1, sizeof(ir_node));
    offset_node->code = OC_CONST_INT;
    offset_node->data.const_int = offset;
    offset_node->id = stbds_hash_bytes(offset_node, sizeof(ir_node), 0xcafebabe);

    ir_node *addr = calloc(1, sizeof(ir_node));
    addr->code = OC_ADDR;
    arrput(addr->out, base_node);
    arrput(addr->out, offset_node);
    addr->id = stbds_hash_bytes(addr, sizeof(ir_node), 0xcafebabe);

    ir_node *existing = hmget(global_hash, *addr);
    if (existing) {
        /* Nothing else references the fresh nodes yet: release all of them,
         * not just the address node. */
        arrfree(addr->out);
        free(addr);
        free(base_node);
        free(offset_node);
        return existing;
    }
    return addr;
}
/*
 * Build a store through a pointer-valued variable: the right-hand side is
 * evaluated, then an OC_STORE with operands (control, memory, address, value)
 * is registered and becomes the current memory state. The address operand is
 * the node bound to the left-hand identifier.
 *
 * Returns the value node. If the identifier is not defined, an error is
 * reported on stderr, no store is emitted and the value node is still
 * returned.
 */
static ir_node *build_assign_ptr(ast_node *binary)
{
    ir_node *val_node = build_expression(binary->expr.binary.right);
    char *var_name = binary->expr.binary.left->expr.string.start;

    struct symbol_def *def = get_def(var_name);
    if (!def) {
        /* Same diagnostic as an undefined identifier read in build_expression(). */
        fprintf(stderr, "IR error: undefined identifier '%s'\n", var_name);
        return val_node;
    }

    ir_node *store = calloc(1, sizeof(ir_node));
    store->code = OC_STORE;
    arrput(store->out, current_control);
    arrput(store->out, current_memory);
    arrput(store->out, def->node);
    arrput(store->out, val_node);
    store->id = stbds_hash_bytes(store, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *store, store);
    current_memory = store;
    return val_node;
}
/*
 * Build `name = expr`. The right-hand side is evaluated first. If the name
 * is bound to an lvalue (an address), an OC_STORE with operands
 * (control, memory, address, value) is registered and becomes the current
 * memory; otherwise the name is simply rebound to the value node.
 * Returns the value node in both cases.
 */
static ir_node *build_assign(ast_node *binary)
{
    ir_node *value = build_expression(binary->expr.binary.right);
    char *target_name = binary->expr.binary.left->expr.string.start;
    struct symbol_def *binding = get_def(target_name);

    /* Plain value binding (or a name not seen before): just rebind it. */
    if (binding == NULL || !binding->is_lvalue) {
        set_def(target_name, value, false);
        return value;
    }

    /* Address-backed variable: write through memory. */
    ir_node *address = binding->node;
    ir_node *store = calloc(1, sizeof(ir_node));
    store->code = OC_STORE;
    arrput(store->out, current_control);
    arrput(store->out, current_memory);
    arrput(store->out, address);
    arrput(store->out, value);
    store->id = stbds_hash_bytes(store, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *store, store);
    current_memory = store;
    return value;
}
/*
 * Build the IR for a binary AST node.
 *
 * Assignments are delegated to build_assign() / build_assign_ptr(). Every
 * other operator becomes a two-operand node that is constant-folded, run
 * through peephole() and finally looked up in global_hash so an identical
 * registered node is reused.
 *
 * Returns the resulting node; this may be a pre-existing operand when a
 * peephole identity applied.
 */
static ir_node *build_binary(ast_node *node)
{
    ir_node *n = calloc(1, sizeof(ir_node));
    switch (node->expr.binary.operator) {
    case OP_ASSIGN:
        free(n);
        return build_assign(node);
    case OP_ASSIGN_PTR:
        free(n);
        return build_assign_ptr(node);
    case OP_PLUS:
        n->code = OC_ADD;
        break;
    case OP_MINUS:
        n->code = OC_SUB;
        break;
    case OP_DIV:
        n->code = OC_DIV;
        break;
    case OP_MUL:
        n->code = OC_MUL;
        break;
    case OP_MOD:
        n->code = OC_MOD;
        break;
    case OP_BOR:
        n->code = OC_BOR;
        break;
    case OP_BAND:
        n->code = OC_BAND;
        break;
    case OP_BXOR:
        n->code = OC_BXOR;
        break;
    case OP_EQ:
        n->code = OC_EQ;
        break;
    case OP_NEQ:
        n->code = OC_NEQ;
        break;
    case OP_LT:
        n->code = OC_LT;
        break;
    case OP_GT:
        n->code = OC_GT;
        break;
    case OP_LE:
        n->code = OC_LE;
        break;
    case OP_GE:
        n->code = OC_GE;
        break;
    case OP_AND:
        n->code = OC_AND;
        break;
    case OP_OR:
        n->code = OC_OR;
        break;
    default:
        break;
    }
    arrput(n->out, build_expression(node->expr.binary.left));
    arrput(n->out, build_expression(node->expr.binary.right));
    n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe);
    const_fold(n);

    ir_node *simplified = peephole(n);
    if (simplified != n) {
        /* peephole() already freed n and handed back one of its operands.
         * That operand is a live, registered node: looking it up would find
         * the operand itself, and freeing it would leave a dangling pointer. */
        return simplified;
    }

    ir_node *existing = hmget(global_hash, *n);
    if (existing && existing != n) {
        arrfree(n->out);
        free(n);
        return existing;
    }
    return n;
}
/*
 * Build an OC_LOAD with operands (memory, address), where the address is the
 * IR of `node`. Returns an identical registered node from global_hash when
 * there is one, otherwise the new node (not inserted here).
 */
static ir_node *build_load(ast_node *node)
{
    /* Evaluate the address first: it may itself advance current_memory
     * (e.g. through a call), and the load must observe that state. */
    ir_node *address = build_expression(node);

    ir_node *n = calloc(1, sizeof(ir_node));
    n->code = OC_LOAD;
    arrput(n->out, current_memory);
    arrput(n->out, address);
    /* Same seed as every other node so ids stay comparable across builders. */
    n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe);

    ir_node *existing = hmget(global_hash, *n);
    if (existing) {
        arrfree(n->out);
        free(n);
        return existing;
    }
    return n;
}
/*
 * Build the IR for a unary AST node.
 *
 *   UOP_MINUS - OC_NEG of the operand; folded when the operand is an int or
 *               float constant, and cancelled when the operand is itself an
 *               OC_NEG (--x => x).
 *   UOP_REF   - the node bound to the identifier operand when it has a
 *               definition, otherwise the operand expression itself.
 *   UOP_DEREF - a load through the operand (see build_load()).
 *
 * Any other operator yields a node with no opcode set and no operands.
 * An identical node already registered in global_hash is reused.
 */
static ir_node *build_unary(ast_node *node)
{
    ir_node *n = calloc(1, sizeof(ir_node));
    ir_node *operand = NULL;

    switch (node->expr.unary.operator) {
    case UOP_MINUS:
        n->code = OC_NEG;
        operand = build_expression(node->expr.unary.right);
        arrput(n->out, operand);
        break;
    case UOP_REF:
        free(n);
        if (node->expr.unary.right->type == NODE_IDENTIFIER) {
            struct symbol_def *def = get_def(node->expr.unary.right->expr.string.start);
            if (def) {
                return def->node;
            }
        }
        return build_expression(node->expr.unary.right);
    case UOP_DEREF:
        free(n);
        return build_load(node->expr.unary.right);
    default:
        break;
    }

    /* `operand` is only set for negation; a NULL operand (failed expression)
     * skips all simplifications instead of being dereferenced. */
    if (operand && operand->code == OC_CONST_INT) {
        /* Constant folding: -(int constant) */
        n->data.const_int = -(operand->data.const_int);
        n->code = OC_CONST_INT;
        arrfree(n->out); n->out = NULL;
    } else if (operand && operand->code == OC_CONST_FLOAT) {
        /* Constant folding: -(float constant) */
        n->data.const_float = -(operand->data.const_float);
        n->code = OC_CONST_FLOAT;
        arrfree(n->out); n->out = NULL;
    } else if (operand && operand->code == OC_NEG) {
        /* Peephole: double negation elimination --x => x */
        ir_node *inner = operand->out[0];
        arrfree(n->out);
        free(n);
        return inner;
    }

    n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe);
    ir_node *existing = hmget(global_hash, *n);
    if (existing) {
        arrfree(n->out);
        free(n);
        return existing;
    }
    return n;
}
/*
 * Build the IR for a while loop and return its OC_LOOP header node.
 *
 * Shape produced:
 *   loop    (OC_LOOP)  operands: entry control, then the back-edge control
 *   mem_phi (OC_PHI)   operands: loop, entry memory, then back-edge memory
 *   one OC_PHI per non-lvalue variable in scope:
 *                      operands: loop, value before the loop, back-edge value
 *   if_node (OC_IF)    operands: condition, loop
 *   proj_body / proj_exit (OC_PROJ of if_node)
 *
 * Back-edge operands are appended only when the body falls through
 * (current_control is still non-NULL after the body).
 * On return, current_control is the exit projection, current_memory is the
 * memory phi, and every variable that received a phi is bound to that phi.
 *
 * NOTE(review): nodes are registered in global_hash before their back-edge
 * operands are appended and their ids re-hashed, so the stored keys hold the
 * pre-update bytes; confirm nothing relies on looking these nodes up by value.
 * NOTE(review): var_phis is keyed by name only, so a name present in more
 * than one symbol table shares a single slot there.
 */
static ir_node *build_while(ast_node *node)
{
// Save state before loop
ir_node *entry_control = current_control;
ir_node *entry_memory = current_memory;
// Create loop header region - initially with just entry control
// Back edge will be added after processing the body
ir_node *loop = calloc(1, sizeof(ir_node));
loop->code = OC_LOOP;
arrput(loop->out, entry_control);
// Placeholder for back edge - will be updated later
loop->id = stbds_hash_bytes(loop, sizeof(ir_node), 0xcafebabe);
hmput(global_hash, *loop, loop);
// Create memory phi for the loop
ir_node *mem_phi = calloc(1, sizeof(ir_node));
mem_phi->code = OC_PHI;
arrput(mem_phi->out, loop);
arrput(mem_phi->out, entry_memory);
// Placeholder for back edge memory - index 2 will be updated later
mem_phi->id = stbds_hash_bytes(mem_phi, sizeof(ir_node), 0xcafebabe);
hmput(global_hash, *mem_phi, mem_phi);
// Create phi nodes for all variables in scope
// We need to track which phi corresponds to which variable
struct { char *key; ir_node *value; } *var_phis = NULL;
for (int i = 0; i < arrlen(current_scope->data.symbol_tables); i++) {
symbol_table *table = current_scope->data.symbol_tables[i];
for (int j = 0; j < shlen(table); j++) {
char *name = table[j].key;
struct symbol_def *def = table[j].value;
// Address-backed (lvalue) variables are skipped: they get no value phi
if (!def->is_lvalue) {
// Create phi for this variable
ir_node *var_phi = calloc(1, sizeof(ir_node));
var_phi->code = OC_PHI;
arrput(var_phi->out, loop);
arrput(var_phi->out, def->node);
// Placeholder for back edge value
var_phi->id = stbds_hash_bytes(var_phi, sizeof(ir_node), 0xcafebabe);
hmput(global_hash, *var_phi, var_phi);
// Update the variable to use the phi
struct symbol_def *new_def = calloc(1, sizeof(struct symbol_def));
new_def->node = var_phi;
new_def->is_lvalue = false;
shput(current_scope->data.symbol_tables[i], name, new_def);
// Track the phi for later update
shput(var_phis, name, var_phi);
}
}
}
// Set current state to loop header
current_control = loop;
current_memory = mem_phi;
// Build the condition expression
ir_node *condition = build_expression(node->expr.whle.condition);
// Create if node for the loop condition
ir_node *if_node = calloc(1, sizeof(ir_node));
if_node->code = OC_IF;
arrput(if_node->out, condition);
arrput(if_node->out, current_control);
if_node->id = stbds_hash_bytes(if_node, sizeof(ir_node), 0xcafebabe);
hmput(global_hash, *if_node, if_node);
// Create projections for true (body) and false (exit)
ir_node *proj_body = calloc(1, sizeof(ir_node));
proj_body->code = OC_PROJ;
arrput(proj_body->out, if_node);
proj_body->id = stbds_hash_bytes(proj_body, sizeof(ir_node), 0xcafebabe);
hmput(global_hash, *proj_body, proj_body);
ir_node *proj_exit = calloc(1, sizeof(ir_node));
proj_exit->code = OC_PROJ;
arrput(proj_exit->out, if_node);
proj_exit->id = stbds_hash_bytes(proj_exit, sizeof(ir_node), 0xcafebabe);
hmput(global_hash, *proj_exit, proj_exit);
// Process the loop body
current_control = proj_body;
ast_node *current = node->expr.whle.body;
while (current && current->type == NODE_UNIT) {
// Statements after a return (current_control == NULL) are not built
if (current->expr.unit_node.expr && current_control) {
build_expression(current->expr.unit_node.expr);
}
current = current->expr.unit_node.next;
}
// After body - add back edge to loop header if control didn't terminate
if (current_control) {
// Add back edge control to loop region
arrput(loop->out, current_control);
loop->id = stbds_hash_bytes(loop, sizeof(ir_node), 0xcafebabe);
// Add back edge memory to memory phi
arrput(mem_phi->out, current_memory);
mem_phi->id = stbds_hash_bytes(mem_phi, sizeof(ir_node), 0xcafebabe);
// Update variable phis with back edge values
for (int i = 0; i < shlen(var_phis); i++) {
char *name = var_phis[i].key;
ir_node *phi = var_phis[i].value;
// Get current value of variable after loop body
struct symbol_def *current_def = get_def(name);
if (current_def && current_def->node) {
arrput(phi->out, current_def->node);
phi->id = stbds_hash_bytes(phi, sizeof(ir_node), 0xcafebabe);
}
}
}
// Restore phi values as current definitions for use after the loop
for (int i = 0; i < shlen(var_phis); i++) {
char *name = var_phis[i].key;
ir_node *phi = var_phis[i].value;
// Find which scope table contains this variable and update it
// (outermost table first; only the first match is rebound)
for (int j = 0; j < arrlen(current_scope->data.symbol_tables); j++) {
if (shget(current_scope->data.symbol_tables[j], name)) {
struct symbol_def *def = calloc(1, sizeof(struct symbol_def));
def->node = phi;
def->is_lvalue = false;
shput(current_scope->data.symbol_tables[j], name, def);
break;
}
}
}
// Clean up var_phis
shfree(var_phis);
// Exit the loop - continue with false projection
current_control = proj_exit;
// Memory after loop is the memory phi (represents all possible memory states)
current_memory = mem_phi;
return loop;
}
/*
 * Build the IR for an if / else statement.
 *
 * An OC_IF (operands: condition, control) gets two OC_PROJ children. Each
 * branch is built with its own scope and starting from the memory state
 * that was current before the branches.
 *
 * Merge rules:
 *   - both branches terminated (returned): current_control becomes NULL and
 *     NULL is returned;
 *   - exactly one branch terminated: building continues in the state of the
 *     surviving branch, whose control node is returned;
 *   - both fall through: an OC_REGION joins the two controls, memory and
 *     every variable of the pre-branch scope whose value differs between
 *     the branches get an OC_PHI (operands: region, then-value, else-value),
 *     and the region is returned.
 *
 * Variables declared only inside a branch are not visible after the merge.
 * A merged definition stays an lvalue when it is an lvalue in both branches,
 * so address-backed variables keep being accessed through loads and stores.
 */
static ir_node *build_if(ast_node *node)
{
    ir_node *condition = build_expression(node->expr.if_stmt.condition);

    ir_node *if_node = calloc(1, sizeof(ir_node));
    if_node->code = OC_IF;
    arrput(if_node->out, condition);
    arrput(if_node->out, current_control);
    if_node->id = stbds_hash_bytes(if_node, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *if_node, if_node);

    ir_node *proj_true = calloc(1, sizeof(ir_node));
    proj_true->code = OC_PROJ;
    arrput(proj_true->out, if_node);
    proj_true->id = stbds_hash_bytes(proj_true, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *proj_true, proj_true);

    ir_node *proj_false = calloc(1, sizeof(ir_node));
    proj_false->code = OC_PROJ;
    arrput(proj_false->out, if_node);
    proj_false->id = stbds_hash_bytes(proj_false, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *proj_false, proj_false);

    /* Snapshot of the state both branches start from. */
    ir_node *base_scope = copy_scope(current_scope);
    ir_node *base_mem = current_memory;

    /* Then branch: built directly in the live scope. */
    current_control = proj_true;
    ast_node *current = node->expr.if_stmt.body;
    while (current && current->type == NODE_UNIT) {
        if (current->expr.unit_node.expr && current_control) {
            build_expression(current->expr.unit_node.expr);
        }
        current = current->expr.unit_node.next;
    }
    ir_node *then_scope = current_scope;
    ir_node *then_mem = current_memory;
    ir_node *then_control = current_control;

    /* Else branch: built in a fresh copy of the snapshot. */
    current_scope = copy_scope(base_scope);
    current_memory = base_mem;
    current_control = proj_false;
    current = node->expr.if_stmt.otherwise;
    while (current && current->type == NODE_UNIT) {
        if (current->expr.unit_node.expr && current_control) {
            build_expression(current->expr.unit_node.expr);
        }
        current = current->expr.unit_node.next;
    }
    ir_node *else_scope = current_scope;
    ir_node *else_mem = current_memory;
    ir_node *else_control = current_control;

    /* Handle control flow merging based on which branches terminated. */
    if (!then_control && !else_control) {
        /* Both branches returned - code after the if is unreachable. */
        current_control = NULL;
        current_scope = base_scope;
        return NULL;
    }
    if (!then_control) {
        /* Only the then branch returned - continue with the else state. */
        current_control = else_control;
        current_memory = else_mem;
        current_scope = else_scope;
        return else_control;
    }
    if (!else_control) {
        /* Only the else branch returned - continue with the then state. */
        current_control = then_control;
        current_memory = then_mem;
        current_scope = then_scope;
        return then_control;
    }

    /* Both branches fall through - create the merge region. */
    ir_node *region = calloc(1, sizeof(ir_node));
    region->code = OC_REGION;
    arrput(region->out, then_control);
    arrput(region->out, else_control);
    region->id = stbds_hash_bytes(region, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *region, region);

    if (then_mem->id != else_mem->id) {
        ir_node *phi = calloc(1, sizeof(ir_node));
        phi->code = OC_PHI;
        arrput(phi->out, region);
        arrput(phi->out, then_mem);
        arrput(phi->out, else_mem);
        phi->id = stbds_hash_bytes(phi, sizeof(ir_node), 0xcafebabe);
        hmput(global_hash, *phi, phi);
        current_memory = phi;
    } else {
        current_memory = then_mem;
    }

    /* Merge variable definitions into the snapshot scope. */
    current_scope = base_scope;
    for (int i = 0; i < arrlen(current_scope->data.symbol_tables); i++) {
        symbol_table *base_table = current_scope->data.symbol_tables[i];
        for (int j = 0; j < shlen(base_table); j++) {
            char *key = base_table[j].key;
            struct symbol_def *base_def = base_table[j].value;

            /* A branch entry that is missing or has no node falls back to
             * the pre-branch definition. */
            struct symbol_def *then_def = shget(then_scope->data.symbol_tables[i], key);
            if (!then_def || !then_def->node) {
                then_def = base_def;
            }
            struct symbol_def *else_def = shget(else_scope->data.symbol_tables[i], key);
            if (!else_def || !else_def->node) {
                else_def = base_def;
            }

            /* Nothing to merge for a name that never received a node. */
            if (!then_def || !else_def || !then_def->node || !else_def->node) {
                continue;
            }

            /* symbol_defs are shared between scope copies, so the merged
             * binding always gets a definition of its own. */
            struct symbol_def *merged = calloc(1, sizeof(struct symbol_def));
            merged->is_lvalue = then_def->is_lvalue && else_def->is_lvalue;

            if (then_def->node->id != else_def->node->id) {
                ir_node *phi = calloc(1, sizeof(ir_node));
                phi->code = OC_PHI;
                arrput(phi->out, region);
                arrput(phi->out, then_def->node);
                arrput(phi->out, else_def->node);
                phi->id = stbds_hash_bytes(phi, sizeof(ir_node), 0xcafebabe);
                merged->node = phi;
                shput(current_scope->data.symbol_tables[i], key, merged);
                hmput(global_hash, *phi, phi);
            } else {
                merged->node = then_def->node;
                shput(current_scope->data.symbol_tables[i], key, merged);
            }
        }
    }

    current_control = region;
    return region;
}
/*
 * Build an OC_CALL node for a call expression.
 *
 * Operands: control, memory, then one operand per argument in source order.
 * Three OC_PROJ nodes are created on the call; the first becomes the current
 * control, the second the current memory, and the third (the return value)
 * is returned.
 */
static ir_node *build_call(ast_node *node)
{
    /* Evaluate the arguments before sampling control and memory: an argument
     * may itself contain a call that advances both, and this call must be
     * ordered after it. */
    ir_node **args = NULL;
    ast_node *param = node->expr.call.parameters;
    while (param && param->type == NODE_UNIT) {
        if (param->expr.unit_node.expr) {
            ir_node *arg = build_expression(param->expr.unit_node.expr);
            arrput(args, arg);
        }
        param = param->expr.unit_node.next;
    }

    ir_node *call = calloc(1, sizeof(ir_node));
    call->code = OC_CALL;
    call->data.call_name = node->expr.call.name;
    /* Call inputs: control, memory, then arguments */
    arrput(call->out, current_control);
    arrput(call->out, current_memory);
    for (int i = 0; i < arrlen(args); i++) {
        arrput(call->out, args[i]);
    }
    arrfree(args);
    call->id = stbds_hash_bytes(call, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *call, call);

    /* Projection for the new control */
    ir_node *call_ctrl = calloc(1, sizeof(ir_node));
    call_ctrl->code = OC_PROJ;
    arrput(call_ctrl->out, call);
    call_ctrl->id = stbds_hash_bytes(call_ctrl, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *call_ctrl, call_ctrl);
    current_control = call_ctrl;

    /* Projection for the new memory state */
    ir_node *call_mem = calloc(1, sizeof(ir_node));
    call_mem->code = OC_PROJ;
    arrput(call_mem->out, call);
    call_mem->id = stbds_hash_bytes(call_mem, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *call_mem, call_mem);
    current_memory = call_mem;

    /* Projection for the return value */
    ir_node *call_ret = calloc(1, sizeof(ir_node));
    call_ret->code = OC_PROJ;
    arrput(call_ret->out, call);
    call_ret->id = stbds_hash_bytes(call_ret, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *call_ret, call_ret);
    return call_ret;
}
/*
 * Record a return statement for the function being built. The returned
 * value (or a fresh OC_VOID node when there is none) is queued in
 * current_func together with the current control and memory; the actual
 * OC_RETURN node is created later by finalize_function(). Control is then
 * cleared so that following statements are treated as unreachable.
 */
static void build_return(ast_node *node)
{
    ir_node *result;

    if (!node->expr.ret.value) {
        result = calloc(1, sizeof(ir_node));
        result->code = OC_VOID;
        result->id = stbds_hash_bytes(result, sizeof(ir_node), 0xcafebabe);
    } else {
        result = build_expression(node->expr.ret.value);
    }

    arrput(current_func.return_controls, current_control);
    arrput(current_func.return_memories, current_memory);
    arrput(current_func.return_values, result);

    current_control = NULL;
}
/*
 * Emit the single OC_RETURN node of the function being built from the
 * returns collected in current_func.
 *
 *   no return recorded  - nothing is emitted;
 *   exactly one         - its control, memory and value are used directly;
 *   several             - an OC_REGION joins the controls and two OC_PHI
 *                         nodes (operands: region, then one entry per
 *                         return) merge the memories and the values.
 *
 * The OC_RETURN node has operands (control, memory, value) and is
 * registered in global_hash.
 */
static void finalize_function(void)
{
    int count = arrlen(current_func.return_controls);
    if (count == 0) {
        return;
    }

    ir_node *final_ctrl = NULL;
    ir_node *final_mem = NULL;
    ir_node *final_val = NULL;

    if (count == 1) {
        final_ctrl = current_func.return_controls[0];
        final_mem = current_func.return_memories[0];
        final_val = current_func.return_values[0];
    } else {
        /* Each node gets its id before it is copied into global_hash as a
         * key, so the stored key matches the node's final contents. */
        ir_node *region = calloc(1, sizeof(ir_node));
        region->code = OC_REGION;
        for (int i = 0; i < count; i++) {
            arrput(region->out, current_func.return_controls[i]);
        }
        region->id = stbds_hash_bytes(region, sizeof(ir_node), 0xcafebabe);
        hmput(global_hash, *region, region);
        final_ctrl = region;

        ir_node *mem_phi = calloc(1, sizeof(ir_node));
        mem_phi->code = OC_PHI;
        arrput(mem_phi->out, region);
        for (int i = 0; i < count; i++) {
            arrput(mem_phi->out, current_func.return_memories[i]);
        }
        mem_phi->id = stbds_hash_bytes(mem_phi, sizeof(ir_node), 0xcafebabe);
        hmput(global_hash, *mem_phi, mem_phi);
        final_mem = mem_phi;

        ir_node *val_phi = calloc(1, sizeof(ir_node));
        val_phi->code = OC_PHI;
        arrput(val_phi->out, region);
        for (int i = 0; i < count; i++) {
            arrput(val_phi->out, current_func.return_values[i]);
        }
        val_phi->id = stbds_hash_bytes(val_phi, sizeof(ir_node), 0xcafebabe);
        hmput(global_hash, *val_phi, val_phi);
        final_val = val_phi;
    }

    ir_node *ret = calloc(1, sizeof(ir_node));
    ret->code = OC_RETURN;
    arrput(ret->out, final_ctrl);
    arrput(ret->out, final_mem);
    arrput(ret->out, final_val);
    ret->id = stbds_hash_bytes(ret, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *ret, ret);
}
/*
 * Build the IR for one function definition and return its OC_START node.
 *
 * Creates the start node, a control and a memory projection on it (which
 * become the current control / memory), a fresh scope, and one OC_PROJ per
 * parameter bound to the parameter's name. Body statements are built in
 * order until control terminates, then finalize_function() emits the
 * return node.
 *
 * NOTE(review): current_stack is not reset here, so frame offsets of
 * address-taken variables keep growing across functions; confirm whether
 * that is intended.
 */
static ir_node *build_function(ast_node *node)
{
    /* Release the return lists left over from the previous function before
     * clearing the bookkeeping. */
    arrfree(current_func.return_controls);
    arrfree(current_func.return_memories);
    arrfree(current_func.return_values);
    memset(&current_func, 0x0, sizeof(current_func));

    ast_node *current = node->expr.function.body;

    ir_node *func = calloc(1, sizeof(ir_node));
    func->code = OC_START;
    func->id = stbds_hash_bytes(func, sizeof(ir_node), 0xcafebabe);
    func->data.start_name = node->expr.function.name;

    /* The two start projections have identical contents at this point, so
     * their ids are derived from the node addresses instead. */
    ir_node *start_ctrl = calloc(1, sizeof(ir_node));
    start_ctrl->code = OC_PROJ;
    start_ctrl->id = stbds_hash_bytes(&start_ctrl, sizeof(start_ctrl), 0xcafebabe);
    arrput(start_ctrl->out, func);
    hmput(global_hash, *start_ctrl, start_ctrl);
    current_control = start_ctrl;

    ir_node *start_mem = calloc(1, sizeof(ir_node));
    start_mem->code = OC_PROJ;
    start_mem->id = stbds_hash_bytes(&start_mem, sizeof(start_mem), 0xcafebabe);
    arrput(start_mem->out, func);
    hmput(global_hash, *start_mem, start_mem);
    current_memory = start_mem;

    current_scope = calloc(1, sizeof(ir_node));
    current_scope->code = OC_SCOPE;
    push_scope();

    /* Parameters: one projection of the start node per parameter. */
    member *m = node->expr.function.parameters;
    while (m) {
        ir_node *proj_param = calloc(1, sizeof(ir_node));
        proj_param->code = OC_PROJ;
        arrput(proj_param->out, func);
        proj_param->id = stbds_hash_bytes(proj_param, sizeof(ir_node), 0xcafebabe);
        set_def(m->name, proj_param, false);
        hmput(global_hash, *proj_param, proj_param);
        m = m->next;
    }

    /* Body: statements after a return (current_control == NULL) are skipped. */
    while (current && current->type == NODE_UNIT) {
        if (current->expr.unit_node.expr && current_control) {
            build_expression(current->expr.unit_node.expr);
        }
        current = current->expr.unit_node.next;
    }

    /* Re-hash now that the name and all uses are in place. */
    func->id = stbds_hash_bytes(func, sizeof(ir_node), 0xcafebabe);
    finalize_function();
    return func;
}
/*
 * Build the IR for one AST expression / statement and return the node that
 * represents its value.
 *
 * Returns NULL for a NULL AST node, for an undefined identifier (reported on
 * stderr), for a return statement, and for node kinds that are not handled.
 * Any non-NULL result that reaches the end of the function is registered in
 * global_hash.
 *
 * Variable declarations:
 *   - address-taken variables get a frame slot: an OC_STORE with operands
 *     (memory, address, value) becomes the current memory and the name is
 *     bound to the address as an lvalue;
 *   - all other variables are bound directly to their initializer's node.
 *
 * NOTE(review): this declaration store has no control operand, unlike the
 * stores emitted by build_assign() / build_assign_ptr(); confirm that
 * consumers expect both layouts.
 */
static ir_node *build_expression(ast_node *node)
{
    if (!node) {
        return NULL;
    }

    ir_node *n = NULL;
    ir_node *tmp = NULL;

    switch (node->type) {
    case NODE_UNARY:
        n = build_unary(node);
        break;
    case NODE_BINARY:
        n = build_binary(node);
        break;
    case NODE_INTEGER:
        n = calloc(1, sizeof(ir_node));
        n->code = OC_CONST_INT;
        n->data.const_int = node->expr.integer;
        n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe);
        tmp = hmget(global_hash, *n);
        if (tmp) {
            free(n);
            return tmp;
        }
        break;
    case NODE_VAR_DECL:
        if (node->address_taken) {
            /* Build the initializer first: it may advance current_memory,
             * and the store has to be chained after that state. */
            ir_node *value = build_expression(node->expr.var_decl.value);
            ir_node *address = build_address(-1, current_stack);

            ir_node *store = calloc(1, sizeof(ir_node));
            store->code = OC_STORE;
            arrput(store->out, current_memory);
            arrput(store->out, address);
            arrput(store->out, value);
            current_memory = store;
            current_stack += node->expr_type->size;
            store->id = stbds_hash_bytes(store, sizeof(ir_node), 0xcafebabe);
            hmput(global_hash, *store, store);

            /* The variable itself is represented by its address. */
            n = address;
            set_def(node->expr.var_decl.name, n, true);
        } else {
            n = build_expression(node->expr.var_decl.value);
            set_def(node->expr.var_decl.name, n, false);
        }
        return n;
    case NODE_IDENTIFIER: {
        struct symbol_def *def = get_def(node->expr.string.start);
        if (!def) {
            fprintf(stderr, "IR error: undefined identifier '%s'\n", node->expr.string.start);
            return NULL;
        }
        n = def->node;
        if (n && def->is_lvalue) {
            /* Address-backed variable: reading it is a load from memory. */
            ir_node *addr_node = n;
            n = calloc(1, sizeof(ir_node));
            n->code = OC_LOAD;
            arrput(n->out, current_memory);
            arrput(n->out, addr_node);
            n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe);
            tmp = hmget(global_hash, *n);
            if (tmp) {
                free(n);
                n = tmp;
            } else {
                hmput(global_hash, *n, n);
            }
        }
        break;
    }
    case NODE_IF:
        n = build_if(node);
        break;
    case NODE_WHILE:
        n = build_while(node);
        break;
    case NODE_RETURN:
        build_return(node);
        break;
    case NODE_CALL:
        n = build_call(node);
        break;
    default:
        break;
    }

    if (n) hmput(global_hash, *n, n);
    return n;
}
// Build the IR for a whole translation unit and print it.
//
// Creates the program-level START node and the initial frame-pointer memory
// state, opens the outermost scope, then lowers every function found in the
// NODE_UNIT chain starting at `ast`. Each function is scheduled with GCM and
// its schedule printed; finally the whole graph is emitted in Graphviz form.
void ir_build(ast_node *ast)
{
    // Root node. The id is hashed before start_name is filled in.
    graph = calloc(1, sizeof(ir_node));
    graph->code = OC_START;
    graph->id = stbds_hash_bytes(graph, sizeof(ir_node), 0xcafebabe);
    graph->data.start_name = "program";

    // Initial memory state.
    current_memory = calloc(1, sizeof(ir_node));
    current_memory->code = OC_FRAME_PTR;
    current_memory->id = stbds_hash_bytes(current_memory, sizeof(ir_node), 0xcafebabe);

    // Outermost scope.
    current_scope = calloc(1, sizeof(ir_node));
    current_scope->code = OC_SCOPE;
    push_scope();

    for (ast_node *unit = ast; unit && unit->type == NODE_UNIT; unit = unit->expr.unit_node.next) {
        ast_node *decl = unit->expr.unit_node.expr;
        if (!decl || decl->type != NODE_FUNCTION) {
            continue;
        }

        ir_node *fn = build_function(decl);
        arrput(graph->out, fn);
        hmput(global_hash, *fn, fn);

        // Run GCM on this function and dump the result
        ir_function *plan = gcm_schedule(fn);
        if (plan) {
            gcm_print_scheduled(plan);
            printf("\n");
        }
    }

    printf("digraph G {\n");
    print_graph(graph);
    printf("}\n");
}
// Next id handed out by create_block(); gcm_schedule() resets it to 0.
static int block_id_counter = 0;

// Allocate a basic block headed by `control`. The block receives the next
// sequential id; every other field starts out empty (NULL / 0 / false).
static basic_block *create_block(ir_node *control)
{
    basic_block *fresh = calloc(1, sizeof(basic_block));
    // Members not named here are zero-initialized by the compound literal.
    *fresh = (basic_block){
        .id = block_id_counter++,
        .control = control,
    };
    return fresh;
}
// Record a CFG edge from -> to: `from` becomes a predecessor of `to` and `to`
// a successor of `from`. No duplicate check is performed.
static void add_edge(basic_block *from, basic_block *to)
{
    arrput(to->preds, from);
    arrput(from->succs, to);
}
// Tell whether `node` is a control node, i.e. one that marks a basic block
// boundary: START, REGION, LOOP, IF, RETURN, or a PROJ (projections from
// IF/CALL are control). NULL is not a control node.
static bool is_control_node(ir_node *node)
{
    if (node == NULL) {
        return false;
    }
    return node->code == OC_START
        || node->code == OC_REGION
        || node->code == OC_LOOP
        || node->code == OC_IF
        || node->code == OC_RETURN
        || node->code == OC_PROJ;
}
// Tell whether `node` is pinned, i.e. must stay in a specific block rather
// than float during scheduling. Pinned opcodes are the control nodes (START,
// REGION, LOOP, IF, RETURN, PROJ) plus PHI, STORE, LOAD and CALL.
// NULL is not pinned.
static bool is_pinned(ir_node *node)
{
    if (node == NULL) {
        return false;
    }
    return node->code == OC_START
        || node->code == OC_REGION
        || node->code == OC_LOOP
        || node->code == OC_IF
        || node->code == OC_RETURN
        || node->code == OC_PHI
        || node->code == OC_PROJ
        || node->code == OC_STORE
        || node->code == OC_LOAD
        || node->code == OC_CALL;
}
// Map from control nodes to basic blocks (stb_ds hash map keyed by node
// pointer). build_cfg() frees and repopulates it for each function;
// schedule_late() reads it to find the predecessor block of a PHI input.
static struct { ir_node *key; basic_block *value; } *control_to_block = NULL;
// Append `node` and everything reachable from it to *all_nodes, in depth-first
// pre-order. Reachability follows the `out` array, which holds a node's inputs
// in this IR. The `scheduled` flag doubles as the visited mark, so the caller
// has to clear it again afterwards.
static void collect_nodes(ir_node *node, ir_node ***all_nodes)
{
    if (node == NULL) {
        return;
    }
    if (node->scheduled) {
        return; // already visited
    }

    node->scheduled = true;
    arrput(*all_nodes, node);

    int edge = 0;
    while (edge < arrlen(node->out)) {
        collect_nodes(node->out[edge], all_nodes);
        edge++;
    }
}
// Open a new basic block headed by `head`: allocate it, record the
// head -> block mapping in control_to_block, pin `head` to it and append it
// to func->blocks. Returns the new block.
static basic_block *cfg_open_block(ir_node *head, ir_function *func)
{
    basic_block *bb = create_block(head);
    hmput(control_to_block, head, bb);
    head->block = bb;
    arrput(func->blocks, bb);
    return bb;
}

// Build the control-flow graph of one function from its control nodes.
//
// func_start: the function's START node; every node reachable from it through
//             `out` (the input edges) is considered part of the function.
// func:       receives the created blocks (func->blocks, func->block_count).
//
// Returns the entry block (the block of the START node), or NULL if no START
// node was encountered.
//
// Side effects: resets scheduled/pinned/early/late/block on every reachable
// node and rebuilds the file-level control_to_block map.
//
// Steps:
//   1. open blocks for START, REGION, LOOP, RETURN and for projections of
//      IF/CALL/LOOP; projections of START share the entry block;
//   2. pin PHI nodes to their region's block and IF nodes to the block of
//      their control input;
//   3. add the edges.
// Step 2 has to run before step 3: an IF has no block of its own, so without
// the pin the edge pass skips it and the blocks of its projections never get
// a predecessor.
static basic_block *build_cfg(ir_node *func_start, ir_function *func)
{
    ir_node **all_nodes = NULL;
    collect_nodes(func_start, &all_nodes);

    // collect_nodes() used `scheduled` as its visited mark; clear it together
    // with the rest of the per-function scheduling state.
    for (int i = 0; i < arrlen(all_nodes); i++) {
        ir_node *node = all_nodes[i];
        node->scheduled = false;
        node->pinned = is_pinned(node);
        node->early = NULL;
        node->late = NULL;
        node->block = NULL;
    }

    hmfree(control_to_block);
    control_to_block = NULL;

    // Step 1: create blocks for control nodes.
    basic_block *entry = NULL;
    for (int i = 0; i < arrlen(all_nodes); i++) {
        ir_node *node = all_nodes[i];
        switch (node->code) {
        case OC_START:
            entry = cfg_open_block(node, func);
            break;
        case OC_REGION:
        case OC_LOOP:
        case OC_RETURN:
            cfg_open_block(node, func);
            break;
        case OC_PROJ: {
            if (arrlen(node->out) == 0) break;
            ir_node *parent = node->out[0];
            if (!parent) break;
            if (parent->code == OC_START) {
                // Projections from START are part of the entry block. They are
                // also entered in the map so their users can find that block.
                basic_block *bb = hmget(control_to_block, parent);
                node->block = bb;
                hmput(control_to_block, node, bb);
            } else if (parent->code == OC_IF || parent->code == OC_CALL ||
                       parent->code == OC_LOOP) {
                // Each IF/CALL/LOOP projection heads a block of its own.
                cfg_open_block(node, func);
            }
            break;
        }
        default:
            break;
        }
    }

    // Step 2: pin PHI nodes to their region's block and IF nodes to the block
    // of their control input (the second input).
    for (int i = 0; i < arrlen(all_nodes); i++) {
        ir_node *node = all_nodes[i];
        if (node->code == OC_PHI && arrlen(node->out) > 0) {
            ir_node *region = node->out[0];
            node->block = hmget(control_to_block, region);
        } else if (node->code == OC_IF && arrlen(node->out) > 1) {
            ir_node *ctrl = node->out[1];
            node->block = hmget(control_to_block, ctrl);
        }
    }

    // Step 3: build CFG edges.
    // NOTE(review): blocks opened for CALL and LOOP projections receive no
    // incoming edge here - confirm whether that is intended.
    for (int i = 0; i < arrlen(all_nodes); i++) {
        ir_node *node = all_nodes[i];
        basic_block *bb = node->block;
        if (!bb) continue;

        if (node->code == OC_PROJ) {
            // A projection of an IF is a successor of the block the IF ends.
            ir_node *parent = arrlen(node->out) > 0 ? node->out[0] : NULL;
            if (parent && parent->code == OC_IF &&
                parent->block && parent->block != bb) {
                add_edge(parent->block, bb);
            }
        } else if (node->code == OC_REGION || node->code == OC_LOOP) {
            // Region/Loop: every control input is a predecessor.
            for (int j = 0; j < arrlen(node->out); j++) {
                ir_node *pred_ctrl = node->out[j];
                if (!pred_ctrl) continue;
                basic_block *pred = hmget(control_to_block, pred_ctrl);
                if (pred && pred != bb) {
                    add_edge(pred, bb);
                }
            }
        } else if (node->code == OC_RETURN) {
            // Return: the first input is its control input.
            if (arrlen(node->out) > 0) {
                ir_node *ctrl_in = node->out[0];
                basic_block *pred = hmget(control_to_block, ctrl_in);
                if (pred && pred != bb) {
                    add_edge(pred, bb);
                }
            }
        }
    }

    arrfree(all_nodes);
    func->block_count = arrlen(func->blocks);
    return entry;
}
// Compute immediate dominators with a simple iterative algorithm.
//
// On return every block that is connected to the entry through predecessor
// edges has `idom` set (the entry is its own idom), `dom_children` lists each
// block under its idom, and `dom_depth` holds the block's depth in the final
// dominator tree (entry = 0). Blocks not connected to the entry keep
// idom == NULL. Does nothing when the function has no entry or no blocks.
static void compute_dominators(ir_function *func)
{
    if (!func->entry || func->block_count == 0) return;

    // Initialize: entry dominates itself
    func->entry->idom = func->entry;
    func->entry->dom_depth = 0;

    bool changed = true;
    while (changed) {
        changed = false;
        for (int i = 0; i < func->block_count; i++) {
            basic_block *bb = func->blocks[i];
            if (bb == func->entry) continue;

            // Fold all predecessors that already have an idom into their
            // common dominator.
            basic_block *new_idom = NULL;
            for (int j = 0; j < arrlen(bb->preds); j++) {
                basic_block *pred = bb->preds[j];
                if (!pred->idom) continue;
                if (!new_idom) {
                    new_idom = pred;
                    continue;
                }
                // Intersect: lift the deeper side until both meet.
                basic_block *a = pred;
                basic_block *b = new_idom;
                while (a != b) {
                    while (a && a->dom_depth > b->dom_depth) a = a->idom;
                    while (b && b->dom_depth > a->dom_depth) b = b->idom;
                    if (a != b) {
                        if (a) a = a->idom;
                        if (b) b = b->idom;
                    }
                }
                new_idom = a;
            }

            if (new_idom && bb->idom != new_idom) {
                bb->idom = new_idom;
                bb->dom_depth = new_idom->dom_depth + 1;
                changed = true;
            }
        }
    }

    // Build dominator tree children
    for (int i = 0; i < func->block_count; i++) {
        basic_block *bb = func->blocks[i];
        if (bb->idom && bb->idom != bb) {
            arrput(bb->idom->dom_children, bb);
        }
    }

    // Recompute depths from the finished tree. Inside the loop above a depth
    // is refreshed only when the block's own idom pointer changes, so the
    // descendants of a block that was re-parented later would otherwise keep
    // a stale depth, and every depth-guided walk up the tree depends on them.
    basic_block **work = NULL;
    arrput(work, func->entry);
    while (arrlen(work) > 0) {
        basic_block *parent = arrpop(work);
        for (int j = 0; j < arrlen(parent->dom_children); j++) {
            basic_block *child = parent->dom_children[j];
            child->dom_depth = parent->dom_depth + 1;
            arrput(work, child);
        }
    }
    arrfree(work);
}
// Assign an approximate loop depth to every block.
//
// First pass: each block headed by a LOOP node gets depth 1, and so do its
// direct children in the dominator tree. Second pass: walking func->blocks in
// order, a block inherits its idom's loop depth when that one is larger.
// Nested loops are not distinguished; the depth is either 0 or 1.
static void compute_loop_depths(ir_function *func)
{
    for (int b = 0; b < func->block_count; b++) {
        basic_block *head = func->blocks[b];
        if (!head->control || head->control->code != OC_LOOP) {
            continue;
        }
        // Loop header: mark it and the blocks it immediately dominates.
        head->loop_depth = 1;
        for (int c = 0; c < arrlen(head->dom_children); c++) {
            basic_block *child = head->dom_children[c];
            child->loop_depth = head->loop_depth;
        }
    }

    // Push the marks further down the dominator tree.
    for (int b = 0; b < func->block_count; b++) {
        basic_block *blk = func->blocks[b];
        basic_block *dom = blk->idom;
        if (dom && blk->loop_depth < dom->loop_depth) {
            blk->loop_depth = dom->loop_depth;
        }
    }
}
// Schedule Early: compute node->early, the earliest block the node may be
// placed in.
//
// A pinned node that already has a block simply uses that block. Any other
// node starts at `entry` and is pushed down to the deepest (by dom_depth)
// early block among its inputs, which are scheduled recursively first.
// Nodes that already have `early` set are left untouched, which also stops
// the recursion on cyclic graphs.
static void schedule_early(ir_node *node, basic_block *entry)
{
    if (node == NULL || node->early != NULL) {
        return;
    }

    if (node->pinned && node->block) {
        node->early = node->block;
        return;
    }

    // Provisional answer; set before recursing so cycles terminate.
    node->early = entry;

    for (int k = 0; k < arrlen(node->out); k++) {
        ir_node *operand = node->out[k];
        if (operand == NULL) {
            continue;
        }
        schedule_early(operand, entry);

        // We cannot sit above the place where an operand becomes available.
        basic_block *floor = operand->early;
        if (floor == NULL) {
            continue;
        }
        if (floor->dom_depth > node->early->dom_depth) {
            node->early = floor;
        }
    }
}
// Find the least common ancestor of `a` and `b` in the dominator tree.
//
// A NULL argument means "no constraint": the other argument is returned.
// Otherwise both sides are walked up through `idom`, guided by `dom_depth`,
// until they meet.
//
// The walk is defensive about blocks that do not share a tree:
//   - if one side runs off the top (a block whose idom is NULL), the current
//     position of the other side is returned;
//   - if both sides reach distinct roots (a root is its own idom), the root
//     reached from `a` is returned.
// Without these checks the first case dereferences NULL and the second one
// never terminates.
static basic_block *dom_lca(basic_block *a, basic_block *b)
{
    if (!a) return b;
    if (!b) return a;

    while (a != b) {
        // Lift whichever side is deeper.
        while (a && a->dom_depth > b->dom_depth) a = a->idom;
        if (!a) return b;
        while (b && b->dom_depth > a->dom_depth) b = b->idom;
        if (!b) return a;
        if (a == b) break;

        // Same depth but different blocks: step both sides up together.
        basic_block *up_a = a->idom;
        basic_block *up_b = b->idom;
        if (!up_a) return b;
        if (!up_b) return a;
        if (up_a == a && up_b == b) return a; // two distinct roots
        a = up_a;
        b = up_b;
    }
    return a;
}
// Append to *uses every node among all_nodes[0..count) that has `node` as one
// of its inputs. Each user is added once, no matter how many of its inputs
// refer to `node`; `node` itself is never reported.
static void find_uses(ir_node *node, ir_node **all_nodes, int count, ir_node ***uses)
{
    for (int idx = 0; idx < count; idx++) {
        ir_node *candidate = all_nodes[idx];
        if (candidate == node) {
            continue;
        }

        bool consumes = false;
        for (int in = 0; !consumes && in < arrlen(candidate->out); in++) {
            consumes = (candidate->out[in] == node);
        }

        if (consumes) {
            arrput(*uses, candidate);
        }
    }
}
// Schedule Late: compute node->late, the latest block the node may be placed
// in, as the dominator-tree LCA of the blocks of all its uses.
//
// Pinned nodes get late = their block. Nodes whose `late` is already set are
// skipped. For a use that is a PHI, the block considered is the predecessor
// block matching the PHI input that refers to `node`, not the PHI's block.
// When no use yields a block, late falls back to node->early.
static void schedule_late(ir_node *node, ir_node **all_nodes, int count)
{
    if (node == NULL || node->late != NULL) {
        return;
    }
    if (node->pinned) {
        node->late = node->block;
        return;
    }

    // `early` acts as an in-progress marker so recursive visits through a
    // cycle stop here.
    node->late = node->early;
    if (node->late == NULL) {
        // schedule_early has not placed this node; nothing to refine.
        return;
    }

    ir_node **consumers = NULL;
    find_uses(node, all_nodes, count, &consumers);

    basic_block *meet = NULL;
    for (int u = 0; u < arrlen(consumers); u++) {
        ir_node *consumer = consumers[u];

        // Make sure the consumer has been placed first.
        if (consumer->late == NULL) {
            schedule_late(consumer, all_nodes, count);
        }

        // Preference order: final block, then late, then early.
        basic_block *where;
        if (consumer->block) {
            where = consumer->block;
        } else if (consumer->late) {
            where = consumer->late;
        } else {
            where = consumer->early;
        }
        if (where == NULL) {
            continue;
        }

        if (consumer->code == OC_PHI) {
            // Locate the first PHI input (inputs start at index 1) that is us.
            int slot = 1;
            while (slot < arrlen(consumer->out) && consumer->out[slot] != node) {
                slot++;
            }
            if (slot < arrlen(consumer->out)) {
                // Input `slot` pairs with control input `slot - 1` of the region.
                ir_node *region = consumer->out[0];
                if (region && slot - 1 < arrlen(region->out)) {
                    ir_node *pred_ctrl = region->out[slot - 1];
                    basic_block *pred = hmget(control_to_block, pred_ctrl);
                    if (pred) {
                        where = pred;
                    }
                }
            }
        }

        meet = dom_lca(meet, where);
    }
    arrfree(consumers);

    node->late = meet ? meet : node->early;
}
// Select the final block for `node` somewhere between node->late and
// node->early on the dominator tree.
//
// Nodes that already have a block are left alone. If only one of early/late
// is known, that one is used. Otherwise the walk starts at `late`, climbs
// through `idom` towards `early`, and keeps the block with the smallest
// loop_depth (the first one found wins ties, i.e. the latest such block).
static void select_block(ir_node *node)
{
    if (!node || node->block) return; // Already placed
    if (!node->early || !node->late) {
        node->block = node->early ? node->early : node->late;
        return;
    }

    basic_block *best = node->late;
    basic_block *current = node->late;
    while (current && current->dom_depth >= node->early->dom_depth) {
        if (current->loop_depth < best->loop_depth) {
            best = current;
        }
        if (current == node->early) break;
        // The entry block is its own idom (see compute_dominators). If the
        // walk reaches such a root without meeting `early`, stepping to idom
        // again would make no progress and the loop would never end.
        if (current->idom == current) break;
        current = current->idom;
    }
    node->block = best;
}
// Order the nodes assigned to block `bb` and append them to bb->nodes
// (topological sort on input dependencies).
//
// bb:        block to schedule.
// all_nodes: candidate nodes; those with node->block == bb and not yet
//            marked `scheduled` are taken.
// count:     number of entries in all_nodes.
//
// A node is emitted once none of its inputs is an unscheduled node of the
// same block. If no such node exists (a dependency cycle inside the block),
// the first pending node is emitted anyway so the loop always makes progress.
// Every emitted node gets scheduled = true.
static void schedule_block(basic_block *bb, ir_node **all_nodes, int count)
{
    ir_node **pending = NULL;

    // Collect nodes scheduled to this block
    for (int i = 0; i < count; i++) {
        ir_node *node = all_nodes[i];
        if (node->block == bb && !node->scheduled) {
            arrput(pending, node);
        }
    }

    // Topological sort
    while (arrlen(pending) > 0) {
        // Find a node with all inputs satisfied
        int ready_idx = -1;
        for (int i = 0; i < arrlen(pending); i++) {
            ir_node *node = pending[i];
            bool inputs_ready = true;
            for (int j = 0; j < arrlen(node->out); j++) {
                ir_node *input = node->out[j];
                if (!input) continue;
                // Input is ready if it's in a different block or already scheduled
                if (input->block == bb && !input->scheduled) {
                    inputs_ready = false;
                    break;
                }
            }
            if (inputs_ready) {
                ready_idx = i;
                break;
            }
        }
        if (ready_idx == -1) {
            // Cycle detected or all remaining have unsatisfied deps - just pick first
            ready_idx = 0;
        }

        ir_node *node = pending[ready_idx];
        node->scheduled = true;
        arrput(bb->nodes, node);
        // Remove from pending (keeps the order of the remaining entries)
        arrdel(pending, ready_idx);
    }
    arrfree(pending);
}
// Main GCM entry point: run Global Code Motion for the function whose START
// node is `func_start`.
//
// Returns a newly allocated ir_function describing the blocks and the node
// order inside each block, or NULL when `func_start` is NULL, is not a START
// node, or no entry block could be built.
//
// Phases: CFG construction, dominators, loop depths, schedule-early,
// schedule-late, final block selection, then ordering within each block.
// The node set handed to the scheduling phases is every node currently stored
// in global_hash (no per-function filtering yet).
ir_function *gcm_schedule(ir_node *func_start)
{
    if (func_start == NULL) {
        return NULL;
    }
    if (func_start->code != OC_START) {
        return NULL;
    }

    block_id_counter = 0;

    ir_function *result = calloc(1, sizeof(ir_function));
    result->name = func_start->data.start_name;
    result->blocks = NULL;

    // CFG
    result->entry = build_cfg(func_start, result);
    if (result->entry == NULL) {
        free(result);
        return NULL;
    }

    // Dominator tree and loop nesting
    compute_dominators(result);
    compute_loop_depths(result);

    // Gather the nodes to schedule and clear their `scheduled` mark.
    ir_node **pool = NULL;
    for (int slot = 0; slot < hmlen(global_hash); slot++) {
        ir_node *member = global_hash[slot].value;
        arrput(pool, member);
        member->scheduled = false;
    }
    int node_count = arrlen(pool);

    // Each phase is completed for all nodes before the next one starts.
    for (int k = 0; k < node_count; k++) {
        schedule_early(pool[k], result->entry);
    }
    for (int k = 0; k < node_count; k++) {
        schedule_late(pool[k], pool, node_count);
    }
    for (int k = 0; k < node_count; k++) {
        select_block(pool[k]);
    }

    // `scheduled` is reused by the per-block ordering below.
    for (int k = 0; k < node_count; k++) {
        pool[k]->scheduled = false;
    }
    for (int b = 0; b < result->block_count; b++) {
        schedule_block(result->blocks[b], pool, node_count);
    }

    arrfree(pool);
    return result;
}
// Print scheduled IR for debugging
void gcm_print_scheduled(ir_function *func)
{
if (!func) return;
printf("Function: %s\n", func->name ? func->name : "<unnamed>");
printf("Blocks: %d\n\n", func->block_count);
for (int i = 0; i < func->block_count; i++) {
basic_block *bb = func->blocks[i];
printf("BB%d (depth=%d, loop=%d):\n", bb->id, bb->dom_depth, bb->loop_depth);
// Print predecessors
printf(" preds: ");
for (int j = 0; j < arrlen(bb->preds); j++) {
printf("BB%d ", bb->preds[j]->id);
}
printf("\n");
// Print successors
printf(" succs: ");
for (int j = 0; j < arrlen(bb->succs); j++) {
printf("BB%d ", bb->succs[j]->id);
}
printf("\n");
// Print idom
if (bb->idom && bb->idom != bb) {
printf(" idom: BB%d\n", bb->idom->id);
}
// Print scheduled nodes
printf(" instructions:\n");
for (int j = 0; j < arrlen(bb->nodes); j++) {
ir_node *node = bb->nodes[j];
printf(" [%ld] ", node->id);
switch (node->code) {
case OC_START: printf("START %s", node->data.start_name); break;
case OC_ADD: printf("ADD"); break;
case OC_SUB: printf("SUB"); break;
case OC_MUL: printf("MUL"); break;
case OC_DIV: printf("DIV"); break;
case OC_MOD: printf("MOD"); break;
case OC_BAND: printf("AND"); break;
case OC_BOR: printf("OR"); break;
case OC_BXOR: printf("XOR"); break;
case OC_NEG: printf("NEG"); break;
case OC_EQ: printf("EQ"); break;
case OC_NEQ: printf("NEQ"); break;
case OC_LT: printf("LT"); break;
case OC_GT: printf("GT"); break;
case OC_LE: printf("LE"); break;
case OC_GE: printf("GE"); break;
case OC_AND: printf("LAND"); break;
case OC_OR: printf("LOR"); break;
case OC_CONST_INT: printf("CONST %ld", node->data.const_int); break;
case OC_CONST_FLOAT: printf("CONST %f", node->data.const_float); break;
case OC_VOID: printf("VOID"); break;
case OC_FRAME_PTR: printf("FRAME_PTR"); break;
case OC_ADDR: printf("ADDR"); break;
case OC_STORE: printf("STORE"); break;
case OC_LOAD: printf("LOAD"); break;
case OC_REGION: printf("REGION"); break;
case OC_PHI: printf("PHI"); break;
case OC_IF: printf("IF"); break;
case OC_PROJ: printf("PROJ"); break;
case OC_LOOP: printf("LOOP"); break;
case OC_CALL: printf("CALL %s", node->data.call_name); break;
case OC_RETURN: printf("RETURN"); break;
default: printf("OP_%d", node->code); break;
}
// Print inputs
if (arrlen(node->out) > 0) {
printf(" (");
for (int k = 0; k < arrlen(node->out); k++) {
if (k > 0) printf(", ");
if (node->out[k]) {
printf("%ld", node->out[k]->id);
} else {
printf("null");
}
}
printf(")");
}
printf("\n");
}
printf("\n");
}
}