preliminary work on sea of nodes based intermediate representation.

This commit is contained in:
Lorenzo Torres 2025-12-07 16:48:48 +01:00
parent 989a32fa7b
commit 849e0b6863
13 changed files with 918 additions and 58 deletions

View file

@ -3,8 +3,8 @@
include config.mk include config.mk
SRC = lc.c utils.c lexer.c parser.c sema.c SRC = lc.c utils.c lexer.c parser.c sema.c ir.c
HDR = config.def.h utils.h lexer.h parser.h sema.h HDR = config.def.h utils.h lexer.h parser.h sema.h ir.h
OBJ = ${SRC:.c=.o} OBJ = ${SRC:.c=.o}
all: options lc all: options lc
@ -51,5 +51,9 @@ install: all
uninstall: uninstall:
rm -f ${DESTDIR}${PREFIX}/bin/lc\ rm -f ${DESTDIR}${PREFIX}/bin/lc\
${DESTDIR}${MANPREFIX}/man1/lc.1 ${DESTDIR}${MANPREFIX}/man1/lc.1
graph: clean all
./lc > graph.dot
dot -Tpdf graph.dot > graph.pdf
zathura ./graph.pdf
.PHONY: all options clean dist install uninstall .PHONY: all options clean dist install uninstall

View file

@ -15,7 +15,7 @@ INCS = -I.
LIBS = LIBS =
# flags # flags
CPPFLAGS = -DVERSION=\"${VERSION}\" CPPFLAGS = -DVERSION=\"${VERSION}\"
CFLAGS := -std=c99 -pedantic -Wall -O0 ${INCS} ${CPPFLAGS} CFLAGS := -std=c23 -pedantic -Wall -O0 ${INCS} ${CPPFLAGS}
CFLAGS := ${CFLAGS} -g CFLAGS := ${CFLAGS} -g
LDFLAGS = ${LIBS} LDFLAGS = ${LIBS}

662
ir.c Normal file
View file

