diff --git a/examples/hello_world.l b/examples/hello_world.l index 8bdb256..68b4c08 100644 --- a/examples/hello_world.l +++ b/examples/hello_world.l @@ -6,6 +6,5 @@ struct b { u32 test() { - u32 a = (u32)3; - a = (u32)2; + f32 a = 5.0; } diff --git a/sema.c b/sema.c index a1499a4..15ea070 100644 --- a/sema.c +++ b/sema.c @@ -22,6 +22,9 @@ static scope *global_scope = NULL; static scope *current_scope = NULL; static type *current_return = NULL; +static type *const_int = NULL; +static type *const_float = NULL; + static bool in_loop = false; /* Print the error message and sync the parser. */ @@ -422,6 +425,21 @@ static type *get_identifier_type(sema *s, ast_node *node) static bool match(type *t1, type *t2); +static bool can_cast(type *source, type *dest) +{ + if (!dest || !source) return false; + + switch (dest->tag) { + case TYPE_INTEGER: + case TYPE_UINTEGER: + return source->tag == TYPE_INTEGER_CONST; + case TYPE_FLOAT: + return source->tag == TYPE_FLOAT_CONST; + default: + return false; + } +} + static type *get_expression_type(sema *s, ast_node *node) { if (!node) { @@ -434,9 +452,9 @@ static type *get_expression_type(sema *s, ast_node *node) case NODE_IDENTIFIER: return get_identifier_type(s, node); case NODE_INTEGER: - return shget(type_reg, "i32"); + return const_int; case NODE_FLOAT: - return shget(type_reg, "f64"); + return const_float; case NODE_STRING: return get_string_type(s, node); case NODE_CHAR: @@ -514,6 +532,9 @@ static bool match(type *t1, type *t2) case TYPE_GENERIC: /* TODO */ return false; + case TYPE_INTEGER_CONST: + case TYPE_FLOAT_CONST: + return false; } return false; @@ -600,7 +621,7 @@ static void check_statement(sema *s, ast_node *node) error(node, "redeclaration of variable."); break; } - if (!match(t, get_expression_type(s, node->expr.var_decl.value))) { + if (!can_cast(get_expression_type(s, node->expr.var_decl.value), t) && !match(t, get_expression_type(s, node->expr.var_decl.value))) { error(node, "type mismatch."); } shput(current_scope->defs, name, t); @@ -685,6 +706,16 @@ sema *sema_init(parser *p, arena *a) register_type(s, "f32", create_float(s, "f32", 32)); register_type(s, "f64", create_float(s, "f64", 64)); + const_int = arena_alloc(s->allocator, sizeof(type)); + const_int->name = "const_int"; + const_int->tag = TYPE_INTEGER_CONST; + const_int->data.integer = 0; + + const_float = arena_alloc(s->allocator, sizeof(type)); + const_float->name = "const_float"; + const_float->tag = TYPE_FLOAT_CONST; + const_float->data.flt = 0; + analyze_unit(s, s->ast); return s; diff --git a/sema.h b/sema.h index 47832b7..a1f4285 100644 --- a/sema.h +++ b/sema.h @@ -12,7 +12,9 @@ typedef enum { TYPE_PTR, TYPE_SLICE, TYPE_FLOAT, + TYPE_FLOAT_CONST, TYPE_INTEGER, + TYPE_INTEGER_CONST, TYPE_UINTEGER, TYPE_STRUCT, TYPE_UNION,