/*
 * lc/ir.c
 * (source-viewer header preserved as a comment: 812 lines, 21 KiB, C)
 */

#include "ir.h"
#include <stdlib.h>
#include <stdio.h>
#include "stb_ds.h"
#include "sema.h"
/* Every IR node registered so far. The key is the node's full struct bytes
 * (lookups are hmget(global_hash, *node)); the value is the node itself.
 * print_graph() iterates this map to emit the DOT graph. */
struct { ir_node key; ir_node *value; } *global_hash = NULL;
/* Root START node ("program") created by ir_build(); each lowered function is
 * appended to its inputs. */
static ir_node *graph;
/* Most recent memory-state node: stores replace it, loads take it as input. */
static ir_node *current_memory;
/* Control node that newly built stores/ifs attach to; build_return() sets it
 * to NULL. */
static ir_node *current_control;
/* Next free frame offset handed to an address-taken local (advanced by the
 * variable's type size after each such declaration). */
static usize current_stack = 0;
/* OC_SCOPE node holding a stack of symbol tables; the innermost table is the
 * last element. */
static ir_node *current_scope = NULL;
static ir_node *build_expression(ast_node *node);
/* Return sites of the function currently being lowered, appended by
 * build_return() and merged by finalize_function(). The three arrays are
 * index-parallel: entry i of each belongs to the same return statement. */
static struct {
ir_node **return_controls;
ir_node **return_memories;
ir_node **return_values;
} current_func = {0};
/*
 * Emits the Graphviz DOT statement that declares one IR node: the node's id
 * followed by an attribute list selected by opcode. START and constant nodes
 * show their payload as the label; unrecognised opcodes are labelled with
 * their numeric code. A NULL node is printed as a red placeholder named
 * `null`.
 */
static void node_name(ir_node *node)
{
    if (node == NULL) {
        printf("null [label=\"NULL\", style=filled, fillcolor=red]\n");
        return;
    }

    printf("%ld ", node->id);

    switch (node->code) {
    /* Control and structural nodes. */
    case OC_START:
        printf("[label=\"%s\", style=filled, color=orange]\n", node->data.start_name);
        return;
    case OC_RETURN:
        printf("[label=\"return\", style=filled, color=orange]\n");
        return;
    case OC_REGION:
        printf("[label=\"region\", shape=diamond, style=filled, color=green]\n");
        return;
    case OC_IF:
        printf("[label=\"if\", shape=diamond, style=filled, color=lightblue]\n");
        return;
    case OC_PROJ:
        printf("[label=\"proj\", shape=diamond, style=filled, color=cyan]\n");
        return;
    case OC_PHI:
        printf("[label=\"phi\", shape=triangle]\n");
        return;

    /* Constants carry their value as the label. */
    case OC_CONST_INT:
        printf("[label=\"%ld\"]\n", node->data.const_int);
        return;
    case OC_CONST_FLOAT:
        printf("[label=\"%f\"]\n", node->data.const_float);
        return;

    /* Memory and addressing. */
    case OC_FRAME_PTR:
        printf("[label=\"frame_ptr\"]\n");
        return;
    case OC_ADDR:
        printf("[label=\"addr\"]\n");
        return;
    case OC_LOAD:
        printf("[label=\"load\", shape=box]\n");
        return;
    case OC_STORE:
        printf("[label=\"store\", shape=box]\n");
        return;

    /* Arithmetic, bitwise and comparison operators. */
    case OC_ADD:
        printf("[label=\"+\"]\n");
        return;
    case OC_SUB:
    case OC_NEG:
        printf("[label=\"-\"]\n");
        return;
    case OC_MUL:
        printf("[label=\"*\"]\n");
        return;
    case OC_DIV:
        printf("[label=\"/\"]\n");
        return;
    case OC_MOD:
        printf("[label=\"%%\"]\n");
        return;
    case OC_BAND:
        printf("[label=\"&\"]\n");
        return;
    case OC_BOR:
        printf("[label=\"|\"]\n");
        return;
    case OC_BXOR:
        printf("[label=\"^\"]\n");
        return;
    case OC_EQ:
        printf("[label=\"==\"]\n");
        return;

    default:
        printf("[label=\"%d\"]\n", node->code);
        return;
    }
}
/*
 * Emits every node registered in global_hash as Graphviz DOT statements on
 * stdout. The caller prints the surrounding "digraph" braces. For each node it
 * prints the node declaration, then one edge per non-NULL input, drawn from
 * the input to the node that uses it. The input's declaration is re-emitted
 * next to each edge; DOT tolerates duplicate node statements.
 *
 * `node` (the graph root) is not consulted: only global_hash is walked. The
 * old body shadowed this parameter with a local of the same name, which hid
 * that fact.
 */
