/*
 * Builds a graph IR from the AST and dumps it as a Graphviz digraph.
 *
 * Every node keeps its operands in its `out` array.  Nodes are registered in
 * `global_hash`, keyed by the node's bytes, and that table is what gets
 * printed at the end of ir_build().
 */
#include "ir.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "stb_ds.h"
#include "sema.h"

/* Seed handed to stbds_hash_bytes() when deriving node ids. */
#define IR_HASH_SEED 0xcafebabe

/* Registry of emitted nodes: key is a by-value copy of the node. */
struct {
    ir_node key;
    ir_node *value;
} *global_hash = NULL;

static ir_node *graph;
static ir_node *current_memory;
static ir_node *current_control;
static usize current_stack = 0;
static ir_node *current_scope = NULL;

static ir_node *build_expression(ast_node *node);

/* Control / memory / value triples recorded by each `return` of the
 * function currently being built; merged by finalize_function(). */
static struct {
    ir_node **return_controls;
    ir_node **return_memories;
    ir_node **return_values;
} current_func = {0};

/* Allocates a zeroed node with the given opcode; exits on allocation failure. */
static ir_node *new_node(int code)
{
    ir_node *n = calloc(1, sizeof *n);
    if (!n) {
        perror("calloc");
        exit(EXIT_FAILURE);
    }
    n->code = code;
    return n;
}

/* Allocates a symbol definition bound to `node`; exits on allocation failure. */
static struct symbol_def *new_def(ir_node *node, bool lvalue)
{
    struct symbol_def *def = calloc(1, sizeof *def);
    if (!def) {
        perror("calloc");
        exit(EXIT_FAILURE);
    }
    def->is_lvalue = lvalue;
    def->node = node;
    return def;
}

/* Sets the node id to the hash of the node's current bytes. */
static void seal_node(ir_node *n)
{
    n->id = stbds_hash_bytes(n, sizeof(ir_node), IR_HASH_SEED);
}

/* Releases a node that lost to an already-registered equivalent, together
 * with its operand array. */
static void discard_node(ir_node *n)
{
    arrfree(n->out);
    free(n);
}

/*
 * Prints the Graphviz declaration line for one node: its id followed by a
 * label/attribute list chosen by opcode.  A NULL node is printed as a red
 * placeholder named "null".
 */
static void node_name(ir_node *node)
{
    if (!node) {
        printf("null [label=\"NULL\", style=filled, fillcolor=red]\n");
        return;
    }

    printf("%ld ", node->id);
    switch (node->code) {
    case OC_START:
        printf("[label=\"%s\", style=filled, color=orange]\n", node->data.start_name);
        break;
    case OC_RETURN:
        printf("[label=\"return\", style=filled, color=orange]\n");
        break;
    case OC_ADD:
        printf("[label=\"+\"]\n");
        break;
    case OC_NEG: /* negation shares the "-" label with subtraction */
    case OC_SUB:
        printf("[label=\"-\"]\n");
        break;
    case OC_DIV:
        printf("[label=\"/\"]\n");
        break;
    case OC_MUL:
        printf("[label=\"*\"]\n");
        break;
    case OC_MOD:
        printf("[label=\"%%\"]\n");
        break;
    case OC_BAND:
        printf("[label=\"&\"]\n");
        break;
    case OC_BOR:
        printf("[label=\"|\"]\n");
        break;
    case OC_BXOR:
        printf("[label=\"^\"]\n");
        break;
    case OC_EQ:
        printf("[label=\"==\"]\n");
        break;
    case OC_CONST_INT:
        printf("[label=\"%ld\"]\n", node->data.const_int);
        break;
    case OC_CONST_FLOAT:
        printf("[label=\"%f\"]\n", node->data.const_float);
        break;
    case OC_FRAME_PTR:
        printf("[label=\"frame_ptr\"]\n");
        break;
    case OC_STORE:
        printf("[label=\"store\", shape=box]\n");
        break;
    case OC_LOAD:
        printf("[label=\"load\", shape=box]\n");
        break;
    case OC_ADDR:
        printf("[label=\"addr\"]\n");
        break;
    case OC_REGION:
        printf("[label=\"region\", shape=diamond, style=filled, color=green]\n");
        break;
    case OC_PHI:
        printf("[label=\"phi\", shape=triangle]\n");
        break;
    case OC_IF:
        printf("[label=\"if\", shape=diamond, style=filled, color=lightblue]\n");
        break;
    case OC_PROJ:
        printf("[label=\"proj\", shape=diamond, style=filled, color=cyan]\n");
        break;
    default:
        /* Opcodes without a dedicated label print their numeric code. */
        printf("[label=\"%d\"]\n", node->code);
        break;
    }
}

