From ed25810927a2d37f6830e5f2d8ceee627dc0fb7f Mon Sep 17 00:00:00 2001 From: Lorenzo Torres Date: Thu, 16 Apr 2026 13:04:22 +0200 Subject: [PATCH] Fixed various correctness related issues --- src/Fdt.zig | 83 +++++++++++++++++++++++++++++++++++- src/drivers/Console.zig | 25 ++++++++--- src/main.zig | 24 +++++++---- src/mem/BuddyAllocator.zig | 17 ++++---- src/riscv/PageTable.zig | 86 +++++++++++++++++++++++++++----------- 5 files changed, 185 insertions(+), 50 deletions(-) diff --git a/src/Fdt.zig b/src/Fdt.zig index 669d4ee..efc42a2 100644 --- a/src/Fdt.zig +++ b/src/Fdt.zig @@ -468,6 +468,19 @@ pub fn parse(ptr: *const anyopaque) Error!Fdt { const size_dt_strings = std.mem.bigToNative(u32, header.size_dt_strings); const off_mem_rsvmap = std.mem.bigToNative(u32, header.off_mem_rsvmap); + if (totalsize < @sizeOf(RawHeader)) { + return Error.InvalidStructure; + } + if (!sliceWithinBounds(off_dt_struct, size_dt_struct, totalsize)) { + return Error.InvalidStructure; + } + if (!sliceWithinBounds(off_dt_strings, size_dt_strings, totalsize)) { + return Error.InvalidStructure; + } + if (off_mem_rsvmap > totalsize) { + return Error.InvalidStructure; + } + return Fdt{ .base = base, .totalsize = totalsize, @@ -616,11 +629,74 @@ pub fn cpus(self: *const Fdt) ?Node { return self.findNode("/cpus"); } +pub fn parent(self: *const Fdt, target: Node) ?Node { + var pos: usize = 0; + var stack: [64]Node = undefined; + var depth: usize = 0; + + while (pos < self.struct_block.len) { + const token = self.readToken(pos) orelse return null; + pos += 4; + + switch (token) { + .begin_node => { + const name_start = pos; + while (pos < self.struct_block.len and self.struct_block[pos] != 0) { + pos += 1; + } + if (pos >= self.struct_block.len) return null; + + const node = Node{ + .name = self.struct_block[name_start..pos], + .fdt = self, + .struct_offset = alignUp(pos + 1, 4), + }; + pos = node.struct_offset; + + if (node.struct_offset == target.struct_offset and std.mem.eql(u8, node.name, target.name)) { + if (depth == 0) return null; + return stack[depth - 1]; + } + + if (depth >= stack.len) return null; + stack[depth] = node; + depth += 1; + }, + .end_node => { + if (depth == 0) return null; + depth -= 1; + }, + .prop => { + const len = self.readU32(pos) orelse return null; + pos += 8; + pos += alignUp(len, 4); + }, + .nop => {}, + .end => return null, + } + } + + return null; +} + +pub fn regIterator(self: *const Fdt, node: Node) ?RegIterator { + const reg = node.reg() orelse return null; + const parent_node = self.parent(node) orelse self.root() orelse return null; + return parseReg(reg, parent_node.addressCells(), parent_node.sizeCells()); +} + // Internal helpers fn readToken(self: *const Fdt, offset: usize) ?Token { if (offset + 4 > self.struct_block.len) return null; const val = std.mem.bigToNative(u32, @as(*const u32, @alignCast(@ptrCast(self.struct_block.ptr + offset))).*); - return @as(Token, @enumFromInt(val)); + return switch (val) { + @intFromEnum(Token.begin_node) => .begin_node, + @intFromEnum(Token.end_node) => .end_node, + @intFromEnum(Token.prop) => .prop, + @intFromEnum(Token.nop) => .nop, + @intFromEnum(Token.end) => .end, + else => null, + }; } fn readU32(self: *const Fdt, offset: usize) ?u32 { @@ -673,6 +749,11 @@ fn alignUp(value: anytype, alignment: @TypeOf(value)) @TypeOf(value) { return (value + alignment - 1) & ~(alignment - 1); } +fn sliceWithinBounds(offset: u32, size: u32, total: u32) bool { + const end = @as(u64, offset) + @as(u64, size); + return end <= total; +} + pub fn parseReg(data: []const u8, address_cells: u32, size_cells: u32) RegIterator { return .{ .data = data, diff --git a/src/drivers/Console.zig b/src/drivers/Console.zig index 8b478b4..f7fecbf 100644 --- a/src/drivers/Console.zig +++ b/src/drivers/Console.zig @@ -7,11 +7,25 @@ pub const Writer = struct { interface: std.Io.Writer, pub fn drain(io_w: *std.Io.Writer, data: []const []const u8, splat: usize) !usize { - _ = splat; const self: *Writer = @fieldParentPtr("interface", io_w); - self.base.* = data[0][0]; + var written: usize = 0; - return data.len; + for (data[0 .. data.len - 1]) |chunk| { + for (chunk) |byte| { + self.base.* = byte; + } + written += chunk.len; + } + + const pattern = data[data.len - 1]; + for (0..splat) |_| { + for (pattern) |byte| { + self.base.* = byte; + } + written += pattern.len; + } + + return written; } }; @@ -19,9 +33,8 @@ mmio: *volatile u8, pub fn init(fdt: Fdt) ?Console { if (fdt.findFirstCompatible("ns16550a")) |console| { - var iter = console.getProperty("reg").?.asU32Array(); - _ = iter.next(); - const start = iter.next().?; + var iter = fdt.regIterator(console) orelse return null; + const start = iter.next().?.address; return .{ .mmio = @ptrFromInt(@as(usize, @intCast(start))), }; diff --git a/src/main.zig b/src/main.zig index 7bd4a12..9004d62 100644 --- a/src/main.zig +++ b/src/main.zig @@ -13,7 +13,7 @@ export fn _start() linksection(".text.init") callconv(.naked) noreturn { asm volatile ( \\li fp, 0 \\li ra, 0 - \\li sp, 0x88000000 + \\la sp, __stack_top \\tail kmain ); } @@ -44,14 +44,22 @@ export fn kmain(hartid: u64, fdt_ptr: *const anyopaque) callconv(.c) noreturn { const allocator = buddy.allocator(); const mmu_type = fdt.mmuType(); - var table = isa.PageTable.init(allocator) catch { @panic("Unable to create page table.\n"); }; - table.identityMap(allocator, memory_end, mmu_type) catch {}; - table.map(allocator, @intFromPtr(console.mmio), @intFromPtr(console.mmio), .{.read = 1, .write = 1, .execute = 1}, mmu_type) catch { + if (mmu_type == .bare) { + debug.print("mmu disabled by firmware description.\n", .{}); + } else { + var table = isa.PageTable.init(allocator) catch { + @panic("Unable to create page table.\n"); + }; + table.identityMap(allocator, reg.address, memory_end, mmu_type) catch { + @panic("Unable to identity map kernel memory.\n"); + }; + table.map(allocator, @intFromPtr(console.mmio), @intFromPtr(console.mmio), .{ .read = 1, .write = 1, .execute = 1 }, mmu_type) catch { + @panic("Unable to map console MMIO.\n"); + }; - }; - - table.load(mmu_type); - debug.print("loaded kernel page table.\n", .{}); + table.load(mmu_type); + debug.print("loaded kernel page table.\n", .{}); + } while (true) { asm volatile ("wfi"); diff --git a/src/mem/BuddyAllocator.zig b/src/mem/BuddyAllocator.zig index 6ae181b..414701c 100644 --- a/src/mem/BuddyAllocator.zig +++ b/src/mem/BuddyAllocator.zig @@ -22,6 +22,11 @@ fn getLevelSize(level: u8) usize { return @as(usize, MIN_BLOCK_SIZE) << @intCast(level); } +fn levelForRequest(len: usize, alignment: std.mem.Alignment) usize { + const required_size = @max(@max(len, MIN_BLOCK_SIZE), alignment.toByteUnits()); + return math.log2_int_ceil(usize, (required_size + MIN_BLOCK_SIZE - 1) / MIN_BLOCK_SIZE); +} + pub fn freeRange(self: *BuddyAllocator, start: [*]u8, size: usize) void { var current_addr = @intFromPtr(start); const end_addr = current_addr + size; @@ -119,23 +124,15 @@ pub fn allocator(self: *BuddyAllocator) Allocator { } fn alloc(ctx: *anyopaque, len: usize, ptr_align: std.mem.Alignment, ret_addr: usize) ?[*]u8 { - _ = ptr_align; _ = ret_addr; const self: *BuddyAllocator = @ptrCast(@alignCast(ctx)); - const actual_size = @max(len, MIN_BLOCK_SIZE); - const level = math.log2_int_ceil(usize, (actual_size + MIN_BLOCK_SIZE - 1) / MIN_BLOCK_SIZE); - - return self.allocBlock(@intCast(level)); + return self.allocBlock(@intCast(levelForRequest(len, ptr_align))); } fn free(ctx: *anyopaque, buf: []u8, buf_align: std.mem.Alignment, ret_addr: usize) void { - _ = buf_align; _ = ret_addr; const self: *BuddyAllocator = @ptrCast(@alignCast(ctx)); - const actual_size = @max(buf.len, MIN_BLOCK_SIZE); - const level = math.log2_int_ceil(usize, (actual_size + MIN_BLOCK_SIZE - 1) / MIN_BLOCK_SIZE); - - self.freeBlock(buf.ptr, @intCast(level)); + self.freeBlock(buf.ptr, @intCast(levelForRequest(buf.len, buf_align))); } diff --git a/src/riscv/PageTable.zig b/src/riscv/PageTable.zig index 1edf701..4f6d28d 100644 --- a/src/riscv/PageTable.zig +++ b/src/riscv/PageTable.zig @@ -1,11 +1,8 @@ const std = @import("std"); const Allocator = std.mem.Allocator; const PageTable = @This(); -const debug = @import("../debug.zig"); const isa = @import("isa.zig"); -const MEMORY_START = @extern([*]u8, .{.name = "__memory_start"}); - pub const EntryFlags = packed struct { valid: u1 = 1, read: u1, @@ -25,17 +22,21 @@ pub const Entry = packed struct { n: u1 = 0, }; -entries: [512]Entry, +const PTE_VALID = @as(u64, 1) << 0; +const PTE_READ = @as(u64, 1) << 1; +const PTE_WRITE = @as(u64, 1) << 2; +const PTE_EXECUTE = @as(u64, 1) << 3; +const PTE_USER = @as(u64, 1) << 4; + +entries: [512]u64, pub fn init(allocator: Allocator) !*PageTable { const table = try allocator.create(PageTable); - for (&table.entries) |*entry| { - entry.* = @bitCast(@as(u64, 0x0)); - } + @memset(&table.entries, 0); return table; } -pub fn identityMap(self: *PageTable, allocator: Allocator, memory_end: u64, mode: isa.Satp.Mode) !void { +pub fn identityMap(self: *PageTable, allocator: Allocator, start: u64, end: u64, mode: isa.Satp.Mode) !void { const flags = EntryFlags{ .valid = 1, .read = 1, @@ -44,9 +45,16 @@ pub fn identityMap(self: *PageTable, allocator: Allocator, memory_end: u64, mode .user = 0, }; - var addr: u64 = 0x0; - while (addr < memory_end) : (addr += 0x1000) { - try self.map(allocator, addr, addr, flags, mode); + var addr = start; + while (addr < end) { + const remaining = end - addr; + if (addr % (2 * 1024 * 1024) == 0 and remaining >= 2 * 1024 * 1024) { + try self.mapLarge2MiB(allocator, addr, addr, flags, mode); + addr += 2 * 1024 * 1024; + } else { + try self.map(allocator, addr, addr, flags, mode); + addr += 0x1000; + } } } @@ -58,9 +66,21 @@ pub fn load(self: *const PageTable, mmu_type: isa.Satp.Mode) void { } pub fn map(self: *PageTable, allocator: Allocator, virtual: u64, physical: u64, flags: EntryFlags, mode: isa.Satp.Mode) !void { - const PAGE_SIZE = 4096; + return self.mapAtLevel(allocator, virtual, physical, flags, mode, 0); +} - if (virtual % PAGE_SIZE != 0 or physical % PAGE_SIZE != 0) { +pub fn mapLarge2MiB(self: *PageTable, allocator: Allocator, virtual: u64, physical: u64, flags: EntryFlags, mode: isa.Satp.Mode) !void { + return self.mapAtLevel(allocator, virtual, physical, flags, mode, 1); +} + +fn mapAtLevel(self: *PageTable, allocator: Allocator, virtual: u64, physical: u64, flags: EntryFlags, mode: isa.Satp.Mode, leaf_level: usize) !void { + const page_size: u64 = switch (leaf_level) { + 0 => 4096, + 1 => 2 * 1024 * 1024, + else => unreachable, + }; + + if (virtual % page_size != 0 or physical % page_size != 0) { return error.AddressNotAligned; } @@ -81,26 +101,42 @@ pub fn map(self: *PageTable, allocator: Allocator, virtual: u64, physical: u64, var current_table = self; - while (level > 0) : (level -= 1) { + while (level > leaf_level) : (level -= 1) { const index = vpn[level]; const entry = ¤t_table.entries[index]; - if (entry.flags.valid == 0) { + if (!isValid(entry.*)) { const new_table = try PageTable.init(allocator); - + const table_phys = @intFromPtr(new_table); - entry.* = .{ - .flags = .{ .valid = 1, .read = 0, .write = 0, .execute = 0, .user = 0 }, - .ppn = @truncate(table_phys >> 12), - }; + entry.* = encodeBranchEntry(table_phys); } - const next_table_phys = @as(u64, entry.ppn) << 12; + const next_table_phys = decodePhysical(entry.*); current_table = @ptrFromInt(next_table_phys); } - current_table.entries[vpn[0]] = .{ - .flags = flags, - .ppn = @truncate(physical >> 12), - }; + current_table.entries[vpn[leaf_level]] = encodeLeafEntry(physical, flags); +} + +fn isValid(entry: u64) bool { + return (entry & PTE_VALID) != 0; +} + +fn decodePhysical(entry: u64) u64 { + return ((entry >> 10) & ((@as(u64, 1) << 44) - 1)) << 12; +} + +fn encodeBranchEntry(physical: u64) u64 { + return PTE_VALID | ((physical >> 12) << 10); +} + +fn encodeLeafEntry(physical: u64, flags: EntryFlags) u64 { + var entry = (physical >> 12) << 10; + entry |= PTE_VALID; + if (flags.read != 0) entry |= PTE_READ; + if (flags.write != 0) entry |= PTE_WRITE; + if (flags.execute != 0) entry |= PTE_EXECUTE; + if (flags.user != 0) entry |= PTE_USER; + return entry; }