static void print_graph(ir_node *node)
{
    (void)node;

    for (int i = 0; i < hmlen(global_hash); i++) {
        ir_node *cur = global_hash[i].value;

        node_name(cur);
        for (int j = 0; j < arrlen(cur->out); j++) {
            ir_node *input = cur->out[j];

            if (input == NULL)
                continue;
            node_name(input);
            printf("%ld->%ld\n", input->id, cur->id);
        }
    }
}
static void push_scope(void)
{
arrput(current_scope->data.symbol_tables, NULL);
}
static struct symbol_def *get_def(char *name)
{
for (int i = arrlen(current_scope->data.symbol_tables) - 1; i >= 0; i--) {
struct symbol_def *def = shget(current_scope->data.symbol_tables[i], name);
if (def) return def;
}
return NULL;
}
/*
 * Binds `name` to `node` with a newly allocated symbol_def.
 *
 * If some scope already defines `name`, the innermost such scope gets the
 * new binding. Otherwise the name is introduced in the innermost scope.
 * `lvalue` marks bindings whose node is a memory address rather than a value.
 */
static void set_def(char *name, ir_node *node, bool lvalue)
{
    symbol_table **tables = current_scope->data.symbol_tables;
    int target = arrlen(tables) - 1;

    for (int level = target; level >= 0; level--) {
        if (shget(tables[level], name) != NULL) {
            target = level;
            break;
        }
    }

    struct symbol_def *binding = calloc(1, sizeof(struct symbol_def));
    binding->is_lvalue = lvalue;
    binding->node = node;
    shput(tables[target], name, binding);
}
/*
 * Returns a new OC_SCOPE node with the same stack of symbol tables as `src`.
 * The copy is shallow: every table is duplicated, but keys and symbol_def
 * pointers are shared with `src`.
 */
static ir_node *copy_scope(ir_node *src)
{
    ir_node *copy = calloc(1, sizeof(ir_node));
    int depth = arrlen(src->data.symbol_tables);

    copy->code = OC_SCOPE;
    for (int level = 0; level < depth; level++) {
        symbol_table *from = src->data.symbol_tables[level];
        symbol_table *to = NULL;
        int entries = shlen(from);

        for (int k = 0; k < entries; k++)
            shput(to, from[k].key, from[k].value);
        arrput(copy->data.symbol_tables, to);
    }
    return copy;
}
/*
 * Turns `node` into a leaf constant. The caller must already have written
 * the new opcode and value. This drops the node's inputs and recomputes its
 * id with the id field zeroed, exactly as a freshly built literal's id is
 * computed. That lets the caller's global_hash lookup match the folded node
 * against an identical existing constant.
 */
static void finish_constant_fold(ir_node *node)
{
    arrfree(node->out);
    node->out = NULL;
    arrfree(node->in);
    node->in = NULL;
    node->id = 0;
    node->id = stbds_hash_bytes(node, sizeof(ir_node), 0xcafebabe);
}

/*
 * Folds a binary node in place when both inputs (out[0], out[1]) are
 * constants of the same kind and the opcode is foldable.
 *
 * Integers: + - * / % | & ^ ==.
 * Floats: + - * /.
 *
 * Division or modulo by a constant zero is NOT folded. The node is left as
 * the operation; previously it silently became the constant 0. Mixed
 * int/float operands, other opcodes, and missing inputs leave the node
 * untouched.
 *
 * NOTE(review): if const_int is a signed type, folding + - * (and
 * MIN / -1) can overflow, which is undefined behaviour. Consider folding
 * in the matching unsigned type.
 */