/*
 * Prints every node registered in global_hash and, for each non-NULL entry
 * of its `out` array, the operand's declaration plus an edge
 * "operand->node".
 *
 * The `node` argument is not consulted: the registry is walked instead.
 */
static void print_graph(ir_node *node)
{
    (void)node;

    for (int i = 0; i < hmlen(global_hash); i++) {
        ir_node *user = global_hash[i].value;
        node_name(user);
        for (int j = 0; j < arrlen(user->out); j++) {
            ir_node *operand = user->out[j];
            if (operand) {
                node_name(operand);
                printf("%ld->%ld\n", operand->id, user->id);
            }
        }
    }
}

/* Appends a new, empty symbol table to the current scope node. */
static void push_scope(void)
{
    arrput(current_scope->data.symbol_tables, NULL);
}

/*
 * Looks `name` up in the current scope, searching from the most recently
 * pushed table towards the first one.  Returns NULL when no table has it.
 */
static struct symbol_def *get_def(char *name)
{
    for (int i = arrlen(current_scope->data.symbol_tables) - 1; i >= 0; i--) {
        struct symbol_def *def = shget(current_scope->data.symbol_tables[i], name);
        if (def)
            return def;
    }
    return NULL;
}

/*
 * Binds `name` to `node`.  If some table already defines the name, the
 * innermost such table is updated; otherwise the binding goes into the most
 * recently pushed table.
 *
 * The previous definition object is not freed here because copy_scope()
 * shares definition pointers between scope copies.
 */
static void set_def(char *name, ir_node *node, bool lvalue)
{
    int target = arrlen(current_scope->data.symbol_tables) - 1;

    if (target < 0) {
        /* No table yet: create one rather than index table -1. */
        push_scope();
        target = 0;
    }

    for (int i = target; i >= 0; i--) {
        if (shget(current_scope->data.symbol_tables[i], name)) {
            target = i;
            break;
        }
    }

    shput(current_scope->data.symbol_tables[target], name, new_def(node, lvalue));
}

/*
 * Returns a new OC_SCOPE node with the same number of tables as `src` and
 * the same key -> definition entries.  The definition objects themselves are
 * shared with `src`, not duplicated.
 */
static ir_node *copy_scope(ir_node *src)
{
    ir_node *dst = new_node(OC_SCOPE);

    for (int level = 0; level < arrlen(src->data.symbol_tables); level++) {
        arrput(dst->data.symbol_tables, NULL);
        symbol_table *from = src->data.symbol_tables[level];
        for (int k = 0; k < shlen(from); k++) {
            shput(dst->data.symbol_tables[level], from[k].key, from[k].value);
        }
    }
    return dst;
}

/* Looks `key` up in table `level` of `scope`; NULL if the level or the key
 * does not exist. */
static struct symbol_def *scope_lookup(ir_node *scope, int level, char *key)
{
    if (level >= arrlen(scope->data.symbol_tables))
        return NULL;
    return shget(scope->data.symbol_tables[level], key);
}

/* True when both pointers denote the same node (same pointer or same id). */
static bool same_node(ir_node *a, ir_node *b)
{
    if (a == b)
        return true;
    if (!a || !b)
        return false;
    return a->id == b->id;
}

/*
 * Finishes turning `folded` into a constant: drops its edge arrays and
 * re-derives its id.  The id is zeroed before hashing so the node's bytes
 * match those of a constant built directly from a literal (which is hashed
 * while its id is still zero); otherwise the registry lookup that follows in
 * build_binary() could never match an equal constant.
 */
static void retire_operands(ir_node *folded)
{
    arrfree(folded->out);
    folded->out = NULL;
    arrfree(folded->in);
    folded->in = NULL;
    folded->id = 0;
    seal_node(folded);
}

/*
 * Folds a binary node in place when both operands are integer constants or
 * both are float constants.  On success the node becomes OC_CONST_INT /
 * OC_CONST_FLOAT with no operands.
 *
 * Left untouched: unsupported opcodes, mixed or non-constant operands, and
 * division / modulo by a zero constant (folding those would invent a
 * result).
 */
