diff --git a/examples/hello_world.l b/examples/hello_world.l index ec4071b..8bdb256 100644 --- a/examples/hello_world.l +++ b/examples/hello_world.l @@ -1,10 +1,11 @@ struct b { - u32 ciao, - u32 test, - u32 test1, + i32 a, + u32 b, + u32 c, } -u32 test(b hello) +u32 test() { - + u32 a = (u32)3; + a = (u32)2; } diff --git a/lc.c b/lc.c index 1587486..e4a0229 100644 --- a/lc.c +++ b/lc.c @@ -123,12 +123,6 @@ void print_ast(ast_node *node, int depth) { current = current->expr.unit_node.next; } break; - case NODE_COMPOUND: - printf("Block\n"); - for (usize i = 0; i < node->expr.compound.stmt_len; ++i) { - print_ast(node->expr.compound.statements[i], depth + 1); - } - break; case NODE_CALL: printf("Call: %.*s\n", (int)node->expr.call.name_len, node->expr.call.name); current = node->expr.call.parameters; diff --git a/lexer.c b/lexer.c index 26c1eb1..07d1a9d 100644 --- a/lexer.c +++ b/lexer.c @@ -396,6 +396,8 @@ lexer *lexer_init(char *source, usize size, arena *arena) lex->source = source; keywords = arena_alloc(arena, sizeof(trie_node)); + trie_insert(keywords, lex->allocator, "true", TOKEN_TRUE); + trie_insert(keywords, lex->allocator, "false", TOKEN_FALSE); trie_insert(keywords, lex->allocator, "struct", TOKEN_STRUCT); trie_insert(keywords, lex->allocator, "enum", TOKEN_ENUM); trie_insert(keywords, lex->allocator, "union", TOKEN_UNION); diff --git a/lexer.h b/lexer.h index 2aa86ec..a3859ba 100644 --- a/lexer.h +++ b/lexer.h @@ -55,6 +55,8 @@ typedef enum { TOKEN_IDENTIFIER, TOKEN_STRING, TOKEN_CHAR, + TOKEN_TRUE, + TOKEN_FALSE, TOKEN_GOTO, TOKEN_LOOP, TOKEN_WHILE, diff --git a/parser.c b/parser.c index 3ba7c2a..96b5c87 100644 --- a/parser.c +++ b/parser.c @@ -428,15 +428,13 @@ ast_node *parse_expression(parser *p) left = node; } - return left; } /* * If after parsing an expression a `.` character * is found, it should be a member access expression. */ - if (match_peek(p, TOKEN_DOT)) - { + if (match_peek(p, TOKEN_DOT) && p->tokens->next && p->tokens->next->type != TOKEN_LCURLY) { while (match(p, TOKEN_DOT)) { if (!match_peek(p, TOKEN_IDENTIFIER)) { error(p, "expected identifier after member access."); @@ -450,7 +448,6 @@ ast_node *parse_expression(parser *p) left = node; } - return left; } /* @@ -749,7 +746,7 @@ parse_captures: arena_start = arena_snapshot(p->allocator); node->expr.fr.captures = arena_alloc(p->allocator, sizeof(ast_node)); node->expr.fr.captures->type = NODE_UNIT; - node->expr.fr.captures->expr.unit_node.expr = parse_expression(p); + node->expr.fr.captures->expr.unit_node.expr = parse_factor(p); if (node->expr.fr.captures->expr.unit_node.expr && node->expr.fr.captures->expr.unit_node.expr->type != NODE_IDENTIFIER) { error(p, "captures must be identifiers."); arena_reset_to_snapshot(p->allocator, arena_start); @@ -776,7 +773,7 @@ parse_captures: tail->expr.unit_node.next->expr.unit_node.expr = expr; tail = tail->expr.unit_node.next; tail->type = NODE_UNIT; - expr = parse_expression(p); + expr = parse_factor(p); if (!expr) { error(p, "expected `|`."); arena_reset_to_snapshot(p->allocator, arena_start); diff --git a/parser.h b/parser.h index 40f9a9c..5428101 100644 --- a/parser.h +++ b/parser.h @@ -5,6 +5,7 @@ #include "utils.h" #include +struct _type; struct _ast_node; typedef enum { @@ -12,21 +13,14 @@ typedef enum { OP_MINUS, // - OP_DIV, // / OP_MUL, // * - OP_EQ, // == - OP_ASSIGN, // = - OP_AND, // && - OP_OR, // || - OP_NEQ, // != - OP_GT, // > - OP_LT, // < - OP_GE, // >= - OP_LE, // <= - OP_RSHIFT_EQ, // >>= - OP_LSHIFT_EQ, // <<= + OP_MOD, // % OP_BOR, // | OP_BAND, // & OP_BXOR, // ^ - OP_MOD, // % + + OP_ASSIGN, // = + OP_RSHIFT_EQ, // >>= + OP_LSHIFT_EQ, // <<= OP_PLUS_EQ, // += OP_MINUS_EQ, // -= OP_DIV_EQ, // /= @@ -35,6 +29,15 @@ typedef enum { OP_BAND_EQ, // &= OP_BXOR_EQ, // ^= OP_MOD_EQ, // %= + + OP_EQ, // == + OP_AND, // && + OP_OR, // || + OP_NEQ, // != + OP_GT, // > + OP_LT, // < + OP_GE, // >= + OP_LE, // <= } binary_op; typedef enum { @@ -79,32 +82,34 @@ typedef enum { NODE_FLOAT, NODE_STRING, NODE_CHAR, + NODE_BOOL, NODE_CAST, NODE_UNARY, NODE_BINARY, NODE_RANGE, NODE_ARRAY_SUBSCRIPT, - NODE_ACCESS, - NODE_CALL, NODE_POSTFIX, + NODE_CALL, + NODE_ACCESS, + NODE_STRUCT_INIT, + NODE_TERNARY, /* TODO */ + NODE_BREAK, NODE_RETURN, - NODE_LABEL, - NODE_GOTO, NODE_IMPORT, NODE_FOR, NODE_WHILE, NODE_IF, - NODE_COMPOUND, + NODE_VAR_DECL, + NODE_LABEL, + NODE_GOTO, + NODE_ENUM, NODE_STRUCT, NODE_UNION, - NODE_VAR_DECL, NODE_FUNCTION, NODE_PTR_TYPE, - NODE_TERNARY, /* TODO */ NODE_SWITCH, /* TODO */ - NODE_STRUCT_INIT, NODE_UNIT, } node_type; @@ -120,6 +125,7 @@ typedef enum { typedef struct _ast_node { node_type type; source_pos position; + struct _type *expr_type; union { struct { struct _ast_node *type; @@ -138,6 +144,7 @@ typedef struct _ast_node { struct _ast_node *right; unary_op operator; } unary; + u8 boolean; i64 integer; f64 flt; // float struct { @@ -183,9 +190,9 @@ typedef struct _ast_node { struct { /* These should be lists of unit_node */ struct _ast_node *slices; + usize slice_len; struct _ast_node *captures; - int capture_len; - int slice_len; + usize capture_len; struct _ast_node* body; } fr; // for struct { diff --git a/sema.c b/sema.c index ea4cf10..a1499a4 100644 --- a/sema.c +++ b/sema.c @@ -18,6 +18,12 @@ static struct { char *key; type *value; } *type_reg; static struct { char *key; prototype *value; } *prototypes; +static scope *global_scope = NULL; +static scope *current_scope = NULL; +static type *current_return = NULL; + +static bool in_loop = false; + /* Print the error message and sync the parser. */ static void error(ast_node *n, char *msg) { @@ -170,6 +176,9 @@ static void register_struct(sema *s, char *name, type *t) return; } + char *n = intern_string(s, m->name, m->name_len); + shput(t->data.structure.member_types, n, m_type); + if (m_type->size == 0) { error(m->type, "a struct member can't be of type `void`."); return; @@ -195,8 +204,6 @@ static void register_struct(sema *s, char *name, type *t) } t->size = offset; - - printf("%ld\n", t->size); } static void register_union(sema *s, char *name, type *t) @@ -206,6 +213,15 @@ static void register_union(sema *s, char *name, type *t) member *m = t->data.structure.members; while (m) { type *m_type = get_type(s, m->type); + + if (!m_type) { + error(m->type, "unknown type."); + return; + } + + char *n = intern_string(s, m->name, m->name_len); + shput(t->data.structure.member_types, n, m_type); + if (alignment < m_type->alignment) { alignment = m_type->alignment; } @@ -299,6 +315,9 @@ static void create_prototype(sema *s, ast_node *node) { prototype *p = arena_alloc(s->allocator, sizeof(prototype)); p->name = intern_string(s, node->expr.function.name, node->expr.function.name_len); + if (shget(prototypes, p->name)) { + error(node, "function already defined."); + } member *m = node->expr.function.parameters; while (m) { @@ -316,6 +335,304 @@ static void create_prototype(sema *s, ast_node *node) shput(prototypes, p->name, p); } +static void push_scope(sema *s) +{ + scope *scp = arena_alloc(s->allocator, sizeof(scope)); + scp->parent = current_scope; + current_scope = scp; +} + +static void pop_scope(sema *s) +{ + current_scope = current_scope->parent; +} + +static type *get_def(sema *s, char *name) +{ + scope *current = current_scope; + while (current) { + type *t = shget(current->defs, name); + if (t) return t; + + current = current->parent; + } + + return NULL; +} + +static type *get_string_type(sema *s, ast_node *node) +{ + type *string_type = arena_alloc(s->allocator, sizeof(type)); + string_type->tag = TYPE_PTR; + string_type->size = sizeof(usize); + string_type->alignment = sizeof(usize); + string_type->name = "slice"; + string_type->data.slice.child = shget(type_reg, "u8"); + string_type->data.slice.is_const = true; + string_type->data.slice.is_volatile = false; + string_type->data.slice.len = node->expr.string.len; + return string_type; +} + +static type *get_range_type(sema *s, ast_node *node) +{ + type *range_type = arena_alloc(s->allocator, sizeof(type)); + range_type->tag = TYPE_PTR; + range_type->size = sizeof(usize); + range_type->alignment = sizeof(usize); + range_type->name = "slice"; + range_type->data.slice.child = shget(type_reg, "usize"); + range_type->data.slice.is_const = true; + range_type->data.slice.is_volatile = false; + range_type->data.slice.len = node->expr.binary.right->expr.integer - node->expr.binary.left->expr.integer; + return range_type; +} + +static type *get_expression_type(sema *s, ast_node *node); +static type *get_access_type(sema *s, ast_node *node) +{ + type *t = get_expression_type(s, node->expr.access.expr); + ast_node *member = node->expr.access.member; + char *name_start = member->expr.string.start; + usize name_len = member->expr.string.len; + if (!t || (t->tag != TYPE_STRUCT && t->tag != TYPE_UNION)) { + error(node, "invalid expression."); + return NULL; + } + char *name = intern_string(s, name_start, name_len); + type *res = shget(t->data.structure.member_types, name); + if (!res) { + error(node, "struct doesn't have that member"); + return NULL; + } + + return res; +} + +static type *get_identifier_type(sema *s, ast_node *node) +{ + char *name_start = node->expr.string.start; + usize name_len = node->expr.string.len; + type *t = get_def(s, intern_string(s, name_start, name_len)); + if (!t) { + error(node, "unknown identifier."); + } + return t; +} + +static bool match(type *t1, type *t2); + +static type *get_expression_type(sema *s, ast_node *node) +{ + if (!node) { + return shget(type_reg, "void"); + } + + type *t = NULL; + prototype *prot = NULL; + switch (node->type) { + case NODE_IDENTIFIER: + return get_identifier_type(s, node); + case NODE_INTEGER: + return shget(type_reg, "i32"); + case NODE_FLOAT: + return shget(type_reg, "f64"); + case NODE_STRING: + return get_string_type(s, node); + case NODE_CHAR: + return shget(type_reg, "u8"); + case NODE_BOOL: + return shget(type_reg, "bool"); + case NODE_CAST: + return get_type(s, node->expr.cast.type); + case NODE_POSTFIX: + case NODE_UNARY: + return get_expression_type(s, node->expr.unary.right); + case NODE_BINARY: + t = get_expression_type(s, node->expr.binary.left); + if (!t) return NULL; + if (!match(t, get_expression_type(s, node->expr.binary.right))) { + error(node, "type mismatch."); + return NULL; + } + if (node->expr.binary.operator >= OP_EQ) { + return shget(type_reg, "bool"); + } else if (node->expr.binary.operator >= OP_ASSIGN && node->expr.binary.operator <= OP_MOD_EQ) { + return shget(type_reg, "void"); + } else { + return t; + } + case NODE_RANGE: + return get_range_type(s, node); + case NODE_ARRAY_SUBSCRIPT: + t = get_expression_type(s, node->expr.subscript.expr); + switch (t->tag) { + case TYPE_SLICE: + return t->data.slice.child; + case TYPE_PTR: + return t->data.ptr.child; + default: + error(node, "only pointers and slices can be indexed."); + return NULL; + } + case NODE_CALL: + prot = shget(prototypes, intern_string(s, node->expr.call.name, node->expr.call.name_len)); + if (!prot) { + error(node, "unknown function."); + return NULL; + } + return prot->type; + case NODE_ACCESS: + return get_access_type(s, node); + default: + return shget(type_reg, "void"); + } +} + +static bool match(type *t1, type *t2) +{ + if (!t1 || !t2) return false; + if (t1->tag != t2->tag) return false; + + switch(t1->tag) { + case TYPE_VOID: + case TYPE_BOOL: + return true; + case TYPE_PTR: + return (t1->data.ptr.is_const == t2->data.ptr.is_const) && (t1->data.ptr.is_volatile == t2->data.ptr.is_volatile) && match(t1->data.ptr.child, t2->data.ptr.child); + case TYPE_SLICE: + return (t1->data.slice.is_const == t2->data.slice.is_const) && (t1->data.slice.is_volatile == t2->data.slice.is_volatile) && match(t1->data.slice.child, t2->data.slice.child) && t1->data.slice.len == t2->data.slice.len; + case TYPE_STRUCT: + case TYPE_UNION: + return t1 == t2; + case TYPE_INTEGER: + case TYPE_UINTEGER: + return t1->data.integer == t2->data.integer; + case TYPE_FLOAT: + return t1->data.flt == t2->data.flt; + case TYPE_ENUM: + case TYPE_GENERIC: + /* TODO */ + return false; + } + + return false; +} + +static void check_statement(sema *s, ast_node *node); +static void check_body(sema *s, ast_node *node) +{ + push_scope(s); + + ast_node *current = node; + while (current && current->type == NODE_UNIT) { + check_statement(s, current->expr.unit_node.expr); + current = current->expr.unit_node.next; + } + + pop_scope(s); +} + +static void check_for(sema *s, ast_node *node) +{ + ast_node *slices = node->expr.fr.slices; + ast_node *captures = node->expr.fr.captures; + + push_scope(s); + + ast_node *current_capture = captures; + ast_node *current_slice = slices; + + while (current_capture) { + type *c_type = get_expression_type(s, current_slice->expr.unit_node.expr); + char *c_name = intern_string(s, current_capture->expr.unit_node.expr->expr.string.start, current_capture->expr.unit_node.expr->expr.string.len); + shput(current_scope->defs, c_name, c_type); + current_capture = current_capture->expr.unit_node.next; + current_slice = current_slice->expr.unit_node.next; + } + + ast_node *current = node->expr.fr.body; + + in_loop = true; + while (current && current->type == NODE_UNIT) { + check_statement(s, current->expr.unit_node.expr); + current = current->expr.unit_node.next; + } + in_loop = false; + + pop_scope(s); +} + +static void check_statement(sema *s, ast_node *node) +{ + if (!node) return; + + type *t = NULL; + char *name = NULL; + switch(node->type) { + case NODE_RETURN: + if (!match(get_expression_type(s, node->expr.ret.value), current_return)) { + error(node, "return type doesn't match function's one."); + } + break; + case NODE_BREAK: + if (!in_loop) { + error(node, "`break` isn't in a loop."); + } + break; + case NODE_WHILE: + if (!match(get_expression_type(s, node->expr.whle.condition), shget(type_reg, "bool"))) { + error(node, "expected boolean value."); + return; + } + + in_loop = true; + check_body(s, node->expr.whle.body); + in_loop = false; + break; + case NODE_FOR: + check_for(s, node); + break; + case NODE_VAR_DECL: + t = get_type(s, node->expr.var_decl.type); + name = intern_string(s, node->expr.var_decl.name, node->expr.var_decl.name_len); + if (get_def(s, name)) { + error(node, "redeclaration of variable."); + break; + } + if (!match(t, get_expression_type(s, node->expr.var_decl.value))) { + error(node, "type mismatch."); + } + shput(current_scope->defs, name, t); + break; + default: + get_expression_type(s, node); + break; + } +} + +static void check_function(sema *s, ast_node *f) +{ + push_scope(s); + current_return = get_type(s, f->expr.function.type); + + member *param = f->expr.function.parameters; + while (param) { + type *p_type = get_type(s, param->type); + char *t_name = intern_string(s, param->name, param->name_len); + shput(current_scope->defs, t_name, p_type); + param = param->next; + } + + ast_node *current = f->expr.function.body; + while (current && current->type == NODE_UNIT) { + check_statement(s, current->expr.unit_node.expr); + current = current->expr.unit_node.next; + } + + pop_scope(s); +} + static void analyze_unit(sema *s, ast_node *node) { ast_node *current = node; @@ -334,8 +651,12 @@ static void analyze_unit(sema *s, ast_node *node) current = current->expr.unit_node.next; } - for (int i=0; i < shlen(prototypes); i++) { - printf("f: %s\n", prototypes[i].key); + current = node; + while (current && current->type == NODE_UNIT) { + if (current->expr.unit_node.expr->type == NODE_FUNCTION) { + check_function(s, current->expr.unit_node.expr); + } + current = current->expr.unit_node.next; } } @@ -346,7 +667,13 @@ sema *sema_init(parser *p, arena *a) types = NULL; s->ast = p->ast; + global_scope = arena_alloc(a, sizeof(scope)); + global_scope->parent = NULL; + global_scope->defs = NULL; + current_scope = global_scope; + register_type(s, "void", create_integer(s, "void", 0, false)); + register_type(s, "bool", create_integer(s, "bool", 8, false)); register_type(s, "u8", create_integer(s, "u8", 8, false)); register_type(s, "u16", create_integer(s, "u16", 16, false)); register_type(s, "u32", create_integer(s, "u32", 32, false)); diff --git a/sema.h b/sema.h index 9c27332..47832b7 100644 --- a/sema.h +++ b/sema.h @@ -8,6 +8,7 @@ typedef enum { TYPE_VOID, + TYPE_BOOL, TYPE_PTR, TYPE_SLICE, TYPE_FLOAT, @@ -42,6 +43,7 @@ typedef struct _type { char *name; usize name_len; member *members; + struct { char *key; struct _type *value; } *member_types; } structure; struct { char *name; @@ -57,6 +59,11 @@ typedef struct { type **parameters; } prototype; +typedef struct _scope { + struct _scope *parent; + struct { char *key; type *value; } *defs; +} scope; + typedef struct { arena *allocator; ast_node *ast;