static void const_fold(ir_node *binary)
{
    if (arrlen(binary->out) < 2)
        return;

    ir_node *left = binary->out[0];
    ir_node *right = binary->out[1];

    if (left == NULL || right == NULL)
        return;

    if (left->code == OC_CONST_INT && right->code == OC_CONST_INT) {
        switch (binary->code) {
        case OC_ADD:
            binary->data.const_int = left->data.const_int + right->data.const_int;
            break;
        case OC_SUB:
            binary->data.const_int = left->data.const_int - right->data.const_int;
            break;
        case OC_MUL:
            binary->data.const_int = left->data.const_int * right->data.const_int;
            break;
        case OC_DIV:
            if (right->data.const_int == 0)
                return; /* keep the division; do not invent a value */
            binary->data.const_int = left->data.const_int / right->data.const_int;
            break;
        case OC_MOD:
            if (right->data.const_int == 0)
                return; /* keep the modulo; do not invent a value */
            binary->data.const_int = left->data.const_int % right->data.const_int;
            break;
        case OC_BOR:
            binary->data.const_int = left->data.const_int | right->data.const_int;
            break;
        case OC_BAND:
            binary->data.const_int = left->data.const_int & right->data.const_int;
            break;
        case OC_BXOR:
            binary->data.const_int = left->data.const_int ^ right->data.const_int;
            break;
        case OC_EQ:
            binary->data.const_int = left->data.const_int == right->data.const_int;
            break;
        default:
            return;
        }
        binary->code = OC_CONST_INT;
        finish_constant_fold(binary);
        return;
    }

    if (left->code == OC_CONST_FLOAT && right->code == OC_CONST_FLOAT) {
        switch (binary->code) {
        case OC_ADD:
            binary->data.const_float = left->data.const_float + right->data.const_float;
            break;
        case OC_SUB:
            binary->data.const_float = left->data.const_float - right->data.const_float;
            break;
        case OC_MUL:
            binary->data.const_float = left->data.const_float * right->data.const_float;
            break;
        case OC_DIV:
            if (right->data.const_float == 0.0f)
                return; /* keep the division; do not invent a value */
            binary->data.const_float = left->data.const_float / right->data.const_float;
            break;
        default:
            return;
        }
        binary->code = OC_CONST_FLOAT;
        finish_constant_fold(binary);
    }
}
/*
 * Builds an OC_ADDR node whose inputs are [base, offset].
 *
 * A `base` of (usize)-1 selects the frame pointer (OC_FRAME_PTR); any other
 * value becomes an OC_CONST_INT base. `offset` is always an OC_CONST_INT.
 * The new node is not inserted into global_hash here.
 *
 * If an identical node is already registered, that node is returned and
 * everything allocated by this call is freed. The old code leaked the two
 * operand nodes and the input array in that case.
 */
static ir_node *build_address(usize base, usize offset)
{
    ir_node *addr = calloc(1, sizeof(ir_node));
    addr->code = OC_ADDR;

    ir_node *base_node = calloc(1, sizeof(ir_node));
    if (base == (usize)-1) {
        base_node->code = OC_FRAME_PTR;
    } else {
        base_node->code = OC_CONST_INT;
        base_node->data.const_int = base;
    }
    base_node->id = stbds_hash_bytes(base_node, sizeof(ir_node), 0xcafebabe);

    ir_node *offset_node = calloc(1, sizeof(ir_node));
    offset_node->code = OC_CONST_INT;
    offset_node->data.const_int = offset;
    offset_node->id = stbds_hash_bytes(offset_node, sizeof(ir_node), 0xcafebabe);

    arrput(addr->out, base_node);
    arrput(addr->out, offset_node);
    addr->id = stbds_hash_bytes(addr, sizeof(ir_node), 0xcafebabe);

    ir_node *existing = hmget(global_hash, *addr);
    if (existing) {
        /* The registered node has its own operands; ours are unreferenced. */
        arrfree(addr->out);
        free(addr);
        free(base_node);
        free(offset_node);
        return existing;
    }
    return addr;
}
/*
 * Lowers an OP_ASSIGN_PTR node (`*name = value`) into an OC_STORE.
 *
 * The right operand is evaluated first. The left operand is read as an
 * identifier whose binding holds the pointer. The store's inputs are
 * [control, memory, address, value], and the store becomes the new memory
 * state. Returns the stored value node.
 *
 * If the pointer variable is itself address-taken (an lvalue binding), its
 * binding is a stack slot. In that case the pointer is loaded first so the
 * store targets the pointee, not the pointer variable. An undeclared pointer
 * aborts with a diagnostic instead of dereferencing NULL.
 */