static void const_fold(ir_node *binary)
{
    ir_node *left = binary->out[0];
    ir_node *right = binary->out[1];

    if (!left || !right)
        return;

    if (left->code == OC_CONST_INT && right->code == OC_CONST_INT) {
        switch (binary->code) {
        case OC_ADD:
            binary->data.const_int = left->data.const_int + right->data.const_int;
            break;
        case OC_SUB:
            binary->data.const_int = left->data.const_int - right->data.const_int;
            break;
        case OC_MUL:
            binary->data.const_int = left->data.const_int * right->data.const_int;
            break;
        case OC_DIV:
            if (right->data.const_int == 0)
                return;
            binary->data.const_int = left->data.const_int / right->data.const_int;
            break;
        case OC_MOD:
            if (right->data.const_int == 0)
                return;
            binary->data.const_int = left->data.const_int % right->data.const_int;
            break;
        case OC_BOR:
            binary->data.const_int = left->data.const_int | right->data.const_int;
            break;
        case OC_BAND:
            binary->data.const_int = left->data.const_int & right->data.const_int;
            break;
        case OC_BXOR:
            binary->data.const_int = left->data.const_int ^ right->data.const_int;
            break;
        case OC_EQ:
            binary->data.const_int = left->data.const_int == right->data.const_int;
            break;
        default:
            return;
        }
        binary->code = OC_CONST_INT;
        retire_operands(binary);
        return;
    }

    if (left->code == OC_CONST_FLOAT && right->code == OC_CONST_FLOAT) {
        switch (binary->code) {
        case OC_ADD:
            binary->data.const_float = left->data.const_float + right->data.const_float;
            break;
        case OC_SUB:
            binary->data.const_float = left->data.const_float - right->data.const_float;
            break;
        case OC_MUL:
            binary->data.const_float = left->data.const_float * right->data.const_float;
            break;
        case OC_DIV:
            if (right->data.const_float == 0.0f)
                return;
            binary->data.const_float = left->data.const_float / right->data.const_float;
            break;
        default:
            return;
        }
        binary->code = OC_CONST_FLOAT;
        retire_operands(binary);
    }
}

/*
 * Builds an OC_ADDR node whose operands are [base, offset].  A base of
 * (usize)-1 selects an OC_FRAME_PTR base node; any other value becomes an
 * integer constant.  Returns a registered equivalent if the lookup finds
 * one.
 *
 * NOTE(review): the registry key is the node's raw bytes, which include the
 * freshly allocated `out` pointer, so the lookup looks unlikely to match for
 * nodes that have operands - confirm whether that is intended.
 */
static ir_node *build_address(usize base, usize offset)
{
    ir_node *addr = new_node(OC_ADDR);

    ir_node *base_node;
    if (base == (usize)-1) {
        base_node = new_node(OC_FRAME_PTR);
    } else {
        base_node = new_node(OC_CONST_INT);
        base_node->data.const_int = base;
    }
    seal_node(base_node);

    ir_node *offset_node = new_node(OC_CONST_INT);
    offset_node->data.const_int = offset;
    seal_node(offset_node);

    arrput(addr->out, base_node);
    arrput(addr->out, offset_node);
    seal_node(addr);

    ir_node *existing = hmget(global_hash, *addr);
    if (existing) {
        /* The operand nodes are referenced only by the discarded node. */
        free(base_node);
        free(offset_node);
        discard_node(addr);
        return existing;
    }
    return addr;
}

/*
 * Emits and registers an OC_STORE with operands
 * [current_control, current_memory, addr, value] and makes it the current
 * memory state.
 */
static void emit_store(ir_node *addr, ir_node *value)
{
    ir_node *store = new_node(OC_STORE);

    arrput(store->out, current_control);
    arrput(store->out, current_memory);
    arrput(store->out, addr);
    arrput(store->out, value);
    seal_node(store);
    hmput(global_hash, *store, store);
    current_memory = store;
}

/*
 * OP_ASSIGN_PTR: evaluates the right-hand side and stores it to the node
 * bound to the left-hand identifier, whether or not that binding is marked
 * as an lvalue.  Returns the value node.
 */
