almost finished implementing sema

This commit is contained in:
Lorenzo Torres 2025-12-05 23:27:22 +01:00
parent 8b4b81e90b
commit 463ba71843
8 changed files with 380 additions and 43 deletions

View file

@ -1,10 +1,11 @@
struct b { struct b {
u32 ciao, i32 a,
u32 test, u32 b,
u32 test1, u32 c,
} }
u32 test(b hello) u32 test()
{ {
u32 a = (u32)3;
a = (u32)2;
} }

6
lc.c
View file

@ -123,12 +123,6 @@ void print_ast(ast_node *node, int depth) {
current = current->expr.unit_node.next; current = current->expr.unit_node.next;
} }
break; break;
case NODE_COMPOUND:
printf("Block\n");
for (usize i = 0; i < node->expr.compound.stmt_len; ++i) {
print_ast(node->expr.compound.statements[i], depth + 1);
}
break;
case NODE_CALL: case NODE_CALL:
printf("Call: %.*s\n", (int)node->expr.call.name_len, node->expr.call.name); printf("Call: %.*s\n", (int)node->expr.call.name_len, node->expr.call.name);
current = node->expr.call.parameters; current = node->expr.call.parameters;

View file

@ -396,6 +396,8 @@ lexer *lexer_init(char *source, usize size, arena *arena)
lex->source = source; lex->source = source;
keywords = arena_alloc(arena, sizeof(trie_node)); keywords = arena_alloc(arena, sizeof(trie_node));
trie_insert(keywords, lex->allocator, "true", TOKEN_TRUE);
trie_insert(keywords, lex->allocator, "false", TOKEN_FALSE);
trie_insert(keywords, lex->allocator, "struct", TOKEN_STRUCT); trie_insert(keywords, lex->allocator, "struct", TOKEN_STRUCT);
trie_insert(keywords, lex->allocator, "enum", TOKEN_ENUM); trie_insert(keywords, lex->allocator, "enum", TOKEN_ENUM);
trie_insert(keywords, lex->allocator, "union", TOKEN_UNION); trie_insert(keywords, lex->allocator, "union", TOKEN_UNION);

View file