static ir_node *build_assign_ptr(ast_node *binary)
{
    ir_node *val_node = build_expression(binary->expr.binary.right);
    char *var_name = binary->expr.binary.left->expr.string.start;
    struct symbol_def *def = get_def(var_name);

    if (def == NULL) {
        fprintf(stderr, "ir: store through undeclared pointer '%s'\n", var_name);
        abort();
    }

    ir_node *address = def->node;
    if (def->is_lvalue) {
        /* The pointer lives in memory: read its current value. */
        ir_node *load = calloc(1, sizeof(ir_node));
        load->code = OC_LOAD;
        arrput(load->out, current_memory);
        arrput(load->out, def->node);
        load->id = stbds_hash_bytes(load, sizeof(ir_node), 0xcafebabe);

        ir_node *existing = hmget(global_hash, *load);
        if (existing) {
            arrfree(load->out);
            free(load);
            load = existing;
        } else {
            hmput(global_hash, *load, load);
        }
        address = load;
    }

    ir_node *store = calloc(1, sizeof(ir_node));
    store->code = OC_STORE;
    arrput(store->out, current_control);
    arrput(store->out, current_memory);
    arrput(store->out, address);
    arrput(store->out, val_node);
    store->id = stbds_hash_bytes(store, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *store, store);
    current_memory = store;
    return val_node;
}
/*
 * Lowers a plain assignment `name = value`. The right-hand side is built
 * first and its node is returned.
 *
 * - A name that is unknown, or bound to a value (not an lvalue), is simply
 *   rebound to the new value.
 * - An lvalue binding (an address-taken variable) gets an OC_STORE with
 *   inputs [control, memory, address, value]. That store becomes the new
 *   memory state.
 */
static ir_node *build_assign(ast_node *binary)
{
    ir_node *value = build_expression(binary->expr.binary.right);
    char *name = binary->expr.binary.left->expr.string.start;
    struct symbol_def *binding = get_def(name);

    if (binding == NULL || !binding->is_lvalue) {
        set_def(name, value, false);
        return value;
    }

    ir_node *write = calloc(1, sizeof(ir_node));
    write->code = OC_STORE;
    arrput(write->out, current_control);
    arrput(write->out, current_memory);
    arrput(write->out, binding->node);
    arrput(write->out, value);
    write->id = stbds_hash_bytes(write, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *write, write);

    current_memory = write;
    return value;
}
/*
 * Lowers a NODE_BINARY.
 *
 * Assignment operators are delegated to build_assign / build_assign_ptr.
 * Every other operator becomes a node with inputs [left, right], which is
 * then offered to const_fold().
 *
 * If an identical node already exists in global_hash, it is returned and the
 * temporary is released together with its input arrays; the old code leaked
 * them. A fresh node is returned unregistered; build_expression() inserts it.
 * An unrecognised operator leaves the zero-initialised opcode in place.
 */
static ir_node *build_binary(ast_node *node)
{
    switch (node->expr.binary.operator) {
    case OP_ASSIGN:
        return build_assign(node);
    case OP_ASSIGN_PTR:
        return build_assign_ptr(node);
    default:
        break;
    }

    ir_node *n = calloc(1, sizeof(ir_node));
    switch (node->expr.binary.operator) {
    case OP_PLUS:
        n->code = OC_ADD;
        break;
    case OP_MINUS:
        n->code = OC_SUB;
        break;
    case OP_DIV:
        n->code = OC_DIV;
        break;
    case OP_MUL:
        n->code = OC_MUL;
        break;
    case OP_MOD:
        n->code = OC_MOD;
        break;
    case OP_BOR:
        n->code = OC_BOR;
        break;
    case OP_BAND:
        n->code = OC_BAND;
        break;
    case OP_BXOR:
        n->code = OC_BXOR;
        break;
    case OP_EQ:
        n->code = OC_EQ;
        break;
    default:
        break;
    }

    arrput(n->out, build_expression(node->expr.binary.left));
    arrput(n->out, build_expression(node->expr.binary.right));
    n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe);
    const_fold(n);

    ir_node *existing = hmget(global_hash, *n);
    if (existing) {
        arrfree(n->out);
        arrfree(n->in);
        free(n);
        return existing;
    }
    return n;
}
/*
 * Builds an OC_LOAD for a dereference `*expr`. Its inputs are
 * [memory state, address], where the address is the value of `node`.
 *
 * Returns an identical node already in global_hash when one exists, freeing
 * the temporary and its input array. Otherwise the new node is returned; it
 * is registered once build_expression() receives it.
 *
 * The id uses the same 0xcafebabe seed as every other node. The previous
 * 64-bit seed was inconsistent and is truncated where size_t is 32 bits.
 */
static ir_node *build_load(ast_node *node)
{
    ir_node *load = calloc(1, sizeof(ir_node));
    load->code = OC_LOAD;
    arrput(load->out, current_memory);
    arrput(load->out, build_expression(node));
    load->id = stbds_hash_bytes(load, sizeof(ir_node), 0xcafebabe);

    ir_node *existing = hmget(global_hash, *load);
    if (existing) {
        arrfree(load->out);
        free(load);
        return existing;
    }
    return load;
}
/*
 * Lowers a NODE_UNARY.
 *
 * - `&x`: returns x's bound node when x is a known identifier; otherwise it
 *   builds the operand.
 * - `*e`: becomes a load (see build_load).
 * - `-e`: becomes OC_NEG. It is folded into an OC_CONST_INT / OC_CONST_FLOAT
 *   when the operand is a constant of that kind.
 *
 * A node identical to one already in global_hash is deduplicated; the
 * temporary and its input array are freed. A NULL operand is never
 * dereferenced.
 *
 * NOTE(review): negating the most negative const_int overflows if const_int
 * is a signed type.
 */