static ir_node *build_assign_ptr(ast_node *binary)
{
    ir_node *val_node = build_expression(binary->expr.binary.right);
    char *var_name = binary->expr.binary.left->expr.string.start;

    struct symbol_def *def = get_def(var_name);
    if (!def) {
        /* Previously a NULL dereference. */
        fprintf(stderr, "ir: store through undefined symbol '%s'\n", var_name);
        return val_node;
    }

    emit_store(def->node, val_node);
    return val_node;
}

/*
 * OP_ASSIGN: evaluates the right-hand side.  If the left-hand identifier is
 * bound as an lvalue, a store to that binding's node is emitted; otherwise
 * the name is simply rebound to the value node.  Returns the value node.
 */
static ir_node *build_assign(ast_node *binary)
{
    ir_node *val_node = build_expression(binary->expr.binary.right);
    char *var_name = binary->expr.binary.left->expr.string.start;
    struct symbol_def *def = get_def(var_name);

    if (def && def->is_lvalue) {
        emit_store(def->node, val_node);
    } else {
        set_def(var_name, val_node, false);
    }
    return val_node;
}

/*
 * Builds the node for a binary AST expression.  Assignments are delegated
 * to build_assign() / build_assign_ptr(); other operators create a node with
 * operands [left, right], which is constant-folded when possible and then
 * looked up in the registry.  Operators without a mapping keep opcode 0.
 */
static ir_node *build_binary(ast_node *node)
{
    int code = 0;

    switch (node->expr.binary.operator) {
    case OP_ASSIGN:
        return build_assign(node);
    case OP_ASSIGN_PTR:
        return build_assign_ptr(node);
    case OP_PLUS:
        code = OC_ADD;
        break;
    case OP_MINUS:
        code = OC_SUB;
        break;
    case OP_DIV:
        code = OC_DIV;
        break;
    case OP_MUL:
        code = OC_MUL;
        break;
    case OP_MOD:
        code = OC_MOD;
        break;
    case OP_BOR:
        code = OC_BOR;
        break;
    case OP_BAND:
        code = OC_BAND;
        break;
    case OP_BXOR:
        code = OC_BXOR;
        break;
    case OP_EQ:
        code = OC_EQ;
        break;
    default:
        break;
    }

    ir_node *n = new_node(code);
    arrput(n->out, build_expression(node->expr.binary.left));
    arrput(n->out, build_expression(node->expr.binary.right));
    seal_node(n);
    const_fold(n);

    ir_node *existing = hmget(global_hash, *n);
    if (existing) {
        discard_node(n);
        return existing;
    }
    return n;
}

/*
 * Builds an OC_LOAD with operands [current_memory, address], where the
 * address is the node produced for `node`.
 *
 * NOTE(review): this id uses a different hash seed from every other node in
 * this file; kept as found - confirm whether that is deliberate.
 */
static ir_node *build_load(ast_node *node)
{
    ir_node *n = new_node(OC_LOAD);

    arrput(n->out, current_memory);
    arrput(n->out, build_expression(node));
    n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabebabecafe);

    ir_node *existing = hmget(global_hash, *n);
    if (existing) {
        discard_node(n);
        return existing;
    }
    return n;
}

/*
 * Builds the node for a unary AST expression.
 *   UOP_REF   - the node bound to the identifier operand (no load), or the
 *               operand's own node if it is not a known identifier.
 *   UOP_DEREF - a load through the operand.
 *   UOP_MINUS - a negated constant when the operand is a constant,
 *               otherwise an OC_NEG node.
 * Any other operator yields a node with opcode 0 and no operands.
 */
static ir_node *build_unary(ast_node *node)
{
    ast_node *operand_ast = node->expr.unary.right;

    switch (node->expr.unary.operator) {
    case UOP_REF:
        if (operand_ast->type == NODE_IDENTIFIER) {
            struct symbol_def *def = get_def(operand_ast->expr.string.start);
            if (def)
                return def->node;
        }
        return build_expression(operand_ast);
    case UOP_DEREF:
        return build_load(operand_ast);
    default:
        break;
    }

    ir_node *n = new_node(0);
    if (node->expr.unary.operator == UOP_MINUS) {
        ir_node *operand = build_expression(operand_ast);
        if (operand && operand->code == OC_CONST_INT) {
            n->code = OC_CONST_INT;
            n->data.const_int = -(operand->data.const_int);
        } else if (operand && operand->code == OC_CONST_FLOAT) {
            n->code = OC_CONST_FLOAT;
            n->data.const_float = -(operand->data.const_float);
        } else {
            n->code = OC_NEG;
            arrput(n->out, operand);
        }
    }
    seal_node(n);

    ir_node *existing = hmget(global_hash, *n);
    if (existing) {
        discard_node(n);
        return existing;
    }
    return n;
}

