From 3e87a79d944ca38fc4a1ab7483a56e0b822eab65 Mon Sep 17 00:00:00 2001 From: Lorenzo Torres Date: Tue, 16 Dec 2025 23:30:02 +0100 Subject: [PATCH] project cleanup --- src/Lexer.zig | 33 ++++++++++++++++- src/Node.zig | 69 +++++++++++++++++++++++++++++++++++ src/Parser.zig | 97 ++++++++++++++------------------------------------ src/Type.zig | 65 +++++++++++++++++++++++++++++++++ src/main.zig | 15 ++++---- 5 files changed, 202 insertions(+), 77 deletions(-) create mode 100644 src/Node.zig create mode 100644 src/Type.zig diff --git a/src/Lexer.zig b/src/Lexer.zig index 6e7e8aa..36cdd7b 100644 --- a/src/Lexer.zig +++ b/src/Lexer.zig @@ -1,6 +1,8 @@ const std = @import("std"); const Lexer = @This(); +var keywords: std.StringHashMap(TokenType) = undefined; + index: usize, source: []u8, start: usize, @@ -10,10 +12,14 @@ pub const TokenType = enum { minus, star, slash, + double_colon, + integer, float, identifier, + function, + eof, illegal, }; @@ -57,6 +63,7 @@ pub fn next(lexer: *Lexer) Token { // Identifiers if (std.ascii.isAlphabetic(c)) { + return lexer.identifier(); } // Single Character Tokens @@ -66,6 +73,14 @@ '-' => return lexer.makeToken(.minus), '*' => return lexer.makeToken(.star), '/' => return lexer.makeToken(.slash), + ':' => { + if (lexer.index < lexer.source.len and lexer.source[lexer.index] == ':') { + lexer.index += 1; + return lexer.makeToken(.double_colon); + } else { + return lexer.makeToken(.illegal); + } + }, else => return lexer.makeToken(.illegal), } } @@ -103,6 +118,19 @@ fn skipWhitespaceAndComments(lexer: *Lexer) void { } } +fn identifier(lexer: *Lexer) Token { + while (lexer.index < lexer.source.len and (std.ascii.isAlphanumeric(lexer.source[lexer.index]) or lexer.source[lexer.index] == '_')) { + lexer.index += 1; + } + + var token = lexer.makeToken(.identifier); + if (keywords.get(token.lexeme)) |keyword| { + token.@"type" = keyword; + } + + return token; +} + fn number(lexer: *Lexer) Token { while (lexer.index < lexer.source.len and
std.ascii.isDigit(lexer.source[lexer.index])) { lexer.index += 1; @@ -123,7 +151,10 @@ fn number(lexer: *Lexer) Token { /// If `source` was allocated on the heap, /// the caller must free it. -pub fn init(source: []u8) Lexer { +pub fn init(allocator: std.mem.Allocator, source: []u8) !Lexer { + keywords = std.StringHashMap(TokenType).init(allocator); + try keywords.put("fn", .function); + return .{ .index = 0, .source = source, diff --git a/src/Node.zig b/src/Node.zig new file mode 100644 index 0000000..9108707 --- /dev/null +++ b/src/Node.zig @@ -0,0 +1,69 @@ +const std = @import("std"); +const Parser = @import("Parser.zig"); +const Node = @This(); + +@"type": NodeType, +id: u64, +inputs: std.ArrayList(*Node), +outputs: std.ArrayList(*Node), +data: extern union { + integer: u64, + float: f64, +}, + +pub const NodeType = enum { + add, + sub, + mul, + div, + + integer, + float, + + start, + @"return", +}; + + +pub fn init(parser: *Parser, @"type": NodeType) !*Node { + var node = try parser.allocator.create(Node); + node.@"type" = @"type"; + node.inputs = .{}; + node.outputs = .{}; + node.data = undefined; + return node; +} + +pub fn globalNumbering(node: *Node, parser: *Parser) !*Node { + const node_hash = node.hash(); + node.id = node_hash; + if (parser.node_table.get(node_hash)) |n| { + parser.allocator.destroy(node); + return n; + } + + try parser.node_table.put(node_hash, node); + + return node; +} + +pub fn hash(node: *Node) u64 { + var hasher = std.hash.Wyhash.init(0); + std.hash.autoHash(&hasher, node.@"type"); + + switch (node.@"type") { + .integer => std.hash.autoHash(&hasher, node.data.integer), + .float => std.hash.autoHash(&hasher, @as(u64, @bitCast(node.data.float))), + else => {}, + } + + for (node.inputs.items) |n| { + std.hash.autoHash(&hasher, @intFromPtr(n)); + } + + return hasher.final(); +} + +pub fn deinit(node: *Node, parser: *Parser) void { + parser.allocator.destroy(node); +} diff --git a/src/Parser.zig b/src/Parser.zig index 75bd9bd..393803e 
100644 --- a/src/Parser.zig +++ b/src/Parser.zig @@ -1,80 +1,16 @@ const std = @import("std"); const Lexer = @import("Lexer.zig"); +const Type = @import("Type.zig"); const Parser = @This(); +pub const Node = @import("Node.zig"); + lexer: *Lexer, allocator: std.mem.Allocator, node_table: std.AutoHashMap(u64, *Node), previous: Lexer.Token, current: Lexer.Token, -pub const NodeType = enum { - add, - sub, - mul, - div, - - integer, - float, - - start, - @"return", -}; - -pub const Node = struct { - @"type": NodeType, - id: u64, - inputs: std.ArrayList(*Node), - outputs: std.ArrayList(*Node), - data: extern union { - integer: u64, - float: f64, - }, - - pub fn init(parser: *Parser, @"type": NodeType) !*Node { - var node = try parser.allocator.create(Node); - node.@"type" = @"type"; - node.inputs = .{}; - node.outputs = .{}; - node.data = undefined; - return node; - } - - pub fn globalNumbering(node: *Node, parser: *Parser) !*Node { - const node_hash = node.hash(); - node.id = node_hash; - if (parser.node_table.get(node_hash)) |n| { - parser.allocator.destroy(node); - return n; - } - - try parser.node_table.put(node_hash, node); - - return node; - } - - pub fn hash(node: *Node) u64 { - var hasher = std.hash.Wyhash.init(0); - std.hash.autoHash(&hasher, node.@"type"); - - switch (node.@"type") { - .integer => std.hash.autoHash(&hasher, node.data.integer), - .float => std.hash.autoHash(&hasher, @as(u64, @bitCast(node.data.float))), - else => {}, - } - - for (node.inputs.items) |n| { - std.hash.autoHash(&hasher, @intFromPtr(n)); - } - - return hasher.final(); - } - - pub fn deinit(node: *Node, parser: *Parser) void { - parser.allocator.destroy(node); - } -}; - pub fn match(parser: *Parser, expected: Lexer.TokenType) bool { if (parser.current.@"type" == expected) { parser.advance(); @@ -83,6 +19,10 @@ pub fn match(parser: *Parser, expected: Lexer.TokenType) bool { return false; } +pub fn check(parser: *Parser, expected: Lexer.TokenType) bool { + return parser.current.@"type" 
== expected; +} + pub fn advance(parser: *Parser) void { parser.previous = parser.current; parser.current = parser.lexer.next(); @@ -113,14 +53,17 @@ pub fn buildTerm(parser: *Parser) !?*Node { var lhs = try parser.buildFactor(); while (parser.match(.star) or parser.match(.slash)) { - const node_type: NodeType = switch (parser.previous.@"type") { + const node_type: Node.NodeType = switch (parser.previous.@"type") { .star => .mul, .slash => .div, else => unreachable, }; var node = try Node.init(parser, node_type); - try node.inputs.append(parser.allocator, (try parser.buildFactor()).?); + const rhs = try parser.buildFactor(); + try node.inputs.append(parser.allocator, rhs.?); try node.inputs.append(parser.allocator, lhs.?); + try lhs.?.outputs.append(parser.allocator, node); + try rhs.?.outputs.append(parser.allocator, node); node = try node.globalNumbering(parser); lhs = node; } @@ -132,7 +75,7 @@ pub fn buildExpression(parser: *Parser) !?*Node { var lhs = try parser.buildTerm(); while (parser.match(.plus) or parser.match(.minus)) { - const node_type: NodeType = switch (parser.previous.@"type") { + const node_type: Node.NodeType = switch (parser.previous.@"type") { .plus => .add, .minus => .sub, else => unreachable, @@ -147,6 +90,20 @@ return lhs; } +pub fn buildStatement(parser: *Parser) !?*Node { + if (parser.match(.identifier)) { + const id = parser.previous; + _ = id; + if (parser.match(.double_colon)) { + // Type signature + } else if (parser.check(.identifier)) { + // Function definition + } + } + + return error.UnexpectedToken; +} + pub fn buildGraph(parser: *Parser) !?*Node { return try buildExpression(parser); } diff --git a/src/Type.zig b/src/Type.zig new file mode 100644 index 0000000..177cb22 --- /dev/null +++ b/src/Type.zig @@ -0,0 +1,65 @@ +const Type = @This(); + +kind: Kind, +data: extern union { + tensor: struct { + @"type": *Type, + shape: []const usize, + }, + + @"struct": struct { + name: 
[]const u8, + fields: []Field, + }, + + function: struct { + params: []Parameter, + @"return": *Type, + }, + + generic: struct { + name: []const u8, + }, +}, + +pub const Kind = enum { + @"void", + @"bool", + @"struct", + @"u8", + @"u16", + @"u32", + @"u64", + @"i8", + @"i16", + @"i32", + @"i64", + @"f16", + @"f32", + @"f64", + + tensor, + function, + generic, +}; + +pub const Function = struct { + name: []const u8, + parameters: []*Type, + @"return": *Type, +}; + +pub const Trait = struct { + name: []const u8, + methods: []Function, +}; + +pub const Field = struct { + name: []const u8, + @"type": *Type, +}; + +pub const Parameter = struct { + name: []const u8, + @"type": *Type, +}; diff --git a/src/main.zig b/src/main.zig index 8c5faad..d7c0f81 100644 --- a/src/main.zig +++ b/src/main.zig @@ -30,13 +30,16 @@ pub fn main() !void { } const allocator = gpa.allocator(); - var lexer = al.Lexer.init(@constCast("3*2+2.2")); + var lexer = try al.Lexer.init(allocator, @constCast("2+3+4")); var parser = al.Parser.init(allocator, &lexer); defer parser.deinit(); const graph = try parser.buildGraph(); - defer graph.?.deinit(&parser); - std.debug.print("digraph G {{\n", .{}); - nodeName(graph.?); - printGraph(graph.?); - std.debug.print("}}\n", .{}); + if (graph) |g| { + defer g.deinit(&parser); + std.debug.print("digraph G {{\n", .{}); + nodeName(g); + printGraph(g); + std.debug.print("}}\n", .{}); + } + }