implemented integer and float constant type inference

This commit is contained in:
Lorenzo Torres 2025-12-06 00:02:46 +01:00
parent 463ba71843
commit 989a32fa7b
3 changed files with 37 additions and 5 deletions

View file

@ -6,6 +6,5 @@ struct b {
u32 test() u32 test()
{ {
u32 a = (u32)3; f32 a = 5.0;
a = (u32)2;
} }

37
sema.c
View file

@ -22,6 +22,9 @@ static scope *global_scope = NULL;
static scope *current_scope = NULL; static scope *current_scope = NULL;
static type *current_return = NULL; static type *current_return = NULL;
static type *const_int = NULL;
static type *const_float = NULL;
static bool in_loop = false; static bool in_loop = false;
/* Print the error message and sync the parser. */ /* Print the error message and sync the parser. */
@ -422,6 +425,21 @@ static type *get_identifier_type(sema *s, ast_node *node)
static bool match(type *t1, type *t2); static bool match(type *t1, type *t2);
static bool can_cast(type *source, type *dest)
{
if (!dest || !source) return false;
switch (dest->tag) {
case TYPE_INTEGER:
case TYPE_UINTEGER:
return source->tag == TYPE_INTEGER_CONST;
case TYPE_FLOAT:
return source->tag == TYPE_FLOAT_CONST;
default:
return false;
}
}
static type *get_expression_type(sema *s, ast_node *node) static type *get_expression_type(sema *s, ast_node *node)
{ {
if (!node) { if (!node) {
@ -434,9 +452,9 @@ static type *get_expression_type(sema *s, ast_node *node)
case NODE_IDENTIFIER: case NODE_IDENTIFIER:
return get_identifier_type(s, node); return get_identifier_type(s, node);
case NODE_INTEGER: case NODE_INTEGER:
return shget(type_reg, "i32"); return const_int;
case NODE_FLOAT: case NODE_FLOAT:
return shget(type_reg, "f64"); return const_float;
case NODE_STRING: case NODE_STRING:
return get_string_type(s, node); return get_string_type(s, node);
case NODE_CHAR: case NODE_CHAR:
@ -514,6 +532,9 @@ static bool match(type *t1, type *t2)
case TYPE_GENERIC: case TYPE_GENERIC:
/* TODO */ /* TODO */
return false; return false;
case TYPE_INTEGER_CONST:
case TYPE_FLOAT_CONST:
return false;
} }
return false; return false;
@ -600,7 +621,7 @@ static void check_statement(sema *s, ast_node *node)
error(node, "redeclaration of variable."); error(node, "redeclaration of variable.");
break; break;
} }
if (!match(t, get_expression_type(s, node->expr.var_decl.value))) { if (!can_cast(get_expression_type(s, node->expr.var_decl.value), t) && !match(t, get_expression_type(s, node->expr.var_decl.value))) {
error(node, "type mismatch."); error(node, "type mismatch.");
} }
shput(current_scope->defs, name, t); shput(current_scope->defs, name, t);
@ -685,6 +706,16 @@ sema *sema_init(parser *p, arena *a)
register_type(s, "f32", create_float(s, "f32", 32)); register_type(s, "f32", create_float(s, "f32", 32));
register_type(s, "f64", create_float(s, "f64", 64)); register_type(s, "f64", create_float(s, "f64", 64));
const_int = arena_alloc(s->allocator, sizeof(type));
const_int->name = "const_int";
const_int->tag = TYPE_INTEGER_CONST;
const_int->data.integer = 0;
const_float = arena_alloc(s->allocator, sizeof(type));
const_float->name = "const_float";
const_float->tag = TYPE_FLOAT_CONST;
const_float->data.flt = 0;
analyze_unit(s, s->ast); analyze_unit(s, s->ast);
return s; return s;

2
sema.h
View file

@ -12,7 +12,9 @@ typedef enum {
TYPE_PTR, TYPE_PTR,
TYPE_SLICE, TYPE_SLICE,
TYPE_FLOAT, TYPE_FLOAT,
TYPE_FLOAT_CONST,
TYPE_INTEGER, TYPE_INTEGER,
TYPE_INTEGER_CONST,
TYPE_UINTEGER, TYPE_UINTEGER,
TYPE_STRUCT, TYPE_STRUCT,
TYPE_UNION, TYPE_UNION,