diff --git a/build.zig b/build.zig index 5f680cd..413d336 100644 --- a/build.zig +++ b/build.zig @@ -12,13 +12,15 @@ pub fn build(b: *std.Build) void { const kernel = b.addExecutable(.{ .name = "kernel", .root_module = b.createModule(.{ - .root_source_file = b.path("src/main.zig"), + .root_source_file = b.path("src/kernel.zig"), .target = target, .optimize = optimize, .code_model = .medium, }), }); + kernel.root_module.addAssemblyFile(b.path("src/boot.S")); + kernel.setLinkerScript(b.path("linker.ld")); b.installArtifact(kernel); diff --git a/linker.ld b/linker.ld index d87d503..ee89ccc 100644 --- a/linker.ld +++ b/linker.ld @@ -1,35 +1,86 @@ OUTPUT_ARCH(riscv) ENTRY(_start) -MEMORY -{ - RAM (rwx) : ORIGIN = 0x80200000, LENGTH = 126M -} +KERNEL_PHYS_BASE = 0x80200000; +KERNEL_VIRT_OFFSET = 0xffffffc000000000; +BOOT_STACK_SIZE = 0x10000; +KERNEL_STACK_SIZE = 0x10000; SECTIONS { - .text : { - *(.text.init) - *(.text .text.*) - } > RAM + . = KERNEL_PHYS_BASE; - .rodata : { - *(.rodata .rodata.*) - } > RAM + __boot_phys_start = .; - .data : { - *(.data .data.*) - } > RAM + .boot.text : ALIGN(0x1000) { + *(.boot.text) + *(.boot.text.*) + } - .bss : { - __bss_start = .; - *(.bss .bss.*) - *(COMMON) - __bss_end = .; - } > RAM + .boot.rodata : ALIGN(0x1000) { + *(.boot.rodata) + *(.boot.rodata.*) + } + + .boot.data : ALIGN(0x1000) { + *(.boot.data) + *(.boot.data.*) + } + + .boot.bss (NOLOAD) : ALIGN(0x1000) { + *(.boot.bss) + *(.boot.bss.*) + } . = ALIGN(16); - __stack_top = . + 0x10000; + . += BOOT_STACK_SIZE; + __boot_stack_top = .; . = ALIGN(0x1000); - PROVIDE(__memory_start = .); + __boot_phys_end = .; + + __kernel_phys_start = .; + . = KERNEL_VIRT_OFFSET + __kernel_phys_start; + __kernel_start = .; + __kernel_virt_start = .; + + .text : AT(ADDR(.text) - KERNEL_VIRT_OFFSET) ALIGN(0x1000) { + *(.text) + *(.text.*) + } + + .rodata : AT(ADDR(.rodata) - KERNEL_VIRT_OFFSET) ALIGN(0x1000) { + *(.rodata) + *(.rodata.*) + } + + .data : AT(ADDR(.data) - KERNEL_VIRT_OFFSET) ALIGN(0x1000) { + PROVIDE(__global_pointer$ = . + 0x800); + *(.data) + *(.data.*) + *(.sdata) + *(.sdata.*) + } + + .bss : AT(ADDR(.bss) - KERNEL_VIRT_OFFSET) ALIGN(0x1000) { + __bss_start = .; + *(.bss) + *(.bss.*) + *(.sbss) + *(.sbss.*) + *(COMMON) + . = ALIGN(16); + . += KERNEL_STACK_SIZE; + __stack_top = .; + __bss_end = .; + } + + __kernel_end = ALIGN(., 0x1000); + __kernel_virt_end = __kernel_end; + __kernel_phys_end = __kernel_end - KERNEL_VIRT_OFFSET; + __memory_start = __kernel_phys_end; + + /DISCARD/ : { + *(.eh_frame) + *(.eh_frame_hdr) + } } diff --git a/src/Fdt.zig b/src/Fdt.zig index 669d4ee..efc42a2 100644 --- a/src/Fdt.zig +++ b/src/Fdt.zig @@ -468,6 +468,19 @@ pub fn parse(ptr: *const anyopaque) Error!Fdt { const size_dt_strings = std.mem.bigToNative(u32, header.size_dt_strings); const off_mem_rsvmap = std.mem.bigToNative(u32, header.off_mem_rsvmap); + if (totalsize < @sizeOf(RawHeader)) { + return Error.InvalidStructure; + } + if (!sliceWithinBounds(off_dt_struct, size_dt_struct, totalsize)) { + return Error.InvalidStructure; + } + if (!sliceWithinBounds(off_dt_strings, size_dt_strings, totalsize)) { + return Error.InvalidStructure; + } + if (off_mem_rsvmap > totalsize) { + return Error.InvalidStructure; + } + return Fdt{ .base = base, .totalsize = totalsize, @@ -616,11 +629,74 @@ pub fn cpus(self: *const Fdt) ?Node { return self.findNode("/cpus"); } +pub fn parent(self: *const Fdt, target: Node) ?Node { + var pos: usize = 0; + var stack: [64]Node = undefined; + var depth: usize = 0; + + while (pos < self.struct_block.len) { + const token = self.readToken(pos) orelse return null; + pos += 4; + + switch (token) { + .begin_node => { + const name_start = pos; + while (pos < self.struct_block.len and self.struct_block[pos] != 0) { + pos += 1; + } + if (pos >= self.struct_block.len) return null; + + const node = Node{ + .name = self.struct_block[name_start..pos], + .fdt = self, + .struct_offset = alignUp(pos + 1, 4), + }; + pos = node.struct_offset; + + if (node.struct_offset == target.struct_offset and std.mem.eql(u8, node.name, target.name)) { + if (depth == 0) return null; + return stack[depth - 1]; + } + + if (depth >= stack.len) return null; + stack[depth] = node; + depth += 1; + }, + .end_node => { + if (depth == 0) return null; + depth -= 1; + }, + .prop => { + const len = self.readU32(pos) orelse return null; + pos += 8; + pos += alignUp(len, 4); + }, + .nop => {}, + .end => return null, + } + } + + return null; +} + +pub fn regIterator(self: *const Fdt, node: Node) ?RegIterator { + const reg = node.reg() orelse return null; + const parent_node = self.parent(node) orelse self.root() orelse return null; + return parseReg(reg, parent_node.addressCells(), parent_node.sizeCells()); +} + // Internal helpers fn readToken(self: *const Fdt, offset: usize) ?Token { if (offset + 4 > self.struct_block.len) return null; const val = std.mem.bigToNative(u32, @as(*const u32, @alignCast(@ptrCast(self.struct_block.ptr + offset))).*); - return @as(Token, @enumFromInt(val)); + return switch (val) { + @intFromEnum(Token.begin_node) => .begin_node, + @intFromEnum(Token.end_node) => .end_node, + @intFromEnum(Token.prop) => .prop, + @intFromEnum(Token.nop) => .nop, + @intFromEnum(Token.end) => .end, + else => null, + }; } fn readU32(self: *const Fdt, offset: usize) ?u32 { @@ -673,6 +749,11 @@ fn alignUp(value: anytype, alignment: @TypeOf(value)) @TypeOf(value) { return (value + alignment - 1) & ~(alignment - 1); } +fn sliceWithinBounds(offset: u32, size: u32, total: u32) bool { + const end = @as(u64, offset) + @as(u64, size); + return end <= total; +} + pub fn parseReg(data: []const u8, address_cells: u32, size_cells: u32) RegIterator { return .{ .data = data, diff --git a/src/boot.S b/src/boot.S new file mode 100644 index 0000000..f4e421b --- /dev/null +++ b/src/boot.S @@ -0,0 +1,90 @@ +.section .boot.text, "ax", @progbits +.globl _start + +.equ PTE_VALID, 0x001 +.equ PTE_READ, 0x002 +.equ PTE_WRITE, 0x004 +.equ PTE_EXECUTE, 0x008 +.equ PTE_FLAGS, PTE_VALID | PTE_READ | PTE_WRITE | PTE_EXECUTE + +.equ SATP_MODE_SV39, (8 << 60) + +.equ LOW_RAM_ROOT_INDEX, (2 * 8) +.equ KERNEL_ROOT_INDEX, (0x102 * 8) +.equ PHYSMAP_ROOT_INDEX, (0x142 * 8) +.equ MMIO_ROOT_INDEX, (0x180 * 8) + +.equ LOW_RAM_GIGAPAGE_PTE, 0x2000000f +.equ ZERO_GIGAPAGE_PTE, 0x0000000f +.equ UART_PHYS, 0x10000000 + +_start: + li fp, 0 + li ra, 0 + la sp, __boot_stack_top + + la t0, boot_root_page + + li t1, LOW_RAM_GIGAPAGE_PTE + sd t1, LOW_RAM_ROOT_INDEX(t0) + li t2, PHYSMAP_ROOT_INDEX + add t2, t2, t0 + sd t1, 0(t2) + + li t1, LOW_RAM_GIGAPAGE_PTE + li t2, KERNEL_ROOT_INDEX + add t2, t2, t0 + sd t1, 0(t2) + li t1, ZERO_GIGAPAGE_PTE + li t2, MMIO_ROOT_INDEX + add t2, t2, t0 + sd t1, 0(t2) + + srli t1, t0, 12 + li t2, SATP_MODE_SV39 + or t1, t1, t2 + csrw satp, t1 + sfence.vma + + li a2, UART_PHYS + la t3, memory_start_ptr + ld a3, 0(t3) + li a4, 0 + la t3, bss_start_ptr + ld t0, 0(t3) + la t3, bss_end_ptr + ld t1, 0(t3) + bgeu t0, t1, 2f +1: + sd zero, 0(t0) + addi t0, t0, 8 + bltu t0, t1, 1b +2: + la t3, stack_top_ptr + ld sp, 0(t3) + la t3, global_pointer_ptr + ld gp, 0(t3) + li fp, 0 + la t3, kmain_ptr + ld t0, 0(t3) + jr t0 + +.section .boot.rodata, "a", @progbits +.balign 8 +memory_start_ptr: + .dword __memory_start +bss_start_ptr: + .dword __bss_start +bss_end_ptr: + .dword __bss_end +stack_top_ptr: + .dword __stack_top +global_pointer_ptr: + .dword __global_pointer$ +kmain_ptr: + .dword kmain + +.section .boot.bss, "aw", @nobits +.balign 4096 +boot_root_page: + .skip 4096 diff --git a/src/drivers/Console.zig b/src/drivers/Console.zig index 8b478b4..8290852 100644 --- a/src/drivers/Console.zig +++ b/src/drivers/Console.zig @@ -2,47 +2,35 @@ const std = @import("std"); const Fdt = @import("../Fdt.zig"); const Console = @This(); -pub const Writer = struct { - base: *volatile u8, - interface: std.Io.Writer, - - pub fn drain(io_w: *std.Io.Writer, data: []const []const u8, splat: usize) !usize { - _ = splat; - const self: *Writer = @fieldParentPtr("interface", io_w); - self.base.* = data[0][0]; - - return data.len; - } -}; - mmio: *volatile u8, +pub fn initAt(address: usize) Console { + return .{ + .mmio = @ptrFromInt(address), + }; +} + pub fn init(fdt: Fdt) ?Console { if (fdt.findFirstCompatible("ns16550a")) |console| { - var iter = console.getProperty("reg").?.asU32Array(); - _ = iter.next(); - const start = iter.next().?; - return .{ - .mmio = @ptrFromInt(@as(usize, @intCast(start))), - }; + var iter = fdt.regIterator(console) orelse return null; + const start = iter.next().?.address; + return initAt(@as(usize, @intCast(start))); } return null; } -pub fn writer(self: *const Console) Writer { - return .{ - .base = self.mmio, - .interface = std.Io.Writer { - .buffer = &[_]u8{}, - .vtable = &.{.drain = Writer.drain}, - }, - }; +pub fn write(self: *const Console, bytes: []const u8) void { + for (bytes) |byte| { + self.mmio.* = byte; + } } pub fn print(self: *const Console, comptime s: []const u8, args: anytype) void { - var w = self.writer(); - w.interface.print(s, args) catch { - @panic("Failed to print debug message.\n"); + var buffer: [512]u8 = undefined; + const formatted = std.fmt.bufPrint(&buffer, s, args) catch { + self.write("debug print overflow\n"); + return; }; + self.write(formatted); } diff --git a/src/kernel.zig b/src/kernel.zig new file mode 100644 index 0000000..b60a9ef --- /dev/null +++ b/src/kernel.zig @@ -0,0 +1,96 @@ +const std = @import("std"); +const Fdt = @import("Fdt.zig"); +const Console = @import("drivers/Console.zig"); +const debug = @import("debug.zig"); +const mem = @import("mem.zig"); +const isa = @import("riscv/isa.zig"); + +const KERNEL_PHYS_BASE: u64 = 0x80200000; + +pub const panic = debug.KernelPanic; + +pub export fn kmain(hartid: u64, fdt_phys: usize, console_phys: usize, alloc_start_phys: usize, memory_end_hint: usize) callconv(.c) noreturn { + _ = hartid; + + const console = Console.initAt(@intCast(mem.physToMmioVirt(console_phys))); + debug.init(console); + debug.print("entered higher-half kernel.\n", .{}); + + const fdt = Fdt.parse(@ptrFromInt(mem.physToDirectMap(fdt_phys))) catch { + @panic("Unable to parse higher-half FDT.\n"); + }; + debug.print("fdt remapped at 0x{x}.\n", .{mem.physToDirectMap(fdt_phys)}); + + const root = fdt.root().?; + const memory = fdt.memory().?; + var reg_iter = Fdt.parseReg(memory.getProperty("reg").?.data, root.addressCells(), root.sizeCells()); + const reg = reg_iter.next().?; + const detected_memory_end = reg.address + reg.size; + const memory_end = if (memory_end_hint != 0) @min(detected_memory_end, memory_end_hint) else detected_memory_end; + debug.print("detected RAM end at 0x{x}.\n", .{memory_end}); + + var buddy: mem.BuddyAllocator = .{}; + const alloc_start = mem.physToDirectMap(alloc_start_phys); + buddy.init(@as([*]u8, @ptrFromInt(alloc_start))[0..memory_end - alloc_start_phys]); + debug.print("direct map allocator initialized.\n", .{}); + + const allocator = buddy.allocator(); + const mmu_type = fdt.mmuType(); + + if (mmu_type != .bare) { + var table = isa.PageTable.init(allocator) catch { + @panic("Unable to allocate higher-half page table.\n"); + }; + + const kernel_flags: isa.PageTable.EntryFlags = .{ + .valid = 1, + .read = 1, + .write = 1, + .execute = 1, + .user = 0, + }; + + table.mapRange( + allocator, + mem.physToKernelVirt(KERNEL_PHYS_BASE), + KERNEL_PHYS_BASE, + alloc_start_phys - KERNEL_PHYS_BASE, + kernel_flags, + mmu_type, + mem.PHYS_MAP_BASE, + ) catch { + @panic("Unable to map higher-half kernel image.\n"); + }; + + table.mapRange( + allocator, + mem.physToDirectMap(reg.address), + reg.address, + memory_end - reg.address, + .{ .read = 1, .write = 1, .execute = 1 }, + mmu_type, + mem.PHYS_MAP_BASE, + ) catch { + @panic("Unable to map direct map window.\n"); + }; + + const console_page = std.mem.alignBackward(u64, @intCast(console_phys), mem.PAGE_SIZE); + table.map( + allocator, + mem.physToMmioVirt(console_page), + console_page, + .{ .read = 1, .write = 1, .execute = 1 }, + mmu_type, + mem.PHYS_MAP_BASE, + ) catch { + @panic("Unable to map console MMIO.\n"); + }; + + table.load(mmu_type, mem.PHYS_MAP_BASE); + debug.print("reloaded higher-half kernel page table.\n", .{}); + } + + while (true) { + asm volatile ("wfi"); + } +} diff --git a/src/main.zig b/src/main.zig deleted file mode 100644 index 1e9fc1f..0000000 --- a/src/main.zig +++ /dev/null @@ -1,73 +0,0 @@ -const std = @import("std"); -const isa = @import("riscv/isa.zig"); -const Fdt = @import("Fdt.zig"); -const Console = @import("drivers/Console.zig"); -const debug = @import("debug.zig"); -const mem = @import("mem.zig"); - -const UART_BASE: usize = 0x10000000; -const MEMORY_START = @extern([*]u8, .{.name = "__memory_start"}); - -fn uart_put(c: u8) void { - const uart: *volatile u8 = @ptrFromInt(UART_BASE); - uart.* = c; -} - -fn print(s: []const u8) void { - for (s) |c| { - uart_put(c); - } -} - -pub const panic = debug.KernelPanic; - -export fn _start() linksection(".text.init") callconv(.naked) noreturn { - asm volatile ( - \\li fp, 0 - \\li ra, 0 - \\li sp, 0x88000000 - \\tail kmain - ); -} - -export fn kmain(hartid: u64, fdt_ptr: *const anyopaque) callconv(.c) noreturn { - _ = hartid; - - const fdt = Fdt.parse(fdt_ptr) catch { - while (true) asm volatile ("wfi"); - }; - - const root = fdt.root().?; - - const console = Console.init(fdt).?; - debug.init(console); - - debug.print("booting hydra...\n", .{}); - - const memory = fdt.memory().?; - var reg_iter = Fdt.parseReg(memory.getProperty("reg").?.data, root.addressCells(), root.sizeCells()); - const reg = reg_iter.next().?; - const memory_end = reg.address + reg.size; - - var buddy: mem.BuddyAllocator = .{}; - buddy.init(MEMORY_START[0..memory_end - @intFromPtr(MEMORY_START)]); - debug.print("memory allocator initialized.\n", .{}); - - const allocator = buddy.allocator(); - const mmu_type = fdt.mmuType(); - - var table = isa.PageTable.init(allocator) catch { @panic("Unable to create page table.\n"); }; - table.identityMap(allocator, memory_end, mmu_type) catch {}; - table.map(allocator, @intFromPtr(console.mmio), @intFromPtr(console.mmio), .{.read = 1, .write = 1, .execute = 1}, mmu_type) catch { - - }; - isa.write_satp(.{ - .ppn = @as(u44, @intCast(@intFromPtr(table) >> 12)), - .mode = mmu_type, - }); - debug.print("loaded kernel page table.\n", .{}); - - while (true) { - asm volatile ("wfi"); - } -} diff --git a/src/mem.zig b/src/mem.zig index 95ea408..b9f0d74 100644 --- a/src/mem.zig +++ b/src/mem.zig @@ -1,3 +1,22 @@ pub const BuddyAllocator = @import("mem/BuddyAllocator.zig"); pub const PAGE_SIZE = 0x1000; +pub const KERNEL_VIRT_OFFSET: u64 = 0xffffffc000000000; +pub const PHYS_MAP_BASE: u64 = 0xffffffd000000000; +pub const MMIO_VIRT_OFFSET: u64 = 0xffffffe000000000; + +pub fn physToKernelVirt(physical: u64) u64 { + return KERNEL_VIRT_OFFSET + physical; +} + +pub fn physToDirectMap(physical: u64) u64 { + return PHYS_MAP_BASE + physical; +} + +pub fn directMapToPhys(virtual: u64) u64 { + return virtual - PHYS_MAP_BASE; +} + +pub fn physToMmioVirt(physical: u64) u64 { + return MMIO_VIRT_OFFSET + physical; +} diff --git a/src/mem/BuddyAllocator.zig b/src/mem/BuddyAllocator.zig index 6ae181b..414701c 100644 --- a/src/mem/BuddyAllocator.zig +++ b/src/mem/BuddyAllocator.zig @@ -22,6 +22,11 @@ fn getLevelSize(level: u8) usize { return @as(usize, MIN_BLOCK_SIZE) << @intCast(level); } +fn levelForRequest(len: usize, alignment: std.mem.Alignment) usize { + const required_size = @max(@max(len, MIN_BLOCK_SIZE), alignment.toByteUnits()); + return math.log2_int_ceil(usize, (required_size + MIN_BLOCK_SIZE - 1) / MIN_BLOCK_SIZE); +} + pub fn freeRange(self: *BuddyAllocator, start: [*]u8, size: usize) void { var current_addr = @intFromPtr(start); const end_addr = current_addr + size; @@ -119,23 +124,15 @@ pub fn allocator(self: *BuddyAllocator) Allocator { } fn alloc(ctx: *anyopaque, len: usize, ptr_align: std.mem.Alignment, ret_addr: usize) ?[*]u8 { - _ = ptr_align; _ = ret_addr; const self: *BuddyAllocator = @ptrCast(@alignCast(ctx)); - const actual_size = @max(len, MIN_BLOCK_SIZE); - const level = math.log2_int_ceil(usize, (actual_size + MIN_BLOCK_SIZE - 1) / MIN_BLOCK_SIZE); - - return self.allocBlock(@intCast(level)); + return self.allocBlock(@intCast(levelForRequest(len, ptr_align))); } fn free(ctx: *anyopaque, buf: []u8, buf_align: std.mem.Alignment, ret_addr: usize) void { - _ = buf_align; _ = ret_addr; const self: *BuddyAllocator = @ptrCast(@alignCast(ctx)); - const actual_size = @max(buf.len, MIN_BLOCK_SIZE); - const level = math.log2_int_ceil(usize, (actual_size + MIN_BLOCK_SIZE - 1) / MIN_BLOCK_SIZE); - - self.freeBlock(buf.ptr, @intCast(level)); + self.freeBlock(buf.ptr, @intCast(levelForRequest(buf.len, buf_align))); } diff --git a/src/riscv/PageTable.zig b/src/riscv/PageTable.zig index d702eae..47a901b 100644 --- a/src/riscv/PageTable.zig +++ b/src/riscv/PageTable.zig @@ -1,11 +1,8 @@ const std = @import("std"); const Allocator = std.mem.Allocator; const PageTable = @This(); -const debug = @import("../debug.zig"); const isa = @import("isa.zig"); -const MEMORY_START = @extern([*]u8, .{.name = "__memory_start"}); - pub const EntryFlags = packed struct { valid: u1 = 1, read: u1, @@ -25,17 +22,21 @@ pub const Entry = packed struct { n: u1 = 0, }; -entries: [512]Entry, +const PTE_VALID = @as(u64, 1) << 0; +const PTE_READ = @as(u64, 1) << 1; +const PTE_WRITE = @as(u64, 1) << 2; +const PTE_EXECUTE = @as(u64, 1) << 3; +const PTE_USER = @as(u64, 1) << 4; + +entries: [512]u64, pub fn init(allocator: Allocator) !*PageTable { const table = try allocator.create(PageTable); - for (&table.entries) |*entry| { - entry.* = @bitCast(@as(u64, 0x0)); - } + @memset(&table.entries, 0); return table; } -pub fn identityMap(self: *PageTable, allocator: Allocator, memory_end: u64, mode: isa.Satp.Mode) !void { +pub fn identityMap(self: *PageTable, allocator: Allocator, start: u64, end: u64, mode: isa.Satp.Mode, direct_map_base: u64) !void { const flags = EntryFlags{ .valid = 1, .read = 1, @@ -44,16 +45,53 @@ pub fn identityMap(self: *PageTable, allocator: Allocator, memory_end: u64, mode .user = 0, }; - var addr: u64 = 0x0; - while (addr < memory_end) : (addr += 0x1000) { - try self.map(allocator, addr, addr, flags, mode); + return self.mapRange(allocator, start, start, end - start, flags, mode, direct_map_base); +} + +pub fn mapRange(self: *PageTable, allocator: Allocator, virtual_start: u64, physical_start: u64, length: u64, flags: EntryFlags, mode: isa.Satp.Mode, direct_map_base: u64) !void { + if (length == 0) return; + + var virtual = virtual_start; + var physical = physical_start; + const virtual_end = virtual_start + length; + + while (virtual < virtual_end) { + const remaining = virtual_end - virtual; + if (virtual % (2 * 1024 * 1024) == 0 and physical % (2 * 1024 * 1024) == 0 and remaining >= 2 * 1024 * 1024) { + try self.mapLarge2MiB(allocator, virtual, physical, flags, mode, direct_map_base); + virtual += 2 * 1024 * 1024; + physical += 2 * 1024 * 1024; + } else { + try self.map(allocator, virtual, physical, flags, mode, direct_map_base); + virtual += 0x1000; + physical += 0x1000; + } } } -pub fn map(self: *PageTable, allocator: Allocator, virtual: u64, physical: u64, flags: EntryFlags, mode: isa.Satp.Mode) !void { - const PAGE_SIZE = 4096; +pub fn load(self: *const PageTable, mmu_type: isa.Satp.Mode, direct_map_base: u64) void { + isa.write_satp(.{ + .ppn = @as(u44, @intCast(pointerToPhysical(@intFromPtr(self), direct_map_base) >> 12)), + .mode = mmu_type, + }); +} - if (virtual % PAGE_SIZE != 0 or physical % PAGE_SIZE != 0) { +pub fn map(self: *PageTable, allocator: Allocator, virtual: u64, physical: u64, flags: EntryFlags, mode: isa.Satp.Mode, direct_map_base: u64) !void { + return self.mapAtLevel(allocator, virtual, physical, flags, mode, 0, direct_map_base); +} + +pub fn mapLarge2MiB(self: *PageTable, allocator: Allocator, virtual: u64, physical: u64, flags: EntryFlags, mode: isa.Satp.Mode, direct_map_base: u64) !void { + return self.mapAtLevel(allocator, virtual, physical, flags, mode, 1, direct_map_base); +} + +fn mapAtLevel(self: *PageTable, allocator: Allocator, virtual: u64, physical: u64, flags: EntryFlags, mode: isa.Satp.Mode, leaf_level: usize, direct_map_base: u64) !void { + const page_size: u64 = switch (leaf_level) { + 0 => 4096, + 1 => 2 * 1024 * 1024, + else => unreachable, + }; + + if (virtual % page_size != 0 or physical % page_size != 0) { return error.AddressNotAligned; } @@ -74,26 +112,52 @@ pub fn map(self: *PageTable, allocator: Allocator, virtual: u64, physical: u64, var current_table = self; - while (level > 0) : (level -= 1) { + while (level > leaf_level) : (level -= 1) { const index = vpn[level]; - var entry = ¤t_table.entries[index]; + const entry = ¤t_table.entries[index]; - if (entry.flags.valid == 0) { + if (!isValid(entry.*)) { const new_table = try PageTable.init(allocator); - - const table_phys = @intFromPtr(new_table); - entry.* = .{ - .flags = .{ .valid = 1, .read = 0, .write = 0, .execute = 0, .user = 0 }, - .ppn = @truncate(table_phys >> 12), - }; + + const table_phys = pointerToPhysical(@intFromPtr(new_table), direct_map_base); + entry.* = encodeBranchEntry(table_phys); } - const next_table_phys = @as(u64, entry.ppn) << 12; - current_table = @ptrFromInt(next_table_phys); + const next_table_phys = decodePhysical(entry.*); + current_table = @ptrFromInt(physicalToPointer(next_table_phys, direct_map_base)); } - current_table.entries[vpn[0]] = .{ - .flags = flags, - .ppn = @truncate(physical >> 12), - }; + current_table.entries[vpn[leaf_level]] = encodeLeafEntry(physical, flags); +} + +fn pointerToPhysical(pointer: u64, direct_map_base: u64) u64 { + if (direct_map_base == 0) return pointer; + return pointer - direct_map_base; +} + +fn physicalToPointer(physical: u64, direct_map_base: u64) u64 { + if (direct_map_base == 0) return physical; + return physical + direct_map_base; +} + +fn isValid(entry: u64) bool { + return (entry & PTE_VALID) != 0; +} + +fn decodePhysical(entry: u64) u64 { + return ((entry >> 10) & ((@as(u64, 1) << 44) - 1)) << 12; +} + +fn encodeBranchEntry(physical: u64) u64 { + return PTE_VALID | ((physical >> 12) << 10); +} + +fn encodeLeafEntry(physical: u64, flags: EntryFlags) u64 { + var entry = (physical >> 12) << 10; + entry |= PTE_VALID; + if (flags.read != 0) entry |= PTE_READ; + if (flags.write != 0) entry |= PTE_WRITE; + if (flags.execute != 0) entry |= PTE_EXECUTE; + if (flags.user != 0) entry |= PTE_USER; + return entry; }