@ -0,0 +1,662 @@
#include "ir.h"
#include <stdlib.h>
#include <stdio.h>
#include "stb_ds.h"
#include "sema.h"
struct { ir_node key; ir_node *value; } *global_hash = NULL;
static ir_node *graph;
static ir_node *current_memory;
static ir_node *current_control;
static usize current_stack = 0;
static ir_node *current_scope = NULL;
static ir_node *build_expression(ast_node *node);
static void node_name(ir_node *node)
{
if (!node) {
printf("null [label=\"NULL\", style=filled, fillcolor=red]\n");
return;
}
printf("%ld ", node->id);
switch (node->code) {
case OC_START:
printf("[label=\"start\", style=filled, color=orange]\n");
break;
case OC_ADD:
printf("[label=\"+\"]\n");
break;
case OC_NEG:
case OC_SUB:
printf("[label=\"-\"]\n");
break;
case OC_DIV:
printf("[label=\"/\"]\n");
break;
case OC_MUL:
printf("[label=\"*\"]\n");
break;
case OC_MOD:
printf("[label=\"%%\"]\n");
break;
case OC_BAND:
printf("[label=\"&\"]\n");
break;
case OC_BOR:
printf("[label=\"|\"]\n");
break;
case OC_BXOR:
printf("[label=\"^\"]\n");
break;
case OC_EQ:
printf("[label=\"==\"]\n");
break;
case OC_CONST_INT:
printf("[label=\"%ld\"]\n", node->data.const_int);
break;
case OC_CONST_FLOAT:
printf("[label=\"%f\"]\n", node->data.const_float);
break;
case OC_FRAME_PTR:
printf("[label=\"frame_ptr\"]\n");
break;
case OC_STORE:
printf("[label=\"store\", shape=box]\n");
break;
case OC_LOAD:
printf("[label=\"load\", shape=box]\n");
break;
case OC_ADDR:
printf("[label=\"addr\"]\n");
break;
case OC_REGION:
printf("[label=\"region\", shape=diamond, style=filled, color=green]\n");
break;
case OC_PHI:
printf("[label=\"phi\", shape=triangle]\n");
break;
case OC_IF:
printf("[label=\"if\", shape=diamond, style=filled, color=lightblue]\n");
break;
case OC_PROJ:
printf("[label=\"proj\", shape=diamond, style=filled, color=cyan]\n");
break;
default:
printf("[label=\"%d\"]\n", node->code);
break;
}
}
static void print_graph(ir_node *node)
{
for (int i = 0; i < hmlen(global_hash); i++) {
ir_node *node = global_hash[i].value;
node_name(node);
for (int j = 0; j < arrlen(node->out); j++) {
if (node->out[j]) {
node_name(node->out[j]);
printf("%ld->%ld\n", node->out[j]->id, node->id);
}
}
}
}
static void push_scope(void)
{
arrput(current_scope->data.symbol_tables, NULL);
}
static struct symbol_def *get_def(char *name)
{
for (int i = arrlen(current_scope->data.symbol_tables) - 1; i >= 0; i--) {
struct symbol_def *def = shget(current_scope->data.symbol_tables[i], name);
if (def) return def;
}
return NULL;
}
static void set_def(char *name, ir_node *node, bool lvalue)
{
for (int i = arrlen(current_scope->data.symbol_tables) - 1; i >= 0; i--) {
if (shget(current_scope->data.symbol_tables[i], name)) {
struct symbol_def *def = calloc(1, sizeof(struct symbol_def));
def->is_lvalue = lvalue;
def->node = node;
shput(current_scope->data.symbol_tables[i], name, def);
return;
}
}
int index = arrlen(current_scope->data.symbol_tables) - 1;
struct symbol_def *def = calloc(1, sizeof(struct symbol_def));
def->is_lvalue = lvalue;
def->node = node;
shput(current_scope->data.symbol_tables[index], name, def);
}
static ir_node *copy_scope(ir_node *src)
{
ir_node *dst = calloc(1, sizeof(ir_node));
dst->code = OC_SCOPE;
for (int i=0; i < arrlen(src->data.symbol_tables); i++) {
arrput(dst->data.symbol_tables, NULL);
symbol_table *src_table = src->data.symbol_tables[i];
for (int j=0; j < shlen(src_table); j++) {
shput(dst->data.symbol_tables[i], src_table[j].key, src_table[j].value);
}
}
return dst;
}
static void const_fold(ir_node *binary)
{
ir_node *left = binary->out[0];
ir_node *right = binary->out[1];
if (left->code == OC_CONST_INT && right->code == OC_CONST_INT) {
switch (binary->code) {
case OC_ADD:
binary->data.const_int = left->data.const_int + right->data.const_int;
break;
case OC_SUB:
binary->data.const_int = left->data.const_int - right->data.const_int;
break;
case OC_MUL:
binary->data.const_int = left->data.const_int * right->data.const_int;
break;
case OC_DIV:
if (right->data.const_int != 0)
binary->data.const_int = left->data.const_int / right->data.const_int;
break;
case OC_MOD:
if (right->data.const_int != 0)
binary->data.const_int = left->data.const_int % right->data.const_int;
break;
case OC_BOR:
binary->data.const_int = left->data.const_int | right->data.const_int;
break;
case OC_BAND:
binary->data.const_int = left->data.const_int & right->data.const_int;
break;
case OC_BXOR:
binary->data.const_int = left->data.const_int ^ right->data.const_int;
break;
case OC_EQ:
binary->data.const_int = left->data.const_int == right->data.const_int;
break;
default:
return;
}
binary->code = OC_CONST_INT;
arrfree(binary->out); binary->out = NULL;
arrfree(binary->in); binary->in = NULL;
binary->id = stbds_hash_bytes(binary, sizeof(ir_node), 0xcafebabe);
}
if (left->code == OC_CONST_FLOAT && right->code == OC_CONST_FLOAT) {
switch (binary->code) {
case OC_ADD:
binary->data.const_float = left->data.const_float + right->data.const_float;
break;
case OC_SUB:
binary->data.const_float = left->data.const_float - right->data.const_float;
break;
case OC_MUL:
binary->data.const_float = left->data.const_float * right->data.const_float;
break;
case OC_DIV:
if (right->data.const_float != 0.0f)
binary->data.const_float = left->data.const_float / right->data.const_float;
break;
default:
return;
}
binary->code = OC_CONST_FLOAT;
arrfree(binary->out); binary->out = NULL;
arrfree(binary->in); binary->in = NULL;
binary->id = stbds_hash_bytes(binary, sizeof(ir_node), 0xcafebabe);
}
}
static ir_node *build_address(usize base, usize offset) {
ir_node *addr = calloc(1, sizeof(ir_node));
addr->code = OC_ADDR;
ir_node *base_node = calloc(1, sizeof(ir_node));
if (base == -1) {
base_node->code = OC_FRAME_PTR;
base_node->id = stbds_hash_bytes(base_node, sizeof(ir_node), 0xcafebabe);
} else {
base_node->code = OC_CONST_INT;
base_node->data.const_int = base;
base_node->id = stbds_hash_bytes(base_node, sizeof(ir_node), 0xcafebabe);
}
ir_node *offset_node = calloc(1, sizeof(ir_node));
offset_node->code = OC_CONST_INT;
offset_node->data.const_int = offset;
offset_node->id = stbds_hash_bytes(offset_node, sizeof(ir_node), 0xcafebabe);
arrput(addr->out, base_node);
arrput(addr->out, offset_node);
addr->id = stbds_hash_bytes(addr, sizeof(ir_node), 0xcafebabe);
ir_node *tmp = hmget(global_hash, *addr);
if (tmp) {
free(addr);
return tmp;
}
return addr;
}
static ir_node *build_assign_ptr(ast_node *binary)
{
ir_node *val_node = build_expression(binary->expr.binary.right);
char *var_name = binary->expr.binary.left->expr.string.start;
ir_node *existing_def = get_def(var_name)->node;
ir_node *store = calloc(1, sizeof(ir_node));
store->code = OC_STORE;
arrput(store->out, current_memory);
arrput(store->out, existing_def);
arrput(store->out, val_node);
store->id = stbds_hash_bytes(store, sizeof(ir_node), 0xcafebabe);
hmput(global_hash, *store, store);
current_memory = store;
return val_node;
}
static ir_node *build_assign(ast_node *binary)
{
ir_node *val_node = build_expression(binary->expr.binary.right);
char *var_name = binary->expr.binary.left->expr.string.start;
struct symbol_def *def = get_def(var_name);
if (def && def->is_lvalue) {
ir_node *existing_def = def->node;
ir_node *store = calloc(1, sizeof(ir_node));
store->code = OC_STORE;
arrput(store->out, current_memory);
arrput(store->out, existing_def);
arrput(store->out, val_node);
store->id = stbds_hash_bytes(store, sizeof(ir_node), 0xcafebabe);
hmput(global_hash, *store, store);
current_memory = store;
return val_node;
}
set_def(var_name, val_node, false);
return val_node;
}
static ir_node *build_binary(ast_node *node)
{
ir_node *n = calloc(1, sizeof(ir_node));
switch (node->expr.binary.operator) {
case OP_ASSIGN:
free(n);
return build_assign(node);
case OP_ASSIGN_PTR:
free(n);
return build_assign_ptr(node);
case OP_PLUS:
n->code = OC_ADD;
break;
case OP_MINUS:
n->code = OC_SUB;
break;
case OP_DIV:
n->code = OC_DIV;
break;
case OP_MUL:
n->code = OC_MUL;
break;
case OP_MOD:
n->code = OC_MOD;
break;
case OP_BOR:
n->code = OC_BOR;
break;
case OP_BAND:
n->code = OC_BAND;
break;
case OP_BXOR:
n->code = OC_BXOR;
break;
case OP_EQ:
n->code = OC_EQ;
break;
default:
break;
}
arrput(n->out, build_expression(node->expr.binary.left));
arrput(n->out, build_expression(node->expr.binary.right));
n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe);
const_fold(n);
ir_node *tmp = hmget(global_hash, *n);
if (tmp) {
free(n);
return tmp;
}
return n;
}
static ir_node *build_load(ast_node *node)
{
ir_node *n = calloc(1, sizeof(ir_node));
n->code = OC_LOAD;
arrput(n->out, current_memory);
arrput(n->out, build_expression(node));
n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabebabecafe);
ir_node *tmp = hmget(global_hash, *n);
if (tmp) {
free(n);
return tmp;
}
return n;
}
static ir_node *build_unary(ast_node *node)
{
ir_node *n = calloc(1, sizeof(ir_node));
switch (node->expr.unary.operator) {
case UOP_MINUS:
n->code = OC_NEG;
arrput(n->out, build_expression(node->expr.unary.right));
break;
case UOP_REF:
free(n);
if (node->expr.unary.right->type == NODE_IDENTIFIER) {
struct symbol_def *def = get_def(node->expr.unary.right->expr.string.start);
if (def) {
return def->node;
}
}
return build_expression(node->expr.unary.right);
case UOP_DEREF:
free(n);
return build_load(node->expr.unary.right);
default:
break;
}
if (n->out && n->out[0]->code == OC_CONST_INT) {
switch (n->code) {
case OC_NEG:
n->data.const_int = -(n->out[0]->data.const_int);
break;
default:
break;
}
n->code = OC_CONST_INT;
arrfree(n->out); n->out = NULL;
} else if (n->out && n->out[0]->code == OC_CONST_FLOAT) {
switch (n->code) {
case OC_NEG:
n->data.const_float = -(n->out[0]->data.const_float);
break;
default:
break;
}
n->code = OC_CONST_FLOAT;
arrfree(n->out); n->out = NULL;
}
n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe);
ir_node *tmp = hmget(global_hash, *n);
if (tmp) {
free(n);
return tmp;
}
return n;
}
static ir_node *build_if(ast_node *node)
{
ir_node *condition = build_expression(node->expr.if_stmt.condition);
ir_node *if_node = calloc(1, sizeof(ir_node));
if_node->code = OC_IF;
arrput(if_node->out, condition);
arrput(if_node->out, current_control);
if_node->id = stbds_hash_bytes(if_node, sizeof(ir_node), 0xcafebabe);
hmput(global_hash, *if_node, if_node);
ir_node *proj_true = calloc(1, sizeof(ir_node));
proj_true->code = OC_PROJ;
arrput(proj_true->out, if_node);
proj_true->id = stbds_hash_bytes(proj_true, sizeof(ir_node), 0xcafebabe);
hmput(global_hash, *proj_true, proj_true);
ir_node *proj_false = calloc(1, sizeof(ir_node));
proj_false->code = OC_PROJ;
arrput(proj_false->out, if_node);
proj_false->id = stbds_hash_bytes(proj_false, sizeof(ir_node), 0xcafebabe);
hmput(global_hash, *proj_false, proj_false);
ir_node *base_scope = copy_scope(current_scope);
ir_node *base_mem = current_memory;
current_control = proj_true;
ast_node *current = node->expr.if_stmt.body;
while (current && current->type == NODE_UNIT) {
if (current->expr.unit_node.expr) {
ir_node *expr = build_expression(current->expr.unit_node.expr);
arrput(graph->out, expr);
}
current = current->expr.unit_node.next;
}
ir_node *then_scope = current_scope;
ir_node *then_mem = current_memory;
ir_node *then_control = current_control;
current_scope = copy_scope(base_scope);
current_memory = base_mem;
current_control = proj_false;
current = node->expr.if_stmt.otherwise;
while (current && current->type == NODE_UNIT) {
if (current->expr.unit_node.expr) {
ir_node *expr = build_expression(current->expr.unit_node.expr);
arrput(graph->out, expr);
}
current = current->expr.unit_node.next;
}
ir_node *else_scope = current_scope;
ir_node *else_mem = current_memory;
ir_node *else_control = current_control;
ir_node *region = calloc(1, sizeof(ir_node));
region->code = OC_REGION;
arrput(region->out, then_control);
arrput(region->out, else_control);
region->id = stbds_hash_bytes(region, sizeof(ir_node), 0xcafebabe);
hmput(global_hash, *region, region);
if (then_mem->id != else_mem->id) {
ir_node *phi = calloc(1, sizeof(ir_node));
phi->code = OC_PHI;
arrput(phi->out, region);
arrput(phi->out, then_mem);
arrput(phi->out, else_mem);
phi->id = stbds_hash_bytes(phi, sizeof(ir_node), 0xcafebabe);
hmput(global_hash, *phi, phi);
current_memory = phi;
} else {
current_memory = then_mem;
}
current_scope = base_scope;
for (int i = 0; i < arrlen(current_scope->data.symbol_tables); i++) {
symbol_table *base_table = current_scope->data.symbol_tables[i];
for (int j = 0; j < shlen(base_table); j++) {
char *key = base_table[j].key;
ir_node *found_then = NULL;
symbol_table *t_table = then_scope->data.symbol_tables[i];
if (shget(t_table, key)->node) found_then = shget(t_table, key)->node;
else found_then = base_table[j].value->node;
ir_node *found_else = NULL;
symbol_table *e_table = else_scope->data.symbol_tables[i];
if (shget(e_table, key)->node) found_else = shget(e_table, key)->node;
else found_else = base_table[j].value->node;
if (found_then->id != found_else->id) {
ir_node *phi = calloc(1, sizeof(ir_node));
phi->code = OC_PHI;
arrput(phi->out, region);
arrput(phi->out, found_then);
arrput(phi->out, found_else);
phi->id = stbds_hash_bytes(phi, sizeof(ir_node), 0xcafebabe);
struct symbol_def *def = calloc(1, sizeof(struct symbol_def));
def->node = phi;
def->is_lvalue = false;
shput(current_scope->data.symbol_tables[i], key, def);
hmput(global_hash, *phi, phi);
} else {
struct symbol_def *def = calloc(1, sizeof(struct symbol_def));
def->node = found_then;
def->is_lvalue = false;
shput(current_scope->data.symbol_tables[i], key, def);
}
}
}
current_control = region;
return region;
}
static ir_node *build_expression(ast_node *node)
{
ir_node *n = NULL;
ir_node *tmp = NULL;
switch (node->type) {
case NODE_UNARY:
n = build_unary(node);
break;
case NODE_BINARY:
n = build_binary(node);
break;
case NODE_INTEGER:
n = calloc(1, sizeof(ir_node));
n->code = OC_CONST_INT;
n->data.const_int = node->expr.integer;
n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe);
tmp = hmget(global_hash, *n);
if (tmp) {
free(n);
return tmp;
}
break;
case NODE_VAR_DECL:
n = calloc(1, sizeof(ir_node));
if (node->address_taken) {
n->code = OC_STORE;
arrput(n->out, current_memory);
arrput(n->out, build_address(-1, current_stack));
arrput(n->out, build_expression(node->expr.var_decl.value));
current_memory = n;
current_stack += node->expr_type->size;
n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe);
hmput(global_hash, *n, n);
n = n->out[1];
set_def(node->expr.var_decl.name, n, true);
} else {
n = build_expression(node->expr.var_decl.value);
set_def(node->expr.var_decl.name, n, false);
}
return n;
case NODE_IDENTIFIER:
struct symbol_def *def = get_def(node->expr.string.start);
n = def->node;
if (n && def->is_lvalue) {
ir_node *addr_node = n;
n = calloc(1, sizeof(ir_node));
n->code = OC_LOAD;
arrput(n->out, current_memory);
arrput(n->out, addr_node);
n->id = stbds_hash_bytes(n, sizeof(ir_node), 0xcafebabe);
ir_node *tmp = hmget(global_hash, *n);
if (tmp) {
free(n);
n = tmp;
} else {
hmput(global_hash, *n, n);
}
}
break;
case NODE_IF:
n = build_if(node);
break;
default:
break;
}
if (n) hmput(global_hash, *n, n);
return n;
}
void ir_build(ast_node *ast)
{
ast_node *current = ast;
graph = calloc(1, sizeof(ir_node));
graph->code = OC_START;
graph->id = stbds_hash_bytes(graph, sizeof(ir_node), 0xcafebabe);
current_control = graph;
current_memory = calloc(1, sizeof(ir_node));
current_memory->code = OC_FRAME_PTR;
current_memory->id = stbds_hash_bytes(current_memory, sizeof(ir_node), 0xcafebabe);
current_scope = calloc(1, sizeof(ir_node));
current_scope->code = OC_SCOPE;
push_scope();
while (current && current->type == NODE_UNIT) {
if (current->expr.unit_node.expr) {
ir_node *expr = build_expression(current->expr.unit_node.expr);
arrput(graph->out, expr);
}
current = current->expr.unit_node.next;
}
printf("digraph G {\n");
print_graph(graph);
printf("}\n");
}

