preliminary work on sea of nodes based intermediate representation.

This commit is contained in:
Lorenzo Torres 2025-12-07 16:48:48 +01:00
parent 989a32fa7b
commit 849e0b6863
13 changed files with 918 additions and 58 deletions

161
sema.c
View file

@ -27,7 +27,6 @@ static type *const_float = NULL;
static bool in_loop = false;
/* Print the error message and sync the parser. */
static void error(ast_node *n, char *msg)
{
if (n) {
@ -78,7 +77,6 @@ static type *create_float(sema *s, char *name, u8 bits)
return t;
}
/* https://en.wikipedia.org/wiki/Topological_sorting */
static void order_type(sema *s, ast_node *node)
{
if (node->type == NODE_STRUCT || node->type == NODE_UNION) {
@ -350,12 +348,12 @@ static void pop_scope(sema *s)
current_scope = current_scope->parent;
}
static type *get_def(sema *s, char *name)
static ast_node *get_def(sema *s, char *name)
{
scope *current = current_scope;
while (current) {
type *t = shget(current->defs, name);
if (t) return t;
ast_node *def = shget(current->defs, name);
if (def) return def;
current = current->parent;
}
@ -416,11 +414,13 @@ static type *get_identifier_type(sema *s, ast_node *node)
{
char *name_start = node->expr.string.start;
usize name_len = node->expr.string.len;
type *t = get_def(s, intern_string(s, name_start, name_len));
if (!t) {
char *name = intern_string(s, name_start, name_len);
node->expr.string.start = name;
ast_node *def = get_def(s, name);
if (!def) {
error(node, "unknown identifier.");
}
return t;
return def->expr_type;
}
static bool match(type *t1, type *t2);
@ -450,60 +450,129 @@ static type *get_expression_type(sema *s, ast_node *node)
prototype *prot = NULL;
switch (node->type) {
case NODE_IDENTIFIER:
return get_identifier_type(s, node);
t = get_identifier_type(s, node);
node->expr_type = t;
return t;
case NODE_INTEGER:
node->expr_type = const_int;
return const_int;
case NODE_FLOAT:
node->expr_type = const_float;
return const_float;
case NODE_STRING:
return get_string_type(s, node);
t = get_string_type(s, node);
node->expr_type = t;
return t;
case NODE_CHAR:
return shget(type_reg, "u8");
t = shget(type_reg, "u8");
node->expr_type = t;
return t;
case NODE_BOOL:
return shget(type_reg, "bool");
t = shget(type_reg, "bool");
node->expr_type = t;
return t;
case NODE_CAST:
return get_type(s, node->expr.cast.type);
t = get_type(s, node->expr.cast.type);
node->expr_type = t;
return t;
case NODE_POSTFIX:
case NODE_UNARY:
return get_expression_type(s, node->expr.unary.right);
t = get_expression_type(s, node->expr.unary.right);
if (node->expr.unary.operator == UOP_REF) {
ast_node *target = node->expr.unary.right;
while (target->type == NODE_ACCESS) {
target = target->expr.access.expr;
}
if (target->type != NODE_IDENTIFIER) {
error(node, "expected identifier.");
return NULL;
}
char *name = target->expr.string.start;
ast_node *def = get_def(s, name);
if (def) {
def->address_taken = true;
target->address_taken = true;
}
type *tmp = t;
t = arena_alloc(s->allocator, sizeof(type));
t->tag = TYPE_PTR;
t->size = sizeof(usize);
t->alignment = sizeof(usize);
t->name = "ptr";
t->data.ptr.is_const = false;
t->data.ptr.is_volatile = false;
t->data.ptr.child = tmp;
} else if (node->expr.unary.operator == UOP_DEREF) {
if (t->tag != TYPE_PTR) {
error(node, "only pointers can be dereferenced.");
return NULL;
}
t = t->data.ptr.child;
}
node->expr_type = t;
return t;
case NODE_BINARY:
t = get_expression_type(s, node->expr.binary.left);
if (!t) return NULL;
if (!match(t, get_expression_type(s, node->expr.binary.right))) {
if (node->expr.binary.operator == OP_ASSIGN_PTR) {
if (t->tag != TYPE_PTR) {
error(node, "expected pointer.");
return NULL;
}
t = t->data.ptr.child;
}
if (!can_cast(get_expression_type(s, node->expr.binary.right), t) && !match(t, get_expression_type(s, node->expr.binary.right))) {
error(node, "type mismatch.");
node->expr_type = NULL;
return NULL;
}
if (node->expr.binary.operator >= OP_EQ) {
return shget(type_reg, "bool");
t = shget(type_reg, "bool");
} else if (node->expr.binary.operator >= OP_ASSIGN && node->expr.binary.operator <= OP_MOD_EQ) {
return shget(type_reg, "void");
} else {
return t;
t = shget(type_reg, "void");
}
node->expr_type = t;
return t;
case NODE_RANGE:
return get_range_type(s, node);
t = get_range_type(s, node);
node->expr_type = t;
return t;
case NODE_ARRAY_SUBSCRIPT:
t = get_expression_type(s, node->expr.subscript.expr);
switch (t->tag) {
case TYPE_SLICE:
return t->data.slice.child;
t = t->data.slice.child;
break;
case TYPE_PTR:
return t->data.ptr.child;
t = t->data.ptr.child;
break;
default:
error(node, "only pointers and slices can be indexed.");
return NULL;
}
node->expr_type = t;
return t;
case NODE_CALL:
prot = shget(prototypes, intern_string(s, node->expr.call.name, node->expr.call.name_len));
if (!prot) {
error(node, "unknown function.");
return NULL;
}
return prot->type;
t = prot->type;
node->expr_type = t;
return t;
case NODE_ACCESS:
return get_access_type(s, node);
t = get_access_type(s, node);
node->expr_type = t;
return t;
default:
return shget(type_reg, "void");
t = shget(type_reg, "void");
node->expr_type = t;
return t;
}
}
@ -567,7 +636,14 @@ static void check_for(sema *s, ast_node *node)
while (current_capture) {
type *c_type = get_expression_type(s, current_slice->expr.unit_node.expr);
char *c_name = intern_string(s, current_capture->expr.unit_node.expr->expr.string.start, current_capture->expr.unit_node.expr->expr.string.len);
shput(current_scope->defs, c_name, c_type);
ast_node *cap_node = arena_alloc(s->allocator, sizeof(ast_node));
cap_node->type = NODE_VAR_DECL;
cap_node->expr_type = c_type;
cap_node->address_taken = false;
cap_node->expr.var_decl.name = c_name;
shput(current_scope->defs, c_name, cap_node);
current_capture = current_capture->expr.unit_node.next;
current_slice = current_slice->expr.unit_node.next;
}
@ -611,12 +687,23 @@ static void check_statement(sema *s, ast_node *node)
check_body(s, node->expr.whle.body);
in_loop = false;
break;
case NODE_IF:
if (!match(get_expression_type(s, node->expr.if_stmt.condition), shget(type_reg, "bool"))) {
error(node, "expected boolean value.");
return;
}
check_body(s, node->expr.if_stmt.body);
if (node->expr.if_stmt.otherwise) check_body(s, node->expr.if_stmt.otherwise);
break;
case NODE_FOR:
check_for(s, node);
break;
case NODE_VAR_DECL:
t = get_type(s, node->expr.var_decl.type);
node->expr_type = t;
name = intern_string(s, node->expr.var_decl.name, node->expr.var_decl.name_len);
node->expr.var_decl.name = name;
if (get_def(s, name)) {
error(node, "redeclaration of variable.");
break;
@ -624,7 +711,7 @@ static void check_statement(sema *s, ast_node *node)
if (!can_cast(get_expression_type(s, node->expr.var_decl.value), t) && !match(t, get_expression_type(s, node->expr.var_decl.value))) {
error(node, "type mismatch.");
}
shput(current_scope->defs, name, t);
shput(current_scope->defs, name, node);
break;
default:
get_expression_type(s, node);
@ -641,7 +728,14 @@ static void check_function(sema *s, ast_node *f)
while (param) {
type *p_type = get_type(s, param->type);
char *t_name = intern_string(s, param->name, param->name_len);
shput(current_scope->defs, t_name, p_type);
ast_node *param_node = arena_alloc(s->allocator, sizeof(ast_node));
param_node->type = NODE_VAR_DECL;
param_node->expr_type = p_type;
param_node->address_taken = false;
param_node->expr.var_decl.name = t_name;
shput(current_scope->defs, t_name, param_node);
param = param->next;
}
@ -658,7 +752,8 @@ static void analyze_unit(sema *s, ast_node *node)
{
ast_node *current = node;
while (current && current->type == NODE_UNIT) {
order_type(s, current->expr.unit_node.expr);
if (current->expr.unit_node.expr)
order_type(s, current->expr.unit_node.expr);
current = current->expr.unit_node.next;
}
@ -666,7 +761,7 @@ static void analyze_unit(sema *s, ast_node *node)
current = node;
while (current && current->type == NODE_UNIT) {
if (current->expr.unit_node.expr->type == NODE_FUNCTION) {
if (current->expr.unit_node.expr && current->expr.unit_node.expr->type == NODE_FUNCTION) {
create_prototype(s, current->expr.unit_node.expr);
}
current = current->expr.unit_node.next;
@ -674,8 +769,10 @@ static void analyze_unit(sema *s, ast_node *node)
current = node;
while (current && current->type == NODE_UNIT) {
if (current->expr.unit_node.expr->type == NODE_FUNCTION) {
if (current->expr.unit_node.expr && current->expr.unit_node.expr->type == NODE_FUNCTION) {
check_function(s, current->expr.unit_node.expr);
} else {
check_statement(s, current->expr.unit_node.expr);
}
current = current->expr.unit_node.next;
}
@ -720,5 +817,3 @@ sema *sema_init(parser *p, arena *a)
return s;
}