static ir_node *build_unary(ast_node *node)
{
    ast_node *operand = node->expr.unary.right;

    switch (node->expr.unary.operator) {
    case UOP_REF:
        if (operand->type == NODE_IDENTIFIER) {
            struct symbol_def *def = get_def(operand->expr.string.start);
            if (def)
                return def->node;
        }
        return build_expression(operand);
    case UOP_DEREF:
        return build_load(operand);
    default:
        break;
    }

    ir_node *n = calloc(1, sizeof(ir_node));
    switch (node->expr.unary.operator) {
    case UOP_MINUS:
        n->code = OC_NEG;
        arrput(n->out, build_expression(operand));
        break;
    default:
        break;
    }

    ir_node *arg = arrlen(n->out) > 0 ? n->out[0] : NULL;
    if (n->code == OC_NEG && arg != NULL) {
        if (arg->code == OC_CONST_INT) {
            n->data.const_int = -(arg->data.const_int);
            n->code = OC_CONST_INT;
            arrfree(n->out);
            n->out = NULL;
        } else if (arg->code == OC_CONST_FLOAT) {
            n->data.const_float = -(arg->data.const_float);
            n->code = OC_CONST_FLOAT;
            arrfree(n->out);
            n->out = NULL;
        }
    }

    n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe);
    ir_node *existing = hmget(global_hash, *n);
    if (existing) {
        arrfree(n->out);
        free(n);
        return existing;
    }
    return n;
}
/* Lowers each statement of a NODE_UNIT chain, in order, into the current
 * control/memory/scope state. */
static void build_branch_body(ast_node *unit)
{
    for (; unit && unit->type == NODE_UNIT; unit = unit->expr.unit_node.next) {
        if (unit->expr.unit_node.expr)
            build_expression(unit->expr.unit_node.expr);
    }
}

/* Creates and registers one OC_PROJ whose single input is `if_node`.
 * NOTE(review): the true and false projections carry no index, so nothing
 * but their position in the code distinguishes them. */
static ir_node *make_if_projection(ir_node *if_node)
{
    ir_node *proj = calloc(1, sizeof(ir_node));
    proj->code = OC_PROJ;
    arrput(proj->out, if_node);
    proj->id = stbds_hash_bytes(proj, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *proj, proj);
    return proj;
}

/* Creates and registers a two-way OC_PHI with inputs [region, a, b]. */
static ir_node *make_merge_phi(ir_node *region, ir_node *a, ir_node *b)
{
    ir_node *phi = calloc(1, sizeof(ir_node));
    phi->code = OC_PHI;
    arrput(phi->out, region);
    arrput(phi->out, a);
    arrput(phi->out, b);
    phi->id = stbds_hash_bytes(phi, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *phi, phi);
    return phi;
}

/*
 * Lowers an if/else statement.
 *
 * 1. Builds OC_IF with inputs [condition, control], plus one projection per
 *    branch.
 * 2. Lowers the then-body in the live scope.
 * 3. Lowers the else-body in a copy of the scope taken before the then-body.
 * 4. Joins both branches with an OC_REGION. Memory state and every variable
 *    visible before the `if` are merged with a PHI when the branches
 *    disagree. Names declared inside a branch are dropped.
 *
 * Variables whose node is unchanged by both branches keep their is_lvalue
 * flag. The old merge forced it to false, so an address-taken variable read
 * after any `if` produced its address instead of a load.
 *
 * Returns the region, which also becomes the current control.
 */