/* Builds every expression of a NODE_UNIT-linked statement list, in order. */
static void build_unit_list(ast_node *unit)
{
    while (unit && unit->type == NODE_UNIT) {
        if (unit->expr.unit_node.expr) {
            build_expression(unit->expr.unit_node.expr);
        }
        unit = unit->expr.unit_node.next;
    }
}

/* Creates and registers an OC_PHI with operands [region, a, b]. */
static ir_node *new_phi(ir_node *region, ir_node *a, ir_node *b)
{
    ir_node *phi = new_node(OC_PHI);

    arrput(phi->out, region);
    arrput(phi->out, a);
    arrput(phi->out, b);
    seal_node(phi);
    hmput(global_hash, *phi, phi);
    return phi;
}

/*
 * Builds an if statement: an OC_IF on [condition, control], one OC_PROJ per
 * branch, both branch bodies, and an OC_REGION joining the two resulting
 * controls.  Memory and every symbol visible before the `if` are merged with
 * phis where the branches disagree.  Returns the region, which also becomes
 * the current control.
 *
 * Symbols first defined inside a branch are not carried past the `if`: only
 * names present in the pre-branch snapshot are merged.
 */
static ir_node *build_if(ast_node *node)
{
    ir_node *condition = build_expression(node->expr.if_stmt.condition);

    ir_node *if_node = new_node(OC_IF);
    arrput(if_node->out, condition);
    arrput(if_node->out, current_control);
    seal_node(if_node);
    hmput(global_hash, *if_node, if_node);

    ir_node *proj_true = new_node(OC_PROJ);
    arrput(proj_true->out, if_node);
    seal_node(proj_true);
    hmput(global_hash, *proj_true, proj_true);

    ir_node *proj_false = new_node(OC_PROJ);
    arrput(proj_false->out, if_node);
    seal_node(proj_false);
    hmput(global_hash, *proj_false, proj_false);

    /* Snapshot taken before either branch runs. */
    ir_node *base_scope = copy_scope(current_scope);
    ir_node *base_mem = current_memory;

    /* Then branch works directly on the scope that was current on entry. */
    current_control = proj_true;
    build_unit_list(node->expr.if_stmt.body);
    ir_node *then_scope = current_scope;
    ir_node *then_mem = current_memory;
    ir_node *then_control = current_control;

    /* Else branch restarts from a fresh copy of the snapshot. */
    current_scope = copy_scope(base_scope);
    current_memory = base_mem;
    current_control = proj_false;
    build_unit_list(node->expr.if_stmt.otherwise);
    ir_node *else_scope = current_scope;
    ir_node *else_mem = current_memory;
    ir_node *else_control = current_control;

    ir_node *region = new_node(OC_REGION);
    arrput(region->out, then_control);
    arrput(region->out, else_control);
    seal_node(region);
    hmput(global_hash, *region, region);

    if (then_mem->id != else_mem->id) {
        current_memory = new_phi(region, then_mem, else_mem);
    } else {
        current_memory = then_mem;
    }

    /* Continue in the snapshot, rewriting each entry with the merged value.
     * Re-putting an existing key only replaces its value. */
    current_scope = base_scope;
    for (int i = 0; i < arrlen(current_scope->data.symbol_tables); i++) {
        symbol_table *base_table = current_scope->data.symbol_tables[i];
        for (int j = 0; j < shlen(base_table); j++) {
            char *key = base_table[j].key;
            struct symbol_def *base_def = base_table[j].value;

            /* Fall back to the snapshot when a branch has no usable entry. */
            struct symbol_def *then_def = scope_lookup(then_scope, i, key);
            if (!then_def || !then_def->node)
                then_def = base_def;
            struct symbol_def *else_def = scope_lookup(else_scope, i, key);
            if (!else_def || !else_def->node)
                else_def = base_def;

            ir_node *found_then = then_def->node;
            ir_node *found_else = else_def->node;

            struct symbol_def *merged;
            if (same_node(found_then, found_else)) {
                /* Keep the lvalue flag: clearing it would make later reads
                 * of an address-taken variable produce its address instead
                 * of a load. */
                merged = new_def(found_then, then_def->is_lvalue);
            } else {
                merged = new_def(new_phi(region, found_then, found_else), false);
            }
            shput(current_scope->data.symbol_tables[i], key, merged);
        }
    }

    current_control = region;
    return region;
}