63
ir.h Normal file
View file

@ -0,0 +1,63 @@
#ifndef IR_H
#define IR_H
#include "utils.h"
#include "parser.h"
struct _ir_node;
struct symbol_def {
struct _ir_node *node;
bool is_lvalue;
};
typedef struct { char *key; struct symbol_def *value; } symbol_table;
typedef enum {
OC_START,
OC_ADD,
OC_SUB,
OC_MUL,
OC_DIV,
OC_MOD,
OC_BAND,
OC_BOR,
OC_BXOR,
OC_NEG,
OC_EQ,
OC_CONST_INT,
OC_CONST_FLOAT,
OC_FRAME_PTR,
OC_ADDR,
OC_STORE,
OC_LOAD,
OC_REGION,
OC_PHI,
OC_IF,
OC_PROJ,
OC_STOP,
OC_RETURN,
OC_SCOPE,
} opcode;
typedef struct _ir_node {
opcode code;
usize id;
struct _ir_node **in;
struct _ir_node **out;
union {
i64 const_int;
f64 const_float;
symbol_table **symbol_tables;
} data;
} ir_node;
void ir_build(ast_node *ast);
#endif

9
lc.c
View file

@ -4,6 +4,7 @@
#include "lexer.h" #include "lexer.h"
#include "parser.h" #include "parser.h"
#include "sema.h" #include "sema.h"
#include "ir.h"
void print_indent(int depth) { void print_indent(int depth) {
for (int i = 0; i < depth; i++) printf(" "); for (int i = 0; i < depth; i++) printf(" ");
@ -17,6 +18,7 @@ const char* get_op_str(binary_op op) {
case OP_MUL: return "*"; case OP_MUL: return "*";
case OP_EQ: return "=="; case OP_EQ: return "==";
case OP_ASSIGN: return "="; case OP_ASSIGN: return "=";
case OP_ASSIGN_PTR: return "<-";
case OP_AND: return "&&"; case OP_AND: return "&&";
case OP_OR: return "||"; case OP_OR: return "||";
case OP_NEQ: return "!="; case OP_NEQ: return "!=";
@ -215,7 +217,7 @@ void print_ast(ast_node *node, int depth) {
int main(void) int main(void)
{ {
FILE *fp = fopen("examples/hello_world.l", "r"); FILE *fp = fopen("test.l", "r");
usize size = 0; usize size = 0;
fseek(fp, 0, SEEK_END); fseek(fp, 0, SEEK_END);
size = ftell(fp); size = ftell(fp);
@ -228,8 +230,11 @@ int main(void)
arena a = arena_init(0x1000 * 0x1000 * 64); arena a = arena_init(0x1000 * 0x1000 * 64);
lexer *l = lexer_init(src, size, &a); lexer *l = lexer_init(src, size, &a);
parser *p = parser_init(l, &a); parser *p = parser_init(l, &a);
print_ast(p->ast, 0); //print_ast(p->ast, 0);
sema *s = sema_init(p, &a); sema *s = sema_init(p, &a);
//(void) s;
ir_build(p->ast);
arena_deinit(a); arena_deinit(a);

View file

@ -133,9 +133,6 @@ static bool parse_special(lexer *l)
} else if (l->source[l->index+1] == '-') { } else if (l->source[l->index+1] == '-') {
add_token(l, TOKEN_MINUS_MINUS, 2); add_token(l, TOKEN_MINUS_MINUS, 2);
l->index += 2; l->index += 2;
} else if (l->source[l->index+1] == '>') {
add_token(l, TOKEN_ARROW, 2);
l->index += 2;
} else { } else {
add_token(l, TOKEN_MINUS, 1); add_token(l, TOKEN_MINUS, 1);
l->index += 1; l->index += 1;
@ -231,6 +228,9 @@ static bool parse_special(lexer *l)
if (l->source[l->index+1] == '=') { if (l->source[l->index+1] == '=') {
add_token(l, TOKEN_LESS_EQ, 2); add_token(l, TOKEN_LESS_EQ, 2);
l->index += 2; l->index += 2;
} else if (l->source[l->index+1] == '-') {
add_token(l, TOKEN_ARROW, 2);
l->index += 2;
} else if (l->source[l->index+1] == '<') { } else if (l->source[l->index+1] == '<') {
if (l->source[l->index+2] == '=') { if (l->source[l->index+2] == '=') {
add_token(l, TOKEN_LSHIFT_EQ, 3); add_token(l, TOKEN_LSHIFT_EQ, 3);

View file

@ -16,10 +16,10 @@ typedef enum {
TOKEN_AND, // & TOKEN_AND, // &
TOKEN_HAT, // ^ TOKEN_HAT, // ^
TOKEN_PIPE, // | TOKEN_PIPE, // |
TOKEN_ARROW, // ->
TOKEN_LSHIFT, // << TOKEN_LSHIFT, // <<
TOKEN_RSHIFT, // >> TOKEN_RSHIFT, // >>
TOKEN_DOUBLE_EQ, // == TOKEN_DOUBLE_EQ, // ==
TOKEN_ARROW, // <-
TOKEN_EQ, // = TOKEN_EQ, // =
TOKEN_LESS_THAN, // < TOKEN_LESS_THAN, // <
TOKEN_GREATER_THAN, // > TOKEN_GREATER_THAN, // >

View file

@ -367,9 +367,21 @@ ast_node *parse_term(parser *p)
{ {
ast_node *left = parse_unary(p); ast_node *left = parse_unary(p);
while (match_peek(p, TOKEN_STAR) || match_peek(p, TOKEN_SLASH)) while (match_peek(p, TOKEN_STAR) || match_peek(p, TOKEN_SLASH) || match_peek(p, TOKEN_PERC)) {
{ binary_op op;
binary_op op = peek(p)->type == TOKEN_STAR ? OP_MUL : OP_DIV; switch (peek(p)->type) {
case TOKEN_STAR:
op = OP_MUL;
break;
case TOKEN_SLASH:
op = OP_DIV;
break;
case TOKEN_PERC:
op = OP_MOD;
break;
default:
continue;
}
advance(p); advance(p);
ast_node *right = parse_factor(p); ast_node *right = parse_factor(p);
ast_node *node = arena_alloc(p->allocator, sizeof(ast_node)); ast_node *node = arena_alloc(p->allocator, sizeof(ast_node));
@ -561,6 +573,9 @@ ast_node *parse_expression(parser *p)
binary_op op; binary_op op;
switch (p->tokens->type) switch (p->tokens->type)
{ {
case TOKEN_ARROW:
op = OP_ASSIGN_PTR;
break;
case TOKEN_EQ: case TOKEN_EQ:
op = OP_ASSIGN; op = OP_ASSIGN;
break; break;
@ -854,8 +869,12 @@ static ast_node *parse_if(parser *p)
ast_node *node = arena_alloc(p->allocator, sizeof(ast_node)); ast_node *node = arena_alloc(p->allocator, sizeof(ast_node));
node->type = NODE_IF; node->type = NODE_IF;
node->position = p->previous->position; node->position = p->previous->position;
node->expr.whle.body = body; node->expr.if_stmt.body = body;
node->expr.whle.condition = condition; node->expr.if_stmt.condition = condition;
if (match(p, TOKEN_ELSE)) {
body = parse_compound(p);
node->expr.if_stmt.otherwise = body;
}
return node; return node;
} }
@ -1302,11 +1321,6 @@ static void parse(parser *p)
ast_node *tail = p->ast; ast_node *tail = p->ast;
ast_node *expr = parse_statement(p); ast_node *expr = parse_statement(p);
while (expr) { while (expr) {
if (expr->type != NODE_FUNCTION && expr->type != NODE_VAR_DECL && expr->type != NODE_IMPORT &&
expr->type != NODE_STRUCT && expr->type != NODE_UNION && expr->type != NODE_ENUM) {
error(p, "expected function, struct, enum, union, global variable or import statement.");
return;
}
tail->expr.unit_node.next = arena_alloc(p->allocator, sizeof(ast_node)); tail->expr.unit_node.next = arena_alloc(p->allocator, sizeof(ast_node));
tail->expr.unit_node.next->expr.unit_node.expr = expr; tail->expr.unit_node.next->expr.unit_node.expr = expr;
tail = tail->expr.unit_node.next; tail = tail->expr.unit_node.next;

View file

@ -19,6 +19,7 @@ typedef enum {
OP_BXOR, // ^ OP_BXOR, // ^
OP_ASSIGN, // = OP_ASSIGN, // =
OP_ASSIGN_PTR, // <-
OP_RSHIFT_EQ, // >>= OP_RSHIFT_EQ, // >>=
OP_LSHIFT_EQ, // <<= OP_LSHIFT_EQ, // <<=
OP_PLUS_EQ, // += OP_PLUS_EQ, // +=
@ -126,6 +127,7 @@ typedef struct _ast_node {
node_type type; node_type type;
source_pos position; source_pos position;
struct _type *expr_type; struct _type *expr_type;
bool address_taken; // used in IR generation.
union { union {
struct { struct {
struct _ast_node *type; struct _ast_node *type;
@ -200,6 +202,12 @@ typedef struct _ast_node {
struct _ast_node *body; struct _ast_node *body;
u8 flags; u8 flags;
} whle; // while } whle; // while
struct {
struct _ast_node *condition;
struct _ast_node *body;
struct _ast_node *otherwise;
u8 flags;
} if_stmt; // while
struct { struct {
struct _ast_node **statements; struct _ast_node **statements;
usize stmt_len; usize stmt_len;

161
sema.c
View file

@ -27,7 +27,6 @@ static type *const_float = NULL;
static bool in_loop = false; static bool in_loop = false;
/* Print the error message and sync the parser. */
static void error(ast_node *n, char *msg) static void error(ast_node *n, char *msg)
{ {
if (n) { if (n) {
@ -78,7 +77,6 @@ static type *create_float(sema *s, char *name, u8 bits)
return t; return t;
} }
/* https://en.wikipedia.org/wiki/Topological_sorting */
static void order_type(sema *s, ast_node *node) static void order_type(sema *s, ast_node *node)
{ {
if (node->type == NODE_STRUCT || node->type == NODE_UNION) { if (node->type == NODE_STRUCT || node->type == NODE_UNION) {
@ -350,12 +348,12 @@ static void pop_scope(sema *s)
current_scope = current_scope->parent; current_scope = current_scope->parent;
} }
static type *get_def(sema *s, char *name) static ast_node *get_def(sema *s, char *name)
{ {
scope *current = current_scope; scope *current = current_scope;
while (current) { while (current) {
type *t = shget(current->defs, name); ast_node *def = shget(current->defs, name);
if (t) return t; if (def) return def;
current = current->parent; current = current->parent;
} }
@ -416,11 +414,13 @@ static type *get_identifier_type(sema *s, ast_node *node)
{ {
char *name_start = node->expr.string.start; char *name_start = node->expr.string.start;
usize name_len = node->expr.string.len; usize name_len = node->expr.string.len;
type *t = get_def(s, intern_string(s, name_start, name_len)); char *name = intern_string(s, name_start, name_len);
if (!t) { node->expr.string.start = name;
ast_node *def = get_def(s, name);
if (!def) {
error(node, "unknown identifier."); error(node, "unknown identifier.");
} }
return t; return def->expr_type;
} }
static bool match(type *t1, type *t2); static bool match(type *t1, type *t2);
@ -450,60 +450,129 @@ static type *get_expression_type(sema *s, ast_node *node)
prototype *prot = NULL; prototype *prot = NULL;
switch (node->type) { switch (node->type) {
case NODE_IDENTIFIER: case NODE_IDENTIFIER:
return get_identifier_type(s, node); t = get_identifier_type(s, node);
node->expr_type = t;
return t;
case NODE_INTEGER: case NODE_INTEGER:
node->expr_type = const_int;
return const_int; return const_int;
case NODE_FLOAT: case NODE_FLOAT:
node->expr_type = const_float;
return const_float; return const_float;
case NODE_STRING: case NODE_STRING:
return get_string_type(s, node); t = get_string_type(s, node);
node->expr_type = t;
return t;
case NODE_CHAR: case NODE_CHAR:
return shget(type_reg, "u8"); t = shget(type_reg, "u8");
node->expr_type = t;
return t;
case NODE_BOOL: case NODE_BOOL:
return shget(type_reg, "bool"); t = shget(type_reg, "bool");
node->expr_type = t;
return t;
case NODE_CAST: case NODE_CAST:
return get_type(s, node->expr.cast.type); t = get_type(s, node->expr.cast.type);
node->expr_type = t;
return t;
case NODE_POSTFIX: case NODE_POSTFIX:
case NODE_UNARY: case NODE_UNARY:
return get_expression_type(s, node->expr.unary.right); t = get_expression_type(s, node->expr.unary.right);
if (node->expr.unary.operator == UOP_REF) {
ast_node *target = node->expr.unary.right;
while (target->type == NODE_ACCESS) {
target = target->expr.access.expr;
}
if (target->type != NODE_IDENTIFIER) {
error(node, "expected identifier.");
return NULL;
}
char *name = target->expr.string.start;
ast_node *def = get_def(s, name);
if (def) {
def->address_taken = true;
target->address_taken = true;
}
type *tmp = t;
t = arena_alloc(s->allocator, sizeof(type));
t->tag = TYPE_PTR;
t->size = sizeof(usize);
t->alignment = sizeof(usize);
t->name = "ptr";
t->data.ptr.is_const = false;
t->data.ptr.is_volatile = false;
t->data.ptr.child = tmp;
} else if (node->expr.unary.operator == UOP_DEREF) {
if (t->tag != TYPE_PTR) {
error(node, "only pointers can be dereferenced.");
return NULL;
}
t = t->data.ptr.child;
}
node->expr_type = t;
return t;
case NODE_BINARY: case NODE_BINARY:
t = get_expression_type(s, node->expr.binary.left); t = get_expression_type(s, node->expr.binary.left);
if (!t) return NULL; if (!t) return NULL;
if (!match(t, get_expression_type(s, node->expr.binary.right))) { if (node->expr.binary.operator == OP_ASSIGN_PTR) {
if (t->tag != TYPE_PTR) {
error(node, "expected pointer.");
return NULL;
}
t = t->data.ptr.child;
}
if (!can_cast(get_expression_type(s, node->expr.binary.right), t) && !match(t, get_expression_type(s, node->expr.binary.right))) {
error(node, "type mismatch."); error(node, "type mismatch.");
node->expr_type = NULL;
return NULL; return NULL;
} }
if (node->expr.binary.operator >= OP_EQ) { if (node->expr.binary.operator >= OP_EQ) {
return shget(type_reg, "bool"); t = shget(type_reg, "bool");
} else if (node->expr.binary.operator >= OP_ASSIGN && node->expr.binary.operator <= OP_MOD_EQ) { } else if (node->expr.binary.operator >= OP_ASSIGN && node->expr.binary.operator <= OP_MOD_EQ) {
return shget(type_reg, "void"); t = shget(type_reg, "void");
} else {
return t;
} }
node->expr_type = t;
return t;
case NODE_RANGE: case NODE_RANGE:
return get_range_type(s, node); t = get_range_type(s, node);
node->expr_type = t;
return t;
case NODE_ARRAY_SUBSCRIPT: case NODE_ARRAY_SUBSCRIPT:
t = get_expression_type(s, node->expr.subscript.expr); t = get_expression_type(s, node->expr.subscript.expr);
switch (t->tag) { switch (t->tag) {
case TYPE_SLICE: case TYPE_SLICE:
return t->data.slice.child; t = t->data.slice.child;
break;
case TYPE_PTR: case TYPE_PTR:
return t->data.ptr.child; t = t->data.ptr.child;
break;
default: default:
error(node, "only pointers and slices can be indexed."); error(node, "only pointers and slices can be indexed.");
return NULL; return NULL;
} }
node->expr_type = t;
return t;
case NODE_CALL: case NODE_CALL:
prot = shget(prototypes, intern_string(s, node->expr.call.name, node->expr.call.name_len)); prot = shget(prototypes, intern_string(s, node->expr.call.name, node->expr.call.name_len));
if (!prot) { if (!prot) {
error(node, "unknown function."); error(node, "unknown function.");
return NULL; return NULL;
} }
return prot->type; t = prot->type;
node->expr_type = t;
return t;
case NODE_ACCESS: case NODE_ACCESS:
return get_access_type(s, node); t = get_access_type(s, node);
node->expr_type = t;
return t;
default: default:
return shget(type_reg, "void"); t = shget(type_reg, "void");
node->expr_type = t;
return t;
} }
} }
@ -567,7 +636,14 @@ static void check_for(sema *s, ast_node *node)
while (current_capture) { while (current_capture) {
type *c_type = get_expression_type(s, current_slice->expr.unit_node.expr); type *c_type = get_expression_type(s, current_slice->expr.unit_node.expr);
char *c_name = intern_string(s, current_capture->expr.unit_node.expr->expr.string.start, current_capture->expr.unit_node.expr->expr.string.len); char *c_name = intern_string(s, current_capture->expr.unit_node.expr->expr.string.start, current_capture->expr.unit_node.expr->expr.string.len);
shput(current_scope->defs, c_name, c_type);
ast_node *cap_node = arena_alloc(s->allocator, sizeof(ast_node));
cap_node->type = NODE_VAR_DECL;
cap_node->expr_type = c_type;
cap_node->address_taken = false;
cap_node->expr.var_decl.name = c_name;
shput(current_scope->defs, c_name, cap_node);
current_capture = current_capture->expr.unit_node.next; current_capture = current_capture->expr.unit_node.next;
current_slice = current_slice->expr.unit_node.next; current_slice = current_slice->expr.unit_node.next;
} }
@ -611,12 +687,23 @@ static void check_statement(sema *s, ast_node *node)
check_body(s, node->expr.whle.body); check_body(s, node->expr.whle.body);
in_loop = false; in_loop = false;
break; break;
case NODE_IF:
if (!match(get_expression_type(s, node->expr.if_stmt.condition), shget(type_reg, "bool"))) {
error(node, "expected boolean value.");
return;
}
check_body(s, node->expr.if_stmt.body);
if (node->expr.if_stmt.otherwise) check_body(s, node->expr.if_stmt.otherwise);
break;
case NODE_FOR: case NODE_FOR:
check_for(s, node); check_for(s, node);
break; break;
case NODE_VAR_DECL: case NODE_VAR_DECL:
t = get_type(s, node->expr.var_decl.type); t = get_type(s, node->expr.var_decl.type);
node->expr_type = t;
name = intern_string(s, node->expr.var_decl.name, node->expr.var_decl.name_len); name = intern_string(s, node->expr.var_decl.name, node->expr.var_decl.name_len);
node->expr.var_decl.name = name;
if (get_def(s, name)) { if (get_def(s, name)) {
error(node, "redeclaration of variable."); error(node, "redeclaration of variable.");
break; break;
@ -624,7 +711,7 @@ static void check_statement(sema *s, ast_node *node)
if (!can_cast(get_expression_type(s, node->expr.var_decl.value), t) && !match(t, get_expression_type(s, node->expr.var_decl.value))) { if (!can_cast(get_expression_type(s, node->expr.var_decl.value), t) && !match(t, get_expression_type(s, node->expr.var_decl.value))) {
error(node, "type mismatch."); error(node, "type mismatch.");
} }
shput(current_scope->defs, name, t); shput(current_scope->defs, name, node);
break; break;
default: default:
get_expression_type(s, node); get_expression_type(s, node);
@ -641,7 +728,14 @@ static void check_function(sema *s, ast_node *f)
while (param) { while (param) {
type *p_type = get_type(s, param->type); type *p_type = get_type(s, param->type);
char *t_name = intern_string(s, param->name, param->name_len); char *t_name = intern_string(s, param->name, param->name_len);
shput(current_scope->defs, t_name, p_type);
ast_node *param_node = arena_alloc(s->allocator, sizeof(ast_node));
param_node->type = NODE_VAR_DECL;
param_node->expr_type = p_type;
param_node->address_taken = false;
param_node->expr.var_decl.name = t_name;
shput(current_scope->defs, t_name, param_node);
param = param->next; param = param->next;
} }
@ -658,7 +752,8 @@ static void analyze_unit(sema *s, ast_node *node)
{ {
ast_node *current = node; ast_node *current = node;
while (current && current->type == NODE_UNIT) { while (current && current->type == NODE_UNIT) {
order_type(s, current->expr.unit_node.expr); if (current->expr.unit_node.expr)
order_type(s, current->expr.unit_node.expr);
current = current->expr.unit_node.next; current = current->expr.unit_node.next;
} }
@ -666,7 +761,7 @@ static void analyze_unit(sema *s, ast_node *node)
current = node; current = node;
while (current && current->type == NODE_UNIT) { while (current && current->type == NODE_UNIT) {
if (current->expr.unit_node.expr->type == NODE_FUNCTION) { if (current->expr.unit_node.expr && current->expr.unit_node.expr->type == NODE_FUNCTION) {
create_prototype(s, current->expr.unit_node.expr); create_prototype(s, current->expr.unit_node.expr);
} }
current = current->expr.unit_node.next; current = current->expr.unit_node.next;
@ -674,8 +769,10 @@ static void analyze_unit(sema *s, ast_node *node)
current = node; current = node;
while (current && current->type == NODE_UNIT) { while (current && current->type == NODE_UNIT) {
if (current->expr.unit_node.expr->type == NODE_FUNCTION) { if (current->expr.unit_node.expr && current->expr.unit_node.expr->type == NODE_FUNCTION) {
check_function(s, current->expr.unit_node.expr); check_function(s, current->expr.unit_node.expr);
} else {
check_statement(s, current->expr.unit_node.expr);
} }
current = current->expr.unit_node.next; current = current->expr.unit_node.next;
} }
@ -720,5 +817,3 @@ sema *sema_init(parser *p, arena *a)
return s; return s;
} }

2
sema.h
View file

@ -63,7 +63,7 @@ typedef struct {
typedef struct _scope { typedef struct _scope {
struct _scope *parent; struct _scope *parent;
struct { char *key; type *value; } *defs; struct { char *key; ast_node *value; } *defs;
} scope; } scope;
typedef struct { typedef struct {

15
test.l
View file

@ -1,5 +1,12 @@
u32 a() u32 a = 2;
{
[u32] v = {1, 2, 3}; if (a == 3) {
return z[0]; a = 5;
if (a == 4) {
a = 3;
}
} else {
a = 1;
} }
u32 d = a;

View file

@ -112,10 +112,12 @@ static usize align_forward(usize ptr, usize align) {
arena arena_init(usize size) arena arena_init(usize size)
{ {
void *memory = malloc(size);
memset(memory, 0x0, size);
return (arena){ return (arena){
.capacity = size, .capacity = size,
.position = 0, .position = 0,
.memory = malloc(size), .memory = memory,
}; };
} }