static ir_node *build_if(ast_node *node)
{
    ir_node *condition = build_expression(node->expr.if_stmt.condition);

    ir_node *if_node = calloc(1, sizeof(ir_node));
    if_node->code = OC_IF;
    arrput(if_node->out, condition);
    arrput(if_node->out, current_control);
    if_node->id = stbds_hash_bytes(if_node, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *if_node, if_node);

    ir_node *proj_true = make_if_projection(if_node);
    ir_node *proj_false = make_if_projection(if_node);

    /* Bindings and memory visible before either branch runs. */
    ir_node *base_scope = copy_scope(current_scope);
    ir_node *base_mem = current_memory;

    /* Then-branch works on the live scope. */
    current_control = proj_true;
    build_branch_body(node->expr.if_stmt.body);
    ir_node *then_scope = current_scope;
    ir_node *then_mem = current_memory;
    ir_node *then_control = current_control;

    /* Else-branch restarts from the pre-branch snapshot. */
    current_scope = copy_scope(base_scope);
    current_memory = base_mem;
    current_control = proj_false;
    build_branch_body(node->expr.if_stmt.otherwise);
    ir_node *else_scope = current_scope;
    ir_node *else_mem = current_memory;
    ir_node *else_control = current_control;

    ir_node *region = calloc(1, sizeof(ir_node));
    region->code = OC_REGION;
    arrput(region->out, then_control);
    arrput(region->out, else_control);
    region->id = stbds_hash_bytes(region, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *region, region);

    if (then_mem->id != else_mem->id)
        current_memory = make_merge_phi(region, then_mem, else_mem);
    else
        current_memory = then_mem;

    /* Merge every pre-existing binding back into the snapshot scope. */
    current_scope = base_scope;
    for (int i = 0; i < arrlen(current_scope->data.symbol_tables); i++) {
        symbol_table *base_table = current_scope->data.symbol_tables[i];
        symbol_table *t_table = then_scope->data.symbol_tables[i];
        symbol_table *e_table = else_scope->data.symbol_tables[i];

        for (int j = 0; j < shlen(base_table); j++) {
            char *key = base_table[j].key;
            struct symbol_def *base_def = base_table[j].value;
            struct symbol_def *then_def = shget(t_table, key);
            struct symbol_def *else_def = shget(e_table, key);

            if (then_def == NULL || then_def->node == NULL)
                then_def = base_def;
            if (else_def == NULL || else_def->node == NULL)
                else_def = base_def;

            ir_node *tn = then_def->node;
            ir_node *en = else_def->node;
            struct symbol_def *merged = calloc(1, sizeof(struct symbol_def));

            if (tn != NULL && en != NULL && tn->id != en->id) {
                merged->node = make_merge_phi(region, tn, en);
                /* A phi of two addresses is still an address. */
                merged->is_lvalue = then_def->is_lvalue && else_def->is_lvalue;
            } else if (tn != NULL) {
                merged->node = tn;
                merged->is_lvalue = then_def->is_lvalue;
            } else {
                merged->node = en;
                merged->is_lvalue = else_def->is_lvalue;
            }
            shput(current_scope->data.symbol_tables[i], key, merged);
        }
    }

    current_control = region;
    return region;
}
/*
 * Records a return site for the function being lowered. It saves the
 * current control and memory state, plus the returned value: an OC_VOID
 * node when the statement has no value. Control is then cleared to NULL,
 * because nothing after a return is reachable on this path.
 */
static void build_return(ast_node *node)
{
    ir_node *result;

    if (node->expr.ret.value == NULL) {
        result = calloc(1, sizeof(ir_node));
        result->code = OC_VOID;
        result->id = stbds_hash_bytes(result, sizeof(ir_node), 0xcafebabe);
    } else {
        result = build_expression(node->expr.ret.value);
    }

    arrput(current_func.return_controls, current_control);
    arrput(current_func.return_memories, current_memory);
    arrput(current_func.return_values, result);
    current_control = NULL;
}
/* Creates and registers an OC_PHI with inputs [region, inputs[0..count)]. */
static ir_node *make_return_phi(ir_node *region, ir_node **inputs, int count)
{
    ir_node *phi = calloc(1, sizeof(ir_node));
    phi->code = OC_PHI;
    arrput(phi->out, region);
    for (int i = 0; i < count; i++)
        arrput(phi->out, inputs[i]);
    phi->id = stbds_hash_bytes(phi, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *phi, phi);
    return phi;
}

/*
 * Emits the OC_RETURN node for the function being lowered. Its inputs are
 * [control, memory, value].
 *
 * - With no recorded return sites, nothing is emitted.
 * - With one site, that site's state is used directly.
 * - With several, control is joined by an OC_REGION, and memory and return
 *   value are each merged by a PHI whose first input is that region, like
 *   every other PHI in this file.
 *
 * Each node's id is computed before the node is inserted into global_hash,
 * so the stored key matches the node. The per-function return arrays are
 * freed afterwards.
 */