/*
 * Records a return: the current control, the current memory and the returned
 * value (an OC_VOID node when the statement has no value) are appended to
 * current_func, and the current control is cleared.
 */
static void build_return(ast_node *node)
{
    ir_node *val;

    if (node->expr.ret.value) {
        val = build_expression(node->expr.ret.value);
    } else {
        val = new_node(OC_VOID);
        seal_node(val);
    }

    arrput(current_func.return_controls, current_control);
    arrput(current_func.return_memories, current_memory);
    arrput(current_func.return_values, val);
    current_control = NULL;
}

/*
 * Emits the single OC_RETURN of the function being built, with operands
 * [control, memory, value].  With one recorded return its triple is used
 * directly; with several, the controls are joined by an OC_REGION and the
 * memories and values by OC_PHI nodes.  Does nothing when no return was
 * recorded.
 */
static void finalize_function(void)
{
    int count = arrlen(current_func.return_controls);
    if (count == 0) {
        return;
    }

    ir_node *final_ctrl = NULL;
    ir_node *final_mem = NULL;
    ir_node *final_val = NULL;

    if (count == 1) {
        final_ctrl = current_func.return_controls[0];
        final_mem = current_func.return_memories[0];
        final_val = current_func.return_values[0];
    } else {
        /* Ids are computed before registration so the stored key copy
         * carries the same id as the node. */
        ir_node *region = new_node(OC_REGION);
        for (int i = 0; i < count; i++) {
            arrput(region->out, current_func.return_controls[i]);
        }
        seal_node(region);
        hmput(global_hash, *region, region);
        final_ctrl = region;

        ir_node *mem_phi = new_node(OC_PHI);
        arrput(mem_phi->out, region);
        for (int i = 0; i < count; i++) {
            arrput(mem_phi->out, current_func.return_memories[i]);
        }
        seal_node(mem_phi);
        hmput(global_hash, *mem_phi, mem_phi);
        final_mem = mem_phi;

        /* NOTE(review): unlike the memory phi, the value phi has no region
         * operand; `arrput(val_phi->out, region);` was disabled in the
         * original and is left disabled. */
        ir_node *val_phi = new_node(OC_PHI);
        for (int i = 0; i < count; i++) {
            arrput(val_phi->out, current_func.return_values[i]);
        }
        seal_node(val_phi);
        hmput(global_hash, *val_phi, val_phi);
        final_val = val_phi;
    }

    ir_node *ret = new_node(OC_RETURN);
    arrput(ret->out, final_ctrl);
    arrput(ret->out, final_mem);
    arrput(ret->out, final_val);
    seal_node(ret);
    hmput(global_hash, *ret, ret);
}

/*
 * Builds one function: an OC_START node named after the function, control
 * and memory projections of it, a fresh scope holding one projection per
 * parameter, the body, and finally the merged return.  Returns the start
 * node.
 *
 * NOTE(review): current_stack is not reset here, so frame offsets keep
 * growing across functions - confirm whether that is intended.
 */
static ir_node *build_function(ast_node *node)
{
    /* Release the previous function's return lists before clearing. */
    arrfree(current_func.return_controls);
    arrfree(current_func.return_memories);
    arrfree(current_func.return_values);
    memset(&current_func, 0x0, sizeof(current_func));

    ir_node *func = new_node(OC_START);
    seal_node(func);
    func->data.start_name = node->expr.function.name;

    /* The two start projections take their id from a hash of the pointer
     * value held in the local variable rather than of the node contents. */
    ir_node *start_ctrl = new_node(OC_PROJ);
    start_ctrl->id = stbds_hash_bytes(&start_ctrl, sizeof(usize), IR_HASH_SEED);
    arrput(start_ctrl->out, func);
    hmput(global_hash, *start_ctrl, start_ctrl);
    current_control = start_ctrl;

    ir_node *start_mem = new_node(OC_PROJ);
    start_mem->id = stbds_hash_bytes(&start_mem, sizeof(usize), IR_HASH_SEED);
    arrput(start_mem->out, func);
    hmput(global_hash, *start_mem, start_mem);
    current_memory = start_mem;

    current_scope = new_node(OC_SCOPE);
    push_scope();

    for (member *m = node->expr.function.parameters; m; m = m->next) {
        ir_node *proj_param = new_node(OC_PROJ);
        arrput(proj_param->out, func);
        seal_node(proj_param);
        set_def(m->name, proj_param, false);
        hmput(global_hash, *proj_param, proj_param);
    }

    build_unit_list(node->expr.function.body);

    /* Re-derive the id now that the name has been filled in. */
    seal_node(func);
    finalize_function();
    return func;
}