@ -55,6 +55,8 @@ typedef enum {
TOKEN_IDENTIFIER, TOKEN_IDENTIFIER,
TOKEN_STRING, TOKEN_STRING,
TOKEN_CHAR, TOKEN_CHAR,
TOKEN_TRUE,
TOKEN_FALSE,
TOKEN_GOTO, TOKEN_GOTO,
TOKEN_LOOP, TOKEN_LOOP,
TOKEN_WHILE, TOKEN_WHILE,

View file

@ -428,15 +428,13 @@ ast_node *parse_expression(parser *p)
left = node; left = node;
} }
return left;
} }
/* /*
* If after parsing an expression a `.` character * If after parsing an expression a `.` character
* is found, it should be a member access expression. * is found, it should be a member access expression.
*/ */
if (match_peek(p, TOKEN_DOT)) if (match_peek(p, TOKEN_DOT) && p->tokens->next && p->tokens->next->type != TOKEN_LCURLY) {
{
while (match(p, TOKEN_DOT)) { while (match(p, TOKEN_DOT)) {
if (!match_peek(p, TOKEN_IDENTIFIER)) { if (!match_peek(p, TOKEN_IDENTIFIER)) {
error(p, "expected identifier after member access."); error(p, "expected identifier after member access.");
@ -450,7 +448,6 @@ ast_node *parse_expression(parser *p)
left = node; left = node;
} }
return left;
} }
/* /*
@ -749,7 +746,7 @@ parse_captures:
arena_start = arena_snapshot(p->allocator); arena_start = arena_snapshot(p->allocator);
node->expr.fr.captures = arena_alloc(p->allocator, sizeof(ast_node)); node->expr.fr.captures = arena_alloc(p->allocator, sizeof(ast_node));
node->expr.fr.captures->type = NODE_UNIT; node->expr.fr.captures->type = NODE_UNIT;
node->expr.fr.captures->expr.unit_node.expr = parse_expression(p); node->expr.fr.captures->expr.unit_node.expr = parse_factor(p);
if (node->expr.fr.captures->expr.unit_node.expr && node->expr.fr.captures->expr.unit_node.expr->type != NODE_IDENTIFIER) { if (node->expr.fr.captures->expr.unit_node.expr && node->expr.fr.captures->expr.unit_node.expr->type != NODE_IDENTIFIER) {
error(p, "captures must be identifiers."); error(p, "captures must be identifiers.");
arena_reset_to_snapshot(p->allocator, arena_start); arena_reset_to_snapshot(p->allocator, arena_start);
@ -776,7 +773,7 @@ parse_captures:
tail->expr.unit_node.next->expr.unit_node.expr = expr; tail->expr.unit_node.next->expr.unit_node.expr = expr;
tail = tail->expr.unit_node.next; tail = tail->expr.unit_node.next;
tail->type = NODE_UNIT; tail->type = NODE_UNIT;
expr = parse_expression(p); expr = parse_factor(p);
if (!expr) { if (!expr) {
error(p, "expected `|`."); error(p, "expected `|`.");
arena_reset_to_snapshot(p->allocator, arena_start); arena_reset_to_snapshot(p->allocator, arena_start);

View file

@ -5,6 +5,7 @@
#include "utils.h" #include "utils.h"
#include <stdbool.h> #include <stdbool.h>
struct _type;
struct _ast_node; struct _ast_node;
typedef enum { typedef enum {
@ -12,21 +13,14 @@ typedef enum {
OP_MINUS, // - OP_MINUS, // -
OP_DIV, // / OP_DIV, // /
OP_MUL, // * OP_MUL, // *
OP_EQ, // == OP_MOD, // %
OP_ASSIGN, // =
OP_AND, // &&
OP_OR, // ||
OP_NEQ, // !=
OP_GT, // >
OP_LT, // <
OP_GE, // >=
OP_LE, // <=
OP_RSHIFT_EQ, // >>=
OP_LSHIFT_EQ, // <<=
OP_BOR, // | OP_BOR, // |
OP_BAND, // & OP_BAND, // &
OP_BXOR, // ^ OP_BXOR, // ^
OP_MOD, // %
OP_ASSIGN, // =
OP_RSHIFT_EQ, // >>=
OP_LSHIFT_EQ, // <<=
OP_PLUS_EQ, // += OP_PLUS_EQ, // +=
OP_MINUS_EQ, // -= OP_MINUS_EQ, // -=
OP_DIV_EQ, // /= OP_DIV_EQ, // /=
@ -35,6 +29,15 @@ typedef enum {
OP_BAND_EQ, // &= OP_BAND_EQ, // &=
OP_BXOR_EQ, // ^= OP_BXOR_EQ, // ^=
OP_MOD_EQ, // %= OP_MOD_EQ, // %=
OP_EQ, // ==
OP_AND, // &&
OP_OR, // ||
OP_NEQ, // !=
OP_GT, // >
OP_LT, // <
OP_GE, // >=
OP_LE, // <=
} binary_op; } binary_op;
typedef enum { typedef enum {
@ -79,32 +82,34 @@ typedef enum {
NODE_FLOAT, NODE_FLOAT,
NODE_STRING, NODE_STRING,
NODE_CHAR, NODE_CHAR,
NODE_BOOL,
NODE_CAST, NODE_CAST,
NODE_UNARY, NODE_UNARY,
NODE_BINARY, NODE_BINARY,
NODE_RANGE, NODE_RANGE,
NODE_ARRAY_SUBSCRIPT, NODE_ARRAY_SUBSCRIPT,
NODE_ACCESS,
NODE_CALL,
NODE_POSTFIX, NODE_POSTFIX,
NODE_CALL,
NODE_ACCESS,
NODE_STRUCT_INIT,
NODE_TERNARY, /* TODO */
NODE_BREAK, NODE_BREAK,
NODE_RETURN, NODE_RETURN,
NODE_LABEL,
NODE_GOTO,
NODE_IMPORT, NODE_IMPORT,
NODE_FOR, NODE_FOR,
NODE_WHILE, NODE_WHILE,
NODE_IF, NODE_IF,
NODE_COMPOUND, NODE_VAR_DECL,
NODE_LABEL,
NODE_GOTO,
NODE_ENUM, NODE_ENUM,
NODE_STRUCT, NODE_STRUCT,
NODE_UNION, NODE_UNION,
NODE_VAR_DECL,
NODE_FUNCTION, NODE_FUNCTION,
NODE_PTR_TYPE, NODE_PTR_TYPE,
NODE_TERNARY, /* TODO */
NODE_SWITCH, /* TODO */ NODE_SWITCH, /* TODO */
NODE_STRUCT_INIT,
NODE_UNIT, NODE_UNIT,
} node_type; } node_type;
@ -120,6 +125,7 @@ typedef enum {
typedef struct _ast_node { typedef struct _ast_node {
node_type type; node_type type;
source_pos position; source_pos position;
struct _type *expr_type;
union { union {
struct { struct {
struct _ast_node *type; struct _ast_node *type;
@ -138,6 +144,7 @@ typedef struct _ast_node {
struct _ast_node *right; struct _ast_node *right;
unary_op operator; unary_op operator;
} unary; } unary;
u8 boolean;
i64 integer; i64 integer;
f64 flt; // float f64 flt; // float
struct { struct {
@ -183,9 +190,9 @@ typedef struct _ast_node {
struct { struct {
/* These should be lists of unit_node */ /* These should be lists of unit_node */
struct _ast_node *slices; struct _ast_node *slices;
usize slice_len;
struct _ast_node *captures; struct _ast_node *captures;
int capture_len; usize capture_len;
int slice_len;
struct _ast_node* body; struct _ast_node* body;
} fr; // for } fr; // for
struct { struct {

335
sema.c
View file

@ -18,6 +18,12 @@ static struct { char *key; type *value; } *type_reg;
static struct { char *key; prototype *value; } *prototypes; static struct { char *key; prototype *value; } *prototypes;
static scope *global_scope = NULL;
static scope *current_scope = NULL;
static type *current_return = NULL;
static bool in_loop = false;
/* Print the error message and sync the parser. */ /* Print the error message and sync the parser. */
static void error(ast_node *n, char *msg) static void error(ast_node *n, char *msg)
{ {
@ -170,6 +176,9 @@ static void register_struct(sema *s, char *name, type *t)
return; return;
} }
char *n = intern_string(s, m->name, m->name_len);
shput(t->data.structure.member_types, n, m_type);
if (m_type->size == 0) { if (m_type->size == 0) {
error(m->type, "a struct member can't be of type `void`."); error(m->type, "a struct member can't be of type `void`.");
return; return;
@ -195,8 +204,6 @@ static void register_struct(sema *s, char *name, type *t)
} }
t->size = offset; t->size = offset;
printf("%ld\n", t->size);
} }
static void register_union(sema *s, char *name, type *t) static void register_union(sema *s, char *name, type *t)
@ -206,6 +213,15 @@ static void register_union(sema *s, char *name, type *t)
member *m = t->data.structure.members; member *m = t->data.structure.members;
while (m) { while (m) {
type *m_type = get_type(s, m->type); type *m_type = get_type(s, m->type);
if (!m_type) {
error(m->type, "unknown type.");
return;
}
char *n = intern_string(s, m->name, m->name_len);
shput(t->data.structure.member_types, n, m_type);
if (alignment < m_type->alignment) { if (alignment < m_type->alignment) {
alignment = m_type->alignment; alignment = m_type->alignment;
} }
@ -299,6 +315,9 @@ static void create_prototype(sema *s, ast_node *node)
{ {
prototype *p = arena_alloc(s->allocator, sizeof(prototype)); prototype *p = arena_alloc(s->allocator, sizeof(prototype));
p->name = intern_string(s, node->expr.function.name, node->expr.function.name_len); p->name = intern_string(s, node->expr.function.name, node->expr.function.name_len);
if (shget(prototypes, p->name)) {
error(node, "function already defined.");
}
member *m = node->expr.function.parameters; member *m = node->expr.function.parameters;
while (m) { while (m) {
@ -316,6 +335,304 @@ static void create_prototype(sema *s, ast_node *node)
shput(prototypes, p->name, p); shput(prototypes, p->name, p);
} }
static void push_scope(sema *s)
{
scope *scp = arena_alloc(s->allocator, sizeof(scope));
scp->parent = current_scope;
current_scope = scp;
}
static void pop_scope(sema *s)
{
current_scope = current_scope->parent;
}
static type *get_def(sema *s, char *name)
{
scope *current = current_scope;
while (current) {
type *t = shget(current->defs, name);
if (t) return t;
current = current->parent;
}
return NULL;
}
static type *get_string_type(sema *s, ast_node *node)
{
type *string_type = arena_alloc(s->allocator, sizeof(type));
string_type->tag = TYPE_PTR;
string_type->size = sizeof(usize);
string_type->alignment = sizeof(usize);
string_type->name = "slice";
string_type->data.slice.child = shget(type_reg, "u8");
string_type->data.slice.is_const = true;
string_type->data.slice.is_volatile = false;
string_type->data.slice.len = node->expr.string.len;
return string_type;
}
static type *get_range_type(sema *s, ast_node *node)
{
type *range_type = arena_alloc(s->allocator, sizeof(type));
range_type->tag = TYPE_PTR;
range_type->size = sizeof(usize);
range_type->alignment = sizeof(usize);
range_type->name = "slice";
range_type->data.slice.child = shget(type_reg, "usize");
range_type->data.slice.is_const = true;
range_type->data.slice.is_volatile = false;
range_type->data.slice.len = node->expr.binary.right->expr.integer - node->expr.binary.left->expr.integer;
return range_type;
}
static type *get_expression_type(sema *s, ast_node *node);
static type *get_access_type(sema *s, ast_node *node)
{
type *t = get_expression_type(s, node->expr.access.expr);
ast_node *member = node->expr.access.member;
char *name_start = member->expr.string.start;
usize name_len = member->expr.string.len;
if (!t || (t->tag != TYPE_STRUCT && t->tag != TYPE_UNION)) {
error(node, "invalid expression.");
return NULL;
}
char *name = intern_string(s, name_start, name_len);
type *res = shget(t->data.structure.member_types, name);
if (!res) {
error(node, "struct doesn't have that member");
return NULL;
}
return res;
}
static type *get_identifier_type(sema *s, ast_node *node)
{
char *name_start = node->expr.string.start;
usize name_len = node->expr.string.len;
type *t = get_def(s, intern_string(s, name_start, name_len));
if (!t) {
error(node, "unknown identifier.");
}
return t;
}
static bool match(type *t1, type *t2);
static type *get_expression_type(sema *s, ast_node *node)
{
if (!node) {
return shget(type_reg, "void");
}
type *t = NULL;
prototype *prot = NULL;
switch (node->type) {
case NODE_IDENTIFIER:
return get_identifier_type(s, node);
case NODE_INTEGER:
return shget(type_reg, "i32");
case NODE_FLOAT:
return shget(type_reg, "f64");
case NODE_STRING:
return get_string_type(s, node);
case NODE_CHAR:
return shget(type_reg, "u8");
case NODE_BOOL:
return shget(type_reg, "bool");
case NODE_CAST:
return get_type(s, node->expr.cast.type);
case NODE_POSTFIX:
case NODE_UNARY:
return get_expression_type(s, node->expr.unary.right);
case NODE_BINARY:
t = get_expression_type(s, node->expr.binary.left);
if (!t) return NULL;
if (!match(t, get_expression_type(s, node->expr.binary.right))) {
error(node, "type mismatch.");
return NULL;
}
if (node->expr.binary.operator >= OP_EQ) {
return shget(type_reg, "bool");
} else if (node->expr.binary.operator >= OP_ASSIGN && node->expr.binary.operator <= OP_MOD_EQ) {
return shget(type_reg, "void");
} else {
return t;
}
case NODE_RANGE:
return get_range_type(s, node);
case NODE_ARRAY_SUBSCRIPT:
t = get_expression_type(s, node->expr.subscript.expr);
switch (t->tag) {
case TYPE_SLICE:
return t->data.slice.child;
case TYPE_PTR:
return t->data.ptr.child;
default:
error(node, "only pointers and slices can be indexed.");
return NULL;
}
case NODE_CALL:
prot = shget(prototypes, intern_string(s, node->expr.call.name, node->expr.call.name_len));
if (!prot) {
error(node, "unknown function.");
return NULL;
}
return prot->type;
case NODE_ACCESS:
return get_access_type(s, node);
default:
return shget(type_reg, "void");
}
}
static bool match(type *t1, type *t2)
{
if (!t1 || !t2) return false;
if (t1->tag != t2->tag) return false;
switch(t1->tag) {
case TYPE_VOID:
case TYPE_BOOL:
return true;
case TYPE_PTR:
return (t1->data.ptr.is_const == t2->data.ptr.is_const) && (t1->data.ptr.is_volatile == t2->data.ptr.is_volatile) && match(t1->data.ptr.child, t2->data.ptr.child);
case TYPE_SLICE:
return (t1->data.slice.is_const == t2->data.slice.is_const) && (t1->data.slice.is_volatile == t2->data.slice.is_volatile) && match(t1->data.slice.child, t2->data.slice.child) && t1->data.slice.len == t2->data.slice.len;
case TYPE_STRUCT:
case TYPE_UNION:
return t1 == t2;
case TYPE_INTEGER:
case TYPE_UINTEGER:
return t1->data.integer == t2->data.integer;
case TYPE_FLOAT:
return t1->data.flt == t2->data.flt;
case TYPE_ENUM:
case TYPE_GENERIC:
/* TODO */
return false;
}
return false;
}
static void check_statement(sema *s, ast_node *node);
static void check_body(sema *s, ast_node *node)
{
push_scope(s);
ast_node *current = node;
while (current && current->type == NODE_UNIT) {
check_statement(s, current->expr.unit_node.expr);
current = current->expr.unit_node.next;
}
pop_scope(s);
}
static void check_for(sema *s, ast_node *node)
{
ast_node *slices = node->expr.fr.slices;
ast_node *captures = node->expr.fr.captures;
push_scope(s);
ast_node *current_capture = captures;
ast_node *current_slice = slices;
while (current_capture) {
type *c_type = get_expression_type(s, current_slice->expr.unit_node.expr);
char *c_name = intern_string(s, current_capture->expr.unit_node.expr->expr.string.start, current_capture->expr.unit_node.expr->expr.string.len);
shput(current_scope->defs, c_name, c_type);
current_capture = current_capture->expr.unit_node.next;
current_slice = current_slice->expr.unit_node.next;
}
ast_node *current = node->expr.fr.body;
in_loop = true;
while (current && current->type == NODE_UNIT) {
check_statement(s, current->expr.unit_node.expr);
current = current->expr.unit_node.next;
}
in_loop = false;
pop_scope(s);
}
static void check_statement(sema *s, ast_node *node)
{
if (!node) return;
type *t = NULL;
char *name = NULL;
switch(node->type) {
case NODE_RETURN:
if (!match(get_expression_type(s, node->expr.ret.value), current_return)) {
error(node, "return type doesn't match function's one.");
}
break;
case NODE_BREAK:
if (!in_loop) {
error(node, "`break` isn't in a loop.");
}
break;
case NODE_WHILE:
if (!match(get_expression_type(s, node->expr.whle.condition), shget(type_reg, "bool"))) {
error(node, "expected boolean value.");
return;
}
in_loop = true;
check_body(s, node->expr.whle.body);
in_loop = false;
break;
case NODE_FOR:
check_for(s, node);
break;
case NODE_VAR_DECL:
t = get_type(s, node->expr.var_decl.type);
name = intern_string(s, node->expr.var_decl.name, node->expr.var_decl.name_len);
if (get_def(s, name)) {
error(node, "redeclaration of variable.");
break;
}
if (!match(t, get_expression_type(s, node->expr.var_decl.value))) {
error(node, "type mismatch.");
}
shput(current_scope->defs, name, t);
break;
default:
get_expression_type(s, node);
break;
}
}
static void check_function(sema *s, ast_node *f)
{
push_scope(s);
current_return = get_type(s, f->expr.function.type);
member *param = f->expr.function.parameters;
while (param) {
type *p_type = get_type(s, param->type);
char *t_name = intern_string(s, param->name, param->name_len);
shput(current_scope->defs, t_name, p_type);
param = param->next;
}
ast_node *current = f->expr.function.body;
while (current && current->type == NODE_UNIT) {
check_statement(s, current->expr.unit_node.expr);
current = current->expr.unit_node.next;
}
pop_scope(s);
}
static void analyze_unit(sema *s, ast_node *node) static void analyze_unit(sema *s, ast_node *node)
{ {
ast_node *current = node; ast_node *current = node;
@ -334,8 +651,12 @@ static void analyze_unit(sema *s, ast_node *node)
current = current->expr.unit_node.next; current = current->expr.unit_node.next;
} }
for (int i=0; i < shlen(prototypes); i++) { current = node;
printf("f: %s\n", prototypes[i].key); while (current && current->type == NODE_UNIT) {
if (current->expr.unit_node.expr->type == NODE_FUNCTION) {
check_function(s, current->expr.unit_node.expr);
}
current = current->expr.unit_node.next;
} }
} }
@ -346,7 +667,13 @@ sema *sema_init(parser *p, arena *a)
types = NULL; types = NULL;
s->ast = p->ast; s->ast = p->ast;
global_scope = arena_alloc(a, sizeof(scope));
global_scope->parent = NULL;
global_scope->defs = NULL;
current_scope = global_scope;
register_type(s, "void", create_integer(s, "void", 0, false)); register_type(s, "void", create_integer(s, "void", 0, false));
register_type(s, "bool", create_integer(s, "bool", 8, false));
register_type(s, "u8", create_integer(s, "u8", 8, false)); register_type(s, "u8", create_integer(s, "u8", 8, false));
register_type(s, "u16", create_integer(s, "u16", 16, false)); register_type(s, "u16", create_integer(s, "u16", 16, false));
register_type(s, "u32", create_integer(s, "u32", 32, false)); register_type(s, "u32", create_integer(s, "u32", 32, false));

7
sema.h
View file

@ -8,6 +8,7 @@
typedef enum { typedef enum {
TYPE_VOID, TYPE_VOID,
TYPE_BOOL,
TYPE_PTR, TYPE_PTR,
TYPE_SLICE, TYPE_SLICE,
TYPE_FLOAT, TYPE_FLOAT,
@ -42,6 +43,7 @@ typedef struct _type {
char *name; char *name;
usize name_len; usize name_len;
member *members; member *members;
struct { char *key; struct _type *value; } *member_types;
} structure; } structure;
struct { struct {
char *name; char *name;
@ -57,6 +59,11 @@ typedef struct {
type **parameters; type **parameters;
} prototype; } prototype;
typedef struct _scope {
struct _scope *parent;
struct { char *key; type *value; } *defs;
} scope;
typedef struct { typedef struct {
arena *allocator; arena *allocator;
ast_node *ast; ast_node *ast;