static void finalize_function(void)
{
    int count = arrlen(current_func.return_controls);
    ir_node *final_ctrl = NULL;
    ir_node *final_mem = NULL;
    ir_node *final_val = NULL;

    if (count == 0)
        return;

    if (count == 1) {
        final_ctrl = current_func.return_controls[0];
        final_mem = current_func.return_memories[0];
        final_val = current_func.return_values[0];
    } else {
        ir_node *region = calloc(1, sizeof(ir_node));
        region->code = OC_REGION;
        for (int i = 0; i < count; i++)
            arrput(region->out, current_func.return_controls[i]);
        region->id = stbds_hash_bytes(region, sizeof(ir_node), 0xcafebabe);
        hmput(global_hash, *region, region);

        final_ctrl = region;
        final_mem = make_return_phi(region, current_func.return_memories, count);
        final_val = make_return_phi(region, current_func.return_values, count);
    }

    ir_node *ret = calloc(1, sizeof(ir_node));
    ret->code = OC_RETURN;
    arrput(ret->out, final_ctrl);
    arrput(ret->out, final_mem);
    arrput(ret->out, final_val);
    ret->id = stbds_hash_bytes(ret, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *ret, ret);

    arrfree(current_func.return_controls);
    arrfree(current_func.return_memories);
    arrfree(current_func.return_values);
}
/*
 * Lowers one NODE_FUNCTION. It creates:
 * - a START node named after the function;
 * - control and memory projections off that node;
 * - a fresh scope with one table;
 * - one OC_PROJ per parameter, bound as a plain value.
 *
 * The body statements are then lowered in order, and finalize_function()
 * emits the RETURN node. Returns the START node; the caller registers it.
 *
 * The frame offset counter is reset per function. Stack slots are addressed
 * relative to OC_FRAME_PTR (see build_address), so offsets from earlier
 * functions must not carry over.
 *
 * NOTE(review): parameter projections carry no parameter index.
 */
static ir_node *build_function(ast_node *node)
{
    memset(&current_func, 0x0, sizeof(current_func));
    current_stack = 0;

    ast_node *current = node->expr.function.body;
    ir_node *func = calloc(1, sizeof(ir_node));
    func->code = OC_START;
    func->id = stbds_hash_bytes(func, sizeof(ir_node), 0xcafebabe);
    func->data.start_name = node->expr.function.name;

    /* Both start projections have identical contents when hashed, so their
     * ids are derived from their (distinct) pointer values instead. */
    ir_node *start_ctrl = calloc(1, sizeof(ir_node));
    start_ctrl->code = OC_PROJ;
    start_ctrl->id = stbds_hash_bytes(&start_ctrl, sizeof start_ctrl, 0xcafebabe);
    arrput(start_ctrl->out, func);
    hmput(global_hash, *start_ctrl, start_ctrl);
    current_control = start_ctrl;

    ir_node *start_mem = calloc(1, sizeof(ir_node));
    start_mem->code = OC_PROJ;
    start_mem->id = stbds_hash_bytes(&start_mem, sizeof start_mem, 0xcafebabe);
    arrput(start_mem->out, func);
    hmput(global_hash, *start_mem, start_mem);
    current_memory = start_mem;

    current_scope = calloc(1, sizeof(ir_node));
    current_scope->code = OC_SCOPE;
    push_scope();

    for (member *m = node->expr.function.parameters; m; m = m->next) {
        ir_node *proj_param = calloc(1, sizeof(ir_node));
        proj_param->code = OC_PROJ;
        arrput(proj_param->out, func);
        proj_param->id = stbds_hash_bytes(proj_param, sizeof(ir_node), 0xcafebabe);
        set_def(m->name, proj_param, false);
        hmput(global_hash, *proj_param, proj_param);
    }

    while (current && current->type == NODE_UNIT) {
        if (current->expr.unit_node.expr)
            build_expression(current->expr.unit_node.expr);
        current = current->expr.unit_node.next;
    }

    func->id = stbds_hash_bytes(func, sizeof(ir_node), 0xcafebabe);
    finalize_function();
    return func;
}
/* Returns the OC_CONST_INT for an integer literal. An identical registered
 * constant is reused; otherwise the new node is registered. */
static ir_node *build_integer_node(ast_node *node)
{
    ir_node *n = calloc(1, sizeof(ir_node));
    n->code = OC_CONST_INT;
    n->data.const_int = node->expr.integer;
    n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe);

    ir_node *existing = hmget(global_hash, *n);
    if (existing) {
        free(n);
        return existing;
    }
    hmput(global_hash, *n, n);
    return n;
}