/*
 * Builds the node for any AST expression or statement and registers the
 * result.  Returns NULL for a return statement, for an unhandled node type
 * and for an identifier with no definition.
 */
static ir_node *build_expression(ast_node *node)
{
    ir_node *n = NULL;

    switch (node->type) {
    case NODE_UNARY:
        n = build_unary(node);
        break;
    case NODE_BINARY:
        n = build_binary(node);
        break;
    case NODE_INTEGER: {
        n = new_node(OC_CONST_INT);
        n->data.const_int = node->expr.integer;
        seal_node(n);
        ir_node *existing = hmget(global_hash, *n);
        if (existing) {
            free(n);
            return existing;
        }
        break;
    }
    case NODE_VAR_DECL: {
        if (!node->address_taken) {
            /* Plain value: bind the name straight to the initializer node. */
            n = build_expression(node->expr.var_decl.value);
            set_def(node->expr.var_decl.name, n, false);
            return n;
        }

        /* Address-taken variable: store the initializer at a frame-relative
         * address and bind the name to that address as an lvalue.
         * NOTE(review): this store has no control operand, unlike the
         * stores emitted for assignments - confirm. */
        ir_node *store = new_node(OC_STORE);
        ir_node *addr = build_address((usize)-1, current_stack);
        arrput(store->out, current_memory);
        arrput(store->out, addr);
        arrput(store->out, build_expression(node->expr.var_decl.value));
        current_memory = store;
        current_stack += node->expr_type->size;
        seal_node(store);
        hmput(global_hash, *store, store);

        set_def(node->expr.var_decl.name, addr, true);
        return addr;
    }
    case NODE_IDENTIFIER: {
        struct symbol_def *def = get_def(node->expr.string.start);
        if (!def) {
            /* Previously a NULL dereference. */
            fprintf(stderr, "ir: use of undefined symbol '%s'\n",
                    node->expr.string.start);
            break;
        }
        n = def->node;
        if (n && def->is_lvalue) {
            /* The binding is an address: reading the name is a load. */
            ir_node *load = new_node(OC_LOAD);
            arrput(load->out, current_memory);
            arrput(load->out, n);
            seal_node(load);
            ir_node *existing = hmget(global_hash, *load);
            if (existing) {
                discard_node(load);
                n = existing;
            } else {
                hmput(global_hash, *load, load);
                n = load;
            }
        }
        break;
    }
    case NODE_IF:
        n = build_if(node);
        break;
    case NODE_RETURN:
        build_return(node);
        break;
    default:
        break;
    }

    if (n)
        hmput(global_hash, *n, n);
    return n;
}

/*
 * Entry point: builds the IR for every NODE_FUNCTION found in the top-level
 * unit list, attaches each function's start node to a root OC_START node
 * named "program", and prints the registered nodes to stdout as a Graphviz
 * digraph.
 */
void ir_build(ast_node *ast)
{
    graph = new_node(OC_START);
    seal_node(graph);
    graph->data.start_name = "program";

    current_memory = new_node(OC_FRAME_PTR);
    seal_node(current_memory);

    current_scope = new_node(OC_SCOPE);
    push_scope();

    for (ast_node *unit = ast; unit && unit->type == NODE_UNIT;
         unit = unit->expr.unit_node.next) {
        ast_node *item = unit->expr.unit_node.expr;
        if (item && item->type == NODE_FUNCTION) {
            ir_node *func = build_function(item);
            arrput(graph->out, func);
            hmput(global_hash, *func, func);
        }
    }

    printf("digraph G {\n");
    print_graph(graph);
    printf("}\n");
}