/*
 * Lowers a variable declaration.
 *
 * A non-address-taken variable is bound directly to its initializer's value.
 * An address-taken variable gets the next frame slot and an OC_STORE of the
 * initializer into it. The store's inputs are [control, memory, address,
 * value], the same layout as the other stores in this file, and the store
 * becomes the new memory state. The name is bound to the slot as an lvalue,
 * and the slot is returned.
 *
 * The initializer is evaluated before memory is captured, so stores made
 * inside the initializer stay on the memory chain.
 */
static ir_node *build_var_decl_node(ast_node *node)
{
    if (!node->address_taken) {
        ir_node *value = build_expression(node->expr.var_decl.value);
        set_def(node->expr.var_decl.name, value, false);
        return value;
    }

    ir_node *slot = build_address(-1, current_stack);
    ir_node *value = build_expression(node->expr.var_decl.value);

    ir_node *store = calloc(1, sizeof(ir_node));
    store->code = OC_STORE;
    arrput(store->out, current_control);
    arrput(store->out, current_memory);
    arrput(store->out, slot);
    arrput(store->out, value);
    store->id = stbds_hash_bytes(store, sizeof(ir_node), 0xcafebabe);
    hmput(global_hash, *store, store);
    current_memory = store;
    current_stack += node->expr_type->size;

    set_def(node->expr.var_decl.name, slot, true);
    return slot;
}

/*
 * Resolves an identifier. A value binding yields its node. An lvalue
 * (address-taken) binding yields an OC_LOAD [memory, address], deduplicated
 * against global_hash. An undeclared name aborts with a diagnostic instead
 * of dereferencing NULL.
 */
static ir_node *build_identifier_node(ast_node *node)
{
    char *name = node->expr.string.start;
    struct symbol_def *def = get_def(name);

    if (def == NULL) {
        fprintf(stderr, "ir: use of undeclared identifier '%s'\n", name);
        abort();
    }
    if (def->node == NULL || !def->is_lvalue)
        return def->node;

    ir_node *load = calloc(1, sizeof(ir_node));
    load->code = OC_LOAD;
    arrput(load->out, current_memory);
    arrput(load->out, def->node);
    load->id = stbds_hash_bytes(load, sizeof(ir_node), 0xcafebabe);

    ir_node *existing = hmget(global_hash, *load);
    if (existing) {
        arrfree(load->out);
        free(load);
        return existing;
    }
    hmput(global_hash, *load, load);
    return load;
}

/*
 * Lowers one AST node and returns the resulting IR node. Returns NULL for a
 * NULL input, for statements such as `return`, and for unhandled node types.
 *
 * Integer literals and declarations are returned straight from their helpers.
 * Every other non-NULL result is (re)registered in global_hash.
 */
static ir_node *build_expression(ast_node *node)
{
    ir_node *n = NULL;

    if (node == NULL)
        return NULL;

    switch (node->type) {
    case NODE_UNARY:
        n = build_unary(node);
        break;
    case NODE_BINARY:
        n = build_binary(node);
        break;
    case NODE_INTEGER:
        return build_integer_node(node);
    case NODE_VAR_DECL:
        return build_var_decl_node(node);
    case NODE_IDENTIFIER:
        n = build_identifier_node(node);
        break;
    case NODE_IF:
        n = build_if(node);
        break;
    case NODE_RETURN:
        build_return(node);
        break;
    default:
        break;
    }

    if (n)
        hmput(global_hash, *n, n);
    return n;
}
/*
 * Entry point: lowers every top-level function in the `ast` unit chain into
 * IR, then prints the resulting graph to stdout as a Graphviz digraph.
 *
 * A root START node named "program" collects the lowered functions as
 * inputs. Non-function top-level items are skipped.
 */
void ir_build(ast_node *ast)
{
    graph = calloc(1, sizeof(ir_node));
    graph->code = OC_START;
    graph->id = stbds_hash_bytes(graph, sizeof(ir_node), 0xcafebabe);
    graph->data.start_name = "program";

    current_memory = calloc(1, sizeof(ir_node));
    current_memory->code = OC_FRAME_PTR;
    current_memory->id = stbds_hash_bytes(current_memory, sizeof(ir_node), 0xcafebabe);

    current_scope = calloc(1, sizeof(ir_node));
    current_scope->code = OC_SCOPE;
    push_scope();

    for (ast_node *unit = ast;
         unit != NULL && unit->type == NODE_UNIT;
         unit = unit->expr.unit_node.next) {
        ast_node *item = unit->expr.unit_node.expr;

        if (item == NULL || item->type != NODE_FUNCTION)
            continue;

        ir_node *fn = build_function(item);
        arrput(graph->out, fn);
        hmput(global_hash, *fn, fn);
    }

    printf("digraph G {\n");
    print_graph(graph);
    printf("}\n");
}