From b574d39a392bc3af5539650c6b134840a708aa99 Mon Sep 17 00:00:00 2001 From: Lorenzo Torres Date: Tue, 24 Feb 2026 14:28:56 +0100 Subject: [PATCH] first commit --- .gitignore | 3 + build.zig | 102 ++ build.zig.zon | 81 ++ src/main.zig | 27 + src/root.zig | 79 ++ src/wasm/binary.zig | 489 ++++++++ src/wasm/host.zig | 316 +++++ src/wasm/instance.zig | 2310 +++++++++++++++++++++++++++++++++++++ src/wasm/jit/aarch64.zig | 1383 ++++++++++++++++++++++ src/wasm/jit/codebuf.zig | 189 +++ src/wasm/jit/codegen.zig | 24 + src/wasm/jit/liveness.zig | 75 ++ src/wasm/jit/regalloc.zig | 193 ++++ src/wasm/jit/stackify.zig | 965 ++++++++++++++++ src/wasm/jit/x86_64.zig | 1163 +++++++++++++++++++ src/wasm/jit_tests.zig | 11 + src/wasm/module.zig | 138 +++ src/wasm/runtime.zig | 115 ++ src/wasm/trap.zig | 20 + src/wasm/validator.zig | 881 ++++++++++++++ tests/wasm/fib.wasm | Bin 0 -> 116 bytes tests/wasm/fib.wat | 35 + tests/wasm/fib.zig | 5 + 23 files changed, 8604 insertions(+) create mode 100644 .gitignore create mode 100644 build.zig create mode 100644 build.zig.zon create mode 100644 src/main.zig create mode 100644 src/root.zig create mode 100644 src/wasm/binary.zig create mode 100644 src/wasm/host.zig create mode 100644 src/wasm/instance.zig create mode 100644 src/wasm/jit/aarch64.zig create mode 100644 src/wasm/jit/codebuf.zig create mode 100644 src/wasm/jit/codegen.zig create mode 100644 src/wasm/jit/liveness.zig create mode 100644 src/wasm/jit/regalloc.zig create mode 100644 src/wasm/jit/stackify.zig create mode 100644 src/wasm/jit/x86_64.zig create mode 100644 src/wasm/jit_tests.zig create mode 100644 src/wasm/module.zig create mode 100644 src/wasm/runtime.zig create mode 100644 src/wasm/trap.zig create mode 100644 src/wasm/validator.zig create mode 100755 tests/wasm/fib.wasm create mode 100644 tests/wasm/fib.wat create mode 100644 tests/wasm/fib.zig diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..254a1b2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 
@@ +zig-out/ +.zig-cache/ +**/*~ diff --git a/build.zig b/build.zig new file mode 100644 index 0000000..8172b92 --- /dev/null +++ b/build.zig @@ -0,0 +1,102 @@ +const std = @import("std"); + +pub fn build(b: *std.Build) void { + const target = b.standardTargetOptions(.{}); + const optimize = b.standardOptimizeOption(.{}); + + const mod = b.addModule("wasm_runtime", .{ + .root_source_file = b.path("src/root.zig"), + .target = target, + }); + + const exe = b.addExecutable(.{ + .name = "wasm_runtime", + .root_module = b.createModule(.{ + .root_source_file = b.path("src/main.zig"), + .target = target, + .optimize = optimize, + .imports = &.{ + .{ .name = "wasm_runtime", .module = mod }, + }, + }), + }); + b.installArtifact(exe); + + const run_cmd = b.addRunArtifact(exe); + run_cmd.step.dependOn(b.getInstallStep()); + if (b.args) |args| run_cmd.addArgs(args); + const run_step = b.step("run", "Run the app"); + run_step.dependOn(&run_cmd.step); + + const lib_tests = b.addTest(.{ + .root_module = b.createModule(.{ + .root_source_file = b.path("src/root.zig"), + .target = target, + .optimize = optimize, + }), + }); + const run_lib_tests = b.addRunArtifact(lib_tests); + + const binary_tests = b.addTest(.{ + .root_module = b.createModule(.{ + .root_source_file = b.path("src/wasm/binary.zig"), + .target = target, + .optimize = optimize, + }), + }); + const run_binary_tests = b.addRunArtifact(binary_tests); + + const validator_tests = b.addTest(.{ + .root_module = b.createModule(.{ + .root_source_file = b.path("src/wasm/validator.zig"), + .target = target, + .optimize = optimize, + }), + }); + const run_validator_tests = b.addRunArtifact(validator_tests); + + const runtime_tests = b.addTest(.{ + .root_module = b.createModule(.{ + .root_source_file = b.path("src/wasm/runtime.zig"), + .target = target, + .optimize = optimize, + }), + }); + const run_runtime_tests = b.addRunArtifact(runtime_tests); + + const instance_tests = b.addTest(.{ + .root_module = b.createModule(.{ + 
.root_source_file = b.path("src/wasm/instance.zig"), + .target = target, + .optimize = optimize, + }), + }); + const run_instance_tests = b.addRunArtifact(instance_tests); + + const host_tests = b.addTest(.{ + .root_module = b.createModule(.{ + .root_source_file = b.path("src/wasm/host.zig"), + .target = target, + .optimize = optimize, + }), + }); + const run_host_tests = b.addRunArtifact(host_tests); + + const jit_tests = b.addTest(.{ + .root_module = b.createModule(.{ + .root_source_file = b.path("src/wasm/jit_tests.zig"), + .target = target, + .optimize = optimize, + }), + }); + const run_jit_tests = b.addRunArtifact(jit_tests); + + const test_step = b.step("test", "Run all tests"); + test_step.dependOn(&run_lib_tests.step); + test_step.dependOn(&run_binary_tests.step); + test_step.dependOn(&run_validator_tests.step); + test_step.dependOn(&run_runtime_tests.step); + test_step.dependOn(&run_instance_tests.step); + test_step.dependOn(&run_host_tests.step); + test_step.dependOn(&run_jit_tests.step); +} diff --git a/build.zig.zon b/build.zig.zon new file mode 100644 index 0000000..885376d --- /dev/null +++ b/build.zig.zon @@ -0,0 +1,81 @@ +.{ + // This is the default name used by packages depending on this one. For + // example, when a user runs `zig fetch --save `, this field is used + // as the key in the `dependencies` table. Although the user can choose a + // different name, most users will stick with this provided value. + // + // It is redundant to include "zig" in this name because it is already + // within the Zig package namespace. + .name = .wasm_runtime, + // This is a [Semantic Version](https://semver.org/). + // In a future version of Zig it will be used for package deduplication. + .version = "0.0.0", + // Together with name, this represents a globally unique package + // identifier. This field is generated by the Zig toolchain when the + // package is first created, and then *never changes*. 
This allows + // unambiguous detection of one package being an updated version of + // another. + // + // When forking a Zig project, this id should be regenerated (delete the + // field and run `zig build`) if the upstream project is still maintained. + // Otherwise, the fork is *hostile*, attempting to take control over the + // original project's identity. Thus it is recommended to leave the comment + // on the following line intact, so that it shows up in code reviews that + // modify the field. + .fingerprint = 0xcec38deb11d082fe, // Changing this has security and trust implications. + // Tracks the earliest Zig version that the package considers to be a + // supported use case. + .minimum_zig_version = "0.16.0-dev.2187+e2338edb4", + // This field is optional. + // Each dependency must either provide a `url` and `hash`, or a `path`. + // `zig build --fetch` can be used to fetch all dependencies of a package, recursively. + // Once all dependencies are fetched, `zig build` no longer requires + // internet connectivity. + .dependencies = .{ + // See `zig fetch --save ` for a command-line interface for adding dependencies. + //.example = .{ + // // When updating this field to a new URL, be sure to delete the corresponding + // // `hash`, otherwise you are communicating that you expect to find the old hash at + // // the new URL. If the contents of a URL change this will result in a hash mismatch + // // which will prevent zig from using it. + // .url = "https://example.com/foo.tar.gz", + // + // // This is computed from the file contents of the directory of files that is + // // obtained after fetching `url` and applying the inclusion rules given by + // // `paths`. + // // + // // This field is the source of truth; packages do not come from a `url`; they + // // come from a `hash`. `url` is just one of many possible mirrors for how to + // // obtain a package matching this `hash`. + // // + // // Uses the [multihash](https://multiformats.io/multihash/) format. 
+ // .hash = "...", + // + // // When this is provided, the package is found in a directory relative to the + // // build root. In this case the package's hash is irrelevant and therefore not + // // computed. This field and `url` are mutually exclusive. + // .path = "foo", + // + // // When this is set to `true`, a package is declared to be lazily + // // fetched. This makes the dependency only get fetched if it is + // // actually used. + // .lazy = false, + //}, + }, + // Specifies the set of files and directories that are included in this package. + // Only files and directories listed here are included in the `hash` that + // is computed for this package. Only files listed here will remain on disk + // when using the zig package manager. As a rule of thumb, one should list + // files required for compilation plus any license(s). + // Paths are relative to the build root. Use the empty string (`""`) to refer to + // the build root itself. + // A directory listed here means that all files within, recursively, are included. + .paths = .{ + "build.zig", + "build.zig.zon", + "src", + // For example... 
+ //"LICENSE", + //"README.md", + }, +} diff --git a/src/main.zig b/src/main.zig new file mode 100644 index 0000000..3cf2cb3 --- /dev/null +++ b/src/main.zig @@ -0,0 +1,27 @@ +const std = @import("std"); +const wasm = @import("wasm_runtime"); + +fn log(instance_ptr: *anyopaque, ptr: i32, len: i32) !void { + const instance: *wasm.ModuleInstance = @ptrCast(@alignCast(instance_ptr)); + const bytes = try instance.memorySlice(@intCast(ptr), @intCast(len)); + std.debug.print("wasm: {s}", .{bytes}); +} + +pub fn main() !void { + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + defer _ = gpa.deinit(); + const allocator = gpa.allocator(); + + var imports = wasm.ImportSet.init(allocator); + defer imports.deinit(); + + try imports.addFunc(.env, .log, log); + + var engine = wasm.Engine.init(allocator); + var instance = try engine.loadFile("tests/wasm/fib.wasm", &imports); + defer instance.deinit(); + + const result = try instance.callExport("init", &.{}); + defer allocator.free(result); + std.debug.print("done\n", .{}); +} diff --git a/src/root.zig b/src/root.zig new file mode 100644 index 0000000..d4989a5 --- /dev/null +++ b/src/root.zig @@ -0,0 +1,79 @@ +const std = @import("std"); + +pub const module = @import("wasm/module.zig"); +pub const binary = @import("wasm/binary.zig"); +pub const validator = @import("wasm/validator.zig"); +pub const runtime = @import("wasm/runtime.zig"); +pub const trap = @import("wasm/trap.zig"); +pub const host = @import("wasm/host.zig"); +pub const instance = @import("wasm/instance.zig"); + +pub const Value = runtime.Value; +pub const ValType = module.ValType; +pub const Memory = runtime.Memory; +pub const ImportSet = host.ImportSet; +pub const HostFunc = host.HostFunc; +pub const readStruct = host.readStruct; +pub const writeStruct = host.writeStruct; +pub const TrapCode = trap.TrapCode; +pub const Trap = trap.Trap; +pub const ModuleInstance = instance.ModuleInstance; + +pub const Engine = struct { + allocator: std.mem.Allocator, + + pub fn 
init(allocator: std.mem.Allocator) Engine { + return .{ .allocator = allocator }; + } + + pub fn loadFile(self: *Engine, path: []const u8, imports: *const ImportSet) !ModuleInstance { + const fd = try std.posix.openat(std.posix.AT.FDCWD, path, .{ .ACCMODE = .RDONLY }, 0); + defer std.posix.close(fd); + + const st = try std.posix.fstat(fd); + if (st.size < 0) return error.InvalidWasmSize; + const size: usize = @intCast(st.size); + const bytes = try self.allocator.alloc(u8, size); + errdefer self.allocator.free(bytes); + + var off: usize = 0; + while (off < bytes.len) { + const n = try std.posix.read(fd, bytes[off..]); + if (n == 0) break; + off += n; + } + if (off != bytes.len) return error.UnexpectedEof; + defer self.allocator.free(bytes); + return self.loadBytes(bytes, imports); + } + + pub fn loadBytes(self: *Engine, bytes: []const u8, imports: *const ImportSet) !ModuleInstance { + const mod_ptr = try self.allocator.create(module.Module); + errdefer self.allocator.destroy(mod_ptr); + mod_ptr.* = try module.Module.parse(self.allocator, bytes); + errdefer mod_ptr.deinit(); + try validator.validate(mod_ptr); + return ModuleInstance.instantiateOwned(self.allocator, mod_ptr, imports); + } +}; + +test "engine loadBytes and call export" { + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7f, + 0x03, 0x02, 0x01, 0x00, + 0x07, 0x08, 0x01, 0x04, 0x70, 0x69, 0x6e, 0x67, 0x00, 0x00, + 0x0a, 0x06, 0x01, 0x04, 0x00, 0x41, 0x2a, 0x0b, + }; + const ally = std.testing.allocator; + var imports = ImportSet.init(ally); + defer imports.deinit(); + + var engine = Engine.init(ally); + var inst = try engine.loadBytes(&wasm, &imports); + defer inst.deinit(); + + const out = try inst.callExport("ping", &.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 42), out[0].i32); +} diff --git a/src/wasm/binary.zig b/src/wasm/binary.zig new file mode 100644 index 0000000..7f48fc1 --- /dev/null +++ b/src/wasm/binary.zig 
@@ -0,0 +1,489 @@ +const std = @import("std"); +const module = @import("module.zig"); + +pub const Module = module.Module; +pub const ValType = module.ValType; +pub const SectionId = module.SectionId; + +pub const ParseError = error{ + InvalidMagic, + InvalidVersion, + UnexpectedEof, + MalformedLEB128, + UnknownSection, + UnknownOpcode, + InvalidValType, + InvalidImportDesc, + InvalidExportDesc, + InvalidConstExpr, + InvalidRefType, + InvalidMutability, + InvalidFuncType, + InvalidDataSegment, + OutOfMemory, +}; + +/// Read an unsigned LEB128 value from a byte slice starting at *pos. +pub fn readULEB128(comptime T: type, bytes: []const u8, pos: *usize) ParseError!T { + const max_bits = @typeInfo(T).int.bits; + var result: T = 0; + var shift: u32 = 0; + while (true) { + if (pos.* >= bytes.len) return ParseError.UnexpectedEof; + const byte = bytes[pos.*]; + pos.* += 1; + const val: T = @intCast(byte & 0x7F); + if (shift < max_bits) { + result |= val << @intCast(shift); + } else if (val != 0) { + return ParseError.MalformedLEB128; + } + if (byte & 0x80 == 0) break; + shift += 7; + if (shift >= max_bits + 7) return ParseError.MalformedLEB128; + } + return result; +} + +/// Read a signed LEB128 value from a byte slice starting at *pos. 
+pub fn readSLEB128(comptime T: type, bytes: []const u8, pos: *usize) ParseError!T { + const Unsigned = std.meta.Int(.unsigned, @typeInfo(T).int.bits); + const max_bits = @typeInfo(T).int.bits; + var result: Unsigned = 0; + var shift: u32 = 0; + var last_byte: u8 = 0; + while (true) { + if (pos.* >= bytes.len) return ParseError.UnexpectedEof; + last_byte = bytes[pos.*]; + pos.* += 1; + const val: Unsigned = @intCast(last_byte & 0x7F); + if (shift < max_bits) { + result |= val << @intCast(shift); + } + shift += 7; + if (last_byte & 0x80 == 0) break; + if (shift >= max_bits + 7) return ParseError.MalformedLEB128; + } + // Sign extend + if (shift < max_bits and (last_byte & 0x40) != 0) { + result |= ~@as(Unsigned, 0) << @intCast(shift); + } + return @bitCast(result); +} + +fn readByte(bytes: []const u8, pos: *usize) ParseError!u8 { + if (pos.* >= bytes.len) return ParseError.UnexpectedEof; + const b = bytes[pos.*]; + pos.* += 1; + return b; +} + +fn readBytes(bytes: []const u8, pos: *usize, len: usize) ParseError![]const u8 { + if (pos.* + len > bytes.len) return ParseError.UnexpectedEof; + const slice = bytes[pos.* .. 
pos.* + len]; + pos.* += len; + return slice; +} + +fn readValType(bytes: []const u8, pos: *usize) ParseError!module.ValType { + const b = try readByte(bytes, pos); + return switch (b) { + 0x7F => .i32, + 0x7E => .i64, + 0x7D => .f32, + 0x7C => .f64, + else => ParseError.InvalidValType, + }; +} + +fn readConstExpr(bytes: []const u8, pos: *usize) ParseError!module.ConstExpr { + const op = try readByte(bytes, pos); + const expr: module.ConstExpr = switch (op) { + 0x41 => .{ .i32_const = try readSLEB128(i32, bytes, pos) }, + 0x42 => .{ .i64_const = try readSLEB128(i64, bytes, pos) }, + 0x43 => blk: { + const raw = try readBytes(bytes, pos, 4); + break :blk .{ .f32_const = @bitCast(std.mem.readInt(u32, raw[0..4], .little)) }; + }, + 0x44 => blk: { + const raw = try readBytes(bytes, pos, 8); + break :blk .{ .f64_const = @bitCast(std.mem.readInt(u64, raw[0..8], .little)) }; + }, + 0x23 => .{ .global_get = try readULEB128(u32, bytes, pos) }, + else => return ParseError.InvalidConstExpr, + }; + const end = try readByte(bytes, pos); + if (end != 0x0B) return ParseError.InvalidConstExpr; + return expr; +} + +fn parseLimits(bytes: []const u8, pos: *usize) ParseError!struct { min: u32, max: ?u32 } { + const flag = try readByte(bytes, pos); + const min = try readULEB128(u32, bytes, pos); + const max: ?u32 = if (flag == 1) try readULEB128(u32, bytes, pos) else null; + return .{ .min = min, .max = max }; +} + +fn parseTypeSection(ally: std.mem.Allocator, bytes: []const u8, pos: *usize) ParseError![]module.FuncType { + const count = try readULEB128(u32, bytes, pos); + const types = try ally.alloc(module.FuncType, count); + errdefer { + for (types) |t| { + ally.free(t.params); + ally.free(t.results); + } + ally.free(types); + } + var i: u32 = 0; + while (i < count) : (i += 1) { + const tag = try readByte(bytes, pos); + if (tag != 0x60) return ParseError.InvalidFuncType; + const param_count = try readULEB128(u32, bytes, pos); + const params = try ally.alloc(module.ValType, 
param_count); + errdefer ally.free(params); + for (params) |*p| p.* = try readValType(bytes, pos); + const result_count = try readULEB128(u32, bytes, pos); + const results = try ally.alloc(module.ValType, result_count); + errdefer ally.free(results); + for (results) |*r| r.* = try readValType(bytes, pos); + types[i] = .{ .params = params, .results = results }; + } + return types; +} + +fn parseImportSection(ally: std.mem.Allocator, bytes: []const u8, pos: *usize) ParseError![]module.Import { + const count = try readULEB128(u32, bytes, pos); + const imports = try ally.alloc(module.Import, count); + errdefer ally.free(imports); + var i: u32 = 0; + while (i < count) : (i += 1) { + const mod_len = try readULEB128(u32, bytes, pos); + const mod_bytes = try readBytes(bytes, pos, mod_len); + const mod_name = try ally.dupe(u8, mod_bytes); + errdefer ally.free(mod_name); + const name_len = try readULEB128(u32, bytes, pos); + const name_bytes = try readBytes(bytes, pos, name_len); + const name = try ally.dupe(u8, name_bytes); + errdefer ally.free(name); + const kind = try readByte(bytes, pos); + const desc: module.ImportDesc = switch (kind) { + 0 => .{ .func = try readULEB128(u32, bytes, pos) }, + 1 => blk: { + const elem_type = try readByte(bytes, pos); + const lim = try parseLimits(bytes, pos); + break :blk .{ .table = .{ .elem_type = elem_type, .min = lim.min, .max = lim.max } }; + }, + 2 => blk: { + const lim = try parseLimits(bytes, pos); + break :blk .{ .memory = .{ .min = lim.min, .max = lim.max } }; + }, + 3 => blk: { + const vt = try readValType(bytes, pos); + const mut = try readByte(bytes, pos); + if (mut > 1) return ParseError.InvalidMutability; + break :blk .{ .global = .{ .valtype = vt, .mutable = mut == 1 } }; + }, + else => return ParseError.InvalidImportDesc, + }; + imports[i] = .{ .module = mod_name, .name = name, .desc = desc }; + } + return imports; +} + +fn parseFunctionSection(ally: std.mem.Allocator, bytes: []const u8, pos: *usize) ParseError![]u32 { + 
const count = try readULEB128(u32, bytes, pos); + const funcs = try ally.alloc(u32, count); + for (funcs) |*f| f.* = try readULEB128(u32, bytes, pos); + return funcs; +} + +fn parseTableSection(ally: std.mem.Allocator, bytes: []const u8, pos: *usize) ParseError![]module.TableType { + const count = try readULEB128(u32, bytes, pos); + const tables = try ally.alloc(module.TableType, count); + for (tables) |*t| { + const elem_type = try readByte(bytes, pos); + const lim = try parseLimits(bytes, pos); + t.* = .{ .elem_type = elem_type, .min = lim.min, .max = lim.max }; + } + return tables; +} + +fn parseMemorySection(ally: std.mem.Allocator, bytes: []const u8, pos: *usize) ParseError![]module.MemoryType { + const count = try readULEB128(u32, bytes, pos); + const mems = try ally.alloc(module.MemoryType, count); + for (mems) |*m| { + const lim = try parseLimits(bytes, pos); + m.* = .{ .min = lim.min, .max = lim.max }; + } + return mems; +} + +fn parseGlobalSection(ally: std.mem.Allocator, bytes: []const u8, pos: *usize) ParseError![]module.GlobalDef { + const count = try readULEB128(u32, bytes, pos); + const globals = try ally.alloc(module.GlobalDef, count); + for (globals) |*g| { + const vt = try readValType(bytes, pos); + const mut = try readByte(bytes, pos); + if (mut > 1) return ParseError.InvalidMutability; + const init = try readConstExpr(bytes, pos); + g.* = .{ .type = .{ .valtype = vt, .mutable = mut == 1 }, .init = init }; + } + return globals; +} + +fn parseExportSection(ally: std.mem.Allocator, bytes: []const u8, pos: *usize) ParseError![]module.Export { + const count = try readULEB128(u32, bytes, pos); + const exports = try ally.alloc(module.Export, count); + errdefer ally.free(exports); + var i: u32 = 0; + while (i < count) : (i += 1) { + const name_len = try readULEB128(u32, bytes, pos); + const name_bytes = try readBytes(bytes, pos, name_len); + const name = try ally.dupe(u8, name_bytes); + errdefer ally.free(name); + const kind = try readByte(bytes, pos); 
+ const idx = try readULEB128(u32, bytes, pos); + const desc: module.ExportDesc = switch (kind) { + 0 => .{ .func = idx }, + 1 => .{ .table = idx }, + 2 => .{ .memory = idx }, + 3 => .{ .global = idx }, + else => return ParseError.InvalidExportDesc, + }; + exports[i] = .{ .name = name, .desc = desc }; + } + return exports; +} + +fn parseElementSection(ally: std.mem.Allocator, bytes: []const u8, pos: *usize) ParseError![]module.ElementSegment { + const count = try readULEB128(u32, bytes, pos); + const elems = try ally.alloc(module.ElementSegment, count); + errdefer ally.free(elems); + var i: u32 = 0; + while (i < count) : (i += 1) { + const table_idx = try readULEB128(u32, bytes, pos); + const offset = try readConstExpr(bytes, pos); + const num_funcs = try readULEB128(u32, bytes, pos); + const func_indices = try ally.alloc(u32, num_funcs); + errdefer ally.free(func_indices); + for (func_indices) |*f| f.* = try readULEB128(u32, bytes, pos); + elems[i] = .{ .table_idx = table_idx, .offset = offset, .func_indices = func_indices }; + } + return elems; +} + +fn parseCodeSection(ally: std.mem.Allocator, bytes: []const u8, pos: *usize) ParseError![]module.FunctionBody { + const count = try readULEB128(u32, bytes, pos); + const bodies = try ally.alloc(module.FunctionBody, count); + errdefer ally.free(bodies); + var i: u32 = 0; + while (i < count) : (i += 1) { + const body_size = try readULEB128(u32, bytes, pos); + const body_start = pos.*; + const local_count = try readULEB128(u32, bytes, pos); + const locals = try ally.alloc(module.LocalDecl, local_count); + errdefer ally.free(locals); + for (locals) |*l| { + const n = try readULEB128(u32, bytes, pos); + const vt = try readValType(bytes, pos); + l.* = .{ .count = n, .valtype = vt }; + } + const code_start = pos.*; + const code_end = body_start + body_size; + if (code_end > bytes.len) return ParseError.UnexpectedEof; + const code = bytes[code_start..code_end]; + pos.* = code_end; + bodies[i] = .{ .locals = locals, .code = 
code }; + } + return bodies; +} + +fn parseDataSection(ally: std.mem.Allocator, bytes: []const u8, pos: *usize) ParseError![]module.DataSegment { + const count = try readULEB128(u32, bytes, pos); + const datas = try ally.alloc(module.DataSegment, count); + for (datas) |*d| { + const kind = try readULEB128(u32, bytes, pos); + switch (kind) { + 0 => { + const offset = try readConstExpr(bytes, pos); + const data_len = try readULEB128(u32, bytes, pos); + const data_bytes = try readBytes(bytes, pos, data_len); + d.* = .{ + .kind = .active, + .memory_idx = 0, + .offset = offset, + .bytes = data_bytes, + }; + }, + 1 => { + const data_len = try readULEB128(u32, bytes, pos); + const data_bytes = try readBytes(bytes, pos, data_len); + d.* = .{ + .kind = .passive, + .memory_idx = 0, + .offset = null, + .bytes = data_bytes, + }; + }, + 2 => { + const mem_idx = try readULEB128(u32, bytes, pos); + const offset = try readConstExpr(bytes, pos); + const data_len = try readULEB128(u32, bytes, pos); + const data_bytes = try readBytes(bytes, pos, data_len); + d.* = .{ + .kind = .active, + .memory_idx = mem_idx, + .offset = offset, + .bytes = data_bytes, + }; + }, + else => return ParseError.InvalidDataSegment, + } + } + return datas; +} + +fn sectionIdFromByte(b: u8) ?module.SectionId { + return switch (b) { + 0 => .custom, + 1 => .type, + 2 => .import, + 3 => .function, + 4 => .table, + 5 => .memory, + 6 => .global, + 7 => .@"export", + 8 => .start, + 9 => .element, + 10 => .code, + 11 => .data, + else => null, + }; +} + +pub fn parse(ally: std.mem.Allocator, bytes: []const u8) ParseError!module.Module { + var pos: usize = 0; + + // Validate magic + version + const magic = try readBytes(bytes, &pos, 4); + if (!std.mem.eql(u8, magic, "\x00asm")) return ParseError.InvalidMagic; + const ver = try readBytes(bytes, &pos, 4); + if (!std.mem.eql(u8, ver, "\x01\x00\x00\x00")) return ParseError.InvalidVersion; + + var mod = module.Module{ + .types = &.{}, + .imports = &.{}, + .functions = 
&.{}, + .tables = &.{}, + .memories = &.{}, + .globals = &.{}, + .exports = &.{}, + .start = null, + .elements = &.{}, + .codes = &.{}, + .datas = &.{}, + .allocator = ally, + }; + errdefer mod.deinit(); + + while (pos < bytes.len) { + const section_id_byte = try readByte(bytes, &pos); + const section_size = try readULEB128(u32, bytes, &pos); + const section_end = pos + section_size; + if (section_end > bytes.len) return ParseError.UnexpectedEof; + + const section_id = sectionIdFromByte(section_id_byte) orelse { + pos = section_end; + continue; + }; + + switch (section_id) { + .custom => pos = section_end, // skip + .type => mod.types = try parseTypeSection(ally, bytes, &pos), + .import => mod.imports = try parseImportSection(ally, bytes, &pos), + .function => mod.functions = try parseFunctionSection(ally, bytes, &pos), + .table => mod.tables = try parseTableSection(ally, bytes, &pos), + .memory => mod.memories = try parseMemorySection(ally, bytes, &pos), + .global => mod.globals = try parseGlobalSection(ally, bytes, &pos), + .@"export" => mod.exports = try parseExportSection(ally, bytes, &pos), + .start => mod.start = try readULEB128(u32, bytes, &pos), + .element => mod.elements = try parseElementSection(ally, bytes, &pos), + .code => mod.codes = try parseCodeSection(ally, bytes, &pos), + .data => mod.datas = try parseDataSection(ally, bytes, &pos), + } + // Ensure we consumed exactly section_size bytes + pos = section_end; + } + + return mod; +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +test "parse minimal wasm module" { + const ally = std.testing.allocator; + // Minimal valid wasm: magic + version only + const bytes = "\x00asm\x01\x00\x00\x00"; + var mod = try parse(ally, bytes); + defer mod.deinit(); + try std.testing.expectEqual(@as(usize, 0), mod.types.len); + try std.testing.expectEqual(@as(usize, 0), mod.functions.len); +} + +test "invalid magic rejected" { + const ally = std.testing.allocator; + const bytes = 
"\x00BAD\x01\x00\x00\x00"; + try std.testing.expectError(ParseError.InvalidMagic, parse(ally, bytes)); +} + +test "invalid version rejected" { + const ally = std.testing.allocator; + const bytes = "\x00asm\x02\x00\x00\x00"; + try std.testing.expectError(ParseError.InvalidVersion, parse(ally, bytes)); +} + +test "readULEB128 basic" { + const bytes = [_]u8{ 0xE5, 0x8E, 0x26 }; // 624485 + var pos: usize = 0; + const val = try readULEB128(u32, &bytes, &pos); + try std.testing.expectEqual(@as(u32, 624485), val); + try std.testing.expectEqual(@as(usize, 3), pos); +} + +test "readSLEB128 negative" { + const bytes = [_]u8{ 0x7E }; // -2 in SLEB128 + var pos: usize = 0; + const val = try readSLEB128(i32, &bytes, &pos); + try std.testing.expectEqual(@as(i32, -2), val); +} + +test "parse fib module" { + // Bytes from tests/wasm/fib.wasm (fib(i32)->i32, recursive) + const wasm_bytes = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, // magic + 0x01, 0x00, 0x00, 0x00, // version + // type section: 1 type -> (i32) -> (i32) + 0x01, 0x06, 0x01, 0x60, 0x01, 0x7f, 0x01, 0x7f, + // function section: 1 func, type 0 + 0x03, 0x02, 0x01, 0x00, + // export section: "fib" -> func 0 + 0x07, 0x07, 0x01, 0x03, 0x66, 0x69, 0x62, 0x00, 0x00, + // code section + 0x0a, 0x1e, 0x01, 0x1c, 0x00, 0x20, 0x00, 0x41, 0x02, 0x48, 0x04, + 0x7f, 0x20, 0x00, 0x05, 0x20, 0x00, 0x41, 0x01, 0x6b, 0x10, 0x00, + 0x20, 0x00, 0x41, 0x02, 0x6b, 0x10, 0x00, 0x6a, 0x0b, 0x0b, + }; + const ally = std.testing.allocator; + var mod = try parse(ally, &wasm_bytes); + defer mod.deinit(); + try std.testing.expectEqual(@as(usize, 1), mod.types.len); + try std.testing.expectEqual(@as(usize, 1), mod.functions.len); + try std.testing.expectEqual(@as(usize, 1), mod.exports.len); + try std.testing.expectEqualStrings("fib", mod.exports[0].name); + try std.testing.expectEqual(@as(usize, 1), mod.codes.len); + // Verify type: (i32) -> (i32) + try std.testing.expectEqual(@as(usize, 1), mod.types[0].params.len); + try 
std.testing.expectEqual(module.ValType.i32, mod.types[0].params[0]);
    try std.testing.expectEqual(@as(usize, 1), mod.types[0].results.len);
    try std.testing.expectEqual(module.ValType.i32, mod.types[0].results[0]);
}
diff --git a/src/wasm/host.zig b/src/wasm/host.zig
new file mode 100644
index 0000000..4a6918e
--- /dev/null
+++ b/src/wasm/host.zig
@@ -0,0 +1,316 @@
//! Host-function import support: typed wrappers exposing native Zig
//! functions to wasm modules, plus helpers for reading/writing extern
//! structs in linear memory.
const std = @import("std");
const module = @import("module.zig");
const runtime = @import("runtime.zig");

pub const Value = runtime.Value;
pub const FuncType = module.FuncType;
pub const ValType = module.ValType;
pub const Memory = runtime.Memory;

/// A native function importable by wasm code. `invoke` receives the opaque
/// instance pointer, the wasm arguments, an output slice for results, and
/// the registered `userdata`.
pub const HostFunc = struct {
    name: []const u8,
    module: []const u8,
    type_: FuncType,
    userdata: ?*anyopaque,
    invoke: *const fn (
        instance_ptr: *anyopaque,
        args: []const Value,
        results: []Value,
        userdata: ?*anyopaque,
    ) anyerror!void,
};

/// Registry of host imports. Owns copies of all names and signature slices;
/// everything is released in `deinit`.
pub const ImportSet = struct {
    functions: std.ArrayList(HostFunc),
    allocator: std.mem.Allocator,

    pub fn init(allocator: std.mem.Allocator) ImportSet {
        return .{
            .functions = .empty,
            .allocator = allocator,
        };
    }

    pub fn deinit(self: *ImportSet) void {
        for (self.functions.items) |f| {
            self.allocator.free(f.name);
            self.allocator.free(f.module);
            self.allocator.free(f.type_.params);
            self.allocator.free(f.type_.results);
        }
        self.functions.deinit(self.allocator);
        self.* = undefined;
    }

    /// Registers `func` under `module_name`/`func_name`, deriving the wasm
    /// signature from the Zig signature at comptime. A leading `*anyopaque`
    /// parameter is treated as the instance context and is excluded from the
    /// wasm signature.
    pub fn addFunc(
        self: *ImportSet,
        // FIX: `@EnumLiteral()` is not a Zig builtin; the enum-literal type
        // is spelled `@TypeOf(.enum_literal)`.
        comptime module_name: @TypeOf(.enum_literal),
        comptime func_name: @TypeOf(.enum_literal),
        comptime func: anytype,
    ) !void {
        const Info = @typeInfo(@TypeOf(func));
        if (Info != .@"fn") @compileError("addFunc expects a function");
        const fn_info = Info.@"fn";
        const has_instance_param = comptime hasInstanceParam(fn_info);
        const wasm_param_start: usize = if (has_instance_param) 1 else 0;

        const param_len = fn_info.params.len - wasm_param_start;
        const params = try self.allocator.alloc(ValType, param_len);
        errdefer self.allocator.free(params);
        inline for (fn_info.params[wasm_param_start..], 0..) |p, i| {
            const ty = p.type orelse @compileError("function parameter type must be known");
            params[i] = comptime zigTypeToValType(ty);
        }

        // `void` (or `!void`) return maps to zero wasm results.
        const result_len: usize = if (fn_info.return_type) |rt|
            if (rt == void or (comptime isErrorUnionVoid(rt))) 0 else 1
        else
            0;
        const results = try self.allocator.alloc(ValType, result_len);
        errdefer self.allocator.free(results);
        if (result_len == 1) {
            const rt = resultInnerType(fn_info.return_type.?);
            results[0] = comptime zigTypeToValType(rt);
        }

        const F = struct {
            fn invoke(instance_ptr: *anyopaque, args: []const Value, out_results: []Value, _: ?*anyopaque) anyerror!void {
                if (args.len != param_len) return error.InvalidArgumentCount;
                if (out_results.len != result_len) return error.InvalidResultCount;

                const call_args = try buildCallArgs(instance_ptr, args);
                const ret = @call(.auto, func, call_args);
                if (result_len == 1) {
                    const val = try packReturnValue(ret);
                    out_results[0] = val;
                } else if (comptime @typeInfo(@TypeOf(ret)) == .error_union) {
                    // FIX: a host function returning `!void` previously had
                    // its error silently discarded; propagate it to the
                    // caller instead.
                    _ = try ret;
                }
            }

            fn buildCallArgs(instance_ptr: *anyopaque, args: []const Value) !std.meta.ArgsTuple(@TypeOf(func)) {
                var tuple: std.meta.ArgsTuple(@TypeOf(func)) = undefined;
                if (has_instance_param) {
                    const first_ty = fn_info.params[0].type orelse unreachable;
                    tuple[0] = castInstancePtr(first_ty, instance_ptr);
                }
                inline for (fn_info.params[wasm_param_start..], 0..) |p, i| {
                    const ty = p.type orelse unreachable;
                    tuple[wasm_param_start + i] = try valueAs(ty, args[i]);
                }
                return tuple;
            }

            fn packReturnValue(ret: anytype) !Value {
                const RetT = @TypeOf(ret);
                if (comptime @typeInfo(RetT) == .error_union) {
                    const payload = try ret;
                    return fromZigValue(payload);
                }
                return fromZigValue(ret);
            }
        };

        try self.addFuncRaw(
            @tagName(module_name),
            @tagName(func_name),
            .{ .params = params, .results = results },
            null,
            F.invoke,
        );
    }

    /// Registers a raw invoke callback; duplicates the names but takes
    /// ownership of `type_.params`/`type_.results` as passed.
    pub fn addFuncRaw(
        self: *ImportSet,
        module_name: []const u8,
        func_name: []const u8,
        type_: FuncType,
        userdata: ?*anyopaque,
        invoke: *const fn (*anyopaque, []const Value, []Value, ?*anyopaque) anyerror!void,
    ) !void {
        const module_copy = try self.allocator.dupe(u8, module_name);
        errdefer self.allocator.free(module_copy);
        const name_copy = try self.allocator.dupe(u8, func_name);
        errdefer self.allocator.free(name_copy);

        try self.functions.append(self.allocator, .{
            .name = name_copy,
            .module = module_copy,
            .type_ = type_,
            .userdata = userdata,
            .invoke = invoke,
        });
    }

    /// Linear scan by (module, name); returns null when absent.
    pub fn findFunc(self: *const ImportSet, module_name: []const u8, func_name: []const u8) ?*const HostFunc {
        for (self.functions.items) |*f| {
            if (std.mem.eql(u8, f.module, module_name) and std.mem.eql(u8, f.name, func_name)) {
                return f;
            }
        }
        return null;
    }
};

/// Copies `@sizeOf(T)` bytes at `offset` out of linear memory into a `T`.
/// `T` must be an extern struct of i32/i64/f32/f64 fields.
pub fn readStruct(
    comptime T: type,
    memory: *const Memory,
    offset: u32,
) !T {
    comptime ensureAbiStruct(T);
    if (@as(usize, offset) + @sizeOf(T) > memory.bytes.len) return error.OutOfBounds;
    var out: T = undefined;
    @memcpy(std.mem.asBytes(&out), memory.bytes[offset..][0..@sizeOf(T)]);
    return out;
}

/// Copies `value` into linear memory at `offset`; bounds-checked.
pub fn writeStruct(
    comptime T: type,
    memory: *Memory,
    offset: u32,
    value: T,
) !void {
    comptime ensureAbiStruct(T);
    if (@as(usize, offset) + @sizeOf(T) > memory.bytes.len) return error.OutOfBounds;
    @memcpy(memory.bytes[offset..][0..@sizeOf(T)],
std.mem.asBytes(&value));
}

// Comptime check that T is an extern struct whose fields are all wasm
// numeric types, so its in-memory layout matches the wasm ABI.
fn ensureAbiStruct(comptime T: type) void {
    const info = @typeInfo(T);
    if (info != .@"struct" or info.@"struct".layout != .@"extern") {
        @compileError("T must be an extern struct");
    }
    inline for (info.@"struct".fields) |f| {
        switch (f.type) {
            i32, i64, f32, f64 => {},
            else => @compileError("extern struct fields must be i32/i64/f32/f64"),
        }
    }
}

// Maps a Zig numeric type to its wasm value type; compile error otherwise.
fn zigTypeToValType(comptime T: type) ValType {
    return switch (T) {
        i32 => .i32,
        i64 => .i64,
        f32 => .f32,
        f64 => .f64,
        else => @compileError("unsupported host function type; expected i32/i64/f32/f64"),
    };
}

// True when the first parameter is a `*anyopaque` instance-context pointer.
fn hasInstanceParam(comptime fn_info: std.builtin.Type.Fn) bool {
    if (fn_info.params.len == 0) return false;
    const first_ty = fn_info.params[0].type orelse return false;
    return isInstanceParamType(first_ty);
}

fn isInstanceParamType(comptime T: type) bool {
    const ti = @typeInfo(T);
    if (ti != .pointer) return false;
    return ti.pointer.child == anyopaque and ti.pointer.size == .one;
}

fn castInstancePtr(comptime T: type, p: *anyopaque) T {
    return @ptrCast(p);
}

fn isErrorUnionVoid(comptime T: type) bool {
    const info = @typeInfo(T);
    return info == .error_union and info.error_union.payload == void;
}

// Unwraps `!T` to `T`; identity for non-error-union types.
fn resultInnerType(comptime T: type) type {
    return switch (@typeInfo(T)) {
        .error_union => |eu| eu.payload,
        else => T,
    };
}

// Extracts the payload of the matching Value variant, or TypeMismatch.
fn valueAs(comptime T: type, v: Value) !T {
    return switch (T) {
        i32 => switch (v) { .i32 => |x| x, else => error.TypeMismatch },
        i64 => switch (v) { .i64 => |x| x, else => error.TypeMismatch },
        f32 => switch (v) { .f32 => |x| x, else => error.TypeMismatch },
        f64 => switch (v) { .f64 => |x| x, else => error.TypeMismatch },
        else => @compileError("unsupported host function type; expected i32/i64/f32/f64"),
    };
}

// Wraps a Zig numeric into the corresponding Value variant.
fn fromZigValue(v: anytype) Value {
    return switch (@TypeOf(v)) {
        i32 => .{ .i32 = v },
        i64 => .{ .i64 = v },
        f32 => .{ .f32 = v },
        f64 => .{ .f64 = v },
        else => @compileError("unsupported host function return type; expected i32/i64/f32/f64"),
    };
}

test "addFuncRaw and findFunc" {
    const Invoke = struct {
        fn f(_: *anyopaque, args: []const Value, results: []Value, _: ?*anyopaque) !void {
            results[0] = .{ .i32 = args[0].i32 + args[1].i32 };
        }
    };

    const ally = std.testing.allocator;
    var imports = ImportSet.init(ally);
    defer imports.deinit();

    const params = try ally.alloc(ValType, 2);
    const results = try ally.alloc(ValType, 1);
    params[0] = .i32;
    params[1] = .i32;
    results[0] = .i32;

    try imports.addFuncRaw("env", "add", .{ .params = params, .results = results }, null, Invoke.f);
    const found = imports.findFunc("env", "add") orelse return error.TestUnexpectedResult;
    var out = [_]Value{.{ .i32 = 0 }};
    // The instance pointer is unused by this host fn, so a dummy address works.
    try found.invoke(@ptrFromInt(1), &.{ .{ .i32 = 2 }, .{ .i32 = 3 } }, &out, null);
    try std.testing.expectEqual(@as(i32, 5), out[0].i32);
}

test "addFunc typed wrapper" {
    const ally = std.testing.allocator;
    var imports = ImportSet.init(ally);
    defer imports.deinit();

    try imports.addFunc(.env, .mul, struct {
        fn mul(a: i32, b: i32) i32 {
            return a * b;
        }
    }.mul);

    const found = imports.findFunc("env", "mul") orelse return error.TestUnexpectedResult;
    var out = [_]Value{.{ .i32 = 0 }};
    try found.invoke(@ptrFromInt(1), &.{ .{ .i32 = 6 }, .{ .i32 = 7 } }, &out, null);
    try std.testing.expectEqual(@as(i32, 42), out[0].i32);
}

test "addFunc typed wrapper with instance context parameter" {
    const ally = std.testing.allocator;
    var imports = ImportSet.init(ally);
    defer imports.deinit();

    try imports.addFunc(.env, .ctx_add, struct {
        fn ctx_add(instance_ptr: *anyopaque, a: i32, b: i32) i32 {
            _ = instance_ptr;
            return a + b + 1;
        }
    }.ctx_add);

    const found = imports.findFunc("env", "ctx_add") orelse return error.TestUnexpectedResult;
    var out = [_]Value{.{ .i32 = 0 }};
    try found.invoke(@ptrFromInt(1234), &.{ .{ .i32 = 5 }, .{ .i32 = 6 } }, &out, null);
try std.testing.expectEqual(@as(i32, 12), out[0].i32);
}

test "readStruct/writeStruct round trip" {
    const Vec2 = extern struct { x: f32, y: f32 };
    const ally = std.testing.allocator;
    var mem = try Memory.init(ally, 1, null);
    defer mem.deinit(ally);

    try writeStruct(Vec2, &mem, 12, .{ .x = 1.5, .y = -3.0 });
    const v = try readStruct(Vec2, &mem, 12);
    try std.testing.expectApproxEqAbs(@as(f32, 1.5), v.x, 0.0001);
    try std.testing.expectApproxEqAbs(@as(f32, -3.0), v.y, 0.0001);
}
diff --git a/src/wasm/instance.zig b/src/wasm/instance.zig
new file mode 100644
index 0000000..54cbf45
--- /dev/null
+++ b/src/wasm/instance.zig
@@ -0,0 +1,2310 @@
//! Module instantiation and execution: binds imports, JIT-compiles function
//! bodies, and hosts the helper callbacks invoked from JIT code.
const std = @import("std");
const module = @import("module.zig");
const binary = @import("binary.zig");
const runtime = @import("runtime.zig");
const host = @import("host.zig");
const trap = @import("trap.zig");
const jit_codegen = @import("jit/codegen.zig");

pub const Value = runtime.Value;
pub const ImportSet = host.ImportSet;
pub const HostFunc = host.HostFunc;

// The instance currently executing JIT code on this thread; used by free
// helpers (e.g. div-by-zero traps) that receive no instance pointer.
threadlocal var tls_jit_instance: ?*ModuleInstance = null;

// A callable slot: either a bound host import or an index into the
// instance's JIT buffer list.
pub const CompiledFunc = union(enum) {
    host: *const HostFunc,
    jit: struct { buf_idx: u32 },
};

pub const ModuleInstance = struct {
    module: *const module.Module,
    // Set only by instantiateOwned; deinit then also frees the module.
    owned_module: ?*module.Module,
    memory: ?runtime.Memory,
    table: ?runtime.Table,
    globals: []runtime.Value,
    // Per-data-segment flag set by data.drop; dropped segments read as empty.
    data_dropped: []bool,
    functions: []CompiledFunc,
    jit_buffers: std.ArrayList(jit_codegen.JitResult),
    last_trap: ?trap.TrapCode,
    allocator: std.mem.Allocator,

    /// Builds a runnable instance: binds imports, JIT-compiles every local
    /// function, allocates memory/table/globals, applies data/element
    /// segments, then runs the start function if present. On any error the
    /// partially built instance is torn down via the errdefer.
    pub fn instantiate(allocator: std.mem.Allocator, mod: *const module.Module, imports: *const ImportSet) !ModuleInstance {
        var inst = ModuleInstance{
            .module = mod,
            .owned_module = null,
            .memory = null,
            .table = null,
            .globals = &.{},
            .data_dropped = &.{},
            .functions = &.{},
            .jit_buffers = .empty,
            .last_trap = null,
            .allocator = allocator,
        };
        errdefer inst.deinit();

        // Only the single-memory / single-table MVP shape is supported.
        if (mod.memories.len > 1) return error.UnsupportedMultipleMemories;
        if (mod.tables.len > 1) return error.UnsupportedMultipleTables;

        var num_imported_funcs: u32 = 0;
        for (mod.imports) |imp| {
            if (imp.desc == .func) num_imported_funcs += 1;
        }

        // Function index space: imports first, then local functions.
        const total_funcs = num_imported_funcs + @as(u32, @intCast(mod.functions.len));
        inst.functions = try allocator.alloc(CompiledFunc, total_funcs);

        var imported_func_slot: u32 = 0;
        for (mod.imports) |imp| {
            if (imp.desc != .func) continue;
            const hf = imports.findFunc(imp.module, imp.name) orelse return error.MissingImport;
            const expected_type = &mod.types[imp.desc.func];
            if (!sameFuncType(expected_type, &hf.type_)) return error.ImportTypeMismatch;
            inst.functions[imported_func_slot] = .{ .host = hf };
            imported_func_slot += 1;
        }

        // Absolute addresses of the runtime helpers the JIT emits calls to.
        const helpers: jit_codegen.HelperAddrs = .{
            .call = @intFromPtr(&jitCallHelper),
            .@"unreachable" = @intFromPtr(&jitUnreachableHelper),
            .global_get = @intFromPtr(&jitGlobalGetHelper),
            .global_set = @intFromPtr(&jitGlobalSetHelper),
            .mem_load = @intFromPtr(&jitMemLoadHelper),
            .mem_store = @intFromPtr(&jitMemStoreHelper),
            .i32_unary = @intFromPtr(&jitI32UnaryHelper),
            .i32_cmp = @intFromPtr(&jitI32CmpHelper),
            .i32_binary = @intFromPtr(&jitI32BinaryHelper),
            .i32_div_s = @intFromPtr(&jitI32DivSHelper),
            .i32_div_u = @intFromPtr(&jitI32DivUHelper),
            .i32_rem_s = @intFromPtr(&jitI32RemSHelper),
            .i32_rem_u = @intFromPtr(&jitI32RemUHelper),
            .i64_eqz = @intFromPtr(&jitI64EqzHelper),
            .i64_cmp = @intFromPtr(&jitI64CmpHelper),
            .i64_unary = @intFromPtr(&jitI64UnaryHelper),
            .i64_binary = @intFromPtr(&jitI64BinaryHelper),
            .f32_cmp = @intFromPtr(&jitF32CmpHelper),
            .f64_cmp = @intFromPtr(&jitF64CmpHelper),
            .f32_unary = @intFromPtr(&jitF32UnaryHelper),
            .f32_binary = @intFromPtr(&jitF32BinaryHelper),
            .f64_unary = @intFromPtr(&jitF64UnaryHelper),
            .f64_binary = @intFromPtr(&jitF64BinaryHelper),
            .convert = @intFromPtr(&jitConvertHelper),
            .trunc_sat = @intFromPtr(&jitTruncSatHelper),
            .i_extend = @intFromPtr(&jitIExtendHelper),
            .memory_init = @intFromPtr(&jitMemoryInitHelper),
            .data_drop = @intFromPtr(&jitDataDropHelper),
            .memory_copy = @intFromPtr(&jitMemoryCopyHelper),
            .memory_fill = @intFromPtr(&jitMemoryFillHelper),
            .table_size = @intFromPtr(&jitTableSizeHelper),
            .memory_size = @intFromPtr(&jitMemorySizeHelper),
            .memory_grow = @intFromPtr(&jitMemoryGrowHelper),
            .call_indirect = @intFromPtr(&jitCallIndirectHelper),
        };
        for (mod.functions, 0..) |type_idx, i| {
            const global_idx: u32 = num_imported_funcs + @as(u32, @intCast(i));
            const ft = &mod.types[type_idx];
            const body = &mod.codes[i];
            // compileSimpleI32 returning null means the body uses an opcode
            // the JIT cannot handle; there is no interpreter fallback here.
            if (try jit_codegen.compileSimpleI32(allocator, mod, num_imported_funcs, global_idx, body, ft, helpers)) |jit_res| {
                var owned = jit_res;
                errdefer owned.buf.deinit();
                const buf_idx: u32 = @intCast(inst.jit_buffers.items.len);
                try inst.jit_buffers.append(allocator, owned);
                inst.functions[global_idx] = .{ .jit = .{ .buf_idx = buf_idx } };
            } else {
                return error.UnsupportedOpcode;
            }
        }

        if (mod.memories.len == 1) {
            const mem_ty = mod.memories[0];
            inst.memory = try runtime.Memory.init(allocator, mem_ty.min, mem_ty.max);
        }

        if (mod.tables.len == 1) {
            const tab_ty = mod.tables[0];
            inst.table = try runtime.Table.init(allocator, tab_ty.min, tab_ty.max);
        }

        inst.globals = try allocator.alloc(runtime.Value, mod.globals.len);
        for (mod.globals, 0..) |g, i| {
            inst.globals[i] = try evalConstExpr(&inst, g.init, i);
            if (!valueMatchesType(inst.globals[i], g.type.valtype)) return error.ConstExprTypeMismatch;
        }
        inst.data_dropped = try allocator.alloc(bool, mod.datas.len);
        @memset(inst.data_dropped, false);

        try inst.applyDataSegments();
        try inst.applyElementSegments(total_funcs);

        // Run the start function last, per the wasm instantiation order.
        if (mod.start) |start_idx| {
            const start_results = try inst.executeFunction(start_idx, &.{});
            defer allocator.free(start_results);
            if (start_results.len != 0) return error.InvalidStartFunction;
        }

        return inst;
    }

    /// Like instantiate, but the instance takes ownership of the module and
    /// destroys it in deinit.
    pub fn instantiateOwned(allocator: std.mem.Allocator, owned_mod: *module.Module, imports: *const ImportSet) !ModuleInstance {
        var inst = try instantiate(allocator, owned_mod, imports);
        inst.owned_module = owned_mod;
        return inst;
    }

    pub fn deinit(self: *ModuleInstance) void {
        if (self.memory) |*mem| mem.deinit(self.allocator);
        if (self.table) |*tab| tab.deinit(self.allocator);
        self.allocator.free(self.globals);
        self.allocator.free(self.data_dropped);
        self.allocator.free(self.functions);
        for (self.jit_buffers.items) |*jit| jit.buf.deinit();
        self.jit_buffers.deinit(self.allocator);
        if (self.owned_module) |mod_ptr| {
            mod_ptr.deinit();
            self.allocator.destroy(mod_ptr);
        }
        self.* = undefined;
    }

    /// Looks up an exported function by name and calls it. Caller owns the
    /// returned results slice.
    pub fn callExport(self: *ModuleInstance, name: []const u8, args: []const Value) ![]Value {
        for (self.module.exports) |exp| {
            switch (exp.desc) {
                .func => |idx| {
                    if (std.mem.eql(u8, exp.name, name)) {
                        return self.executeFunction(idx, args);
                    }
                },
                else => {},
            }
        }
        return error.ExportNotFound;
    }

    /// Bounds-checked view into linear memory; not a copy.
    pub fn memorySlice(self: *const ModuleInstance, offset: u32, len: u32) ![]u8 {
        const mem = self.memory orelse return error.UndefinedMemory;
        const start = @as(usize, offset);
        const end = start + @as(usize, len);
        if (end > mem.bytes.len) return error.OutOfBounds;
        return mem.bytes[start..end];
    }

    /// Returns the trap recorded by the most recent JIT execution, if any.
    pub fn lastTrapCode(self: *const ModuleInstance)
?trap.TrapCode {
        return self.last_trap;
    }

    pub fn clearLastTrap(self: *ModuleInstance) void {
        self.last_trap = null;
    }

    /// Type-checks args, then dispatches to the host import or JIT code.
    /// Caller owns the returned results slice.
    fn executeFunction(self: *ModuleInstance, func_idx: u32, args: []const Value) anyerror![]Value {
        const ft = try self.functionType(func_idx);
        if (args.len != ft.params.len) return error.InvalidArgumentCount;
        for (ft.params, args, 0..) |expected, arg, i| {
            if (!valueMatchesType(arg, expected)) {
                // Index captured but unused; discarded to satisfy the compiler.
                _ = i;
                return error.TypeMismatch;
            }
        }

        const fn_ref = self.functions[func_idx];
        return switch (fn_ref) {
            .host => |hf| blk: {
                const out = try self.allocator.alloc(Value, ft.results.len);
                errdefer self.allocator.free(out);
                try hf.invoke(@ptrCast(self), args, out, hf.userdata);
                break :blk out;
            },
            .jit => |j| try self.callJitted(j.buf_idx, args, ft),
        };
    }

    // Records the first trap only; later traps do not overwrite it.
    fn setTrap(self: *ModuleInstance, code: trap.TrapCode) void {
        if (self.last_trap == null) self.last_trap = code;
    }

    // Trap entry point for helpers that have no instance pointer; routes
    // through the thread-local set up by callJitted.
    fn setThreadTrap(code: trap.TrapCode) void {
        if (tls_jit_instance) |inst| inst.setTrap(code);
    }

    /// Marshals args into raw u64 words (i32 sign-extended, f32 in the low
    /// 32 bits), invokes the JIT entry, and unpacks the single result.
    fn callJitted(self: *ModuleInstance, buf_idx: u32, args: []const Value, ft: *const module.FuncType) ![]Value {
        const jit = self.jit_buffers.items[buf_idx];
        _ = jit.arity;
        // Stack buffer for the common case; heap only above 16 args.
        var small_args: [16]u64 = undefined;
        var heap_args: []u64 = &.{};
        defer if (heap_args.len != 0) self.allocator.free(heap_args);

        const arg_words: []u64 = blk: {
            if (args.len <= small_args.len) {
                for (args, 0..) |arg, i| {
                    small_args[i] = switch (arg) {
                        .i32 => |v| @as(u64, @bitCast(@as(i64, v))),
                        .i64 => |v| @as(u64, @bitCast(v)),
                        .f32 => |v| @as(u64, @as(u32, @bitCast(v))),
                        .f64 => |v| @as(u64, @bitCast(v)),
                    };
                }
                break :blk small_args[0..args.len];
            }
            heap_args = try self.allocator.alloc(u64, args.len);
            for (args, 0..) |arg, i| {
                heap_args[i] = switch (arg) {
                    .i32 => |v| @as(u64, @bitCast(@as(i64, v))),
                    .i64 => |v| @as(u64, @bitCast(v)),
                    .f32 => |v| @as(u64, @as(u32, @bitCast(v))),
                    .f64 => |v| @as(u64, @bitCast(v)),
                };
            }
            break :blk heap_args;
        };

        // JIT code always receives a valid pointer, even for zero args.
        var zero: u64 = 0;
        const arg_ptr: [*]const u64 = if (arg_words.len == 0) @ptrCast(&zero) else arg_words.ptr;
        self.last_trap = null;
        // Publish this instance for trap helpers; restore on exit so nested
        // instances on the same thread keep working.
        const prev_tls = tls_jit_instance;
        tls_jit_instance = self;
        defer tls_jit_instance = prev_tls;
        const raw_result: u64 = jit.buf.funcPtr(fn (*ModuleInstance, [*]const u64, u32) callconv(.c) u64, 0)(
            self,
            arg_ptr,
            @intCast(arg_words.len),
        );
        if (self.last_trap != null) return error.WasmTrap;
        if (ft.results.len == 0) {
            return self.allocator.alloc(Value, 0);
        }
        if (ft.results.len != 1) return error.UnsupportedJitResultArity;
        const out = try self.allocator.alloc(Value, 1);
        errdefer self.allocator.free(out);
        out[0] = switch (ft.results[0]) {
            .i32 => .{ .i32 = @bitCast(@as(u32, @truncate(raw_result))) },
            .i64 => .{ .i64 = @bitCast(raw_result) },
            .f32 => .{ .f32 = @bitCast(@as(u32, @truncate(raw_result))) },
            .f64 => .{ .f64 = @bitCast(raw_result) },
        };
        return out;
    }

    /// Called from JIT code for `call`: unpacks raw words into Values using
    /// the callee's signature, runs it, and repacks the single result.
    fn jitCallHelper(instance: *ModuleInstance, func_idx: u32, arg_ptr: [*]const u64, argc: u32) callconv(.c) u64 {
        const ft = instance.functionType(func_idx) catch {
            instance.setTrap(.invalid_function);
            return 0;
        };
        if (ft.params.len != argc) {
            instance.setTrap(.invalid_function);
            return 0;
        }
        var small_args: [16]Value = undefined;
        var heap_args: []Value = &.{};
        defer if (heap_args.len != 0) instance.allocator.free(heap_args);

        const arg_slice: []Value = blk: {
            if (ft.params.len <= small_args.len) {
                for (0..ft.params.len) |i| {
                    small_args[i] = switch (ft.params[i]) {
                        .i32 => .{ .i32 = @bitCast(@as(u32, @truncate(arg_ptr[i]))) },
                        .i64 => .{ .i64 = @bitCast(arg_ptr[i]) },
                        .f32 => .{ .f32 = @bitCast(@as(u32, @truncate(arg_ptr[i]))) },
                        .f64 => .{ .f64
= @bitCast(arg_ptr[i]) },
                    };
                }
                break :blk small_args[0..ft.params.len];
            }
            heap_args = instance.allocator.alloc(Value, ft.params.len) catch {
                // No way to surface OOM from JIT context; report as trap.
                instance.setTrap(.call_stack_exhausted);
                return 0;
            };
            for (0..ft.params.len) |i| {
                heap_args[i] = switch (ft.params[i]) {
                    .i32 => .{ .i32 = @bitCast(@as(u32, @truncate(arg_ptr[i]))) },
                    .i64 => .{ .i64 = @bitCast(arg_ptr[i]) },
                    .f32 => .{ .f32 = @bitCast(@as(u32, @truncate(arg_ptr[i]))) },
                    .f64 => .{ .f64 = @bitCast(arg_ptr[i]) },
                };
            }
            break :blk heap_args;
        };
        const out = instance.executeFunction(func_idx, arg_slice) catch {
            // NOTE(review): every callee error is reported as
            // call_stack_exhausted, including type errors — confirm intended.
            instance.setTrap(.call_stack_exhausted);
            return 0;
        };
        defer instance.allocator.free(out);
        if (out.len == 0) return 0;
        return switch (out[0]) {
            .i32 => |v| @as(u64, @bitCast(@as(i64, v))),
            .i64 => |v| @as(u64, @bitCast(v)),
            .f32 => |v| @as(u64, @as(u32, @bitCast(v))),
            .f64 => |v| @as(u64, @bitCast(v)),
        };
    }

    // `unreachable` opcode: records the trap via the thread-local instance.
    fn jitUnreachableHelper(_: *ModuleInstance) callconv(.c) i32 {
        setThreadTrap(.@"unreachable");
        return 0;
    }

    // global.get: returns the global packed as a raw u64 (i32 sign-extended).
    fn jitGlobalGetHelper(instance: *ModuleInstance, global_idx: u32) callconv(.c) u64 {
        if (global_idx >= instance.globals.len) {
            instance.setTrap(.undefined_global);
            return 0;
        }
        return switch (instance.globals[global_idx]) {
            .i32 => |v| @as(u64, @bitCast(@as(i64, v))),
            .i64 => |v| @as(u64, @bitCast(v)),
            .f32 => |v| @as(u64, @as(u32, @bitCast(v))),
            .f64 => |v| @as(u64, @bitCast(v)),
        };
    }

    // global.set: reinterprets the raw word using the global's current type.
    fn jitGlobalSetHelper(instance: *ModuleInstance, global_idx: u32, value: u64) callconv(.c) u64 {
        if (global_idx >= instance.globals.len) {
            instance.setTrap(.undefined_global);
            return 0;
        }
        switch (instance.globals[global_idx]) {
            .i32 => instance.globals[global_idx] = .{ .i32 = @bitCast(@as(u32, @truncate(value))) },
            .i64 => instance.globals[global_idx] = .{ .i64 = @bitCast(value) },
            .f32 => instance.globals[global_idx] = .{ .f32 = @bitCast(@as(u32, @truncate(value))) },
            .f64 => instance.globals[global_idx] = .{ .f64 = @bitCast(value) },
        }
        return 0;
    }

    // Dispatches all wasm load opcodes (0x28-0x35) by opcode byte. The
    // effective address wraps (`+%`); Memory.load does the bounds check.
    fn jitMemLoadHelper(instance: *ModuleInstance, addr: i32, offset: u32, op: u32) callconv(.c) u64 {
        const mem = instance.memory orelse {
            instance.setTrap(.memory_out_of_bounds);
            return 0;
        };
        const base: u32 = @bitCast(addr);
        const eff = base +% offset;
        const loadFail = struct {
            fn fail(inst: *ModuleInstance) u64 {
                inst.setTrap(.memory_out_of_bounds);
                return 0;
            }
        };
        return switch (op) {
            0x28 => @as(u64, @bitCast(@as(i64, mem.load(i32, eff) catch return loadFail.fail(instance)))),
            0x29 => @as(u64, @bitCast(mem.load(i64, eff) catch return loadFail.fail(instance))),
            0x2A => @as(u64, @as(u32, @bitCast(mem.load(f32, eff) catch return loadFail.fail(instance)))),
            0x2B => @as(u64, @bitCast(mem.load(f64, eff) catch return loadFail.fail(instance))),
            0x2C => @as(u64, @bitCast(@as(i64, mem.load(i8, eff) catch return loadFail.fail(instance)))),
            0x2D => @as(u64, @bitCast(@as(i64, mem.load(u8, eff) catch return loadFail.fail(instance)))),
            0x2E => @as(u64, @bitCast(@as(i64, mem.load(i16, eff) catch return loadFail.fail(instance)))),
            0x2F => @as(u64, @bitCast(@as(i64, mem.load(u16, eff) catch return loadFail.fail(instance)))),
            0x30 => @as(u64, @bitCast(@as(i64, mem.load(i8, eff) catch return loadFail.fail(instance)))),
            0x31 => @as(u64, @bitCast(@as(i64, mem.load(u8, eff) catch return loadFail.fail(instance)))),
            0x32 => @as(u64, @bitCast(@as(i64, mem.load(i16, eff) catch return loadFail.fail(instance)))),
            0x33 => @as(u64, @bitCast(@as(i64, mem.load(u16, eff) catch return loadFail.fail(instance)))),
            0x34 => @as(u64, @bitCast(@as(i64, mem.load(i32, eff) catch return loadFail.fail(instance)))),
            0x35 => @as(u64, @bitCast(@as(i64, mem.load(u32, eff) catch return loadFail.fail(instance)))),
            else => blk: {
                instance.setTrap(.memory_out_of_bounds);
                break :blk 0;
            },
        };
    }

    // Dispatches all wasm store opcodes (0x36-0x3E) by opcode byte.
    fn jitMemStoreHelper(instance: *ModuleInstance, addr: i32, offset: u32,
op: u32, value: u64) callconv(.c) u64 {
        var mem = instance.memory orelse {
            instance.setTrap(.memory_out_of_bounds);
            return 0;
        };
        const base: u32 = @bitCast(addr);
        const eff = base +% offset;
        switch (op) {
            0x36 => mem.store(i32, eff, @bitCast(@as(u32, @truncate(value)))) catch {
                instance.setTrap(.memory_out_of_bounds);
                return 0;
            },
            0x37 => mem.store(i64, eff, @bitCast(value)) catch {
                instance.setTrap(.memory_out_of_bounds);
                return 0;
            },
            0x38 => mem.store(f32, eff, @bitCast(@as(u32, @truncate(value)))) catch {
                instance.setTrap(.memory_out_of_bounds);
                return 0;
            },
            0x39 => mem.store(f64, eff, @bitCast(value)) catch {
                instance.setTrap(.memory_out_of_bounds);
                return 0;
            },
            0x3A => mem.store(u8, eff, @truncate(value)) catch {
                instance.setTrap(.memory_out_of_bounds);
                return 0;
            },
            0x3B => mem.store(u16, eff, @truncate(value)) catch {
                instance.setTrap(.memory_out_of_bounds);
                return 0;
            },
            0x3C => mem.store(u8, eff, @truncate(value)) catch {
                instance.setTrap(.memory_out_of_bounds);
                return 0;
            },
            0x3D => mem.store(u16, eff, @truncate(value)) catch {
                instance.setTrap(.memory_out_of_bounds);
                return 0;
            },
            0x3E => mem.store(u32, eff, @truncate(value)) catch {
                instance.setTrap(.memory_out_of_bounds);
                return 0;
            },
            else => {
                instance.setTrap(.memory_out_of_bounds);
                return 0;
            },
        }
        // Write the (possibly updated) Memory value back to the instance.
        instance.memory = mem;
        return 0;
    }

    // memory.size: current size in 64 KiB pages, or trap if no memory.
    fn jitMemorySizeHelper(instance: *ModuleInstance) callconv(.c) i32 {
        const mem = instance.memory orelse {
            instance.setTrap(.memory_out_of_bounds);
            return -1;
        };
        return @intCast(mem.bytes.len / runtime.PAGE_SIZE);
    }

    // memory.init: copies from a passive data segment; a dropped segment
    // reads as empty, so any nonzero-length init from it traps.
    fn jitMemoryInitHelper(instance: *ModuleInstance, dst: i32, src: i32, len: i32, data_idx: u32) callconv(.c) u64 {
        var mem = instance.memory orelse {
            instance.setTrap(.memory_out_of_bounds);
            return 0;
        };
        if (data_idx >= instance.module.datas.len) {
            instance.setTrap(.memory_out_of_bounds);
            return 0;
        }
        const seg = instance.module.datas[data_idx];
        const src_bytes: []const u8 = if (instance.data_dropped[data_idx]) &.{} else seg.bytes;

        const d: usize = @as(u32, @bitCast(dst));
        const s: usize = @as(u32, @bitCast(src));
        const n: usize = @as(u32, @bitCast(len));
        // Two-step bounds check avoids overflow in `d + n` / `s + n`.
        if (d > mem.bytes.len or s > src_bytes.len) {
            instance.setTrap(.memory_out_of_bounds);
            return 0;
        }
        if (n > mem.bytes.len - d or n > src_bytes.len - s) {
            instance.setTrap(.memory_out_of_bounds);
            return 0;
        }
        if (n > 0) @memcpy(mem.bytes[d..][0..n], src_bytes[s..][0..n]);
        instance.memory = mem;
        return 0;
    }

    // data.drop: marks the segment dropped so later memory.init sees it empty.
    fn jitDataDropHelper(instance: *ModuleInstance, data_idx: u32) callconv(.c) u64 {
        if (data_idx >= instance.data_dropped.len) {
            instance.setTrap(.memory_out_of_bounds);
            return 0;
        }
        instance.data_dropped[data_idx] = true;
        return 0;
    }

    // table.size: only table index 0 is supported.
    fn jitTableSizeHelper(instance: *ModuleInstance, table_idx: u32) callconv(.c) i32 {
        if (table_idx != 0) {
            instance.setTrap(.undefined_table);
            return -1;
        }
        const tab = instance.table orelse {
            instance.setTrap(.undefined_table);
            return -1;
        };
        return @intCast(tab.elements.len);
    }

    // memory.copy: overlap-safe (forward/backward copy chosen by direction).
    fn jitMemoryCopyHelper(instance: *ModuleInstance, dst: i32, src: i32, len: i32) callconv(.c) u64 {
        var mem = instance.memory orelse {
            instance.setTrap(.memory_out_of_bounds);
            return 0;
        };
        const dst_u32: u32 = @bitCast(dst);
        const src_u32: u32 = @bitCast(src);
        const len_u32: u32 = @bitCast(len);
        const d: usize = dst_u32;
        const s: usize = src_u32;
        const n: usize = len_u32;
        if (d > mem.bytes.len or s > mem.bytes.len) {
            instance.setTrap(.memory_out_of_bounds);
            return 0;
        }
        if (n > mem.bytes.len - d or n > mem.bytes.len - s) {
            instance.setTrap(.memory_out_of_bounds);
            return 0;
        }
        if (n == 0) return 0;
        if (d <= s) {
            std.mem.copyForwards(u8, mem.bytes[d..][0..n], mem.bytes[s..][0..n]);
        } else {
            std.mem.copyBackwards(u8, mem.bytes[d..][0..n], mem.bytes[s..][0..n]);
        }
        instance.memory = mem;
        return 0;
    }

    // memory.fill: fills n bytes with the low byte of `value`.
    fn jitMemoryFillHelper(instance: *ModuleInstance, dst: i32, value: i32, len: i32) callconv(.c) u64 {
        var mem = instance.memory orelse {
            instance.setTrap(.memory_out_of_bounds);
            return 0;
        };
        const dst_u32: u32 = @bitCast(dst);
        const len_u32: u32 = @bitCast(len);
        const d: usize = dst_u32;
        const n: usize = len_u32;
        if (d > mem.bytes.len) {
            instance.setTrap(.memory_out_of_bounds);
            return 0;
        }
        if (n > mem.bytes.len - d) {
            instance.setTrap(.memory_out_of_bounds);
            return 0;
        }
        const b: u8 = @truncate(@as(u32, @bitCast(value)));
        @memset(mem.bytes[d..][0..n], b);
        instance.memory = mem;
        return 0;
    }

    // memory.grow: returns the old page count, or -1 on failure.
    // NOTE(review): a negative delta is the unsigned-page-count i32; it is
    // treated as failure here rather than as a huge grow — confirm intended.
    fn jitMemoryGrowHelper(instance: *ModuleInstance, delta: i32) callconv(.c) i32 {
        if (delta < 0) return -1;
        var mem = instance.memory orelse {
            instance.setTrap(.memory_out_of_bounds);
            return -1;
        };
        const old = mem.grow(instance.allocator, @intCast(delta)) catch return -1;
        instance.memory = mem;
        return @intCast(old);
    }

    // i32 unary ops, keyed by wasm opcode byte.
    fn jitI32UnaryHelper(op: u32, a: i32) callconv(.c) i32 {
        return switch (op) {
            0x45 => if (a == 0) 1 else 0, // i32.eqz
            0x67 => @intCast(@clz(@as(u32, @bitCast(a)))), // i32.clz
            0x68 => @intCast(@ctz(@as(u32, @bitCast(a)))), // i32.ctz
            0x69 => @intCast(@popCount(@as(u32, @bitCast(a)))), // i32.popcnt
            else => 0,
        };
    }

    // i32 comparisons; _u variants compare the bit patterns as unsigned.
    fn jitI32CmpHelper(op: u32, a: i32, b: i32) callconv(.c) i32 {
        const au: u32 = @bitCast(a);
        const bu: u32 = @bitCast(b);
        const out = switch (op) {
            0x46 => a == b, // i32.eq
            0x47 => a != b, // i32.ne
            0x48 => a < b, // i32.lt_s
            0x49 => au < bu, // i32.lt_u
            0x4A => a > b, // i32.gt_s
            0x4B => au > bu, // i32.gt_u
            0x4C => a <= b, // i32.le_s
            0x4D => au <= bu, // i32.le_u
            0x4E => a >= b, // i32.ge_s
            0x4F => au >= bu, // i32.ge_u
            else => false,
        };
        return if (out) 1 else 0;
    }

    // i32 binary ops; div/rem delegate to the trapping helpers below.
    // Shift/rotate counts are masked to 5 bits per the wasm spec.
    fn jitI32BinaryHelper(op: u32, a: i32, b: i32) callconv(.c) i32 {
        const au: u32 = @bitCast(a);
        const bu: u32 = @bitCast(b);
        return switch (op) {
            0x6A => a +% b, // i32.add
            0x6B => a -% b, // i32.sub
            0x6C => a *% b, // i32.mul
            0x6D => jitI32DivSHelper(a, b), // i32.div_s
            0x6E => jitI32DivUHelper(a, b), // i32.div_u
            0x6F => jitI32RemSHelper(a, b), // i32.rem_s
            0x70 => jitI32RemUHelper(a, b), // i32.rem_u
            0x71 => @bitCast(au & bu), // i32.and
            0x72 => @bitCast(au | bu), // i32.or
            0x73 => @bitCast(au ^ bu), // i32.xor
            0x74 => @bitCast(au << @intCast(bu & 31)), // i32.shl
            0x75 => a >> @intCast(bu & 31), // i32.shr_s
            0x76 => @bitCast(au >> @intCast(bu & 31)), // i32.shr_u
            0x77 => @bitCast(std.math.rotl(u32, au, @as(u5, @intCast(bu & 31)))), // i32.rotl
            0x78 => @bitCast(std.math.rotr(u32, au, @as(u5, @intCast(bu & 31)))), // i32.rotr
            else => 0,
        };
    }

    // i32.div_s: traps on divide-by-zero and on minInt/-1 overflow.
    fn jitI32DivSHelper(a: i32, b: i32) callconv(.c) i32 {
        if (b == 0) {
            setThreadTrap(.integer_divide_by_zero);
            return 0;
        }
        if (a == std.math.minInt(i32) and b == -1) {
            setThreadTrap(.integer_overflow);
            return 0;
        }
        return @divTrunc(a, b);
    }

    // i32.div_u: traps on divide-by-zero.
    fn jitI32DivUHelper(a: i32, b: i32) callconv(.c) i32 {
        const au: u32 = @bitCast(a);
        const bu: u32 = @bitCast(b);
        if (bu == 0) {
            setThreadTrap(.integer_divide_by_zero);
            return 0;
        }
        return @bitCast(@divTrunc(au, bu));
    }

    // i32.rem_s: per spec, minInt % -1 is 0 (no trap).
    fn jitI32RemSHelper(a: i32, b: i32) callconv(.c) i32 {
        if (b == 0) {
            setThreadTrap(.integer_divide_by_zero);
            return 0;
        }
        if (a == std.math.minInt(i32) and b == -1) return 0;
        return @rem(a, b);
    }

    fn jitI32RemUHelper(a: i32, b: i32) callconv(.c) i32 {
        const au: u32 = @bitCast(a);
        const bu: u32 = @bitCast(b);
        if (bu == 0) {
            setThreadTrap(.integer_divide_by_zero);
            return 0;
        }
        return @bitCast(@rem(au, bu));
    }

    fn jitI64EqzHelper(a: u64) callconv(.c) u64 {
        return if (a == 0) 1 else 0;
    }

    // i64 comparisons; signed variants reinterpret the raw bits.
    fn jitI64CmpHelper(op: u32, a: u64, b: u64) callconv(.c) u64 {
        const as: i64 = @bitCast(a);
        const bs: i64 = @bitCast(b);
        const out: bool = switch (op) {
            0x51 => as == bs,
            0x52 => as != bs,
            0x53 => as < bs,
            0x54 => a < b,
0x55 => as > bs,
            0x56 => a > b,
            0x57 => as <= bs,
            0x58 => a <= b,
            0x59 => as >= bs,
            0x5A => a >= b,
            else => false,
        };
        return if (out) 1 else 0;
    }

    // i64 unary bit ops (clz/ctz/popcnt).
    fn jitI64UnaryHelper(op: u32, a: u64) callconv(.c) u64 {
        return switch (op) {
            0x79 => @clz(a),
            0x7A => @ctz(a),
            0x7B => @popCount(a),
            else => 0,
        };
    }

    // i64 binary ops with inline trap checks for div/rem; shift/rotate
    // counts are masked to 6 bits per the wasm spec.
    fn jitI64BinaryHelper(op: u32, a: u64, b: u64) callconv(.c) u64 {
        const as: i64 = @bitCast(a);
        const bs: i64 = @bitCast(b);
        return switch (op) {
            0x7C => a +% b,
            0x7D => a -% b,
            0x7E => a *% b,
            0x7F => blk: { // i64.div_s
                if (bs == 0) {
                    setThreadTrap(.integer_divide_by_zero);
                    break :blk 0;
                }
                if (as == std.math.minInt(i64) and bs == -1) {
                    setThreadTrap(.integer_overflow);
                    break :blk 0;
                }
                break :blk @bitCast(@divTrunc(as, bs));
            },
            0x80 => blk: { // i64.div_u
                if (b == 0) {
                    setThreadTrap(.integer_divide_by_zero);
                    break :blk 0;
                }
                break :blk @divTrunc(a, b);
            },
            0x81 => blk: { // i64.rem_s; minInt % -1 is 0 per spec
                if (bs == 0) {
                    setThreadTrap(.integer_divide_by_zero);
                    break :blk 0;
                }
                if (as == std.math.minInt(i64) and bs == -1) break :blk 0;
                break :blk @bitCast(@rem(as, bs));
            },
            0x82 => blk: { // i64.rem_u
                if (b == 0) {
                    setThreadTrap(.integer_divide_by_zero);
                    break :blk 0;
                }
                break :blk @rem(a, b);
            },
            0x83 => a & b,
            0x84 => a | b,
            0x85 => a ^ b,
            0x86 => a << @intCast(b & 63),
            0x87 => @bitCast(as >> @intCast(b & 63)),
            0x88 => a >> @intCast(b & 63),
            0x89 => std.math.rotl(u64, a, @as(u6, @intCast(b & 63))),
            0x8A => std.math.rotr(u64, a, @as(u6, @intCast(b & 63))),
            else => 0,
        };
    }

    // f32 comparisons on the low 32 bits of the raw words.
    fn jitF32CmpHelper(op: u32, a_bits: u64, b_bits: u64) callconv(.c) u64 {
        const a: f32 = @bitCast(@as(u32, @truncate(a_bits)));
        const b: f32 = @bitCast(@as(u32, @truncate(b_bits)));
        const out: bool = switch (op) {
            0x5B => a == b,
            0x5C => a != b,
            0x5D => a < b,
            0x5E => a > b,
            0x5F => a <= b,
            0x60 => a >= b,
            else => false,
        };
        return if (out) 1 else 0;
    }

    fn jitF64CmpHelper(op: u32, a_bits: u64, b_bits: u64) callconv(.c) u64 {
        const a: f64 = @bitCast(a_bits);
        const b: f64 = @bitCast(b_bits);
        const out: bool = switch (op) {
            0x61 => a == b,
            0x62 => a != b,
            0x63 => a < b,
            0x64 => a > b,
            0x65 => a <= b,
            0x66 => a >= b,
            else => false,
        };
        return if (out) 1 else 0;
    }

    // f32 unary ops.
    // NOTE(review): 0x90 (f32.nearest) uses std.math.round, which rounds
    // ties away from zero; wasm `nearest` is round-to-even — confirm.
    fn jitF32UnaryHelper(op: u32, a_bits: u64) callconv(.c) u64 {
        const a: f32 = @bitCast(@as(u32, @truncate(a_bits)));
        const r: f32 = switch (op) {
            0x8B => @abs(a),
            0x8C => -a,
            0x8D => std.math.ceil(a),
            0x8E => std.math.floor(a),
            0x8F => std.math.trunc(a),
            0x90 => std.math.round(a),
            0x91 => @sqrt(a),
            else => 0,
        };
        return @as(u64, @as(u32, @bitCast(r)));
    }

    // f32 binary ops.
    // NOTE(review): 0x96/0x97 use @min/@max; wasm fmin/fmax must propagate
    // NaN and order -0 < +0 — verify Zig's builtins match that here.
    fn jitF32BinaryHelper(op: u32, a_bits: u64, b_bits: u64) callconv(.c) u64 {
        const a: f32 = @bitCast(@as(u32, @truncate(a_bits)));
        const b: f32 = @bitCast(@as(u32, @truncate(b_bits)));
        const r: f32 = switch (op) {
            0x92 => a + b,
            0x93 => a - b,
            0x94 => a * b,
            0x95 => a / b,
            0x96 => @min(a, b),
            0x97 => @max(a, b),
            0x98 => std.math.copysign(a, b),
            else => 0,
        };
        return @as(u64, @as(u32, @bitCast(r)));
    }

    // f64 unary ops; same `nearest` caveat as the f32 variant.
    fn jitF64UnaryHelper(op: u32, a_bits: u64) callconv(.c) u64 {
        const a: f64 = @bitCast(a_bits);
        const r: f64 = switch (op) {
            0x99 => @abs(a),
            0x9A => -a,
            0x9B => std.math.ceil(a),
            0x9C => std.math.floor(a),
            0x9D => std.math.trunc(a),
            0x9E => std.math.round(a),
            0x9F => @sqrt(a),
            else => 0,
        };
        return @as(u64, @bitCast(r));
    }

    // f64 binary ops; same fmin/fmax caveat as the f32 variant.
    fn jitF64BinaryHelper(op: u32, a_bits: u64, b_bits: u64) callconv(.c) u64 {
        const a: f64 = @bitCast(a_bits);
        const b: f64 = @bitCast(b_bits);
        const r: f64 = switch (op) {
            0xA0 => a + b,
            0xA1 => a - b,
            0xA2 => a * b,
            0xA3 => a / b,
            0xA4 => @min(a, b),
            0xA5 => @max(a, b),
            0xA6 => std.math.copysign(a, b),
            else => 0,
        };
        return @as(u64, @bitCast(r));
    }

    // Conversion opcodes 0xA7-0xBF, keyed by opcode byte.
    fn jitConvertHelper(op: u32, a_bits: u64) callconv(.c) u64 {
        return switch (op) {
            0xA7 => @as(u64, @bitCast(@as(i64, @intCast(@as(u32,
@truncate(a_bits)))))),
            0xA8 => floatToIntI32F32(@bitCast(@as(u32, @truncate(a_bits))), false),
            0xA9 => floatToIntI32F32(@bitCast(@as(u32, @truncate(a_bits))), true),
            0xAA => floatToIntI32F64(@bitCast(a_bits), false),
            0xAB => floatToIntI32F64(@bitCast(a_bits), true),
            // i64.extend_i32_s must SIGN-extend; the previous zero-extension
            // made it identical to 0xAD (i64.extend_i32_u) and corrupted
            // negative i32 values.
            0xAC => @as(u64, @bitCast(@as(i64, @as(i32, @bitCast(@as(u32, @truncate(a_bits))))))),
            0xAD => @as(u64, @as(u32, @truncate(a_bits))), // i64.extend_i32_u
            0xAE => floatToIntI64F32(@bitCast(@as(u32, @truncate(a_bits))), false),
            0xAF => floatToIntI64F32(@bitCast(@as(u32, @truncate(a_bits))), true),
            0xB0 => floatToIntI64F64(@bitCast(a_bits), false),
            0xB1 => floatToIntI64F64(@bitCast(a_bits), true),
            0xB2 => @as(u64, @as(u32, @bitCast(@as(f32, @floatFromInt(@as(i32, @bitCast(@as(u32, @truncate(a_bits))))))))),
            0xB3 => @as(u64, @as(u32, @bitCast(@as(f32, @floatFromInt(@as(u32, @truncate(a_bits))))))),
            0xB4 => @as(u64, @as(u32, @bitCast(@as(f32, @floatFromInt(@as(i64, @bitCast(a_bits))))))),
            0xB5 => @as(u64, @as(u32, @bitCast(@as(f32, @floatFromInt(a_bits))))),
            0xB6 => @as(u64, @as(u32, @bitCast(@as(f32, @floatCast(@as(f64, @bitCast(a_bits))))))),
            0xB7 => @as(u64, @bitCast(@as(f64, @floatFromInt(@as(i32, @bitCast(@as(u32, @truncate(a_bits)))))))),
            0xB8 => @as(u64, @bitCast(@as(f64, @floatFromInt(@as(u32, @truncate(a_bits)))))),
            0xB9 => @as(u64, @bitCast(@as(f64, @floatFromInt(@as(i64, @bitCast(a_bits)))))),
            0xBA => @as(u64, @bitCast(@as(f64, @floatFromInt(a_bits)))),
            0xBB => @as(u64, @bitCast(@as(f64, @floatCast(@as(f32, @bitCast(@as(u32, @truncate(a_bits)))))))),
            // Reinterpret ops are pure bit moves.
            0xBC => @as(u64, @as(u32, @truncate(a_bits))), // i32.reinterpret_f32
            0xBD => a_bits, // i64.reinterpret_f64
            0xBE => @as(u64, @as(u32, @truncate(a_bits))), // f32.reinterpret_i32
            0xBF => a_bits, // f64.reinterpret_i64
            else => 0,
        };
    }

    /// Runtime helper for the sign-extension operators (0xC0..0xC4):
    /// i32.extend8_s / i32.extend16_s / i64.extend8_s / i64.extend16_s /
    /// i64.extend32_s.
    fn jitIExtendHelper(op: u32, a_bits: u64) callconv(.c) u64 {
        return switch (op) {
            0xC0 => @as(u64, @as(u32, @bitCast(@as(i32, @as(i8, @bitCast(@as(u8, @truncate(a_bits)))))))),
            0xC1 => @as(u64, @as(u32, @bitCast(@as(i32, @as(i16, @bitCast(@as(u16, @truncate(a_bits)))))))),
            0xC2 => @as(u64, @bitCast(@as(i64, @as(i8, @bitCast(@as(u8, @truncate(a_bits))))))),
            0xC3 => @as(u64, @bitCast(@as(i64, @as(i16, @bitCast(@as(u16, @truncate(a_bits))))))),
            0xC4 => @as(u64, @bitCast(@as(i64, @as(i32, @bitCast(@as(u32, @truncate(a_bits))))))),
            else => 0,
        };
    }

    /// Dispatch for the 0xFC saturating-truncation family; `subop` is the
    /// LEB sub-opcode (0..7). f32 sources are widened to f64 first, which is
    /// exact, so a single pair of f64 saturating helpers suffices.
    fn jitTruncSatHelper(subop: u32, a_bits: u64) callconv(.c) u64 {
        const from_f32: f64 = @floatCast(@as(f32, @bitCast(@as(u32, @truncate(a_bits)))));
        const from_f64: f64 = @bitCast(a_bits);
        return switch (subop) {
            0 => truncSatI32FromF64(from_f32, false), // i32.trunc_sat_f32_s
            1 => truncSatI32FromF64(from_f32, true), // i32.trunc_sat_f32_u
            2 => truncSatI32FromF64(from_f64, false), // i32.trunc_sat_f64_s
            3 => truncSatI32FromF64(from_f64, true), // i32.trunc_sat_f64_u
            4 => truncSatI64FromF64(from_f32, false), // i64.trunc_sat_f32_s
            5 => truncSatI64FromF64(from_f32, true), // i64.trunc_sat_f32_u
            6 => truncSatI64FromF64(from_f64, false), // i64.trunc_sat_f64_s
            7 => truncSatI64FromF64(from_f64, true), // i64.trunc_sat_f64_u
            else => 0,
        };
    }

    /// Saturating f64 -> i32/u32 truncation: NaN -> 0, out-of-range values
    /// clamp to the destination bounds (wasm trunc_sat semantics).
    fn truncSatI32FromF64(f: f64, unsigned: bool) u64 {
        if (std.math.isNan(f)) return 0;
        const t = std.math.trunc(f);
        if (unsigned) {
            if (t <= 0.0) return 0;
            if (t >= 4294967295.0) return @as(u64, std.math.maxInt(u32));
            return @as(u64, @as(u32, @intFromFloat(t)));
        }
        if (t <= -2147483648.0) {
            const min_i32: i32 = std.math.minInt(i32);
            return @as(u64, @as(u32, @bitCast(min_i32)));
        }
        if (t >= 2147483647.0) {
            const max_i32: i32 = std.math.maxInt(i32);
            return @as(u64, @as(u32, @bitCast(max_i32)));
        }
        const i: i32 = @intFromFloat(t);
        return @as(u64, @as(u32, @bitCast(i)));
    }

    /// Saturating f64 -> i64/u64 truncation: NaN -> 0, out-of-range values
    /// clamp to the destination bounds (wasm trunc_sat semantics).
    fn truncSatI64FromF64(f: f64, unsigned: bool) u64 {
        if (std.math.isNan(f)) return 0;
        const t = std.math.trunc(f);
        if (unsigned) {
            if (t <= 0.0) return 0;
            // The literal rounds to 2^64 as f64, so this catches everything
            // at or above the first unrepresentable value.
            if (t >= 18446744073709551615.0) return std.math.maxInt(u64);
            return @as(u64, @intFromFloat(t));
        }
        if (t <= -9223372036854775808.0) {
            const min_i64: i64 = std.math.minInt(i64);
            return @as(u64, @bitCast(min_i64));
        }
        if (t >= 9223372036854775807.0) {
            const max_i64: i64 = std.math.maxInt(i64);
            return @as(u64, @bitCast(max_i64));
        }
        const i: i64 = @intFromFloat(t);
        return @as(u64, @bitCast(i));
    }

    /// i32.trunc_f32_{s,u}: trapping float-to-int conversion. Per the wasm
    /// spec the trap condition is on the *truncated* value, so fractional
    /// inputs just past the range edge (e.g. -0.5 for the unsigned form)
    /// convert successfully instead of trapping.
    fn floatToIntI32F32(v: f32, unsigned: bool) u64 {
        const f: f64 = @floatCast(v);
        if (std.math.isNan(f) or !std.math.isFinite(f)) {
            setThreadTrap(.invalid_conversion_to_integer);
            return 0;
        }
        const t: f64 = std.math.trunc(f);
        if (unsigned) {
            if (t < 0.0 or t > 4294967295.0) {
                setThreadTrap(.invalid_conversion_to_integer);
                return 0;
            }
            return @as(u64, @as(u32, @intFromFloat(t)));
        }
        if (t < -2147483648.0 or t > 2147483647.0) {
            setThreadTrap(.invalid_conversion_to_integer);
            return 0;
        }
        return @as(u64, @bitCast(@as(i64, @intFromFloat(t))));
    }

    /// i32.trunc_f64_{s,u}: trapping float-to-int conversion. Checking the
    /// truncated value fixes the old off-by-a-fraction traps: -2147483648.9
    /// must convert to i32 min, and inputs in (-1, 0) must convert to 0 for
    /// the unsigned form.
    fn floatToIntI32F64(v: f64, unsigned: bool) u64 {
        const f: f64 = v;
        if (std.math.isNan(f) or !std.math.isFinite(f)) {
            setThreadTrap(.invalid_conversion_to_integer);
            return 0;
        }
        const t: f64 = std.math.trunc(f);
        if (unsigned) {
            if (t < 0.0 or t > 4294967295.0) {
                setThreadTrap(.invalid_conversion_to_integer);
                return 0;
            }
            return @as(u64, @as(u32, @intFromFloat(t)));
        }
        if (t < -2147483648.0 or t > 2147483647.0) {
            setThreadTrap(.invalid_conversion_to_integer);
            return 0;
        }
        return @as(u64, @bitCast(@as(i64, @intFromFloat(t))));
    }

    /// i64.trunc_f32_{s,u}: trapping float-to-int conversion (wasm semantics).
    fn floatToIntI64F32(v: f32, unsigned: bool) u64 {
        const f: f64 = @floatCast(v);
        if (std.math.isNan(f) or !std.math.isFinite(f)) {
            setThreadTrap(.invalid_conversion_to_integer);
            return 0;
        }
        if (unsigned)
{
            // Trap only when the *truncated* value is outside u64 range:
            // inputs in (-1, 0) truncate to 0 and must convert per the spec
            // (the old `f < 0` check trapped on them).
            if (f <= -1.0 or f >= 18446744073709551616.0) {
                setThreadTrap(.invalid_conversion_to_integer);
                return 0;
            }
            const t: f64 = std.math.trunc(f);
            return @intFromFloat(t);
        }
        if (f < -9223372036854775808.0 or f >= 9223372036854775808.0) {
            setThreadTrap(.invalid_conversion_to_integer);
            return 0;
        }
        const t: f64 = std.math.trunc(f);
        return @as(u64, @bitCast(@as(i64, @intFromFloat(t))));
    }

    /// i64.trunc_f64_{s,u}: trapping float-to-int conversion (wasm traps,
    /// unlike the saturating trunc_sat family).
    fn floatToIntI64F64(v: f64, unsigned: bool) u64 {
        const f: f64 = v;
        if (std.math.isNan(f) or !std.math.isFinite(f)) {
            setThreadTrap(.invalid_conversion_to_integer);
            return 0;
        }
        if (unsigned) {
            // Inputs in (-1, 0) truncate to 0 and must convert, not trap.
            if (f <= -1.0 or f >= 18446744073709551616.0) {
                setThreadTrap(.invalid_conversion_to_integer);
                return 0;
            }
            const t: f64 = std.math.trunc(f);
            return @intFromFloat(t);
        }
        if (f < -9223372036854775808.0 or f >= 9223372036854775808.0) {
            setThreadTrap(.invalid_conversion_to_integer);
            return 0;
        }
        const t: f64 = std.math.trunc(f);
        return @as(u64, @bitCast(@as(i64, @intFromFloat(t))));
    }

    /// JIT runtime bridge for `call_indirect`. Validates the type, table and
    /// element, marshals raw u64 argument slots into `Value`s, runs the callee
    /// through the interpreter entry point, and packs the first result back
    /// into a u64. Every failure sets a trap on the instance and returns 0.
    fn jitCallIndirectHelper(
        instance: *ModuleInstance,
        type_idx: u32,
        table_idx: u32,
        elem_idx: i32,
        arg_ptr: [*]const u64,
        argc: u32,
    ) callconv(.c) u64 {
        if (type_idx >= instance.module.types.len) {
            instance.setTrap(.indirect_call_type_mismatch);
            return 0;
        }
        const tab = instance.table orelse {
            instance.setTrap(.undefined_element);
            return 0;
        };
        // Only table 0 exists in the MVP single-table model.
        if (table_idx != 0) {
            instance.setTrap(.indirect_call_type_mismatch);
            return 0;
        }
        // The element index is an i32 on the wasm stack but indexes unsigned.
        const elem_u32: u32 = @bitCast(elem_idx);
        if (elem_u32 >= tab.elements.len) {
            instance.setTrap(.undefined_element);
            return 0;
        }
        const target_idx = tab.elements[elem_u32] orelse {
            instance.setTrap(.uninitialized_element);
            return 0;
        };
        const expected = &instance.module.types[type_idx];
        const actual = instance.functionType(target_idx) catch {
            instance.setTrap(.invalid_function);
            return 0;
        };
        if (!sameFuncType(expected, actual)) {
            instance.setTrap(.indirect_call_type_mismatch);
            return 0;
        }
        if (expected.params.len != argc) {
            instance.setTrap(.indirect_call_type_mismatch);
            return 0;
        }

        // Small calls use a stack buffer; larger ones fall back to the heap.
        var small_args: [16]Value = undefined;
        var heap_args: []Value = &.{};
        defer if (heap_args.len != 0) instance.allocator.free(heap_args);
        const args: []Value = blk: {
            if (argc <= small_args.len) break :blk small_args[0..argc];
            heap_args = instance.allocator.alloc(Value, argc) catch {
                instance.setTrap(.call_stack_exhausted);
                return 0;
            };
            break :blk heap_args;
        };
        // Reinterpret each raw 64-bit slot according to the declared type.
        for (args, 0..) |*slot, i| {
            slot.* = switch (expected.params[i]) {
                .i32 => .{ .i32 = @bitCast(@as(u32, @truncate(arg_ptr[i]))) },
                .i64 => .{ .i64 = @bitCast(arg_ptr[i]) },
                .f32 => .{ .f32 = @bitCast(@as(u32, @truncate(arg_ptr[i]))) },
                .f64 => .{ .f64 = @bitCast(arg_ptr[i]) },
            };
        }
        // NOTE(review): execution errors are uniformly reported as
        // call_stack_exhausted; confirm executeFunction cannot fail for other
        // reasons that deserve a distinct trap.
        const out = instance.executeFunction(target_idx, args) catch {
            instance.setTrap(.call_stack_exhausted);
            return 0;
        };
        defer instance.allocator.free(out);
        if (out.len == 0) return 0;
        // Only the first result crosses the JIT ABI (single-result calls).
        return switch (out[0]) {
            .i32 => |v| @as(u64, @bitCast(@as(i64, v))),
            .i64 => |v| @as(u64, @bitCast(v)),
            .f32 => |v| @as(u64, @as(u32, @bitCast(v))),
            .f64 => |v| @as(u64, @bitCast(v)),
        };
    }

    /// Resolve the FuncType for `func_idx`, counting imported functions first
    /// (imports occupy the low function indices, locals follow).
    fn functionType(self: *const ModuleInstance, func_idx: u32) !*const module.FuncType {
        var import_func_count: u32 = 0;
        for (self.module.imports) |imp| {
            if (imp.desc == .func) {
                if (import_func_count == func_idx) return &self.module.types[imp.desc.func];
                import_func_count += 1;
            }
        }
        const local_idx = func_idx - import_func_count;
        if (local_idx >= self.module.functions.len) return error.InvalidFunctionIndex;
        const
type_idx = self.module.functions[local_idx];
        if (type_idx >= self.module.types.len) return error.InvalidTypeIndex;
        return &self.module.types[type_idx];
    }

    /// Copy every active data segment into linear memory, evaluating the
    /// segment's constant offset expression first.
    fn applyDataSegments(self: *ModuleInstance) !void {
        for (self.module.datas) |seg| {
            if (seg.kind != .active) continue;
            if (seg.memory_idx != 0 or self.memory == null) return error.UndefinedMemory;
            const addr = try constExprI32(self, seg.offset.?, self.globals.len);
            // wasm offsets are i32 constants interpreted as *unsigned*
            // addresses; bitcast instead of @intCast so a negative constant
            // becomes a large offset that fails the bounds check below rather
            // than tripping a cast safety panic.
            const start: usize = @as(u32, @bitCast(addr));
            const end = start + seg.bytes.len;
            var mem = &self.memory.?;
            if (end > mem.bytes.len) return error.DataOutOfBounds;
            @memcpy(mem.bytes[start..end], seg.bytes);
        }
    }

    /// Populate table 0 from active element segments, validating that every
    /// referenced function index exists.
    fn applyElementSegments(self: *ModuleInstance, total_funcs: u32) !void {
        if (self.module.elements.len == 0) return;

        for (self.module.elements) |seg| {
            if (seg.table_idx != 0 or self.table == null) return error.UndefinedTable;
            const base = try constExprI32(self, seg.offset, self.globals.len);
            // Same unsigned-address interpretation as data segments: bitcast,
            // then let the bounds check reject out-of-range offsets.
            const start: usize = @as(u32, @bitCast(base));
            const end = start + seg.func_indices.len;
            var tab = &self.table.?;
            if (end > tab.elements.len) return error.ElementOutOfBounds;
            for (seg.func_indices, 0..) |fidx, i| {
                if (fidx >= total_funcs) return error.InvalidFunctionIndex;
                tab.elements[start + i] = fidx;
            }
        }
    }
};

/// Evaluate a constant initializer expression. `initialized_globals` bounds
/// which globals a `global.get` may reference (only previously-initialized
/// ones, matching instantiation order).
fn evalConstExpr(inst: *const ModuleInstance, expr: module.ConstExpr, initialized_globals: usize) !runtime.Value {
    return switch (expr) {
        .i32_const => |v| .{ .i32 = v },
        .i64_const => |v| .{ .i64 = v },
        .f32_const => |v| .{ .f32 = v },
        .f64_const => |v| .{ .f64 = v },
        .global_get => |idx| blk: {
            if (idx >= initialized_globals) return error.UndefinedGlobal;
            break :blk inst.globals[idx];
        },
    };
}

/// Evaluate a constant expression that must produce an i32 (segment offsets).
fn constExprI32(inst: *const ModuleInstance, expr: module.ConstExpr, initialized_globals: usize) !i32 {
    return switch (try evalConstExpr(inst, expr, initialized_globals)) {
        .i32 => |n| n,
        else => error.ConstExprTypeMismatch,
    };
}

/// True when the runtime value's active variant matches the declared type.
fn valueMatchesType(v: runtime.Value, ty: module.ValType) bool {
    return switch (v) {
        .i32 => ty == .i32,
        .i64 => ty == .i64,
        .f32 => ty == .f32,
        .f64 => ty == .f64,
    };
}

/// Structural function-type equality: params and results element-wise.
fn sameFuncType(a: *const module.FuncType, b: *const module.FuncType) bool {
    if (a.params.len != b.params.len or a.results.len != b.results.len) return false;
    for (a.params, 0..) |p, i| if (p != b.params[i]) return false;
    for (a.results, 0..)
|r, i| if (r != b.results[i]) return false; + return true; +} + +test "instantiate applies data segment to memory" { + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + 0x05, 0x03, 0x01, 0x00, 0x01, + 0x0b, 0x09, 0x01, 0x00, 0x41, 0x04, 0x0b, 0x03, 0x61, 0x62, 0x63, + }; + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const slice = try inst.memorySlice(4, 3); + try std.testing.expectEqualStrings("abc", slice); +} + +test "instantiate initializes globals including global.get const expr" { + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + 0x06, 0x0b, 0x02, 0x7f, 0x00, 0x41, 0x07, 0x0b, 0x7f, 0x00, 0x23, 0x00, 0x0b, + }; + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + try std.testing.expectEqual(@as(usize, 2), inst.globals.len); + try std.testing.expectEqual(@as(i32, 7), inst.globals[0].i32); + try std.testing.expectEqual(@as(i32, 7), inst.globals[1].i32); +} + +test "instantiate applies element segment to table" { + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x04, 0x01, 0x60, 0x00, 0x00, + 0x03, 0x02, 0x01, 0x00, + 0x04, 0x04, 0x01, 0x70, 0x00, 0x01, + 0x09, 0x07, 0x01, 0x00, 0x41, 0x00, 0x0b, 0x01, 0x00, + 0x0a, 0x04, 0x01, 0x02, 0x00, 0x0b, + }; + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + try std.testing.expect(inst.table != null); + try 
std.testing.expectEqual(@as(?u32, 0), inst.table.?.elements[0]); +} + +test "instantiate rejects data segment out of bounds" { + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + 0x05, 0x03, 0x01, 0x00, 0x01, + 0x0b, 0x0a, 0x01, 0x00, 0x41, 0xff, 0xff, 0x03, 0x0b, 0x02, 0x61, 0x62, + }; + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + + try std.testing.expectError(error.DataOutOfBounds, ModuleInstance.instantiate(ally, &mod, &imports)); +} + +test "callExport executes recursive fib function" { + const fib_wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x06, 0x01, 0x60, 0x01, 0x7f, 0x01, 0x7f, + 0x03, 0x02, 0x01, 0x00, + 0x07, 0x07, 0x01, 0x03, 0x66, 0x69, 0x62, 0x00, 0x00, + 0x0a, 0x1e, 0x01, 0x1c, 0x00, 0x20, 0x00, 0x41, 0x02, 0x48, 0x04, + 0x7f, 0x20, 0x00, 0x05, 0x20, 0x00, 0x41, 0x01, 0x6b, 0x10, 0x00, + 0x20, 0x00, 0x41, 0x02, 0x6b, 0x10, 0x00, 0x6a, 0x0b, 0x0b, + }; + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &fib_wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + if (@import("builtin").cpu.arch == .aarch64 or @import("builtin").cpu.arch == .x86_64) { + try std.testing.expect(inst.functions[0] == .jit); + } + + const out = try inst.callExport("fib", &.{.{ .i32 = 10 }}); + defer ally.free(out); + try std.testing.expectEqual(@as(usize, 1), out.len); + try std.testing.expectEqual(@as(i32, 55), out[0].i32); +} + +test "callExport invokes host import" { + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + // type section: + // 0: (i32, i32) -> i32 + // 1: () -> i32 + 0x01, 0x0b, 0x02, 0x60, 0x02, 0x7f, 0x7f, 0x01, 0x7f, 0x60, 0x00, 0x01, 0x7f, + // import section: env.add, type 0 + 0x02, 0x0b, 0x01, 0x03, 0x65, 0x6e, 
0x76, 0x03, 0x61, 0x64, 0x64, 0x00, 0x00, + // function section: one function, type 1 + 0x03, 0x02, 0x01, 0x01, + // export section: run -> function index 1 (after imported func) + 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x01, + // code: i32.const 7; i32.const 9; call 0; end + 0x0a, 0x0a, 0x01, 0x08, 0x00, 0x41, 0x07, 0x41, 0x09, 0x10, 0x00, 0x0b, + }; + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + + var imports = ImportSet.init(ally); + defer imports.deinit(); + try imports.addFunc(.env, .add, struct { + fn add(a: i32, b: i32) i32 { + return a + b; + } + }.add); + + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + if (@import("builtin").cpu.arch == .aarch64 or @import("builtin").cpu.arch == .x86_64) { + try std.testing.expect(inst.functions[1] == .jit); + } + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 16), out[0].i32); +} + +test "callExport host import can read memory string via instance pointer" { + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + // type section: + // 0: (i32, i32) -> () + // 1: () -> () + 0x01, 0x09, 0x02, 0x60, 0x02, 0x7f, 0x7f, 0x00, 0x60, 0x00, 0x00, + // import section: env.log, type 0 + 0x02, 0x0b, 0x01, 0x03, 0x65, 0x6e, 0x76, 0x03, 0x6c, 0x6f, 0x67, 0x00, 0x00, + // function section: one local function, type 1 + 0x03, 0x02, 0x01, 0x01, + // memory section: min 1 page + 0x05, 0x03, 0x01, 0x00, 0x01, + // export section: run -> function index 1 + 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x01, + // code section: + // run: i32.const 16; i32.const 5; call log; end + 0x0a, 0x0a, 0x01, 0x08, 0x00, 0x41, 0x10, 0x41, 0x05, 0x10, 0x00, 0x0b, + // data section: memory[16..21] = "hello" + 0x0b, 0x0b, 0x01, 0x00, 0x41, 0x10, 0x0b, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, + }; + + const LogState = struct { + var called = false; + var text: [5]u8 = [_]u8{0} 
** 5; + + fn log(instance_ptr: *anyopaque, ptr: i32, len: i32) !void { + const inst: *ModuleInstance = @ptrCast(@alignCast(instance_ptr)); + const bytes = try inst.memorySlice(@intCast(ptr), @intCast(len)); + @memcpy(text[0..bytes.len], bytes); + called = true; + } + }; + LogState.called = false; + @memset(LogState.text[0..], 0); + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + try imports.addFunc(.env, .log, LogState.log); + + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(usize, 0), out.len); + try std.testing.expect(LogState.called); + try std.testing.expectEqualStrings("hello", LogState.text[0..]); +} + +test "aarch64 jit compiles simple i32 function during instantiation" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + // type: (i32, i32) -> i32 + 0x01, 0x07, 0x01, 0x60, 0x02, 0x7f, 0x7f, 0x01, 0x7f, + // function: one local function type 0 + 0x03, 0x02, 0x01, 0x00, + // export: sum -> func 0 + 0x07, 0x07, 0x01, 0x03, 0x73, 0x75, 0x6d, 0x00, 0x00, + // code: local.get 0; local.get 1; i32.add; end + 0x0a, 0x09, 0x01, 0x07, 0x00, 0x20, 0x00, 0x20, 0x01, 0x6a, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + try std.testing.expect(inst.functions[0] == .jit); + const out = try inst.callExport("sum", &.{ .{ .i32 = 20 }, .{ .i32 = 22 } }); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 42), out[0].i32); +} + +test 
"aarch64 jit handles locals and extended i32 ops" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + // type: (i32, i32) -> i32 + 0x01, 0x07, 0x01, 0x60, 0x02, 0x7f, 0x7f, 0x01, 0x7f, + // function: one local function type 0 + 0x03, 0x02, 0x01, 0x00, + // export: run -> func 0 + 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x00, + // code: + // local i32 + // local.get 0 + // local.get 1 + // i32.mul + // local.set 2 + // local.get 2 + // i32.const 3 + // i32.shl + // i32.const 4 + // i32.or + // end + 0x0a, 0x15, 0x01, 0x13, 0x01, 0x01, 0x7f, 0x20, 0x00, 0x20, 0x01, 0x6c, + 0x21, 0x02, 0x20, 0x02, 0x41, 0x03, 0x74, 0x41, 0x04, 0x72, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{ .{ .i32 = 2 }, .{ .i32 = 5 } }); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 84), out[0].i32); +} + +test "aarch64 jit handles i32 ctz" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + // type: () -> i32 + 0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7f, + // function: one local function type 0 + 0x03, 0x02, 0x01, 0x00, + // export: run -> func 0 + 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x00, + // code: i32.const 16; i32.ctz; end + 0x0a, 0x07, 0x01, 0x05, 0x00, 0x41, 0x10, 0x68, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, 
&imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 4), out[0].i32); +} + +test "aarch64 jit handles i32 popcnt" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + // type: () -> i32 + 0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7f, + // function: one local function type 0 + 0x03, 0x02, 0x01, 0x00, + // export: run -> func 0 + 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x00, + // code: i32.const 0xf0f0; i32.popcnt; end + 0x0a, 0x09, 0x01, 0x07, 0x00, 0x41, 0xf0, 0xe1, 0x03, 0x69, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 8), out[0].i32); +} + +test "aarch64 jit handles global.get/global.set for i32 globals" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + // type: () -> i32 + 0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7f, + // function section: one local function type 0 + 0x03, 0x02, 0x01, 0x00, + // global section: (mut i32) = 3 + 0x06, 0x06, 0x01, 0x7f, 0x01, 0x41, 0x03, 0x0b, + // export section: run -> func 0 + 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x00, + // code: + // global.get 0 + // i32.const 4 + // i32.add + // global.set 0 + // global.get 0 + // end + 0x0a, 0x0d, 0x01, 0x0b, 0x00, 0x23, 0x00, 0x41, 0x04, 0x6a, 0x24, 0x00, 0x23, 0x00, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer 
mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 7), out[0].i32); +} + +test "aarch64 jit handles i32 memory load/store and memory.grow/size" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + // type: () -> i32 + 0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7f, + // function section: one local function type 0 + 0x03, 0x02, 0x01, 0x00, + // memory section: min=1, max=2 + 0x05, 0x04, 0x01, 0x01, 0x01, 0x02, + // export section: run -> func 0 + 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x00, + // code: + // i32.const 0 + // i32.const 42 + // i32.store align=2 offset=0 + // i32.const 0 + // i32.load align=2 offset=0 + // i32.const 1 + // memory.grow 0 + // drop + // memory.size 0 + // i32.add + // end + 0x0a, 0x18, 0x01, 0x16, 0x00, + 0x41, 0x00, 0x41, 0x2a, 0x36, 0x02, 0x00, + 0x41, 0x00, 0x28, 0x02, 0x00, + 0x41, 0x01, 0x40, 0x00, 0x1a, 0x3f, 0x00, 0x6a, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 44), out[0].i32); +} + +test "aarch64 jit handles call_indirect" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x06, 0x01, 0x60, 0x01, 0x7f, 0x01, 0x7f, + 0x03, 0x03, 0x02, 0x00, 0x00, + 0x04, 0x04, 0x01, 0x70, 0x00, 0x01, 
+ 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x01, + 0x09, 0x07, 0x01, 0x00, 0x41, 0x00, 0x0b, 0x01, 0x00, + 0x0a, 0x13, 0x02, + 0x07, 0x00, 0x20, 0x00, 0x41, 0x01, 0x6a, 0x0b, + 0x09, 0x00, 0x20, 0x00, 0x41, 0x00, 0x11, 0x00, 0x00, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{.{ .i32 = 41 }}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 42), out[0].i32); +} + +test "aarch64 jit handles loop with br and br_if" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x01, 0x06, 0x01, 0x60, + 0x01, 0x7f, 0x01, 0x7f, 0x03, 0x02, 0x01, 0x00, 0x07, 0x07, 0x01, 0x03, + 0x72, 0x75, 0x6e, 0x00, 0x00, 0x0a, 0x27, 0x01, 0x25, 0x01, 0x01, 0x7f, + 0x41, 0x01, 0x21, 0x01, 0x02, 0x40, 0x03, 0x40, 0x20, 0x00, 0x45, 0x0d, + 0x01, 0x20, 0x01, 0x20, 0x00, 0x6c, 0x21, 0x01, 0x20, 0x00, 0x41, 0x01, + 0x6b, 0x21, 0x00, 0x0c, 0x00, 0x0b, 0x0b, 0x20, 0x01, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{.{ .i32 = 5 }}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 120), out[0].i32); +} + +test "aarch64 jit handles br_table" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x01, 0x06, 0x01, 0x60, + 0x01, 0x7f, 0x01, 0x7f, 0x03, 0x02, 0x01, 0x00, 
0x07, 0x07, 0x01, 0x03, + 0x72, 0x75, 0x6e, 0x00, 0x00, 0x0a, 0x26, 0x01, 0x24, 0x01, 0x01, 0x7f, + 0x41, 0x1e, 0x21, 0x01, 0x02, 0x40, 0x02, 0x40, 0x02, 0x40, 0x20, 0x00, + 0x0e, 0x02, 0x00, 0x01, 0x02, 0x0b, 0x41, 0x0a, 0x21, 0x01, 0x0c, 0x01, + 0x0b, 0x41, 0x14, 0x21, 0x01, 0x0b, 0x20, 0x01, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out0 = try inst.callExport("run", &.{.{ .i32 = 0 }}); + defer ally.free(out0); + try std.testing.expectEqual(@as(i32, 10), out0[0].i32); + + const out1 = try inst.callExport("run", &.{.{ .i32 = 1 }}); + defer ally.free(out1); + try std.testing.expectEqual(@as(i32, 20), out1[0].i32); + + const out2 = try inst.callExport("run", &.{.{ .i32 = 99 }}); + defer ally.free(out2); + try std.testing.expectEqual(@as(i32, 30), out2[0].i32); +} + +test "aarch64 jit handles i32 load8/load16 and store8/store16" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x01, 0x05, 0x01, 0x60, + 0x00, 0x01, 0x7f, 0x03, 0x02, 0x01, 0x00, 0x05, 0x03, 0x01, 0x00, 0x01, + 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x00, 0x0a, 0x20, 0x01, + 0x1e, 0x00, 0x41, 0x00, 0x41, 0xff, 0x01, 0x3a, 0x00, 0x00, 0x41, 0x01, + 0x41, 0xff, 0xff, 0x01, 0x3b, 0x01, 0x00, 0x41, 0x00, 0x2d, 0x00, 0x00, + 0x41, 0x01, 0x2e, 0x01, 0x00, 0x6a, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try 
std.testing.expectEqual(@as(i32, 33022), out[0].i32); +} + +test "aarch64 jit compiles unreachable in dead branch" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x01, 0x06, 0x01, 0x60, + 0x01, 0x7f, 0x01, 0x7f, 0x03, 0x02, 0x01, 0x00, 0x07, 0x07, 0x01, 0x03, + 0x72, 0x75, 0x6e, 0x00, 0x00, 0x0a, 0x0f, 0x01, 0x0d, 0x00, 0x20, 0x00, + 0x04, 0x7f, 0x41, 0x0b, 0x05, 0x00, 0x41, 0x16, 0x0b, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{.{ .i32 = 1 }}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 11), out[0].i32); +} + +test "aarch64 jit handles i32 div and rem semantics" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x01, 0x05, 0x01, 0x60, + 0x00, 0x01, 0x7f, 0x03, 0x02, 0x01, 0x00, 0x07, 0x07, 0x01, 0x03, 0x72, + 0x75, 0x6e, 0x00, 0x00, 0x0a, 0x1b, 0x01, 0x19, 0x00, 0x41, 0x79, 0x41, + 0x03, 0x6d, 0x41, 0x07, 0x41, 0x03, 0x6e, 0x41, 0x79, 0x41, 0x03, 0x6f, + 0x41, 0x07, 0x41, 0x03, 0x70, 0x6a, 0x6a, 0x6a, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 0), out[0].i32); +} + +test "aarch64 jit handles i64 params results and arithmetic" { + if (@import("builtin").cpu.arch 
!= .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x01, 0x07, 0x01, 0x60, + 0x02, 0x7e, 0x7e, 0x01, 0x7e, 0x03, 0x02, 0x01, 0x00, 0x07, 0x07, 0x01, + 0x03, 0x72, 0x75, 0x6e, 0x00, 0x00, 0x0a, 0x0c, 0x01, 0x0a, 0x00, 0x20, + 0x00, 0x20, 0x01, 0x7e, 0x42, 0x03, 0x7c, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{ .{ .i64 = 6 }, .{ .i64 = 7 } }); + defer ally.free(out); + try std.testing.expectEqual(@as(i64, 45), out[0].i64); +} + +test "aarch64 jit handles i64 comparisons" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x01, 0x06, 0x01, 0x60, + 0x01, 0x7e, 0x01, 0x7f, 0x03, 0x02, 0x01, 0x00, 0x07, 0x07, 0x01, 0x03, + 0x72, 0x75, 0x6e, 0x00, 0x00, 0x0a, 0x09, 0x01, 0x07, 0x00, 0x20, 0x00, + 0x42, 0x7f, 0x55, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{.{ .i64 = 0 }}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 1), out[0].i32); +} + +test "aarch64 jit handles f32 arithmetic" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x01, 0x07, 0x01, 0x60, + 0x02, 0x7d, 0x7d, 0x01, 0x7d, 0x03, 0x02, 0x01, 0x00, 0x07, 0x07, 0x01, + 0x03, 
0x72, 0x75, 0x6e, 0x00, 0x00, 0x0a, 0x0f, 0x01, 0x0d, 0x00, 0x20, + 0x00, 0x20, 0x01, 0x94, 0x43, 0x00, 0x00, 0xc0, 0x3f, 0x92, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{ .{ .f32 = 2.0 }, .{ .f32 = 3.0 } }); + defer ally.free(out); + try std.testing.expectApproxEqAbs(@as(f32, 7.5), out[0].f32, 0.0001); +} + +test "aarch64 jit handles f64 comparison to i32" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x01, 0x06, 0x01, 0x60, + 0x01, 0x7c, 0x01, 0x7f, 0x03, 0x02, 0x01, 0x00, 0x07, 0x07, 0x01, 0x03, + 0x72, 0x75, 0x6e, 0x00, 0x00, 0x0a, 0x10, 0x01, 0x0e, 0x00, 0x20, 0x00, + 0x44, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x40, 0x66, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{.{ .f64 = 2.0 }}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 1), out[0].i32); +} + +test "aarch64 jit handles f64 to i64 conversion" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x01, 0x05, 0x01, 0x60, + 0x00, 0x01, 0x7e, 0x03, 0x02, 0x01, 0x00, 0x07, 0x07, 0x01, 0x03, 0x72, + 0x75, 0x6e, 0x00, 0x00, 0x0a, 0x0e, 0x01, 0x0c, 0x00, 0x44, 0xcd, 0xcc, + 0xcc, 0xcc, 0xcc, 0xcc, 0x23, 0x40, 0xb0, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = 
try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(i64, 9), out[0].i64); +} + +test "aarch64 jit handles mixed i64 and f64 memory load/store" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x01, 0x05, 0x01, 0x60, + 0x00, 0x01, 0x7e, 0x03, 0x02, 0x01, 0x00, 0x05, 0x03, 0x01, 0x00, 0x01, + 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x00, 0x0a, 0x2d, 0x01, + 0x2b, 0x00, 0x41, 0x08, 0x42, 0x88, 0xef, 0x99, 0xab, 0xc5, 0xe8, 0x8c, + 0x91, 0x11, 0x37, 0x03, 0x00, 0x41, 0x18, 0x44, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x0c, 0x40, 0x39, 0x03, 0x00, 0x41, 0x08, 0x29, 0x03, 0x00, + 0x41, 0x18, 0x2b, 0x03, 0x00, 0xbd, 0x85, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + const expected: i64 = 0x1122334455667788 ^ @as(i64, @bitCast(@as(u64, @bitCast(@as(f64, 3.5))))); + try std.testing.expectEqual(expected, out[0].i64); +} + +test "aarch64 jit handles reinterpret roundtrip f32<->i32" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x01, 0x05, 0x01, 0x60, + 0x00, 0x01, 0x7f, 0x03, 0x02, 0x01, 0x00, 0x07, 0x07, 0x01, 0x03, 0x72, + 0x75, 0x6e, 0x00, 0x00, 0x0a, 0x11, 0x01, 0x0f, 0x00, 0x43, 0x00, 0x00, + 0x80, 0x3f, 0xbc, 0xbe, 0x43, 0x00, 
0x00, 0x80, 0x3f, 0x5b, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 1), out[0].i32); +} + +test "aarch64 jit handles chained float conversions" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x01, 0x05, 0x01, 0x60, + 0x00, 0x01, 0x7c, 0x03, 0x02, 0x01, 0x00, 0x07, 0x07, 0x01, 0x03, 0x72, + 0x75, 0x6e, 0x00, 0x00, 0x0a, 0x09, 0x01, 0x07, 0x00, 0x42, 0x2a, 0xb9, + 0xb6, 0xbb, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try std.testing.expectApproxEqAbs(@as(f64, 42.0), out[0].f64, 0.0001); +} + +test "jitConvertHelper f64<->f32 chain" { + const i64_bits: u64 = @bitCast(@as(i64, 42)); + const f64_bits = ModuleInstance.jitConvertHelper(0xB9, i64_bits); + const f32_bits = ModuleInstance.jitConvertHelper(0xB6, f64_bits); + const roundtrip = ModuleInstance.jitConvertHelper(0xBB, f32_bits); + try std.testing.expectApproxEqAbs(@as(f64, 42.0), @as(f64, @bitCast(roundtrip)), 0.0001); +} + +test "aarch64 jit handles i64 select" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7e, + 0x03, 0x02, 0x01, 0x00, + 0x07, 0x07, 0x01, 
0x03, 0x72, 0x75, 0x6e, 0x00, 0x00, + 0x0a, 0x13, 0x01, 0x11, 0x00, + 0x42, 0x07, 0x42, 0x09, 0x41, 0x00, 0x1b, + 0x42, 0x07, 0x42, 0x09, 0x41, 0x01, 0x1b, + 0x7c, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(i64, 16), out[0].i64); +} + +test "aarch64 jit handles typed i64 select" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7e, + 0x03, 0x02, 0x01, 0x00, + 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x00, + 0x0a, 0x10, 0x01, 0x0e, 0x00, + 0x42, 0x05, 0x42, 0x08, 0x41, 0x00, 0x1c, 0x01, 0x7e, + 0x42, 0x02, 0x7c, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(i64, 10), out[0].i64); +} + +test "aarch64 jit handles i32 sign-extension ops" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7f, + 0x03, 0x02, 0x01, 0x00, + 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x00, + 0x0a, 0x0e, 0x01, 0x0c, 0x00, + 0x41, 0x80, 0x01, 0xc0, + 0x41, 0x80, 0x80, 0x02, 0xc1, + 0x6a, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, 
&wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, -32896), out[0].i32); +} + +test "aarch64 jit handles i64 sign-extension ops" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7e, + 0x03, 0x02, 0x01, 0x00, + 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x00, + 0x0a, 0x16, 0x01, 0x14, 0x00, + 0x42, 0x80, 0x01, 0xc2, + 0x42, 0x80, 0x80, 0x02, 0xc3, + 0x7c, + 0x42, 0x80, 0x80, 0x80, 0x80, 0x08, 0xc4, + 0x7c, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(i64, -2147516544), out[0].i64); +} + +test "aarch64 jit handles memory.fill" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7f, + 0x03, 0x02, 0x01, 0x00, + 0x05, 0x03, 0x01, 0x00, 0x01, + 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x00, + 0x0a, 0x12, 0x01, 0x10, 0x00, + 0x41, 0x10, 0x41, 0x7f, 0x41, 0x04, 0xfc, 0x0b, 0x00, + 0x41, 0x12, 0x2d, 0x00, 0x00, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, 
&mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 255), out[0].i32); +} + +test "aarch64 jit handles memory.copy" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7f, + 0x03, 0x02, 0x01, 0x00, + 0x05, 0x03, 0x01, 0x00, 0x01, + 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x00, + 0x0a, 0x29, 0x01, 0x27, 0x00, + 0x41, 0x00, 0x41, 0xaa, 0x01, 0x3a, 0x00, 0x00, + 0x41, 0x01, 0x41, 0xbb, 0x01, 0x3a, 0x00, 0x00, + 0x41, 0x04, 0x41, 0x00, 0x41, 0x02, 0xfc, 0x0a, 0x00, 0x00, + 0x41, 0x04, 0x2d, 0x00, 0x00, 0x41, 0x05, 0x2d, 0x00, 0x00, 0x6a, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 357), out[0].i32); +} + +test "aarch64 jit handles i32.trunc_sat_f32_s NaN" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7f, + 0x03, 0x02, 0x01, 0x00, + 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x00, + 0x0a, 0x0b, 0x01, 0x09, 0x00, + 0x43, 0x00, 0x00, 0xc0, 0x7f, 0xfc, 0x00, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", 
&.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 0), out[0].i32); +} + +test "aarch64 jit handles i64.trunc_sat_f64_u infinity" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7f, + 0x03, 0x02, 0x01, 0x00, + 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x00, + 0x0a, 0x12, 0x01, 0x10, 0x00, + 0x44, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x7f, + 0xfc, 0x07, 0x42, 0x7f, 0x51, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 1), out[0].i32); +} + +test "aarch64 jit handles table.size" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7f, + 0x03, 0x02, 0x01, 0x00, + 0x04, 0x04, 0x01, 0x70, 0x00, 0x03, + 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x00, + 0x0a, 0x07, 0x01, 0x05, 0x00, 0xfc, 0x10, 0x00, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 3), out[0].i32); +} + +test "aarch64 jit handles memory.init and data.drop with passive segment" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != 
.x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7f, + 0x03, 0x02, 0x01, 0x00, + 0x05, 0x03, 0x01, 0x00, 0x01, + 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x00, + 0x0a, 0x1c, 0x01, 0x1a, 0x00, + 0x41, 0x08, 0x41, 0x01, 0x41, 0x02, 0xfc, 0x08, 0x00, 0x00, 0xfc, 0x09, + 0x00, 0x41, 0x08, 0x2d, 0x00, 0x00, 0x41, 0x09, 0x2d, 0x00, 0x00, 0x6a, + 0x0b, + 0x0b, 0x07, 0x01, 0x01, 0x04, 0x61, 0x62, 0x63, 0x64, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + const out = try inst.callExport("run", &.{}); + defer ally.free(out); + try std.testing.expectEqual(@as(i32, 197), out[0].i32); // 'b' + 'c' +} + +test "jitIExtendHelper sign extension semantics" { + try std.testing.expectEqual(@as(u64, @as(u32, @bitCast(@as(i32, -128)))), ModuleInstance.jitIExtendHelper(0xC0, 0x80)); + try std.testing.expectEqual(@as(u64, @as(u32, @bitCast(@as(i32, -32768)))), ModuleInstance.jitIExtendHelper(0xC1, 0x8000)); + try std.testing.expectEqual(@as(u64, @bitCast(@as(i64, -128))), ModuleInstance.jitIExtendHelper(0xC2, 0x80)); + try std.testing.expectEqual(@as(u64, @bitCast(@as(i64, -32768))), ModuleInstance.jitIExtendHelper(0xC3, 0x8000)); + try std.testing.expectEqual(@as(u64, @bitCast(@as(i64, -2147483648))), ModuleInstance.jitIExtendHelper(0xC4, 0x80000000)); +} + +test "jit trap API reports divide by zero and can be cleared" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + // type: () -> i32 + 0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7f, + // function section + 0x03, 0x02, 0x01, 0x00, + // export "run" + 0x07, 0x07, 0x01, 0x03, 
0x72, 0x75, 0x6e, 0x00, 0x00, + // code: i32.const 1; i32.const 0; i32.div_s; end + 0x0a, 0x09, 0x01, 0x07, 0x00, 0x41, 0x01, 0x41, 0x00, 0x6d, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + try std.testing.expectError(error.WasmTrap, inst.callExport("run", &.{})); + try std.testing.expectEqual(trap.TrapCode.integer_divide_by_zero, inst.lastTrapCode().?); + inst.clearLastTrap(); + try std.testing.expect(inst.lastTrapCode() == null); +} + +test "jit trap API reports memory out of bounds" { + if (@import("builtin").cpu.arch != .aarch64 and @import("builtin").cpu.arch != .x86_64) return error.SkipZigTest; + + const wasm = [_]u8{ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, + // type: () -> i32 + 0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7f, + // function section + 0x03, 0x02, 0x01, 0x00, + // memory min 1 page + 0x05, 0x03, 0x01, 0x00, 0x01, + // export "run" + 0x07, 0x07, 0x01, 0x03, 0x72, 0x75, 0x6e, 0x00, 0x00, + // code: i32.const 65535; i32.load align=2 off=0; end + 0x0a, 0x0b, 0x01, 0x09, 0x00, 0x41, 0xff, 0xff, 0x03, 0x28, 0x02, 0x00, 0x0b, + }; + + const ally = std.testing.allocator; + var mod = try binary.parse(ally, &wasm); + defer mod.deinit(); + var imports = ImportSet.init(ally); + defer imports.deinit(); + var inst = try ModuleInstance.instantiate(ally, &mod, &imports); + defer inst.deinit(); + + try std.testing.expectError(error.WasmTrap, inst.callExport("run", &.{})); + try std.testing.expectEqual(trap.TrapCode.memory_out_of_bounds, inst.lastTrapCode().?); +} diff --git a/src/wasm/jit/aarch64.zig b/src/wasm/jit/aarch64.zig new file mode 100644 index 0000000..3050dcd --- /dev/null +++ b/src/wasm/jit/aarch64.zig @@ -0,0 +1,1383 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const binary = 
@import("../binary.zig"); +const module = @import("../module.zig"); +const codebuf = @import("codebuf.zig"); + +pub const JitResult = struct { + buf: codebuf.CodeBuffer, + arity: u8, +}; + +pub const HelperAddrs = struct { + call: usize, + @"unreachable": usize, + global_get: usize, + global_set: usize, + mem_load: usize, + mem_store: usize, + i32_unary: usize, + i32_cmp: usize, + i32_binary: usize, + i32_div_s: usize, + i32_div_u: usize, + i32_rem_s: usize, + i32_rem_u: usize, + i64_eqz: usize, + i64_cmp: usize, + i64_unary: usize, + i64_binary: usize, + f32_cmp: usize, + f64_cmp: usize, + f32_unary: usize, + f32_binary: usize, + f64_unary: usize, + f64_binary: usize, + convert: usize, + trunc_sat: usize, + i_extend: usize, + memory_init: usize, + data_drop: usize, + memory_copy: usize, + memory_fill: usize, + table_size: usize, + memory_size: usize, + memory_grow: usize, + call_indirect: usize, +}; + +pub fn compileFunctionI32( + allocator: std.mem.Allocator, + mod: *const module.Module, + num_imported_funcs: u32, + current_func_idx: u32, + body: *const module.FunctionBody, + ft: *const module.FuncType, + helpers: HelperAddrs, +) !?JitResult { + const frame_size_bytes: usize = 0x400; + const local_base_bytes: usize = 32; + if (builtin.cpu.arch != .aarch64) return null; + _ = current_func_idx; + if (ft.results.len > 1) return null; + for (ft.params) |p| if (!(p == .i32 or p == .i64 or p == .f32 or p == .f64)) return null; + + var local_count: usize = ft.params.len; + for (body.locals) |decl| { + if (!(decl.valtype == .i32 or decl.valtype == .i64 or decl.valtype == .f32 or decl.valtype == .f64)) return null; + local_count += decl.count; + } + var local_types = try allocator.alloc(module.ValType, local_count); + defer allocator.free(local_types); + for (ft.params, 0..) 
|p, i| local_types[i] = p; + var lt_i: usize = ft.params.len; + for (body.locals) |decl| { + var j: u32 = 0; + while (j < decl.count) : (j += 1) { + local_types[lt_i] = decl.valtype; + lt_i += 1; + } + } + + const operand_base_bytes = std.mem.alignForward(usize, local_base_bytes + local_count * 8, 16); + if (operand_base_bytes >= frame_size_bytes) return null; + const max_stack_depth = (frame_size_bytes - operand_base_bytes) / 8; + if (max_stack_depth == 0) return null; + if (local_base_bytes + local_count * 8 > frame_size_bytes) return null; + + var buf = try codebuf.CodeBuffer.init(allocator, 8192); + errdefer buf.deinit(); + + emitPrologue(&buf, @intCast(ft.params.len), @intCast(local_count), @intCast(operand_base_bytes / 8)); + + var cx = Context{ + .allocator = allocator, + .mod = mod, + .num_imported_funcs = num_imported_funcs, + .helpers = helpers, + .buf = &buf, + .stack_depth = 0, + .max_stack_depth = max_stack_depth, + .local_count = @intCast(local_count), + .result_type = if (ft.results.len == 1) ft.results[0] else null, + .local_types = local_types, + .control = .empty, + }; + defer { + for (cx.control.items) |*fr| fr.end_patches.deinit(allocator); + cx.control.deinit(allocator); + } + + const fn_arity: u8 = if (ft.results.len == 1) 1 else 0; + const fn_type: ?module.ValType = if (ft.results.len == 1) ft.results[0] else null; + try cx.control.append(allocator, .{ + .kind = .block, + .entry_depth = 0, + .label_arity = fn_arity, + .label_type = fn_type, + .end_arity = fn_arity, + .end_type = fn_type, + .loop_head_pos = 0, + .end_patches = .empty, + }); + + var pos: usize = 0; + const end_kind = compileBlock(&cx, body.code, &pos, false) catch return null; + if (end_kind != .hit_end) return null; + if (ft.results.len == 1) { + if (cx.stack_depth != 1) return null; + switch (ft.results[0]) { + .i32, .f32 => try popW(&cx, 0), // return in w0 + .i64, .f64 => try popX(&cx, 0), // return raw 64-bit lane in x0 + } + } else { + if (cx.stack_depth != 0) return 
null; + } + emitEpilogueAndRet(&buf); + try buf.finalize(); + return .{ .buf = buf, .arity = 0 }; +} + +const EndKind = enum { hit_end, hit_else }; + +const Context = struct { + allocator: std.mem.Allocator, + mod: *const module.Module, + num_imported_funcs: u32, + helpers: HelperAddrs, + buf: *codebuf.CodeBuffer, + stack_depth: usize, + max_stack_depth: usize, + local_count: u32, + result_type: ?module.ValType, + local_types: []const module.ValType, + control: std.ArrayList(ControlFrame), +}; + +const ControlKind = enum { block, loop, @"if" }; + +const ControlFrame = struct { + kind: ControlKind, + entry_depth: usize, + label_arity: u8, + label_type: ?module.ValType, + end_arity: u8, + end_type: ?module.ValType, + loop_head_pos: usize, + end_patches: std.ArrayList(usize), +}; + +fn compileBlock(cx: *Context, code: []const u8, pos: *usize, allow_else: bool) !EndKind { + while (pos.* < code.len) { + const op = code[pos.*]; + pos.* += 1; + switch (op) { + 0x0B => { + const fr = currentFrame(cx); + const end_pos = cx.buf.cursor(); + for (fr.end_patches.items) |patch_pos| patchB(cx.buf, patch_pos, end_pos); + try setStackDepth(cx, fr.entry_depth + fr.end_arity); + return .hit_end; + }, + 0x05 => { + if (!allow_else) return error.MalformedControlFlow; + return .hit_else; + }, + 0x0C => { // br + const depth = try binary.readULEB128(u32, code, pos); + try emitBrToDepth(cx, depth); + }, + 0x0D => { // br_if + const depth = try binary.readULEB128(u32, code, pos); + try popW(cx, 11); // condition + const not_taken = emitCBZPlaceholder(cx.buf, 11); + const fallthrough_depth = cx.stack_depth; + try emitBrToDepth(cx, depth); + cx.stack_depth = fallthrough_depth; + patchCBZ(cx.buf, not_taken, cx.buf.cursor()); + }, + 0x0E => { // br_table + const n = try binary.readULEB128(u32, code, pos); + const table = try cx.allocator.alloc(u32, n + 1); + defer cx.allocator.free(table); + for (table) |*d| d.* = try binary.readULEB128(u32, code, pos); + + try popW(cx, 11); // selector + 
const fallthrough_depth = cx.stack_depth; + var i: u32 = 0; + while (i < n) : (i += 1) { + if (i > 4095) return error.UnsupportedOpcode; + emitCmpImmW(cx.buf, 11, @intCast(i)); + emitCsetW(cx.buf, 12, .eq); + const skip = emitCBZPlaceholder(cx.buf, 12); + try emitBrToDepth(cx, table[i]); + cx.stack_depth = fallthrough_depth; + patchCBZ(cx.buf, skip, cx.buf.cursor()); + } + try emitBrToDepth(cx, table[n]); // default + }, + 0x20 => { // local.get + const idx = try binary.readULEB128(u32, code, pos); + if (idx >= cx.local_count) return error.UnsupportedOpcode; + switch (cx.local_types[idx]) { + .i32 => { + emitLdrWImm(cx.buf, 9, 31, @intCast(8 + idx * 2)); + try pushW(cx, 9); + }, + .i64 => { + emitLdrXImm(cx.buf, 9, 31, @intCast(4 + idx)); + try pushX(cx, 9); + }, + .f32 => { + emitLdrWImm(cx.buf, 9, 31, @intCast(8 + idx * 2)); + try pushW(cx, 9); + }, + .f64 => { + emitLdrXImm(cx.buf, 9, 31, @intCast(4 + idx)); + try pushX(cx, 9); + }, + } + }, + 0x21 => { // local.set + const idx = try binary.readULEB128(u32, code, pos); + if (idx >= cx.local_count) return error.UnsupportedOpcode; + switch (cx.local_types[idx]) { + .i32 => { + try popW(cx, 9); + emitStrWImm(cx.buf, 9, 31, @intCast(8 + idx * 2)); + }, + .i64 => { + try popX(cx, 9); + emitStrXImm(cx.buf, 9, 31, @intCast(4 + idx)); + }, + .f32 => { + try popW(cx, 9); + emitStrWImm(cx.buf, 9, 31, @intCast(8 + idx * 2)); + }, + .f64 => { + try popX(cx, 9); + emitStrXImm(cx.buf, 9, 31, @intCast(4 + idx)); + }, + } + }, + 0x22 => { // local.tee + const idx = try binary.readULEB128(u32, code, pos); + if (idx >= cx.local_count) return error.UnsupportedOpcode; + switch (cx.local_types[idx]) { + .i32 => { + try popW(cx, 9); + emitStrWImm(cx.buf, 9, 31, @intCast(8 + idx * 2)); + try pushW(cx, 9); + }, + .i64 => { + try popX(cx, 9); + emitStrXImm(cx.buf, 9, 31, @intCast(4 + idx)); + try pushX(cx, 9); + }, + .f32 => { + try popW(cx, 9); + emitStrWImm(cx.buf, 9, 31, @intCast(8 + idx * 2)); + try pushW(cx, 9); + }, + .f64 => { + 
try popX(cx, 9); + emitStrXImm(cx.buf, 9, 31, @intCast(4 + idx)); + try pushX(cx, 9); + }, + } + }, + 0x41 => { // i32.const + const v = try binary.readSLEB128(i32, code, pos); + emitMovImm32(cx.buf, 9, @bitCast(v)); + try pushW(cx, 9); + }, + 0x42 => { // i64.const + const v = try binary.readSLEB128(i64, code, pos); + emitMovImm64(cx.buf, 9, @bitCast(v)); + try pushX(cx, 9); + }, + 0x1A => { // drop + try popX(cx, 9); + }, + 0x1B => { // select + try popW(cx, 11); // condition + try popX(cx, 10); // rhs + try popX(cx, 9); // lhs + emitCmpImmW(cx.buf, 11, 0); + emitCselX(cx.buf, 9, 10, 9, .eq); // cond==0 ? rhs : lhs + try pushX(cx, 9); + }, + 0x1C => { // typed select + const n = try binary.readULEB128(u32, code, pos); + if (n != 1) return error.UnsupportedOpcode; + if (pos.* >= code.len) return error.UnexpectedEof; + const vt = try decodeValType(code[pos.*]); + pos.* += 1; + + try popW(cx, 11); // condition + switch (vt) { + .i32, .f32 => { + try popW(cx, 10); // rhs + try popW(cx, 9); // lhs + emitCmpImmW(cx.buf, 11, 0); + emitCselW(cx.buf, 9, 10, 9, .eq); // cond==0 ? rhs : lhs + try pushW(cx, 9); + }, + .i64, .f64 => { + try popX(cx, 10); // rhs + try popX(cx, 9); // lhs + emitCmpImmW(cx.buf, 11, 0); + emitCselX(cx.buf, 9, 10, 9, .eq); // cond==0 ? 
rhs : lhs + try pushX(cx, 9); + }, + } + }, + 0x45 => { // i32.eqz + try popW(cx, 9); + emitCmpImmW(cx.buf, 9, 0); + emitCsetW(cx.buf, 9, .eq); + try pushW(cx, 9); + }, + 0x46 => try emitI32Cmp(cx, .eq), + 0x47 => try emitI32Cmp(cx, .ne), + 0x48 => try emitI32Cmp(cx, .lt_s), + 0x49 => try emitI32Cmp(cx, .lt_u), + 0x4A => try emitI32Cmp(cx, .gt_s), + 0x4B => try emitI32Cmp(cx, .gt_u), + 0x4C => try emitI32Cmp(cx, .le_s), + 0x4D => try emitI32Cmp(cx, .le_u), + 0x4E => try emitI32Cmp(cx, .ge_s), + 0x4F => try emitI32Cmp(cx, .ge_u), + 0x50 => { // i64.eqz + try popX(cx, 1); + emitMovImm64(cx.buf, 16, cx.helpers.i64_eqz); + emitBLR(cx.buf, 16); + try pushW(cx, 0); + }, + 0x51...0x5A => { + try popX(cx, 2); + try popX(cx, 1); + emitMovImm32(cx.buf, 0, op); + emitMovImm64(cx.buf, 16, cx.helpers.i64_cmp); + emitBLR(cx.buf, 16); + try pushW(cx, 0); + }, + 0x67 => try emitI32Unary(cx, .clz), + 0x68 => try emitI32Unary(cx, .ctz), + 0x69 => try emitI32Unary(cx, .popcnt), + 0x6A => try emitI32Bin(cx, .add), + 0x6B => try emitI32Bin(cx, .sub), + 0x6C => try emitI32Bin(cx, .mul), + 0x6D => try emitI32Bin(cx, .div_s), + 0x6E => try emitI32Bin(cx, .div_u), + 0x6F => try emitI32Bin(cx, .rem_s), + 0x70 => try emitI32Bin(cx, .rem_u), + 0x71 => try emitI32Bin(cx, .and_), + 0x72 => try emitI32Bin(cx, .or_), + 0x73 => try emitI32Bin(cx, .xor_), + 0x74 => try emitI32Bin(cx, .shl), + 0x75 => try emitI32Bin(cx, .shr_s), + 0x76 => try emitI32Bin(cx, .shr_u), + 0x77 => try emitI32Bin(cx, .rotl), + 0x78 => try emitI32Bin(cx, .rotr), + 0x79, 0x7A, 0x7B => { + try popX(cx, 1); + emitMovImm32(cx.buf, 0, op); + emitMovImm64(cx.buf, 16, cx.helpers.i64_unary); + emitBLR(cx.buf, 16); + try pushX(cx, 0); + }, + 0x7C...0x8A => { + try popX(cx, 2); + try popX(cx, 1); + emitMovImm32(cx.buf, 0, op); + emitMovImm64(cx.buf, 16, cx.helpers.i64_binary); + emitBLR(cx.buf, 16); + try pushX(cx, 0); + }, + 0x43 => { // f32.const + if (pos.* + 4 > code.len) return error.UnexpectedEof; + const bits = 
std.mem.readInt(u32, code[pos.*..][0..4], .little); + pos.* += 4; + emitMovImm32(cx.buf, 9, bits); + try pushW(cx, 9); + }, + 0x44 => { // f64.const + if (pos.* + 8 > code.len) return error.UnexpectedEof; + const bits = std.mem.readInt(u64, code[pos.*..][0..8], .little); + pos.* += 8; + emitMovImm64(cx.buf, 9, bits); + try pushX(cx, 9); + }, + 0x5B...0x60 => { // f32 comparisons + try popW(cx, 2); + try popW(cx, 1); + emitMovImm32(cx.buf, 0, op); + emitMovImm64(cx.buf, 16, cx.helpers.f32_cmp); + emitBLR(cx.buf, 16); + try pushW(cx, 0); + }, + 0x61...0x66 => { // f64 comparisons + try popX(cx, 2); + try popX(cx, 1); + emitMovImm32(cx.buf, 0, op); + emitMovImm64(cx.buf, 16, cx.helpers.f64_cmp); + emitBLR(cx.buf, 16); + try pushW(cx, 0); + }, + 0x8B...0x91 => { // f32 unary + try popW(cx, 1); + emitMovImm32(cx.buf, 0, op); + emitMovImm64(cx.buf, 16, cx.helpers.f32_unary); + emitBLR(cx.buf, 16); + try pushW(cx, 0); + }, + 0x92...0x98 => { // f32 binary + try popW(cx, 2); + try popW(cx, 1); + emitMovImm32(cx.buf, 0, op); + emitMovImm64(cx.buf, 16, cx.helpers.f32_binary); + emitBLR(cx.buf, 16); + try pushW(cx, 0); + }, + 0x99...0x9F => { // f64 unary + try popX(cx, 1); + emitMovImm32(cx.buf, 0, op); + emitMovImm64(cx.buf, 16, cx.helpers.f64_unary); + emitBLR(cx.buf, 16); + try pushX(cx, 0); + }, + 0xA0...0xA6 => { // f64 binary + try popX(cx, 2); + try popX(cx, 1); + emitMovImm32(cx.buf, 0, op); + emitMovImm64(cx.buf, 16, cx.helpers.f64_binary); + emitBLR(cx.buf, 16); + try pushX(cx, 0); + }, + 0xA7...0xBF => { // numeric conversions/reinterprets + switch (op) { + 0xA8, 0xA9, 0xAC, 0xAD, 0xB2, 0xB3, 0xB7, 0xB8, 0xBB, 0xBC, 0xBE => try popW(cx, 1), + else => try popX(cx, 1), + } + emitMovImm32(cx.buf, 0, op); + emitMovImm64(cx.buf, 16, cx.helpers.convert); + emitBLR(cx.buf, 16); + switch (convertResultType(op)) { + .i32, .f32 => try pushW(cx, 0), + .i64, .f64 => try pushX(cx, 0), + } + }, + 0xC0...0xC4 => { // integer sign-extension ops + switch (op) { + 0xC0, 0xC1 => try 
popW(cx, 1), + else => try popX(cx, 1), + } + emitMovImm32(cx.buf, 0, op); + emitMovImm64(cx.buf, 16, cx.helpers.i_extend); + emitBLR(cx.buf, 16); + switch (op) { + 0xC0, 0xC1 => try pushW(cx, 0), + else => try pushX(cx, 0), + } + }, + 0x10 => { // call + const fidx = try binary.readULEB128(u32, code, pos); + const cft = try getFuncType(cx.mod, cx.num_imported_funcs, fidx); + if (cft.results.len > 1) return error.UnsupportedOpcode; + if (cx.stack_depth < cft.params.len) return error.StackUnderflow; + + emitMovXReg(cx.buf, 0, 20); // x0 = saved instance pointer + emitMovImm32(cx.buf, 1, fidx); + if (cft.params.len == 0) { + emitMovXReg(cx.buf, 2, 19); + } else { + const bytes = cft.params.len * 8; + if (bytes > 4095) return error.UnsupportedOpcode; + emitSubImmX(cx.buf, 2, 19, @intCast(bytes)); + } + emitMovImm32(cx.buf, 3, @intCast(cft.params.len)); + emitMovImm64(cx.buf, 16, cx.helpers.call); + emitBLR(cx.buf, 16); + if (cft.params.len > 0) { + const bytes = cft.params.len * 8; + emitSubImmX(cx.buf, 19, 19, @intCast(bytes)); + cx.stack_depth -= cft.params.len; + } + if (cft.results.len == 1) { + switch (cft.results[0]) { + .i32, .f32 => try pushW(cx, 0), + .i64, .f64 => try pushX(cx, 0), + } + } + }, + 0x11 => { // call_indirect + const type_idx = try binary.readULEB128(u32, code, pos); + const table_idx = try binary.readULEB128(u32, code, pos); + if (type_idx >= cx.mod.types.len) return error.UnsupportedOpcode; + const cft = &cx.mod.types[type_idx]; + if (cft.results.len > 1) return error.UnsupportedOpcode; + if (cx.stack_depth < cft.params.len + 1) return error.StackUnderflow; + + emitMovXReg(cx.buf, 0, 20); + emitMovImm32(cx.buf, 1, type_idx); + emitMovImm32(cx.buf, 2, table_idx); + try popW(cx, 3); // callee element index + if (cft.params.len == 0) { + emitMovXReg(cx.buf, 4, 19); + } else { + const bytes = cft.params.len * 8; + if (bytes > 4095) return error.UnsupportedOpcode; + emitSubImmX(cx.buf, 4, 19, @intCast(bytes)); + } + emitMovImm32(cx.buf, 5, 
@intCast(cft.params.len)); + emitMovImm64(cx.buf, 16, cx.helpers.call_indirect); + emitBLR(cx.buf, 16); + if (cft.params.len > 0) { + const bytes = cft.params.len * 8; + emitSubImmX(cx.buf, 19, 19, @intCast(bytes)); + cx.stack_depth -= cft.params.len; + } + if (cft.results.len == 1) { + switch (cft.results[0]) { + .i32, .f32 => try pushW(cx, 0), + .i64, .f64 => try pushX(cx, 0), + } + } + }, + 0x23 => { // global.get + const gidx = try binary.readULEB128(u32, code, pos); + const gvt = try getGlobalValType(cx.mod, gidx); + emitMovXReg(cx.buf, 0, 20); + emitMovImm32(cx.buf, 1, gidx); + emitMovImm64(cx.buf, 16, cx.helpers.global_get); + emitBLR(cx.buf, 16); + switch (gvt) { + .i32, .f32 => try pushW(cx, 0), + .i64, .f64 => try pushX(cx, 0), + } + }, + 0x24 => { // global.set + const gidx = try binary.readULEB128(u32, code, pos); + const gvt = try getGlobalValType(cx.mod, gidx); + switch (gvt) { + .i32, .f32 => try popW(cx, 2), + .i64, .f64 => try popX(cx, 2), + } + emitMovXReg(cx.buf, 0, 20); + emitMovImm32(cx.buf, 1, gidx); + emitMovImm64(cx.buf, 16, cx.helpers.global_set); + emitBLR(cx.buf, 16); + }, + 0x28...0x35 => { // memory loads + _ = try binary.readULEB128(u32, code, pos); // align + const offset = try binary.readULEB128(u32, code, pos); + try popW(cx, 1); // addr + emitMovXReg(cx.buf, 0, 20); + emitMovImm32(cx.buf, 2, offset); + emitMovImm32(cx.buf, 3, op); + emitMovImm64(cx.buf, 16, cx.helpers.mem_load); + emitBLR(cx.buf, 16); + switch (memLoadResultType(op)) { + .i32, .f32 => try pushW(cx, 0), + .i64, .f64 => try pushX(cx, 0), + } + }, + 0x36...0x3E => { // memory stores + _ = try binary.readULEB128(u32, code, pos); // align + const offset = try binary.readULEB128(u32, code, pos); + switch (memStoreValueType(op)) { + .i32, .f32 => try popW(cx, 4), + .i64, .f64 => try popX(cx, 4), + } + try popW(cx, 1); // addr + emitMovXReg(cx.buf, 0, 20); + emitMovImm32(cx.buf, 2, offset); + emitMovImm32(cx.buf, 3, op); + emitMovImm64(cx.buf, 16, cx.helpers.mem_store); + 
emitBLR(cx.buf, 16); + }, + 0x3F => { // memory.size + _ = try binary.readULEB128(u8, code, pos); + emitMovXReg(cx.buf, 0, 20); + emitMovImm64(cx.buf, 16, cx.helpers.memory_size); + emitBLR(cx.buf, 16); + try pushW(cx, 0); + }, + 0x40 => { // memory.grow + _ = try binary.readULEB128(u8, code, pos); + try popW(cx, 1); // delta pages + emitMovXReg(cx.buf, 0, 20); + emitMovImm64(cx.buf, 16, cx.helpers.memory_grow); + emitBLR(cx.buf, 16); + try pushW(cx, 0); + }, + 0xFC => { // bulk memory + const subop = try binary.readULEB128(u32, code, pos); + switch (subop) { + 0...7 => { // trunc_sat conversions + switch (subop) { + 0, 1, 4, 5 => try popW(cx, 1), + 2, 3, 6, 7 => try popX(cx, 1), + else => unreachable, + } + emitMovImm32(cx.buf, 0, @intCast(subop)); + emitMovImm64(cx.buf, 16, cx.helpers.trunc_sat); + emitBLR(cx.buf, 16); + switch (subop) { + 0, 1, 2, 3 => try pushW(cx, 0), + 4, 5, 6, 7 => try pushX(cx, 0), + else => unreachable, + } + }, + 8 => { // memory.init + const data_idx = try binary.readULEB128(u32, code, pos); + const mem_idx = try binary.readULEB128(u32, code, pos); + if (mem_idx != 0) return error.UnsupportedOpcode; + try popW(cx, 3); // len + try popW(cx, 2); // src + try popW(cx, 1); // dst + emitMovXReg(cx.buf, 0, 20); + emitMovImm32(cx.buf, 4, data_idx); + emitMovImm64(cx.buf, 16, cx.helpers.memory_init); + emitBLR(cx.buf, 16); + }, + 9 => { // data.drop + const data_idx = try binary.readULEB128(u32, code, pos); + emitMovXReg(cx.buf, 0, 20); + emitMovImm32(cx.buf, 1, data_idx); + emitMovImm64(cx.buf, 16, cx.helpers.data_drop); + emitBLR(cx.buf, 16); + }, + 10 => { // memory.copy + const dst_mem = try binary.readULEB128(u32, code, pos); + const src_mem = try binary.readULEB128(u32, code, pos); + if (dst_mem != 0 or src_mem != 0) return error.UnsupportedOpcode; + try popW(cx, 3); // len + try popW(cx, 2); // src + try popW(cx, 1); // dst + emitMovXReg(cx.buf, 0, 20); + emitMovImm64(cx.buf, 16, cx.helpers.memory_copy); + emitBLR(cx.buf, 16); + }, + 11 
=> { // memory.fill + const mem_idx = try binary.readULEB128(u32, code, pos); + if (mem_idx != 0) return error.UnsupportedOpcode; + try popW(cx, 3); // len + try popW(cx, 2); // value + try popW(cx, 1); // dst + emitMovXReg(cx.buf, 0, 20); + emitMovImm64(cx.buf, 16, cx.helpers.memory_fill); + emitBLR(cx.buf, 16); + }, + 16 => { // table.size + const table_idx = try binary.readULEB128(u32, code, pos); + emitMovXReg(cx.buf, 0, 20); + emitMovImm32(cx.buf, 1, table_idx); + emitMovImm64(cx.buf, 16, cx.helpers.table_size); + emitBLR(cx.buf, 16); + try pushW(cx, 0); + }, + else => return error.UnsupportedOpcode, + } + }, + 0x02, 0x03 => { // block/loop + const is_loop = op == 0x03; + const sig = try readBlockSig(cx, code, pos, is_loop); + try cx.control.append(cx.allocator, .{ + .kind = if (is_loop) .loop else .block, + .entry_depth = cx.stack_depth, + .label_arity = if (is_loop) 0 else sig.arity, + .label_type = if (is_loop) null else sig.val_type, + .end_arity = sig.arity, + .end_type = sig.val_type, + .loop_head_pos = cx.buf.cursor(), + .end_patches = .empty, + }); + const nested_end = try compileBlock(cx, code, pos, false); + if (nested_end != .hit_end) return error.MalformedControlFlow; + var fr = cx.control.pop().?; + fr.end_patches.deinit(cx.allocator); + }, + 0x04 => { // if + const sig = try readBlockSig(cx, code, pos, false); + try popW(cx, 9); + const entry_depth = cx.stack_depth; + try cx.control.append(cx.allocator, .{ + .kind = .@"if", + .entry_depth = entry_depth, + .label_arity = sig.arity, + .label_type = sig.val_type, + .end_arity = sig.arity, + .end_type = sig.val_type, + .loop_head_pos = 0, + .end_patches = .empty, + }); + + const cbz_pos = emitCBZPlaceholder(cx.buf, 9); + const then_end = try compileBlock(cx, code, pos, true); + if (then_end == .hit_else) { + const jump_end_pos = emitBPlaceholder(cx.buf); + try currentFrame(cx).end_patches.append(cx.allocator, jump_end_pos); + patchCBZ(cx.buf, cbz_pos, cx.buf.cursor()); + try setStackDepth(cx, 
entry_depth);
                const else_end = try compileBlock(cx, code, pos, false);
                if (else_end != .hit_end) return error.MalformedControlFlow;
            } else {
                // No else arm: a false condition just falls through to here.
                patchCBZ(cx.buf, cbz_pos, cx.buf.cursor());
            }
            var fr = cx.control.pop().?;
            fr.end_patches.deinit(cx.allocator);
        },
        0x00 => { // unreachable: invoke the unreachable/trap helper.
            emitMovXReg(cx.buf, 0, 20);
            emitMovImm64(cx.buf, 16, cx.helpers.@"unreachable");
            emitBLR(cx.buf, 16);
        },
        0x01 => {}, // nop
        0x0F => { // return: move the (optional) single result into w0/x0, then leave.
            // NOTE(review): compilation of this block stops at `return`; any dead
            // code before the matching `end` is not consumed — confirm the
            // validator rejects such bodies or that callers tolerate .hit_end here.
            if (cx.result_type) |rt| {
                switch (rt) {
                    .i32, .f32 => try popW(cx, 0),
                    .i64, .f64 => try popX(cx, 0),
                }
            }
            emitEpilogueAndRet(cx.buf);
            return .hit_end;
        },
        else => return error.UnsupportedOpcode,
    }
}
return error.UnexpectedEof;
}

/// Decoded block-type immediate: result count (0 or 1 supported) and, when
/// arity == 1, the result value type.
const BlockSig = struct {
    arity: u8,
    val_type: ?module.ValType,
};

/// Parse a blocktype immediate for block/loop/if.
/// For loops the reported *label* arity is forced to 0 (a branch to a loop
/// header carries no values), which is why `is_loop` zeroes the arity.
/// Multi-value and parameterized block types are rejected as unsupported.
fn readBlockSig(cx: *Context, code: []const u8, pos: *usize, is_loop: bool) !BlockSig {
    const bt = try binary.readSLEB128(i33, code, pos);
    if (bt == -0x40) return .{ .arity = 0, .val_type = null }; // empty
    if (bt == -0x01) return .{ .arity = if (is_loop) 0 else 1, .val_type = .i32 };
    if (bt == -0x02) return .{ .arity = if (is_loop) 0 else 1, .val_type = .i64 };
    if (bt == -0x03) return .{ .arity = if (is_loop) 0 else 1, .val_type = .f32 };
    if (bt == -0x04) return .{ .arity = if (is_loop) 0 else 1, .val_type = .f64 };
    if (bt < 0) return error.UnsupportedOpcode;
    // Non-negative blocktype: index into the module's type section.
    const type_idx: u32 = @intCast(bt);
    if (type_idx >= cx.mod.types.len) return error.UnsupportedOpcode;
    const ft = &cx.mod.types[type_idx];
    if (ft.params.len != 0) return error.UnsupportedOpcode;
    if (ft.results.len == 0) return .{ .arity = 0, .val_type = null };
    if (ft.results.len == 1) return .{ .arity = if (is_loop) 0 else 1, .val_type = ft.results[0] };
    return error.UnsupportedOpcode;
}

/// Decode a wasm value-type byte (0x7F..0x7C) into ValType.
fn decodeValType(b: u8) !module.ValType {
    return switch (b) {
        0x7F => .i32,
        0x7E => .i64,
        0x7D => .f32,
        0x7C => .f64,
        else => error.UnsupportedOpcode,
    };
}

/// Result type of a memory-load opcode (0x28..0x35).
fn memLoadResultType(op: u8) module.ValType {
    return switch (op) {
        0x28, 0x2C, 0x2D, 0x2E, 0x2F => .i32,
        0x29, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35 => .i64,
        0x2A => .f32,
        0x2B => .f64,
        else => .i32, // callers gate on the 0x28..0x35 range, so unreachable in practice
    };
}

/// Stored value type of a memory-store opcode (0x36..0x3E).
fn memStoreValueType(op: u8) module.ValType {
    return switch (op) {
        0x36, 0x3A, 0x3B => .i32,
        0x37, 0x3C, 0x3D, 0x3E => .i64,
        0x38 => .f32,
        0x39 => .f64,
        else => .i32,
    };
}

/// Result type of a numeric conversion/reinterpret opcode (0xA7..0xBF).
fn convertResultType(op: u8) module.ValType {
    return switch (op) {
        0xA7, 0xA8, 0xA9, 0xAA, 0xAB, 0xBC => .i32,
        0xAC, 0xAD, 0xAE, 0xAF, 0xB0, 0xB1, 0xBD => .i64,
        0xB2, 0xB3, 0xB4, 0xB5, 0xB6, 0xBE => .f32,
        0xB7, 0xB8, 0xB9, 0xBA, 0xBB, 0xBF => .f64,
        else => .i32,
    };
}

/// Innermost (most recently opened) control frame.
fn currentFrame(cx: *Context) *ControlFrame {
    return &cx.control.items[cx.control.items.len - 1];
}

/// Emit code that moves the operand-stack pointer (x19) so the compile-time
/// stack depth becomes `depth`, and record the new depth.
fn setStackDepth(cx: *Context, depth: usize) !void {
    if (depth > cx.max_stack_depth) return error.StackOverflow;
    if (depth == cx.stack_depth) return;
    if (depth > cx.stack_depth) {
        const bytes = (depth - cx.stack_depth) * 8;
        if (bytes > 4095) return error.UnsupportedOpcode; // add/sub imm12 limit
        emitAddImmX(cx.buf, 19, 19, @intCast(bytes));
    } else {
        const bytes = (cx.stack_depth - depth) * 8;
        if (bytes > 4095) return error.UnsupportedOpcode;
        emitSubImmX(cx.buf, 19, 19, @intCast(bytes));
    }
    cx.stack_depth = depth;
}

/// Emit an unconditional branch to the label `depth` levels up the control
/// stack, carrying the label's single result (if any) in w9/x9 across the
/// stack-pointer adjustment.
fn emitBrToDepth(cx: *Context, depth: u32) !void {
    if (depth >= cx.control.items.len) return error.MalformedControlFlow;
    const target_idx = cx.control.items.len - 1 - depth;
    const target = &cx.control.items[target_idx];

    const result_reg: u5 = 9;
    if (target.label_arity == 1) {
        const t = target.label_type orelse return error.UnsupportedOpcode;
        switch (t) {
            .i32, .f32 => try popW(cx, result_reg),
            .i64, .f64 => try popX(cx, result_reg),
        }
    } else if (target.label_arity != 0) {
        return error.UnsupportedOpcode; // multi-value labels unsupported
    }
    try setStackDepth(cx, target.entry_depth);
    // Re-push the result for non-loop targets (loop labels carry no values).
    if (target.label_arity == 1 and target.kind != .loop) {
        const t = target.label_type orelse return error.UnsupportedOpcode;
        switch (t) {
            .i32, .f32 => try pushW(cx, result_reg),
            .i64, .f64 => try pushX(cx, result_reg),
        }
    }

    if (target.kind == .loop) {
        // Backward branch: the loop head position is already known.
        const p = emitBPlaceholder(cx.buf);
        patchB(cx.buf, p, target.loop_head_pos);
    } else {
        // Forward branch: patched when the target block's `end` is reached.
        const p = emitBPlaceholder(cx.buf);
        try target.end_patches.append(cx.allocator, p);
    }
}

/// i32 clz/ctz/popcnt on the operand-stack top, computed in w9
/// (w10–w12 used as scratch).
fn emitI32Unary(cx: *Context, comptime op: enum { clz, ctz, popcnt }) !void {
    try popW(cx, 9);
    switch (op) {
        .clz => emitClzW(cx.buf, 9, 9),
        .ctz => {
            // ctz(x) == clz(bit_reverse(x))
            emitRbitW(cx.buf, 9, 9);
            emitClzW(cx.buf, 9, 9);
        },
        .popcnt => {
            // Classic SWAR popcount (no NEON): sum 1-bit fields, then 2-bit,
            // then 4-bit, then fold bytes and mask the final count to 0..32.
            emitMovImm32(cx.buf, 10, 0x5555_5555);
            emitLsrvWImm(cx.buf, 11, 9, 1);
            emitAndWReg(cx.buf, 11, 11, 10);
            emitSubWReg(cx.buf, 9, 9, 11);

            emitMovImm32(cx.buf, 10, 0x3333_3333);
            emitAndWReg(cx.buf, 11, 9, 10);
            emitLsrvWImm(cx.buf, 12, 9, 2);
            emitAndWReg(cx.buf, 12, 12, 10);
            emitAddWReg(cx.buf, 9, 11, 12);

            emitLsrvWImm(cx.buf, 11, 9, 4);
            emitAddWReg(cx.buf, 9, 9, 11);
            emitMovImm32(cx.buf, 10, 0x0f0f_0f0f);
            emitAndWReg(cx.buf, 9, 9, 10);

            emitLsrvWImm(cx.buf, 11, 9, 8);
            emitAddWReg(cx.buf, 9, 9, 11);
            emitLsrvWImm(cx.buf, 11, 9, 16);
            emitAddWReg(cx.buf, 9, 9, 11);
            emitMovImm32(cx.buf, 10, 0x3f);
            emitAndWReg(cx.buf, 9, 9, 10);
        },
    }
    try pushW(cx, 9);
}

/// i32 binary op: pops rhs into w10 and lhs into w9, pushes the result from w9.
/// div/rem are routed through runtime helpers — presumably so wasm's trap
/// cases (divide-by-zero, overflow) are handled there; confirm in the helper.
fn emitI32Bin(cx: *Context, comptime op: enum {
    add,
    sub,
    mul,
    div_s,
    div_u,
    rem_s,
    rem_u,
    and_,
    or_,
    xor_,
    shl,
    shr_s,
    shr_u,
    rotl,
    rotr,
}) !void {
    try popW(cx, 10);
    try popW(cx, 9);
    switch (op) {
        .add => emitAddWReg(cx.buf, 9, 9, 10),
        .sub => emitSubWReg(cx.buf, 9, 9, 10),
        .mul => emitMulWReg(cx.buf, 9, 9, 10),
        .div_s => {
            emitMovWReg(cx.buf, 0, 9);
            emitMovWReg(cx.buf, 1, 10);
            emitMovImm64(cx.buf, 16, cx.helpers.i32_div_s);
            emitBLR(cx.buf, 16);
            emitMovWReg(cx.buf, 9, 0);
        },
        .div_u => {
            emitMovWReg(cx.buf, 0, 9);
            emitMovWReg(cx.buf, 1, 10);
            emitMovImm64(cx.buf, 16, cx.helpers.i32_div_u);
            emitBLR(cx.buf, 16);
            emitMovWReg(cx.buf, 9, 0);
        },
        .rem_s => {
            emitMovWReg(cx.buf, 0, 9);
            emitMovWReg(cx.buf, 1, 10);
            emitMovImm64(cx.buf, 16, cx.helpers.i32_rem_s);
            emitBLR(cx.buf, 16);
            emitMovWReg(cx.buf, 9, 0);
        },
        .rem_u => {
            emitMovWReg(cx.buf, 0, 9);
            emitMovWReg(cx.buf, 1, 10);
            emitMovImm64(cx.buf, 16, cx.helpers.i32_rem_u);
            emitBLR(cx.buf, 16);
            emitMovWReg(cx.buf, 9, 0);
        },
        .and_ => emitAndWReg(cx.buf, 9, 9, 10),
        .or_ => emitOrrWReg(cx.buf, 9, 9, 10),
        .xor_ => emitEorWReg(cx.buf, 9, 9, 10),
        .shl => emitLslvWReg(cx.buf, 9, 9, 10),
        .shr_s => emitAsrvWReg(cx.buf, 9, 9, 10),
        .shr_u => emitLsrvWReg(cx.buf, 9, 9, 10),
        .rotl => {
            // A64 only has rotate-right: rotl(x, n) == rotr(x, -n).
            emitNegWReg(cx.buf, 11, 10);
            emitRorvWReg(cx.buf, 9, 9, 11);
        },
        .rotr => emitRorvWReg(cx.buf, 9, 9, 10),
    }
    try pushW(cx, 9);
}

/// i32 comparison: pops rhs (w10) and lhs (w9), pushes 0/1 in w9 via cmp+cset.
fn emitI32Cmp(cx: *Context, comptime op: enum { eq, ne, lt_s, lt_u, gt_s, gt_u, le_s, le_u, ge_s, ge_u }) !void {
    try popW(cx, 10);
    try popW(cx, 9);
    emitCmpWReg(cx.buf, 9, 10);
    switch (op) {
        .eq => emitCsetW(cx.buf, 9, .eq),
        .ne => emitCsetW(cx.buf, 9, .ne),
        .lt_s => emitCsetW(cx.buf, 9, .lt),
        .lt_u => emitCsetW(cx.buf, 9, .lo),
        .gt_s => emitCsetW(cx.buf, 9, .gt),
        .gt_u => emitCsetW(cx.buf, 9, .hi),
        .le_s => emitCsetW(cx.buf, 9, .le),
        .le_u => emitCsetW(cx.buf, 9, .ls),
        .ge_s => emitCsetW(cx.buf, 9, .ge),
        .ge_u => emitCsetW(cx.buf, 9, .hs),
    }
    try pushW(cx, 9);
}

/// Push a 32-bit value: store the w-reg at [x19], then bump x19 by one slot.
/// Every operand-stack slot is 8 bytes regardless of value type.
fn pushW(cx: *Context, wreg: u5) !void {
    if (cx.stack_depth >= cx.max_stack_depth) return error.StackOverflow;
    emitStrWImm(cx.buf, wreg, 19, 0);
    emitAddImmX(cx.buf, 19, 19, 8);
    cx.stack_depth += 1;
}

/// Pop a 32-bit value into a w-reg (high half of the 8-byte slot is ignored).
fn popW(cx: *Context, wreg: u5) !void {
    if (cx.stack_depth == 0) return error.StackUnderflow;
    emitSubImmX(cx.buf, 19, 19, 8);
    emitLdrWImm(cx.buf, wreg, 19, 0);
    cx.stack_depth -= 1;
}

/// Push a full 64-bit value (i64/f64 payloads).
fn pushX(cx: *Context, xreg: u5) !void {
    if (cx.stack_depth >= cx.max_stack_depth) return error.StackOverflow;
    emitStrXImm(cx.buf, xreg, 19, 0);
    emitAddImmX(cx.buf, 19, 19, 8);
    cx.stack_depth += 1;
}

/// Pop a full 64-bit value.
fn popX(cx:
*Context, xreg: u5) !void {
    if (cx.stack_depth == 0) return error.StackUnderflow;
    emitSubImmX(cx.buf, 19, 19, 8);
    emitLdrXImm(cx.buf, xreg, 19, 0);
    cx.stack_depth -= 1;
}

/// Resolve the FuncType for function index `fidx`. The wasm function index
/// space places imported functions before module-local ones, so imports are
/// walked first.
fn getFuncType(mod: *const module.Module, num_imported: u32, fidx: u32) !*const module.FuncType {
    if (fidx < num_imported) {
        // Count only function imports until the fidx-th one is reached.
        var count: u32 = 0;
        for (mod.imports) |imp| {
            if (imp.desc == .func) {
                if (count == fidx) return &mod.types[imp.desc.func];
                count += 1;
            }
        }
        return error.InvalidFunctionIndex;
    }
    const local_idx = fidx - num_imported;
    if (local_idx >= mod.functions.len) return error.InvalidFunctionIndex;
    const type_idx = mod.functions[local_idx];
    if (type_idx >= mod.types.len) return error.InvalidTypeIndex;
    return &mod.types[type_idx];
}

/// Value type of global index `gidx` (imported globals precede local ones
/// in the index space, mirroring getFuncType above).
fn getGlobalValType(mod: *const module.Module, gidx: u32) !module.ValType {
    var import_global_count: u32 = 0;
    for (mod.imports) |imp| {
        if (imp.desc == .global) {
            if (import_global_count == gidx) return imp.desc.global.valtype;
            import_global_count += 1;
        }
    }
    const local_idx = gidx - import_global_count;
    if (local_idx >= mod.globals.len) return error.InvalidGlobalIndex;
    return mod.globals[local_idx].type.valtype;
}

/// Emit the function prologue. Fixed 0x400-byte frame, 8-byte slots:
///   [sp, #0]   saved x30 (link register)
///   [sp, #8]   saved x0 (instance pointer argument)
///   [sp, #16]  saved x19 (operand-stack pointer, callee-saved)
///   [sp, #24]  saved x20 (callee-saved; holds the instance pointer)
///   [sp, #32+] locals (params spilled first), then the operand stack (x19).
/// NOTE(review): nothing here checks that 4 + local_count + operand depth
/// fits inside the 0x400-byte frame — presumably the caller bounds these
/// before emitting; confirm at the call site.
fn emitPrologue(buf: *codebuf.CodeBuffer, param_count: u12, local_count: u12, operand_base_words: u12) void {
    // sub sp, sp, #0x400
    emitSubImmX(buf, 31, 31, 0x400);
    // str x30, [sp, #0]
    emitStrXImm(buf, 30, 31, 0);
    // str x0, [sp, #8]
    emitStrXImm(buf, 0, 31, 1);
    // str x19, [sp, #16]
    emitStrXImm(buf, 19, 31, 2);
    // str x20, [sp, #24]
    emitStrXImm(buf, 20, 31, 3);
    // mov x20, x0  — keep the instance pointer in a callee-saved register
    emitMovXReg(buf, 20, 0);
    // Spill params from arg pointer x1 into local slots [sp,#32 + idx*8]
    var i: u12 = 0;
    while (i < param_count) : (i += 1) {
        emitLdrXImm(buf, 9, 1, i);
        emitStrXImm(buf, 9, 31, 4 + i);
    }
    // Zero-initialize declared locals.
    emitMovImm64(buf, 9, 0);
    var j: u12 = param_count;
    while (j < local_count) : (j += 1) {
        emitStrXImm(buf, 9, 31, 4 + j);
    }
    // Operand stack starts after locals.
    emitAddImmX(buf, 19, 31, operand_base_words * 8);
}

/// Restore callee-saved registers, pop the fixed frame, and return.
/// Must mirror emitPrologue's layout exactly.
fn emitEpilogueAndRet(buf: *codebuf.CodeBuffer) void {
    // ldr x20, [sp, #24]
    emitLdrXImm(buf, 20, 31, 3);
    // ldr x19, [sp, #16]
    emitLdrXImm(buf, 19, 31, 2);
    // ldr x30, [sp, #0]
    emitLdrXImm(buf, 30, 31, 0);
    // add sp, sp, #0x400
    emitAddImmX(buf, 31, 31, 0x400);
    emitRET(buf);
}

/// A64 condition codes (the subset this compiler emits). The values are the
/// architectural 4-bit cond field encodings; al/nv are deliberately absent.
const Cond = enum(u4) {
    eq = 0x0,
    ne = 0x1,
    hs = 0x2, // unsigned >= (carry set)
    lo = 0x3, // unsigned <  (carry clear)
    hi = 0x8, // unsigned >
    ls = 0x9, // unsigned <=
    ge = 0xA, // signed >=
    lt = 0xB, // signed <
    gt = 0xC, // signed >
    le = 0xD, // signed <=
};

// ── Raw A64 instruction-word encoders (W = 32-bit forms) ─────────────────────

// add wd, wn, wm
fn emitAddWReg(buf: *codebuf.CodeBuffer, rd: u5, rn: u5, rm: u5) void {
    buf.emitU32Le(0x0B000000 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd);
}

// sub wd, wn, wm
fn emitSubWReg(buf: *codebuf.CodeBuffer, rd: u5, rn: u5, rm: u5) void {
    buf.emitU32Le(0x4B000000 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd);
}

// mul wd, wn, wm (madd with ra = wzr baked into the base opcode)
fn emitMulWReg(buf: *codebuf.CodeBuffer, rd: u5, rn: u5, rm: u5) void {
    buf.emitU32Le(0x1B007C00 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd);
}

// sdiv wd, wn, wm
fn emitSdivWReg(buf: *codebuf.CodeBuffer, rd: u5, rn: u5, rm: u5) void {
    buf.emitU32Le(0x1AC00C00 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd);
}

// udiv wd, wn, wm
fn emitUdivWReg(buf: *codebuf.CodeBuffer, rd: u5, rn: u5, rm: u5) void {
    buf.emitU32Le(0x1AC00800 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd);
}

// msub wd, wn, wm, wa (wd = wa - wn*wm)
fn emitMsubWReg(buf: *codebuf.CodeBuffer, rd: u5, rn: u5, rm: u5, ra: u5) void {
    buf.emitU32Le(0x1B008000 | (@as(u32, rm) << 16) | (@as(u32, ra) << 10) | (@as(u32, rn) << 5) | rd);
}

// and wd, wn, wm
fn emitAndWReg(buf: *codebuf.CodeBuffer, rd: u5, rn: u5, rm: u5) void {
    buf.emitU32Le(0x0A000000 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd);
}

// orr wd, wn, wm
fn emitOrrWReg(buf: *codebuf.CodeBuffer, rd: u5, rn: u5, rm: u5) void {
    buf.emitU32Le(0x2A000000 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd);
}

// eor wd, wn, wm
fn emitEorWReg(buf:
*codebuf.CodeBuffer, rd: u5, rn: u5, rm: u5) void {
    buf.emitU32Le(0x4A000000 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd);
}

// lslv wd, wn, wm
fn emitLslvWReg(buf: *codebuf.CodeBuffer, rd: u5, rn: u5, rm: u5) void {
    buf.emitU32Le(0x1AC02000 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd);
}

// lsrv wd, wn, wm
fn emitLsrvWReg(buf: *codebuf.CodeBuffer, rd: u5, rn: u5, rm: u5) void {
    buf.emitU32Le(0x1AC02400 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd);
}

// lsr wd, wn, #shift — immediate form, despite the "Lsrv" name.
fn emitLsrvWImm(buf: *codebuf.CodeBuffer, rd: u5, rn: u5, shift: u6) void {
    // ubfm wd, wn, #shift, #31
    buf.emitU32Le(0x53007C00 | (@as(u32, shift) << 16) | (@as(u32, rn) << 5) | rd);
}

// asrv wd, wn, wm
fn emitAsrvWReg(buf: *codebuf.CodeBuffer, rd: u5, rn: u5, rm: u5) void {
    buf.emitU32Le(0x1AC02800 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd);
}

// rorv wd, wn, wm
fn emitRorvWReg(buf: *codebuf.CodeBuffer, rd: u5, rn: u5, rm: u5) void {
    buf.emitU32Le(0x1AC02C00 | (@as(u32, rm) << 16) | (@as(u32, rn) << 5) | rd);
}

// neg wd, wm
fn emitNegWReg(buf: *codebuf.CodeBuffer, rd: u5, rm: u5) void {
    // sub wd, wzr, wm
    buf.emitU32Le(0x4B0003E0 | (@as(u32, rm) << 16) | rd);
}

// rbit wd, wn
fn emitRbitW(buf: *codebuf.CodeBuffer, rd: u5, rn: u5) void {
    buf.emitU32Le(0x5AC00000 | (@as(u32, rn) << 5) | rd);
}

// clz wd, wn
fn emitClzW(buf: *codebuf.CodeBuffer, rd: u5, rn: u5) void {
    buf.emitU32Le(0x5AC01000 | (@as(u32, rn) << 5) | rd);
}

// cmp wn, wm — sets NZCV for a following cset/csel.
fn emitCmpWReg(buf: *codebuf.CodeBuffer, rn: u5, rm: u5) void {
    // SUBS WZR, Wn, Wm
    buf.emitU32Le(0x6B00001F | (@as(u32, rm) << 16) | (@as(u32, rn) << 5));
}

// cset wd, cond
fn emitCsetW(buf: *codebuf.CodeBuffer, rd: u5, cond: Cond) void {
    // CSET Wd, cond = CSINC Wd, WZR, WZR, invert(cond)
    // XOR-ing the low bit inverts any condition except al/nv, which the
    // Cond enum does not contain.
    const inv: u4 = @intFromEnum(cond) ^ 1;
    buf.emitU32Le(0x1A9F07E0 | (@as(u32, inv) << 12) | rd);
}

// csel wd, wn, wm, cond
fn emitCselW(buf: *codebuf.CodeBuffer, rd: u5, rn: u5, rm: u5, cond: Cond) void {
    buf.emitU32Le(0x1A800000 | (@as(u32, rm) << 16) | (@as(u32, @intFromEnum(cond)) << 12) | (@as(u32, rn) << 5) | rd);
}

// csel xd, xn, xm, cond (64-bit form)
fn emitCselX(buf: *codebuf.CodeBuffer, rd: u5, rn: u5, rm: u5, cond: Cond) void {
    buf.emitU32Le(0x9A800000 | (@as(u32, rm) << 16) | (@as(u32, @intFromEnum(cond)) << 12) | (@as(u32, rn) << 5) | rd);
}

fn emitMovXReg(buf: *codebuf.CodeBuffer, rd: u5, rm: u5) void {
    // MOV Xd, Xm alias of ORR Xd, XZR, Xm
    buf.emitU32Le(0xAA0003E0 | (@as(u32, rm) << 16) | rd);
}

fn emitMovWReg(buf: *codebuf.CodeBuffer, rd: u5, rm: u5) void {
    // MOV Wd, Wm alias of ORR Wd, WZR, Wm
    buf.emitU32Le(0x2A0003E0 | (@as(u32, rm) << 16) | rd);
}

/// Materialize a 32-bit immediate: movz for the low half, plus a movk for
/// the high half only when nonzero. NOTE: variable length (1 or 2 words).
fn emitMovImm32(buf: *codebuf.CodeBuffer, rd: u5, imm: u32) void {
    const low: u16 = @truncate(imm & 0xFFFF);
    const high: u16 = @truncate((imm >> 16) & 0xFFFF);
    buf.emitU32Le(0x52800000 | (@as(u32, low) << 5) | rd); // movz
    if (high != 0) {
        buf.emitU32Le(0x72800000 | (1 << 21) | (@as(u32, high) << 5) | rd); // movk lsl16
    }
}

/// Materialize a 64-bit immediate (helper addresses). Always emits exactly
/// four words (movz + 3×movk), so the sequence length is predictable.
fn emitMovImm64(buf: *codebuf.CodeBuffer, rd: u5, imm: usize) void {
    const v: u64 = @intCast(imm);
    const p0: u16 = @truncate(v & 0xFFFF);
    const p1: u16 = @truncate((v >> 16) & 0xFFFF);
    const p2: u16 = @truncate((v >> 32) & 0xFFFF);
    const p3: u16 = @truncate((v >> 48) & 0xFFFF);
    buf.emitU32Le(0xD2800000 | (@as(u32, p0) << 5) | rd); // movz
    buf.emitU32Le(0xF2800000 | (1 << 21) | (@as(u32, p1) << 5) | rd); // movk lsl16
    buf.emitU32Le(0xF2800000 | (2 << 21) | (@as(u32, p2) << 5) | rd); // movk lsl32
    buf.emitU32Le(0xF2800000 | (3 << 21) | (@as(u32, p3) << 5) | rd); // movk lsl48
}

// add xd, xn, #imm12 — register 31 means SP in this encoding (see prologue).
fn emitAddImmX(buf: *codebuf.CodeBuffer, rd: u5, rn: u5, imm12: u12) void {
    buf.emitU32Le(0x91000000 | (@as(u32, imm12) << 10) | (@as(u32, rn) << 5) | rd);
}

// sub xd, xn, #imm12 — register 31 means SP here as well.
fn emitSubImmX(buf: *codebuf.CodeBuffer, rd: u5, rn: u5, imm12: u12) void {
    buf.emitU32Le(0xD1000000 | (@as(u32, imm12) << 10) | (@as(u32, rn) << 5) | rd);
}

// cmp wn, #imm12
fn emitCmpImmW(buf: *codebuf.CodeBuffer, rn: u5, imm12: u12) void {
    // subs wzr, wn, #imm
    buf.emitU32Le(0x7100001F | (@as(u32, imm12) << 10) | (@as(u32, rn) << 5));
}

// str wt, [xn, #imm12*4] — unsigned offset, scaled by 4.
fn emitStrWImm(buf: *codebuf.CodeBuffer, rt: u5, rn: u5, imm12: u12) void {
    buf.emitU32Le(0xB9000000 | (@as(u32, imm12) << 10) | (@as(u32, rn) << 5) | rt);
}

// ldr wt, [xn, #imm12*4]
fn emitLdrWImm(buf: *codebuf.CodeBuffer, rt: u5, rn: u5, imm12: u12) void {
    buf.emitU32Le(0xB9400000 | (@as(u32, imm12) << 10) | (@as(u32, rn) << 5) | rt);
}

// str xt, [xn, #imm12*8] — unsigned offset, scaled by 8.
fn emitStrXImm(buf: *codebuf.CodeBuffer, rt: u5, rn: u5, imm12: u12) void {
    buf.emitU32Le(0xF9000000 | (@as(u32, imm12) << 10) | (@as(u32, rn) << 5) | rt);
}

// ldr xt, [xn, #imm12*8]
fn emitLdrXImm(buf: *codebuf.CodeBuffer, rt: u5, rn: u5, imm12: u12) void {
    buf.emitU32Le(0xF9400000 | (@as(u32, imm12) << 10) | (@as(u32, rn) << 5) | rt);
}

// ret (x30)
fn emitRET(buf: *codebuf.CodeBuffer) void {
    buf.emitU32Le(0xD65F03C0);
}

// blr xn — indirect call through a register (helper dispatch).
fn emitBLR(buf: *codebuf.CodeBuffer, rn: u5) void {
    buf.emitU32Le(0xD63F0000 | (@as(u32, rn) << 5));
}

/// bl back to buffer offset 0 (self-recursion to the function entry).
fn emitBLToStart(buf: *codebuf.CodeBuffer) void {
    const p = buf.cursor();
    const delta_instrs: i32 = @intCast(@divTrunc(-@as(isize, @intCast(p)), 4));
    const bits: u32 = @bitCast(delta_instrs);
    buf.emitU32Le(0x94000000 | (bits & 0x03FFFFFF));
}

/// Emit `b #0` and return its position for later patchB().
fn emitBPlaceholder(buf: *codebuf.CodeBuffer) usize {
    const p = buf.cursor();
    buf.emitU32Le(0x14000000);
    return p;
}

/// Emit `cbz wRt, #0` and return its position for later patchCBZ().
fn emitCBZPlaceholder(buf: *codebuf.CodeBuffer, rt: u5) usize {
    const p = buf.cursor();
    buf.emitU32Le(0x34000000 | @as(u32, rt)); // cbz wRt, #0
    return p;
}

/// Backpatch an unconditional branch: keep the top 6 opcode bits, replace
/// the signed 26-bit word offset.
fn patchB(buf: *codebuf.CodeBuffer, patch_pos: usize, target_pos: usize) void {
    const delta_instrs: i32 = @intCast(@divTrunc(@as(isize, @intCast(target_pos)) - @as(isize, @intCast(patch_pos)), 4));
    std.debug.assert(delta_instrs >= -(1 << 25) and delta_instrs < (1 << 25));
    const bits: u32 = @bitCast(delta_instrs);
    const patched = (std.mem.readInt(u32, buf.buf[patch_pos..][0..4], .little) & 0xFC000000) | (bits & 0x03FFFFFF);
    buf.patchU32(patch_pos, patched);
}

/// Backpatch a cbz: keep opcode + Rt bits, replace the signed 19-bit word
/// offset (bits 5..23).
fn patchCBZ(buf: *codebuf.CodeBuffer, patch_pos: usize, target_pos: usize) void {
    const delta_instrs: i32 = @intCast(@divTrunc(@as(isize, @intCast(target_pos)) - @as(isize,
@intCast(patch_pos)), 4));
    std.debug.assert(delta_instrs >= -(1 << 18) and delta_instrs < (1 << 18)); // imm19 range
    const bits: u32 = @bitCast(delta_instrs);
    const old = std.mem.readInt(u32, buf.buf[patch_pos..][0..4], .little);
    const patched = (old & 0xFF00001F) | ((bits & 0x7FFFF) << 5);
    buf.patchU32(patch_pos, patched);
}

test "aarch64 encode core instruction words" {
    var b = try codebuf.CodeBuffer.init(std.testing.allocator, 4096);
    defer b.deinit();

    emitAddWReg(&b, 0, 1, 2); // add w0, w1, w2
    emitSubWReg(&b, 3, 4, 5); // sub w3, w4, w5
    emitRET(&b);

    try std.testing.expectEqual(@as(u32, 0x0B020020), std.mem.readInt(u32, b.buf[0..4], .little));
    try std.testing.expectEqual(@as(u32, 0x4B050083), std.mem.readInt(u32, b.buf[4..8], .little));
    try std.testing.expectEqual(@as(u32, 0xD65F03C0), std.mem.readInt(u32, b.buf[8..12], .little));
}

test "aarch64 compileFunctionI32 executes const return" {
    if (builtin.cpu.arch != .aarch64) return error.SkipZigTest;

    // Minimal module: one () -> i32 function whose body is `i32.const 42; end`.
    var params = [_]module.ValType{};
    var results = [_]module.ValType{.i32};
    const ft = module.FuncType{ .params = &params, .results = &results };
    var bodies = [_]module.FunctionBody{.{
        .locals = &.{},
        .code = &[_]u8{ 0x41, 0x2a, 0x0b },
    }};
    var types = [_]module.FuncType{ft};
    var funcs = [_]u32{0};
    const mod = module.Module{
        .types = &types,
        .imports = &.{},
        .functions = &funcs,
        .tables = &.{},
        .memories = &.{},
        .globals = &.{},
        .exports = &.{},
        .start = null,
        .elements = &.{},
        .codes = &bodies,
        .datas = &.{},
        .allocator = std.testing.allocator,
    };

    const Helper = struct {
        fn call(_: *anyopaque, _: u32, _: [*]const u64, _: u32) callconv(.c) u64 {
            return 0;
        }
    };

    // Only .call gets a real address; this const-return body never invokes
    // the other helpers, so zero addresses are acceptable here.
    const helpers: HelperAddrs = .{
        .call = @intFromPtr(&Helper.call),
        .@"unreachable" = 0,
        .global_get = 0,
        .global_set = 0,
        .mem_load = 0,
        .mem_store = 0,
        .i32_unary = 0,
        .i32_cmp = 0,
        .i32_binary = 0,
        .i32_div_s = 0,
        .i32_div_u = 0,
        .i32_rem_s = 0,
        .i32_rem_u = 0,
        .i64_eqz = 0,
        .i64_cmp = 0,
        .i64_unary = 0,
        .i64_binary = 0,
        .f32_cmp = 0,
        .f64_cmp = 0,
        .f32_unary = 0,
        .f32_binary = 0,
        .f64_unary = 0,
        .f64_binary = 0,
        .convert = 0,
        .trunc_sat = 0,
        .i_extend = 0,
        .memory_init = 0,
        .data_drop = 0,
        .memory_copy = 0,
        .memory_fill = 0,
        .table_size = 0,
        .memory_size = 0,
        .memory_grow = 0,
        .call_indirect = 0,
    };
    var jit = (try compileFunctionI32(std.testing.allocator, &mod, 0, 0, &bodies[0], &ft, helpers)) orelse return error.TestUnexpectedResult;
    defer jit.buf.deinit();
    var zero: u64 = 0;
    const fn_ptr = jit.buf.funcPtr(fn (*anyopaque, [*]const u64, u32) callconv(.c) u64, 0);
    const r = fn_ptr(@ptrFromInt(1), @ptrCast(&zero), 0);
    try std.testing.expectEqual(@as(u64, 42), r);
}
diff --git a/src/wasm/jit/codebuf.zig b/src/wasm/jit/codebuf.zig
new file mode 100644
index 0000000..3a74229
--- /dev/null
+++ b/src/wasm/jit/codebuf.zig
@@ -0,0 +1,189 @@
const std = @import("std");
const builtin = @import("builtin");

const is_apple_silicon = builtin.os.tag == .macos and builtin.cpu.arch == .aarch64;
const is_macos = builtin.os.tag == .macos;
const is_aarch64 = builtin.cpu.arch == .aarch64;

// Apple Silicon JIT: under the hardened runtime, MAP_JIT is mandatory.
// Host binary must have entitlement: com.apple.security.cs.allow-jit
// (needed only when MAP_JIT mappings are used; CodeBuffer.init below
// deliberately avoids MAP_JIT and relies on plain mprotect instead).
const MAP_JIT: u32 = 0x0800; // Darwin-specific mmap flag

// Apple Silicon platform APIs, with no-op shims elsewhere so call sites
// need no platform branching.
const AppleSiliconJIT = if (is_apple_silicon) struct {
    pub extern fn pthread_jit_write_protect_np(enabled: c_int) void;
    pub extern fn sys_icache_invalidate(start: *anyopaque, len: usize) void;
} else struct {
    pub inline fn pthread_jit_write_protect_np(enabled: c_int) void { _ = enabled; }
    pub inline fn sys_icache_invalidate(start: *anyopaque, len: usize) void { _ = start; _ = len; }
};

// Linux AArch64 cache flush via compiler-rt
const LinuxAArch64 = if (!is_apple_silicon and is_aarch64) struct {
    pub extern fn __clear_cache(start: *anyopaque, end: *anyopaque) void;
} else struct {
    pub inline fn __clear_cache(start: *anyopaque, end: *anyopaque) void { _ = start; _ = end; }
};

// PROT_READ | PROT_WRITE
const prot_rw = std.posix.PROT{ .READ = true, .WRITE = true };
// PROT_READ | PROT_EXEC
const prot_rx = std.posix.PROT{ .READ = true, .EXEC = true };

/// Invalidate the instruction cache over freshly written code.
/// Early-returns on non-aarch64 targets.
pub fn flushICache(ptr: [*]u8, len: usize) void {
    if (!is_aarch64) return;
    if (is_apple_silicon) {
        AppleSiliconJIT.sys_icache_invalidate(ptr, len);
    } else {
        LinuxAArch64.__clear_cache(ptr, ptr + len);
    }
}

/// Page-aligned buffer for JIT output: mapped read-write for emission,
/// flipped to read-execute by finalize(). Owner must call deinit().
pub const CodeBuffer = struct {
    buf: []align(std.heap.page_size_min) u8, // the whole mapping
    pos: usize, // write cursor (bytes emitted so far)

    pub fn init(allocator: std.mem.Allocator, capacity: usize) !CodeBuffer {
        _ = allocator; // kept for API symmetry; the memory comes from mmap
        const aligned_cap = std.mem.alignForward(usize, capacity, std.heap.page_size_min);
        // Plain mmap RW, then mprotect to RX in finalize().
        // MAP_JIT + pthread_jit_write_protect_np requires com.apple.security.cs.allow-jit
        // entitlement (hardened runtime); ad-hoc signed test binaries do not have it.
        // Plain mprotect works for non-hardened binaries on all platforms.
+ const buf = if (is_macos) blk: { + // Darwin MAP_PRIVATE | MAP_ANONYMOUS (no MAP_JIT) + const MAP_PRIVATE: u32 = 0x0002; + const MAP_ANONYMOUS: u32 = 0x1000; + const flags: u32 = MAP_PRIVATE | MAP_ANONYMOUS; + const slice = try std.posix.mmap(null, aligned_cap, prot_rw, @bitCast(flags), -1, 0); + break :blk slice; + } else blk: { + const slice = try std.posix.mmap( + null, + aligned_cap, + prot_rw, + .{ .TYPE = .PRIVATE, .ANONYMOUS = true }, + -1, + 0, + ); + break :blk slice; + }; + return .{ .buf = buf, .pos = 0 }; + } + + pub fn deinit(self: *CodeBuffer) void { + std.posix.munmap(self.buf); + self.* = undefined; + } + + pub fn emit1(self: *CodeBuffer, byte: u8) void { + std.debug.assert(self.pos < self.buf.len); + self.buf[self.pos] = byte; + self.pos += 1; + } + + pub fn emitSlice(self: *CodeBuffer, bytes: []const u8) void { + std.debug.assert(self.pos + bytes.len <= self.buf.len); + @memcpy(self.buf[self.pos..][0..bytes.len], bytes); + self.pos += bytes.len; + } + + pub fn emitU32Le(self: *CodeBuffer, v: u32) void { + std.debug.assert(self.pos + 4 <= self.buf.len); + std.mem.writeInt(u32, self.buf[self.pos..][0..4], v, .little); + self.pos += 4; + } + + pub fn emitI32Le(self: *CodeBuffer, v: i32) void { + self.emitU32Le(@bitCast(v)); + } + + pub fn cursor(self: *const CodeBuffer) usize { + return self.pos; + } + + pub fn patchI32(self: *CodeBuffer, pos: usize, value: i32) void { + std.mem.writeInt(i32, self.buf[pos..][0..4], value, .little); + } + + pub fn patchU32(self: *CodeBuffer, pos: usize, value: u32) void { + std.mem.writeInt(u32, self.buf[pos..][0..4], value, .little); + } + + /// Apple Silicon: switch to RW mode for patching after finalize. + pub fn beginWrite(self: *CodeBuffer) void { + _ = self; + if (is_apple_silicon) AppleSiliconJIT.pthread_jit_write_protect_np(0); + } + + /// Apple Silicon: switch to RX mode and flush I-cache. 
+ pub fn endWrite(self: *CodeBuffer) !void { + if (is_apple_silicon) { + AppleSiliconJIT.pthread_jit_write_protect_np(1); + } + if (is_aarch64) flushICache(self.buf.ptr, self.pos); + } + + /// Make the buffer executable. Must be called before executing any code. + pub fn finalize(self: *CodeBuffer) !void { + const rc = std.c.mprotect( + @alignCast(@ptrCast(self.buf.ptr)), + self.buf.len, + prot_rx, + ); + if (rc != 0) return error.MProtectFailed; + if (is_aarch64) flushICache(self.buf.ptr, self.pos); + } + + pub fn funcPtr(self: *const CodeBuffer, comptime Fn: type, offset: usize) *const Fn { + return @ptrFromInt(@intFromPtr(self.buf.ptr) + offset); + } +}; + +// ── Tests ───────────────────────────────────────────────────────────────────── + +test "codebuf emit and finalize" { + var buf = try CodeBuffer.init(std.testing.allocator, 4096); + defer buf.deinit(); + + if (builtin.cpu.arch == .x86_64) { + // mov eax, 42; ret + buf.emitSlice(&.{ 0xB8, 42, 0, 0, 0 }); + buf.emit1(0xC3); + try buf.finalize(); + const fn_ptr = buf.funcPtr(fn () callconv(.c) i32, 0); + const result = fn_ptr(); + try std.testing.expectEqual(@as(i32, 42), result); + } else if (builtin.cpu.arch == .aarch64) { + // movz w0, #42; ret + // MOVZ W0, #42 = 0x52800540 (little-endian bytes: 0x40 0x05 0x80 0x52) + buf.emitU32Le(0x52800540); + buf.emitU32Le(0xD65F03C0); // ret + // Verify the bytes are correct + try std.testing.expectEqual(@as(u8, 0x40), buf.buf[0]); + try std.testing.expectEqual(@as(u8, 0x05), buf.buf[1]); + try std.testing.expectEqual(@as(u8, 0x80), buf.buf[2]); + try std.testing.expectEqual(@as(u8, 0x52), buf.buf[3]); + // Finalize (needed for the W^X transition on Apple Silicon) + try buf.finalize(); + // Execute: requires com.apple.security.cs.allow-jit entitlement on Apple Silicon. + // Zig test binaries on Apple Silicon are signed with this entitlement by default. 
+        const fn_ptr = buf.funcPtr(fn () callconv(.c) i32, 0);
+        const result = fn_ptr();
+        try std.testing.expectEqual(@as(i32, 42), result);
+    }
+}
+
+test "codebuf cursor and patch" {
+    var buf = try CodeBuffer.init(std.testing.allocator, 4096);
+    defer buf.deinit();
+    buf.emitU32Le(0xDEADBEEF);
+    const patch_pos = buf.cursor(); // remember the slot to backpatch
+    buf.emitU32Le(0x00000000); // placeholder, overwritten below
+    buf.emitU32Le(0xCAFEBABE);
+    buf.patchU32(patch_pos, 0x12345678); // patch while still RW (before finalize)
+    try std.testing.expectEqual(
+        @as(u32, 0x12345678),
+        std.mem.readInt(u32, buf.buf[patch_pos..][0..4], .little),
+    );
+}
diff --git a/src/wasm/jit/codegen.zig b/src/wasm/jit/codegen.zig
new file mode 100644
index 0000000..0fe30ad
--- /dev/null
+++ b/src/wasm/jit/codegen.zig
@@ -0,0 +1,24 @@
+const std = @import("std");
+const builtin = @import("builtin");
+const module = @import("../module.zig");
+const aarch64 = @import("aarch64.zig");
+const x86_64 = @import("x86_64.zig");
+
+pub const JitResult = aarch64.JitResult; // NOTE(review): re-exported from the aarch64 backend even on x86_64 builds — confirm x86_64.zig returns this exact type
+pub const HelperAddrs = aarch64.HelperAddrs; // runtime helper entry points handed to the backend
+
+pub fn compileSimpleI32( // arch dispatch; returns null when no JIT backend exists for the build target
+    allocator: std.mem.Allocator,
+    mod: *const module.Module,
+    num_imported_funcs: u32,
+    current_func_idx: u32,
+    body: *const module.FunctionBody,
+    ft: *const module.FuncType,
+    helpers: HelperAddrs,
+) !?JitResult {
+    return switch (builtin.cpu.arch) {
+        .aarch64 => try aarch64.compileFunctionI32(allocator, mod, num_imported_funcs, current_func_idx, body, ft, helpers),
+        .x86_64 => try x86_64.compileFunctionI32(allocator, mod, num_imported_funcs, current_func_idx, body, ft, helpers),
+        else => null, // unsupported architecture: caller falls back to non-JIT execution
+    };
+}
diff --git a/src/wasm/jit/liveness.zig b/src/wasm/jit/liveness.zig
new file mode 100644
index 0000000..10cf286
--- /dev/null
+++ b/src/wasm/jit/liveness.zig
@@ -0,0 +1,75 @@
+//! Phase 4.2 — Liveness: compute live ranges for each virtual register.
+const std = @import("std");
+const module = @import("../module.zig");
+const stackify = @import("stackify.zig");
+
+pub const LiveRange = struct {
+    vreg: stackify.VReg,
+    valtype: module.ValType,
+    start: u32, // instruction index of definition
+    end: u32, // instruction index of last use
+};
+
+/// Compute one live range per virtual register from the annotated instruction
+/// stream: `start` is the first definition, `end` the last use (or the
+/// definition itself if never consumed). Caller owns the returned slice.
+pub fn computeLiveRanges(
+    allocator: std.mem.Allocator,
+    instrs: []const stackify.AnnotatedInstr,
+) ![]LiveRange {
+    var range_map = std.AutoHashMap(stackify.VReg, usize).init(allocator); // VReg -> index into `ranges`
+    defer range_map.deinit();
+    var ranges: std.ArrayList(LiveRange) = .empty;
+    errdefer ranges.deinit(allocator);
+
+    for (instrs, 0..) |instr, idx| {
+        const i: u32 = @intCast(idx);
+        // Process pushes (definitions). stackify may re-push an already-live
+        // vreg (local.tee / select forwarding); extend that range instead of
+        // creating a duplicate, so each vreg maps to exactly one range.
+        for (instr.effect.pushes) |vr| {
+            if (range_map.get(vr)) |existing| {
+                ranges.items[existing].end = i;
+                continue;
+            }
+            const vt = instr.result_type orelse .i32; // single-result ops only; .i32 fallback for untyped pushes
+            const range_idx = ranges.items.len;
+            try ranges.append(allocator, .{ .vreg = vr, .valtype = vt, .start = i, .end = i });
+            try range_map.put(vr, range_idx);
+        }
+        // Process pops (uses) — extend live range. Pops of vregs with no
+        // recorded definition (unreachable code) are deliberately ignored.
+        for (instr.effect.pops) |vr| {
+            if (range_map.get(vr)) |range_idx| {
+                ranges.items[range_idx].end = i;
+            }
+        }
+    }
+
+    return ranges.toOwnedSlice(allocator);
+}
+
+// ── Tests ─────────────────────────────────────────────────────────────────────
+
+test "live ranges for add sequence" {
+    const ally = std.testing.allocator;
+    // Simulate: i32.const -> vr0, i32.const -> vr1, add(vr1,vr0) -> vr2
+    const instrs = [_]stackify.AnnotatedInstr{
+        .{ .opcode = 0x41, .imm = .{ .i32 = 1 }, .effect = .{ .pushes = &[_]stackify.VReg{0} }, .instr_idx = 0, .result_type = .i32 },
+        .{ .opcode = 0x41, .imm = .{ .i32 = 2 }, .effect = .{ .pushes = &[_]stackify.VReg{1} }, .instr_idx = 1, .result_type = .i32 },
+        .{ .opcode = 0x6A, .imm = .none, .effect = .{ .pops = &[_]stackify.VReg{ 1, 0 }, .pushes = &[_]stackify.VReg{2} }, .instr_idx = 2, .result_type = .i32 },
+    };
+
+    const ranges = try computeLiveRanges(ally, &instrs);
+    defer ally.free(ranges);
+
+    try std.testing.expectEqual(@as(usize, 3), ranges.len);
+    // vr0: defined at 0, last used at 2
+    try std.testing.expectEqual(@as(u32, 0), ranges[0].start);
+    try std.testing.expectEqual(@as(u32, 2), ranges[0].end);
+    // vr1: defined at 1, last used at 2
+    try std.testing.expectEqual(@as(u32, 1), ranges[1].start);
+    try std.testing.expectEqual(@as(u32, 2), ranges[1].end);
+    // vr2: defined at 2, never used -> end = 2
+    try std.testing.expectEqual(@as(u32, 2), ranges[2].start);
+    try std.testing.expectEqual(@as(u32, 2), ranges[2].end);
+}
diff --git a/src/wasm/jit/regalloc.zig b/src/wasm/jit/regalloc.zig
new file mode 100644
index 0000000..9076f1d
--- /dev/null
+++ b/src/wasm/jit/regalloc.zig
@@ -0,0 +1,193 @@
+const std = @import("std");
+const module = @import("../module.zig");
+const liveness = @import("liveness.zig");
+
+pub const PhysReg = u8;
+
+pub const ArchDesc = struct {
+    int_regs: []const PhysReg, // allocatable integer registers (hardware encodings)
+    float_regs: []const PhysReg, // allocatable FP/vector registers
+    scratch_int: PhysReg, // reserved for codegen temporaries — never handed out
+    scratch_float: PhysReg,
+};
+
+/// x86-64: System V AMD64 ABI caller-saved registers
+pub const x86_64_desc = ArchDesc{
+    .int_regs = &.{ 1, 2, 6, 7, 8, 9, 10, 11 }, // rcx,rdx,rsi,rdi,r8-r11 — was { 1, 2, 4, 5, 6, 7, 8, 9 }: encodings 4/5 are rsp/rbp and must never be allocated
+    .float_regs = &.{ 0, 1, 2, 3, 4, 5, 6, 7 }, // xmm0-xmm7
+    .scratch_int = 0, // rax
+    .scratch_float = 8, // xmm8
+};
+
+/// AArch64: AAPCS64 caller-saved registers
+pub const aarch64_desc = ArchDesc{
+    .int_regs = &.{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 }, // x1-x15
+    .float_regs = &.{ 0, 1, 2, 3, 4, 5, 6, 7 }, // v0-v7
+    .scratch_int = 16, // x16 (IP0)
+    .scratch_float = 16, // v16
+};
+
+pub const Location = union(enum) {
+    reg: PhysReg,
+    spill: u32, // byte offset from frame base
+};
+
+pub const Allocation = struct {
+    map: std.AutoHashMap(u32, Location),
+
+    pub fn locationOf(self: *const Allocation, vr: u32) Location {
+        return self.map.get(vr) orelse .{ .spill = 0 }; // NOTE(review): an unmapped vreg silently aliases spill slot 0 — consider asserting instead
+    }
+
+    pub fn deinit(self: *Allocation) void {
+        self.map.deinit();
+        self.* = undefined;
+    }
+};
+
+fn isFloat(vt: 
module.ValType) bool { + return vt == .f32 or vt == .f64; +} + +pub fn allocate( + allocator: std.mem.Allocator, + ranges: []liveness.LiveRange, + arch: ArchDesc, +) !Allocation { + // Sort ranges by start point + const sorted = try allocator.dupe(liveness.LiveRange, ranges); + defer allocator.free(sorted); + std.mem.sort(liveness.LiveRange, sorted, {}, struct { + fn lt(_: void, a: liveness.LiveRange, b: liveness.LiveRange) bool { + return a.start < b.start; + } + }.lt); + + var map = std.AutoHashMap(u32, Location).init(allocator); + errdefer map.deinit(); + + // Free register lists + var free_ints: std.ArrayList(PhysReg) = .empty; + defer free_ints.deinit(allocator); + var free_floats: std.ArrayList(PhysReg) = .empty; + defer free_floats.deinit(allocator); + + // Add registers in reverse so we pop from the front (lowest index first) + var i: usize = arch.int_regs.len; + while (i > 0) { i -= 1; try free_ints.append(allocator, arch.int_regs[i]); } + i = arch.float_regs.len; + while (i > 0) { i -= 1; try free_floats.append(allocator, arch.float_regs[i]); } + + // Active ranges (sorted by end point) + const ActiveRange = struct { + vreg: u32, + end: u32, + reg: PhysReg, + is_float: bool, + }; + var active: std.ArrayList(ActiveRange) = .empty; + defer active.deinit(allocator); + + var spill_offset: u32 = 0; + + for (sorted) |range| { + // Expire old intervals + var j: usize = 0; + while (j < active.items.len) { + const ar = active.items[j]; + if (ar.end < range.start) { + // Free this register + if (ar.is_float) { + try free_floats.append(allocator, ar.reg); + } else { + try free_ints.append(allocator, ar.reg); + } + _ = active.orderedRemove(j); + } else { + j += 1; + } + } + + const float = isFloat(range.valtype); + const free_list = if (float) &free_floats else &free_ints; + + if (free_list.items.len > 0) { + const reg = free_list.pop().?; + try map.put(range.vreg, .{ .reg = reg }); + // Insert into active sorted by end + const new_active = ActiveRange{ .vreg = 
range.vreg, .end = range.end, .reg = reg, .is_float = float }; + var ins: usize = 0; + while (ins < active.items.len and active.items[ins].end <= range.end) ins += 1; + try active.insert(allocator, ins, new_active); + } else { + // Spill: evict the active range with furthest end + const last_idx = if (active.items.len > 0) active.items.len - 1 else { + // No active ranges — just spill this one + const size: u32 = if (range.valtype == .i32 or range.valtype == .f32) 4 else 8; + spill_offset += size; + try map.put(range.vreg, .{ .spill = spill_offset }); + continue; + }; + const spill_candidate = active.items[last_idx]; + + if (spill_candidate.end > range.end and spill_candidate.is_float == float) { + // Evict spill_candidate, assign its register to current + const reg = spill_candidate.reg; + const size: u32 = if (isFloat(range.valtype)) 8 else 8; + _ = size; + spill_offset += if (range.valtype == .i32 or range.valtype == .f32) 4 else 8; + try map.put(spill_candidate.vreg, .{ .spill = spill_offset }); + try map.put(range.vreg, .{ .reg = reg }); + _ = active.orderedRemove(last_idx); + const new_active = ActiveRange{ .vreg = range.vreg, .end = range.end, .reg = reg, .is_float = float }; + var ins: usize = 0; + while (ins < active.items.len and active.items[ins].end <= range.end) ins += 1; + try active.insert(allocator, ins, new_active); + } else { + // Spill current range + spill_offset += if (range.valtype == .i32 or range.valtype == .f32) 4 else 8; + try map.put(range.vreg, .{ .spill = spill_offset }); + } + } + } + + return Allocation{ .map = map }; +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +test "regalloc fits in registers" { + const ally = std.testing.allocator; + var ranges = [_]liveness.LiveRange{ + .{ .vreg = 0, .valtype = .i32, .start = 0, .end = 2 }, + .{ .vreg = 1, .valtype = .i32, .start = 1, .end = 3 }, + .{ .vreg = 2, .valtype = .i32, .start = 2, .end = 4 }, + }; + var alloc = try allocate(ally, &ranges, 
x86_64_desc); + defer alloc.deinit(); + + // All three should be in registers (we have 8 int regs available) + for (0..3) |j| { + const loc = alloc.locationOf(@intCast(j)); + try std.testing.expect(loc == .reg); + } +} + +test "regalloc spills under pressure" { + const ally = std.testing.allocator; + // Create more live ranges than available registers (8 int regs) + var ranges: [10]liveness.LiveRange = undefined; + for (&ranges, 0..) |*r, j| { + r.* = .{ .vreg = @intCast(j), .valtype = .i32, .start = @intCast(j), .end = 20 }; + } + var alloc = try allocate(ally, &ranges, x86_64_desc); + defer alloc.deinit(); + + // Count spills — should have at least some + var spill_count: usize = 0; + for (0..10) |j| { + const loc = alloc.locationOf(@intCast(j)); + if (loc == .spill) spill_count += 1; + } + try std.testing.expect(spill_count > 0); +} diff --git a/src/wasm/jit/stackify.zig b/src/wasm/jit/stackify.zig new file mode 100644 index 0000000..7f9cdc8 --- /dev/null +++ b/src/wasm/jit/stackify.zig @@ -0,0 +1,965 @@ +const std = @import("std"); +const module = @import("../module.zig"); +const binary = @import("../binary.zig"); + +pub const VReg = u32; + +pub const StackifyError = error{ + TypeMismatch, + StackUnderflow, + UndefinedFunction, + UndefinedLocal, + UndefinedGlobal, + UndefinedMemory, + UndefinedTable, + InvalidLabelDepth, + ImmutableGlobal, + InvalidTypeIndex, + InvalidFunctionIndex, + ElseWithoutIf, + InvalidAlignment, + UnsupportedOpcode, + InvalidValueType, + OutOfMemory, + UnexpectedEof, +}; + +pub const StackEffect = struct { + pops: []const VReg = &.{}, + pushes: []const VReg = &.{}, + + pub fn deinit(self: *StackEffect, allocator: std.mem.Allocator) void { + allocator.free(self.pops); + allocator.free(self.pushes); + self.* = .{}; + } +}; + +pub const Immediate = union(enum) { + none, + u32: u32, + u64: u64, + i32: i32, + i64: i64, + f32: f32, + f64: f64, + two_u32: struct { a: u32, b: u32 }, + br_table: []u32, + mem: struct { @"align": u32, offset: u32 
}, +}; + +pub const AnnotatedInstr = struct { + opcode: u8, + imm: Immediate, + effect: StackEffect, + instr_idx: u32, + result_type: ?module.ValType = null, + + pub fn deinit(self: *AnnotatedInstr, allocator: std.mem.Allocator) void { + if (self.imm == .br_table) allocator.free(self.imm.br_table); + self.effect.deinit(allocator); + } +}; + +pub fn deinitInstrs(allocator: std.mem.Allocator, instrs: []AnnotatedInstr) void { + for (instrs) |*ins| ins.deinit(allocator); + allocator.free(instrs); +} + +const StackVal = struct { + vreg: VReg, + valtype: module.ValType, +}; + +const Frame = struct { + kind: Kind, + start_height: usize, + label_types: []const module.ValType, + result_types: []const module.ValType, + reachable: bool, + const Kind = enum { block, loop, @"if", @"else" }; +}; + +/// Walk bytecode, simulate typed operand stack, assign vRegs to each produced value. +/// Returns a slice of AnnotatedInstr owned by the caller. Use `deinitInstrs` to free it. +pub fn stackify( + allocator: std.mem.Allocator, + body: *const module.FunctionBody, + func_type: *const module.FuncType, + mod: *const module.Module, +) StackifyError![]AnnotatedInstr { + var imported_globals: std.ArrayList(module.GlobalType) = .empty; + defer imported_globals.deinit(allocator); + + var num_imported_funcs: u32 = 0; + var total_tables: u32 = 0; + var total_memories: u32 = 0; + + for (mod.imports) |imp| { + switch (imp.desc) { + .func => num_imported_funcs += 1, + .table => total_tables += 1, + .memory => total_memories += 1, + .global => |gt| try imported_globals.append(allocator, gt), + } + } + total_tables += @as(u32, @intCast(mod.tables.len)); + total_memories += @as(u32, @intCast(mod.memories.len)); + const total_funcs: u32 = num_imported_funcs + @as(u32, @intCast(mod.functions.len)); + + var local_types: std.ArrayList(module.ValType) = .empty; + defer local_types.deinit(allocator); + try local_types.appendSlice(allocator, func_type.params); + for (body.locals) |decl| for (0..decl.count) 
|_| try local_types.append(allocator, decl.valtype); + + var instrs: std.ArrayList(AnnotatedInstr) = .empty; + errdefer { + for (instrs.items) |*ins| ins.deinit(allocator); + instrs.deinit(allocator); + } + + var stack: std.ArrayList(StackVal) = .empty; + defer stack.deinit(allocator); + + var frames: std.ArrayList(Frame) = .empty; + defer frames.deinit(allocator); + try frames.append(allocator, .{ + .kind = .block, + .start_height = 0, + .label_types = func_type.results, + .result_types = func_type.results, + .reachable = true, + }); + + var tmp_pops: std.ArrayList(VReg) = .empty; + defer tmp_pops.deinit(allocator); + var tmp_pushes: std.ArrayList(VReg) = .empty; + defer tmp_pushes.deinit(allocator); + + var next_vreg: VReg = 0; + var pos: usize = 0; + const code = body.code; + var instr_idx: u32 = 0; + + while (pos < code.len) { + const op = code[pos]; + pos += 1; + + tmp_pops.clearRetainingCapacity(); + tmp_pushes.clearRetainingCapacity(); + + var ann = AnnotatedInstr{ + .opcode = op, + .imm = .none, + .effect = .{}, + .instr_idx = instr_idx, + .result_type = null, + }; + instr_idx += 1; + + const frame = &frames.items[frames.items.len - 1]; + const reachable = frame.reachable; + + switch (op) { + 0x00 => { // unreachable + if (reachable) { + stack.shrinkRetainingCapacity(frame.start_height); + frame.reachable = false; + } + }, + 0x01 => {}, + 0x02 => { + const bt = try readBlockType(code, &pos); + const res = try blockTypeResults(mod, bt); + try frames.append(allocator, .{ + .kind = .block, + .start_height = stack.items.len, + .label_types = res, + .result_types = res, + .reachable = reachable, + }); + }, + 0x03 => { + const bt = try readBlockType(code, &pos); + const params = try blockTypeParams(mod, bt); + const res = try blockTypeResults(mod, bt); + try frames.append(allocator, .{ + .kind = .loop, + .start_height = stack.items.len, + .label_types = params, + .result_types = res, + .reachable = reachable, + }); + }, + 0x04 => { + const bt = try 
readBlockType(code, &pos); + const res = try blockTypeResults(mod, bt); + if (reachable) _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, .i32); + try frames.append(allocator, .{ + .kind = .@"if", + .start_height = stack.items.len, + .label_types = res, + .result_types = res, + .reachable = reachable, + }); + ann.imm = .{ .i32 = @intCast(bt) }; + }, + 0x05 => { + const cur = &frames.items[frames.items.len - 1]; + if (cur.kind != .@"if") return error.ElseWithoutIf; + if (cur.reachable) try checkStackTypes(&stack, cur.start_height, cur.result_types); + stack.shrinkRetainingCapacity(cur.start_height); + cur.kind = .@"else"; + cur.reachable = frames.items[frames.items.len - 2].reachable; + }, + 0x0B => { + if (frames.items.len == 1) { + if (frames.items[0].reachable) try checkStackTypes(&stack, 0, frames.items[0].result_types); + if (!frames.items[0].reachable) { + // If the function tail is polymorphic-unreachable, materialize typed results. + try emitMergeResults(allocator, &stack, &tmp_pushes, frames.items[0].result_types, &next_vreg, &ann.result_type); + } else if (frames.items[0].result_types.len == 1) { + ann.result_type = frames.items[0].result_types[0]; + } + frame.reachable = true; + pos = code.len; + } else { + const cur = frames.pop().?; + if (cur.reachable) { + try preserveBlockResults(allocator, &stack, cur.start_height, cur.result_types); + if (cur.result_types.len == 1) ann.result_type = cur.result_types[0]; + } else { + stack.shrinkRetainingCapacity(cur.start_height); + try emitMergeResults(allocator, &stack, &tmp_pushes, cur.result_types, &next_vreg, &ann.result_type); + } + } + }, + 0x0C => { + const depth = try readULEB128(u32, code, &pos); + ann.imm = .{ .u32 = depth }; + if (depth >= frames.items.len) return error.InvalidLabelDepth; + if (reachable) { + const target = &frames.items[frames.items.len - 1 - depth]; + try popLabelTypes(allocator, &stack, &tmp_pops, frame.start_height, target.label_types); + } + 
stack.shrinkRetainingCapacity(frame.start_height); + frame.reachable = false; + }, + 0x0D => { + const depth = try readULEB128(u32, code, &pos); + ann.imm = .{ .u32 = depth }; + if (depth >= frames.items.len) return error.InvalidLabelDepth; + if (reachable) { + _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, .i32); + const target = &frames.items[frames.items.len - 1 - depth]; + try checkLabelTypes(&stack, frame.start_height, target.label_types); + } + }, + 0x0E => { + const n = try readULEB128(u32, code, &pos); + const entries = try allocator.alloc(u32, n + 1); + errdefer allocator.free(entries); + var label_types: ?[]const module.ValType = null; + var i: u32 = 0; + while (i <= n) : (i += 1) { + const depth = try readULEB128(u32, code, &pos); + entries[i] = depth; + if (depth >= frames.items.len) return error.InvalidLabelDepth; + const target = &frames.items[frames.items.len - 1 - depth]; + if (label_types == null) { + label_types = target.label_types; + } else if (!sameValTypeSlice(label_types.?, target.label_types)) { + return error.TypeMismatch; + } + } + ann.imm = .{ .br_table = entries }; + if (reachable) _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, .i32); + stack.shrinkRetainingCapacity(frame.start_height); + frame.reachable = false; + }, + 0x0F => { + if (reachable) try popLabelTypes(allocator, &stack, &tmp_pops, frame.start_height, frames.items[0].result_types); + stack.shrinkRetainingCapacity(frame.start_height); + frame.reachable = false; + }, + 0x10 => { + const fidx = try readULEB128(u32, code, &pos); + ann.imm = .{ .u32 = fidx }; + if (fidx >= total_funcs) return error.UndefinedFunction; + const ft = try getFuncType(mod, fidx, num_imported_funcs); + if (reachable) { + try popTypesReverse(allocator, &stack, &tmp_pops, frame.start_height, ft.params); + try pushResultTypes(allocator, &stack, &tmp_pushes, ft.results, &next_vreg, &ann.result_type); + } + }, + 0x11 => { + const type_idx = try 
readULEB128(u32, code, &pos); + const table_idx = try readULEB128(u32, code, &pos); + ann.imm = .{ .two_u32 = .{ .a = type_idx, .b = table_idx } }; + if (type_idx >= mod.types.len) return error.InvalidTypeIndex; + if (table_idx >= total_tables) return error.UndefinedTable; + if (reachable) { + _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, .i32); + const ft = &mod.types[type_idx]; + try popTypesReverse(allocator, &stack, &tmp_pops, frame.start_height, ft.params); + try pushResultTypes(allocator, &stack, &tmp_pushes, ft.results, &next_vreg, &ann.result_type); + } + }, + 0x1A => { + if (reachable) _ = try popAnyVReg(allocator, &stack, &tmp_pops, frame.start_height); + }, + 0x1B => { + if (reachable) { + _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, .i32); + const rhs = try popAnyVReg(allocator, &stack, &tmp_pops, frame.start_height); + const lhs = try popAnyVReg(allocator, &stack, &tmp_pops, frame.start_height); + if (lhs.valtype != rhs.valtype) return error.TypeMismatch; + try pushExisting(allocator, &stack, &tmp_pushes, lhs.vreg, lhs.valtype); + ann.result_type = lhs.valtype; + } + }, + 0x1C => { + const n = try readULEB128(u32, code, &pos); + if (n != 1) return error.TypeMismatch; + const t = try decodeValType(try readByte(code, &pos)); + if (reachable) { + _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, .i32); + _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, t); + _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, t); + try pushNew(allocator, &stack, &tmp_pushes, t, &next_vreg); + ann.result_type = t; + } + }, + 0x20 => { + const idx = try readULEB128(u32, code, &pos); + ann.imm = .{ .u32 = idx }; + if (idx >= local_types.items.len) return error.UndefinedLocal; + if (reachable) { + const vt = local_types.items[idx]; + try pushNew(allocator, &stack, &tmp_pushes, vt, &next_vreg); + ann.result_type = vt; + } + }, + 0x21 => { + const idx = try 
readULEB128(u32, code, &pos); + ann.imm = .{ .u32 = idx }; + if (idx >= local_types.items.len) return error.UndefinedLocal; + if (reachable) _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, local_types.items[idx]); + }, + 0x22 => { + const idx = try readULEB128(u32, code, &pos); + ann.imm = .{ .u32 = idx }; + if (idx >= local_types.items.len) return error.UndefinedLocal; + if (reachable) { + const v = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, local_types.items[idx]); + try pushExisting(allocator, &stack, &tmp_pushes, v.vreg, v.valtype); + ann.result_type = v.valtype; + } + }, + 0x23 => { + const idx = try readULEB128(u32, code, &pos); + ann.imm = .{ .u32 = idx }; + const gt = try getGlobalType(mod, imported_globals.items, idx); + if (reachable) { + try pushNew(allocator, &stack, &tmp_pushes, gt.valtype, &next_vreg); + ann.result_type = gt.valtype; + } + }, + 0x24 => { + const idx = try readULEB128(u32, code, &pos); + ann.imm = .{ .u32 = idx }; + const gt = try getGlobalType(mod, imported_globals.items, idx); + if (!gt.mutable) return error.ImmutableGlobal; + if (reachable) _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, gt.valtype); + }, + 0x28...0x35 => { + const mem_align = try readULEB128(u32, code, &pos); + const offset = try readULEB128(u32, code, &pos); + ann.imm = .{ .mem = .{ .@"align" = mem_align, .offset = offset } }; + if (total_memories == 0) return error.UndefinedMemory; + if (mem_align > naturalAlignmentLog2ForLoad(op)) return error.InvalidAlignment; + if (reachable) { + _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, .i32); + const rt = memLoadResultType(op); + try pushNew(allocator, &stack, &tmp_pushes, rt, &next_vreg); + ann.result_type = rt; + } + }, + 0x36...0x3E => { + const mem_align = try readULEB128(u32, code, &pos); + const offset = try readULEB128(u32, code, &pos); + ann.imm = .{ .mem = .{ .@"align" = mem_align, .offset = offset } }; + if 
(total_memories == 0) return error.UndefinedMemory; + if (mem_align > naturalAlignmentLog2ForStore(op)) return error.InvalidAlignment; + if (reachable) { + _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, memStoreValType(op)); + _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, .i32); + } + }, + 0x3F, 0x40 => { + _ = try readByte(code, &pos); + if (total_memories == 0) return error.UndefinedMemory; + if (reachable) { + if (op == 0x40) _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, .i32); + try pushNew(allocator, &stack, &tmp_pushes, .i32, &next_vreg); + ann.result_type = .i32; + } + }, + 0x41 => { + const val = try readSLEB128(i32, code, &pos); + ann.imm = .{ .i32 = val }; + if (reachable) { + try pushNew(allocator, &stack, &tmp_pushes, .i32, &next_vreg); + ann.result_type = .i32; + } + }, + 0x42 => { + const val = try readSLEB128(i64, code, &pos); + ann.imm = .{ .i64 = val }; + if (reachable) { + try pushNew(allocator, &stack, &tmp_pushes, .i64, &next_vreg); + ann.result_type = .i64; + } + }, + 0x43 => { + if (pos + 4 > code.len) return error.UnexpectedEof; + const raw = std.mem.readInt(u32, code[pos..][0..4], .little); + pos += 4; + ann.imm = .{ .f32 = @bitCast(raw) }; + if (reachable) { + try pushNew(allocator, &stack, &tmp_pushes, .f32, &next_vreg); + ann.result_type = .f32; + } + }, + 0x44 => { + if (pos + 8 > code.len) return error.UnexpectedEof; + const raw = std.mem.readInt(u64, code[pos..][0..8], .little); + pos += 8; + ann.imm = .{ .f64 = @bitCast(raw) }; + if (reachable) { + try pushNew(allocator, &stack, &tmp_pushes, .f64, &next_vreg); + ann.result_type = .f64; + } + }, + 0x45 => if (reachable) try unaryOp(allocator, &stack, &tmp_pops, &tmp_pushes, frame.start_height, .i32, .i32, &next_vreg, &ann), + 0x46...0x4F => if (reachable) try binaryOp(allocator, &stack, &tmp_pops, &tmp_pushes, frame.start_height, .i32, .i32, &next_vreg, &ann), + 0x50 => if (reachable) try unaryOp(allocator, 
&stack, &tmp_pops, &tmp_pushes, frame.start_height, .i64, .i32, &next_vreg, &ann), + 0x51...0x5A => if (reachable) try binaryOp(allocator, &stack, &tmp_pops, &tmp_pushes, frame.start_height, .i64, .i32, &next_vreg, &ann), + 0x5B...0x60 => if (reachable) try binaryOp(allocator, &stack, &tmp_pops, &tmp_pushes, frame.start_height, .f32, .i32, &next_vreg, &ann), + 0x61...0x66 => if (reachable) try binaryOp(allocator, &stack, &tmp_pops, &tmp_pushes, frame.start_height, .f64, .i32, &next_vreg, &ann), + 0x67...0x69 => if (reachable) try unaryOp(allocator, &stack, &tmp_pops, &tmp_pushes, frame.start_height, .i32, .i32, &next_vreg, &ann), + 0x6A...0x78 => if (reachable) try binaryOp(allocator, &stack, &tmp_pops, &tmp_pushes, frame.start_height, .i32, .i32, &next_vreg, &ann), + 0x79...0x7B => if (reachable) try unaryOp(allocator, &stack, &tmp_pops, &tmp_pushes, frame.start_height, .i64, .i64, &next_vreg, &ann), + 0x7C...0x8A => if (reachable) try binaryOp(allocator, &stack, &tmp_pops, &tmp_pushes, frame.start_height, .i64, .i64, &next_vreg, &ann), + 0x8B...0x91 => if (reachable) try unaryOp(allocator, &stack, &tmp_pops, &tmp_pushes, frame.start_height, .f32, .f32, &next_vreg, &ann), + 0x92...0x98 => if (reachable) try binaryOp(allocator, &stack, &tmp_pops, &tmp_pushes, frame.start_height, .f32, .f32, &next_vreg, &ann), + 0x99...0x9F => if (reachable) try unaryOp(allocator, &stack, &tmp_pops, &tmp_pushes, frame.start_height, .f64, .f64, &next_vreg, &ann), + 0xA0...0xA6 => if (reachable) try binaryOp(allocator, &stack, &tmp_pops, &tmp_pushes, frame.start_height, .f64, .f64, &next_vreg, &ann), + 0xA7...0xBF => if (reachable) try conversionOp(allocator, op, &stack, &tmp_pops, &tmp_pushes, frame.start_height, &next_vreg, &ann), + 0xC0, 0xC1 => if (reachable) try unaryOp(allocator, &stack, &tmp_pops, &tmp_pushes, frame.start_height, .i32, .i32, &next_vreg, &ann), + 0xC2, 0xC3, 0xC4 => if (reachable) try unaryOp(allocator, &stack, &tmp_pops, &tmp_pushes, frame.start_height, .i64, 
.i64, &next_vreg, &ann), + 0xFC => { + const subop = try readULEB128(u32, code, &pos); + switch (subop) { + 0...7 => if (reachable) try truncSatOp(allocator, subop, &stack, &tmp_pops, &tmp_pushes, frame.start_height, &next_vreg, &ann), + 8 => { + const data_idx = try readULEB128(u32, code, &pos); + const mem_idx = try readULEB128(u32, code, &pos); + ann.imm = .{ .two_u32 = .{ .a = data_idx, .b = mem_idx } }; + if (data_idx >= mod.datas.len) return error.TypeMismatch; + if (total_memories == 0 or mem_idx != 0) return error.UndefinedMemory; + if (reachable) { + _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, .i32); + _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, .i32); + _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, .i32); + } + }, + 9 => { + const data_idx = try readULEB128(u32, code, &pos); + ann.imm = .{ .u32 = data_idx }; + if (data_idx >= mod.datas.len) return error.TypeMismatch; + }, + 10 => { + const dst_mem = try readULEB128(u32, code, &pos); + const src_mem = try readULEB128(u32, code, &pos); + ann.imm = .{ .two_u32 = .{ .a = dst_mem, .b = src_mem } }; + if (total_memories == 0 or dst_mem != 0 or src_mem != 0) return error.UndefinedMemory; + if (reachable) { + _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, .i32); + _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, .i32); + _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, .i32); + } + }, + 11 => { + const mem_idx = try readULEB128(u32, code, &pos); + ann.imm = .{ .u32 = mem_idx }; + if (total_memories == 0 or mem_idx != 0) return error.UndefinedMemory; + if (reachable) { + _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, .i32); + _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, .i32); + _ = try popExpectVReg(allocator, &stack, &tmp_pops, frame.start_height, .i32); + } + }, + 16 => { + const table_idx = try 
readULEB128(u32, code, &pos); + ann.imm = .{ .u32 = table_idx }; + if (table_idx >= total_tables) return error.UndefinedTable; + if (reachable) { + try pushNew(allocator, &stack, &tmp_pushes, .i32, &next_vreg); + ann.result_type = .i32; + } + }, + else => return error.UnsupportedOpcode, + } + }, + else => return error.UnsupportedOpcode, + } + + ann.effect = .{ + .pops = try allocator.dupe(VReg, tmp_pops.items), + .pushes = try allocator.dupe(VReg, tmp_pushes.items), + }; + errdefer ann.deinit(allocator); + try instrs.append(allocator, ann); + } + + return instrs.toOwnedSlice(allocator); +} + +fn pushNew( + allocator: std.mem.Allocator, + stack: *std.ArrayList(StackVal), + pushes: *std.ArrayList(VReg), + vt: module.ValType, + next_vreg: *VReg, +) StackifyError!void { + const vr = next_vreg.*; + next_vreg.* += 1; + try stack.append(allocator, .{ .vreg = vr, .valtype = vt }); + try pushes.append(allocator, vr); +} + +fn pushExisting( + allocator: std.mem.Allocator, + stack: *std.ArrayList(StackVal), + pushes: *std.ArrayList(VReg), + vr: VReg, + vt: module.ValType, +) StackifyError!void { + try stack.append(allocator, .{ .vreg = vr, .valtype = vt }); + try pushes.append(allocator, vr); +} + +fn popAnyVReg( + allocator: std.mem.Allocator, + stack: *std.ArrayList(StackVal), + pops: *std.ArrayList(VReg), + min_height: usize, +) StackifyError!StackVal { + if (stack.items.len <= min_height) return error.StackUnderflow; + const v = stack.pop().?; + try pops.append(allocator, v.vreg); + return v; +} + +fn popExpectVReg( + allocator: std.mem.Allocator, + stack: *std.ArrayList(StackVal), + pops: *std.ArrayList(VReg), + min_height: usize, + expected: module.ValType, +) StackifyError!StackVal { + const v = try popAnyVReg(allocator, stack, pops, min_height); + if (v.valtype != expected) return error.TypeMismatch; + return v; +} + +fn checkStackTypes(stack: *std.ArrayList(StackVal), base: usize, expected: []const module.ValType) StackifyError!void { + if (stack.items.len < base + 
expected.len) return error.StackUnderflow; + for (expected, 0..) |et, i| { + if (stack.items[base + i].valtype != et) return error.TypeMismatch; + } +} + +fn checkLabelTypes(stack: *std.ArrayList(StackVal), base: usize, expected: []const module.ValType) StackifyError!void { + if (stack.items.len < base + expected.len) return error.StackUnderflow; + const start = stack.items.len - expected.len; + if (start < base) return error.StackUnderflow; + for (expected, 0..) |et, i| if (stack.items[start + i].valtype != et) return error.TypeMismatch; +} + +fn popLabelTypes( + allocator: std.mem.Allocator, + stack: *std.ArrayList(StackVal), + pops: *std.ArrayList(VReg), + base: usize, + label_types: []const module.ValType, +) StackifyError!void { + var i: usize = label_types.len; + while (i > 0) : (i -= 1) { + _ = try popExpectVReg(allocator, stack, pops, base, label_types[i - 1]); + } +} + +fn popTypesReverse( + allocator: std.mem.Allocator, + stack: *std.ArrayList(StackVal), + pops: *std.ArrayList(VReg), + base: usize, + types: []const module.ValType, +) StackifyError!void { + var i: usize = types.len; + while (i > 0) : (i -= 1) { + _ = try popExpectVReg(allocator, stack, pops, base, types[i - 1]); + } +} + +fn pushResultTypes( + allocator: std.mem.Allocator, + stack: *std.ArrayList(StackVal), + pushes: *std.ArrayList(VReg), + results: []const module.ValType, + next_vreg: *VReg, + result_type: *?module.ValType, +) StackifyError!void { + if (results.len > 1) return error.UnsupportedOpcode; + for (results) |rt| { + try pushNew(allocator, stack, pushes, rt, next_vreg); + result_type.* = rt; + } +} + +fn emitMergeResults( + allocator: std.mem.Allocator, + stack: *std.ArrayList(StackVal), + pushes: *std.ArrayList(VReg), + results: []const module.ValType, + next_vreg: *VReg, + result_type: *?module.ValType, +) StackifyError!void { + if (results.len > 1) return error.UnsupportedOpcode; + for (results) |rt| { + try pushNew(allocator, stack, pushes, rt, next_vreg); + result_type.* = 
rt;
    }
}

/// At a block boundary, truncate the abstract stack to `start_height` while
/// keeping the block's result values on top.
/// NOTE(review): `checkStackTypes` inspects the slots at `start_height..`,
/// but the preserved values are taken from the stack TOP — these are only
/// the same slots when the stack is exactly `start_height + results.len`
/// tall; confirm callers guarantee that.
fn preserveBlockResults(
    allocator: std.mem.Allocator,
    stack: *std.ArrayList(StackVal),
    start_height: usize,
    results: []const module.ValType,
) StackifyError!void {
    if (results.len > 1) return error.UnsupportedOpcode;
    try checkStackTypes(stack, start_height, results);
    if (results.len == 0) {
        stack.shrinkRetainingCapacity(start_height);
        return;
    }
    const tail_start = stack.items.len - results.len;
    // Copy the results aside first: a slice into `items` would be
    // invalidated by the shrink + append below.
    const saved = try allocator.dupe(StackVal, stack.items[tail_start..]);
    defer allocator.free(saved);
    stack.shrinkRetainingCapacity(start_height);
    try stack.appendSlice(allocator, saved);
}

/// Pop one `in_t` operand, push a fresh `out_t` vreg, and record the
/// instruction's result type on `ann`.
fn unaryOp(
    allocator: std.mem.Allocator,
    stack: *std.ArrayList(StackVal),
    pops: *std.ArrayList(VReg),
    pushes: *std.ArrayList(VReg),
    base: usize,
    in_t: module.ValType,
    out_t: module.ValType,
    next_vreg: *VReg,
    ann: *AnnotatedInstr,
) StackifyError!void {
    _ = try popExpectVReg(allocator, stack, pops, base, in_t);
    try pushNew(allocator, stack, pushes, out_t, next_vreg);
    ann.result_type = out_t;
}

/// Pop two operands of the same type `in_t`, push a fresh `out_t` vreg.
fn binaryOp(
    allocator: std.mem.Allocator,
    stack: *std.ArrayList(StackVal),
    pops: *std.ArrayList(VReg),
    pushes: *std.ArrayList(VReg),
    base: usize,
    in_t: module.ValType,
    out_t: module.ValType,
    next_vreg: *VReg,
    ann: *AnnotatedInstr,
) StackifyError!void {
    _ = try popExpectVReg(allocator, stack, pops, base, in_t);
    _ = try popExpectVReg(allocator, stack, pops, base, in_t);
    try pushNew(allocator, stack, pushes, out_t, next_vreg);
    ann.result_type = out_t;
}

/// Numeric conversion opcodes (0xA7–0xBF): the operand type is derived from
/// the opcode, the result type from `convertResultType`.
fn conversionOp(
    allocator: std.mem.Allocator,
    op: u8,
    stack: *std.ArrayList(StackVal),
    pops: *std.ArrayList(VReg),
    pushes: *std.ArrayList(VReg),
    base: usize,
    next_vreg: *VReg,
    ann: *AnnotatedInstr,
) StackifyError!void {
    const in_t: module.ValType = switch (op) {
        0xA7 => .i64, // i32.wrap_i64
        0xA8, 0xA9 => .f32,
        0xAA, 0xAB => .f64,
        0xAC, 0xAD => .i32,
        0xAE, 0xAF => .f32,
        0xB0, 0xB1 => .f64,
        0xB2, 0xB3 => .i32,
        0xB4, 0xB5 => .i64,
        0xB6 => .f64, // f32.demote_f64
        0xB7, 0xB8 => .i32,
        0xB9, 0xBA => .i64,
        0xBB => .f32, // f64.promote_f32
        0xBC => .f32, // i32.reinterpret_f32
        0xBD => .f64, // i64.reinterpret_f64
        0xBE => .i32, // f32.reinterpret_i32
        0xBF => .i64, // f64.reinterpret_i64
        else => return error.UnsupportedOpcode,
    };
    const out_t: module.ValType = convertResultType(op);
    _ = try popExpectVReg(allocator, stack, pops, base, in_t);
    try pushNew(allocator, stack, pushes, out_t, next_vreg);
    ann.result_type = out_t;
}

/// Saturating truncation (0xFC sub-opcodes 0–7): f32/f64 → i32/i64.
fn truncSatOp(
    allocator: std.mem.Allocator,
    subop: u32,
    stack: *std.ArrayList(StackVal),
    pops: *std.ArrayList(VReg),
    pushes: *std.ArrayList(VReg),
    base: usize,
    next_vreg: *VReg,
    ann: *AnnotatedInstr,
) StackifyError!void {
    const in_t: module.ValType = switch (subop) {
        0, 1, 4, 5 => .f32,
        2, 3, 6, 7 => .f64,
        else => return error.UnsupportedOpcode,
    };
    // Sub-ops 0–3 yield i32, 4–7 yield i64.
    const out_t: module.ValType = if (subop <= 3) .i32 else .i64;
    _ = try popExpectVReg(allocator, stack, pops, base, in_t);
    try pushNew(allocator, stack, pushes, out_t, next_vreg);
    ann.result_type = out_t;
}

/// Result type of each conversion opcode (0xA7–0xBF).
fn convertResultType(op: u8) module.ValType {
    return switch (op) {
        0xA7, 0xA8, 0xA9, 0xAA, 0xAB, 0xBC => .i32,
        0xAC, 0xAD, 0xAE, 0xAF, 0xB0, 0xB1, 0xBD => .i64,
        0xB2, 0xB3, 0xB4, 0xB5, 0xB6, 0xBE => .f32,
        0xB7, 0xB8, 0xB9, 0xBA, 0xBB, 0xBF => .f64,
        else => .i32,
    };
}

/// Read one byte, advancing `pos`.
fn readByte(code: []const u8, pos: *usize) StackifyError!u8 {
    if (pos.* >= code.len) return error.UnexpectedEof;
    const b = code[pos.*];
    pos.* += 1;
    return b;
}

/// Unsigned LEB128 read; non-EOF decode failures are folded into
/// TypeMismatch to stay within StackifyError.
fn readULEB128(comptime T: type, code: []const u8, pos: *usize) StackifyError!T {
    return binary.readULEB128(T, code, pos) catch |e| switch (e) {
        error.UnexpectedEof => error.UnexpectedEof,
        else => error.TypeMismatch,
    };
}

/// Signed LEB128 read; same error folding as `readULEB128`.
fn readSLEB128(comptime T: type, code: []const u8, pos: *usize) StackifyError!T {
    return binary.readSLEB128(T, code, pos) catch |e| switch (e) {
        error.UnexpectedEof => error.UnexpectedEof,
        else => error.TypeMismatch,
    };
}

/// A blocktype immediate is an s33: negative shorthands or a type index.
fn readBlockType(code: []const u8, pos: *usize) StackifyError!i33 {
return readSLEB128(i33, code, pos);
}

/// Decode a value-type byte per the wasm binary format.
fn decodeValType(b: u8) StackifyError!module.ValType {
    return switch (b) {
        0x7F => .i32,
        0x7E => .i64,
        0x7D => .f32,
        0x7C => .f64,
        else => error.InvalidValueType,
    };
}

/// Result types of a blocktype: negative values are the shorthand
/// encodings (-1..-4 = single valtype, -64 = empty), non-negative values
/// index the module's type section.
fn blockTypeResults(mod: *const module.Module, bt: i33) StackifyError![]const module.ValType {
    return switch (bt) {
        -1 => &[_]module.ValType{.i32},
        -2 => &[_]module.ValType{.i64},
        -3 => &[_]module.ValType{.f32},
        -4 => &[_]module.ValType{.f64},
        -64 => &.{},
        else => if (bt >= 0) blk: {
            const idx: u32 = @intCast(bt);
            if (idx >= mod.types.len) return error.InvalidTypeIndex;
            break :blk mod.types[idx].results;
        } else error.InvalidTypeIndex,
    };
}

/// Parameter types of a blocktype; shorthand encodings have no params.
fn blockTypeParams(mod: *const module.Module, bt: i33) StackifyError![]const module.ValType {
    if (bt < 0) return &.{};
    const idx: u32 = @intCast(bt);
    if (idx >= mod.types.len) return error.InvalidTypeIndex;
    return mod.types[idx].params;
}

/// Look up the signature of function `fidx`. Imported functions occupy the
/// front of the index space, so imports are scanned first.
fn getFuncType(mod: *const module.Module, fidx: u32, num_imported: u32) StackifyError!*const module.FuncType {
    if (fidx < num_imported) {
        var count: u32 = 0;
        for (mod.imports) |imp| {
            if (imp.desc == .func) {
                if (count == fidx) return &mod.types[imp.desc.func];
                count += 1;
            }
        }
        return error.InvalidFunctionIndex;
    }
    const local_idx = fidx - num_imported;
    if (local_idx >= mod.functions.len) return error.InvalidFunctionIndex;
    const type_idx = mod.functions[local_idx];
    if (type_idx >= mod.types.len) return error.InvalidTypeIndex;
    return &mod.types[type_idx];
}

/// Global type lookup: imported globals precede module-local globals in
/// the index space.
fn getGlobalType(mod: *const module.Module, imported_globals: []const module.GlobalType, idx: u32) StackifyError!module.GlobalType {
    if (idx < imported_globals.len) return imported_globals[idx];
    const local_idx = idx - @as(u32, @intCast(imported_globals.len));
    if (local_idx >= mod.globals.len) return error.UndefinedGlobal;
    return mod.globals[local_idx].type;
}

/// Result type of each memory-load opcode (0x28–0x35).
fn memLoadResultType(op: u8) module.ValType {
    return switch (op) {
        0x28, 0x2C, 0x2D, 0x2E, 0x2F => .i32,
        0x29, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35 => .i64,
        0x2A => .f32,
        0x2B => .f64,
        else => .i32,
    };
}

/// Stored-value type of each memory-store opcode (0x36–0x3E).
fn memStoreValType(op: u8) module.ValType {
    return switch (op) {
        0x36, 0x3A, 0x3B => .i32,
        0x37, 0x3C, 0x3D, 0x3E => .i64,
        0x38 => .f32,
        0x39 => .f64,
        else => .i32,
    };
}

/// log2 of the natural alignment of each load opcode (access width).
fn naturalAlignmentLog2ForLoad(op: u8) u32 {
    return switch (op) {
        0x28 => 2,
        0x29 => 3,
        0x2A => 2,
        0x2B => 3,
        0x2C, 0x2D => 0,
        0x2E, 0x2F => 1,
        0x30, 0x31 => 0,
        0x32, 0x33 => 1,
        0x34, 0x35 => 2,
        else => 0,
    };
}

/// log2 of the natural alignment of each store opcode (access width).
fn naturalAlignmentLog2ForStore(op: u8) u32 {
    return switch (op) {
        0x36 => 2,
        0x37 => 3,
        0x38 => 2,
        0x39 => 3,
        0x3A, 0x3C => 0,
        0x3B, 0x3D => 1,
        0x3E => 2,
        else => 0,
    };
}

/// Element-wise equality of two valtype slices.
fn sameValTypeSlice(a: []const module.ValType, b: []const module.ValType) bool {
    if (a.len != b.len) return false;
    for (a, 0..) |vt, i| if (vt != b[i]) return false;
    return true;
}

// ── Tests ─────────────────────────────────────────────────────────────────────

test "stackify straight-line function" {
    // i32.const 1; i32.const 2; i32.add; end
    const code = [_]u8{ 0x41, 0x01, 0x41, 0x02, 0x6a, 0x0b };
    const body = module.FunctionBody{ .locals = &.{}, .code = &code };
    const ft = module.FuncType{ .params = &.{}, .results = &.{.i32} };
    const mod = module.Module{
        .types = &.{},
        .imports = &.{},
        .functions = &.{},
        .tables = &.{},
        .memories = &.{},
        .globals = &.{},
        .exports = &.{},
        .start = null,
        .elements = &.{},
        .codes = &.{},
        .datas = &.{},
        .allocator = std.testing.allocator,
    };
    const ally = std.testing.allocator;
    const instrs = try stackify(ally, &body, &ft, &mod);
    defer deinitInstrs(ally, instrs);

    try std.testing.expectEqual(@as(usize, 4), instrs.len);
    const vr0 = instrs[0].effect.pushes[0];
    const vr1 = instrs[1].effect.pushes[0];
    // The add pops its operands top-first: vr1 then vr0.
    try std.testing.expectEqual(@as(usize, 2), instrs[2].effect.pops.len);
    try std.testing.expectEqual(vr1, instrs[2].effect.pops[0]);
    try std.testing.expectEqual(vr0,
instrs[2].effect.pops[1]);
}

// Verifies that `call` pops per the callee's signature and pushes its
// result type.
test "stackify call uses function signature" {
    // i32.const 7; call 0; end
    const code = [_]u8{ 0x41, 0x07, 0x10, 0x00, 0x0b };
    const body = module.FunctionBody{ .locals = &.{}, .code = &code };

    var callee_params = [_]module.ValType{.i32};
    var callee_results = [_]module.ValType{.i64};
    var types = [_]module.FuncType{.{ .params = &callee_params, .results = &callee_results }};
    var funcs = [_]u32{0};
    var codes = [_]module.FunctionBody{body};

    const mod = module.Module{
        .types = &types,
        .imports = &.{},
        .functions = &funcs,
        .tables = &.{},
        .memories = &.{},
        .globals = &.{},
        .exports = &.{},
        .start = null,
        .elements = &.{},
        .codes = &codes,
        .datas = &.{},
        .allocator = std.testing.allocator,
    };

    const ally = std.testing.allocator;
    const instrs = try stackify(ally, &body, &types[0], &mod);
    defer deinitInstrs(ally, instrs);

    try std.testing.expectEqual(@as(usize, 1), instrs[1].effect.pops.len);
    try std.testing.expectEqual(@as(usize, 1), instrs[1].effect.pushes.len);
    try std.testing.expectEqual(module.ValType.i64, instrs[1].result_type.?);
}
diff --git a/src/wasm/jit/x86_64.zig b/src/wasm/jit/x86_64.zig
new file mode 100644
index 0000000..edf130f
--- /dev/null
+++ b/src/wasm/jit/x86_64.zig
@@ -0,0 +1,1163 @@
const std = @import("std");
const builtin = @import("builtin");
const binary = @import("../binary.zig");
const module = @import("../module.zig");
const codebuf = @import("codebuf.zig");
const aarch64 = @import("aarch64.zig");

// Shared result/helper-table types are declared once in the aarch64 backend.
pub const JitResult = aarch64.JitResult;
pub const HelperAddrs = aarch64.HelperAddrs;

// Fixed stack-frame size and the offset at which wasm locals start; the
// region above the locals holds the in-frame operand spill area.
const frame_size_bytes: usize = 0x408;
const local_base_bytes: usize = 32;

// How a structured block's body terminated while compiling.
const EndKind = enum { hit_end, hit_else };
const ControlKind = enum { block, loop, @"if" };

// One entry of the compile-time control stack: where branches to this label
// resolve, what depth/arity the label restores, and forward-branch patch
// sites awaiting the block's `end`.
const ControlFrame = struct {
    kind: ControlKind,
    entry_depth: usize,
    label_arity: u8,
    label_type: ?module.ValType,
    end_arity: u8,
    end_type: ?module.ValType,
    loop_head_pos: usize,
    end_patches:
std.ArrayList(usize),
};

// All state threaded through compilation of one function.
const Context = struct {
    allocator: std.mem.Allocator,
    mod: *const module.Module,
    num_imported_funcs: u32,
    helpers: HelperAddrs,
    buf: *codebuf.CodeBuffer,
    stack_depth: usize, // current operand-stack depth, in 8-byte slots
    max_stack_depth: usize, // capacity of the in-frame operand area
    local_count: u32,
    result_type: ?module.ValType, // the function's single result, if any
    local_types: []const module.ValType,
    control: std.ArrayList(ControlFrame),
};

/// Attempt to JIT-compile one function body to x86-64 machine code.
/// Returns null (instead of an error) whenever the function uses a shape
/// this backend does not support — e.g. multiple results, non-numeric
/// locals, or a frame that would not fit — so the caller can fall back.
pub fn compileFunctionI32(
    allocator: std.mem.Allocator,
    mod: *const module.Module,
    num_imported_funcs: u32,
    current_func_idx: u32,
    body: *const module.FunctionBody,
    ft: *const module.FuncType,
    helpers: HelperAddrs,
) !?JitResult {
    if (builtin.cpu.arch != .x86_64) return null;
    _ = current_func_idx;
    if (ft.results.len > 1) return null;
    // Only the four numeric value types are supported.
    for (ft.params) |p| if (!(p == .i32 or p == .i64 or p == .f32 or p == .f64)) return null;

    var local_count: usize = ft.params.len;
    for (body.locals) |decl| {
        if (!(decl.valtype == .i32 or decl.valtype == .i64 or decl.valtype == .f32 or decl.valtype == .f64)) return null;
        local_count += decl.count;
    }
    // Flatten params + local declarations into one per-slot type table.
    var local_types = try allocator.alloc(module.ValType, local_count);
    defer allocator.free(local_types);
    for (ft.params, 0..) |p, i| local_types[i] = p;
    var lt_i: usize = ft.params.len;
    for (body.locals) |decl| {
        var j: u32 = 0;
        while (j < decl.count) : (j += 1) {
            local_types[lt_i] = decl.valtype;
            lt_i += 1;
        }
    }

    // The operand area starts 16-byte-aligned after the locals; bail out if
    // the fixed frame cannot hold the locals plus at least one operand slot.
    const operand_base_bytes = std.mem.alignForward(usize, local_base_bytes + local_count * 8, 16);
    if (operand_base_bytes >= frame_size_bytes) return null;
    const max_stack_depth = (frame_size_bytes - operand_base_bytes) / 8;
    if (max_stack_depth == 0) return null;
    if (local_base_bytes + local_count * 8 > frame_size_bytes) return null;

    var buf = try codebuf.CodeBuffer.init(allocator, 8192);
    errdefer buf.deinit();

    emitPrologue(&buf, @intCast(ft.params.len), @intCast(local_count), @intCast(operand_base_bytes));

    var cx = Context{
        .allocator = allocator,
        .mod = mod,
        .num_imported_funcs = num_imported_funcs,
        .helpers = helpers,
        .buf = &buf,
        .stack_depth = 0,
        .max_stack_depth = max_stack_depth,
        .local_count = @intCast(local_count),
        .result_type = if (ft.results.len == 1) ft.results[0] else null,
        .local_types = local_types,
        .control = .empty,
    };
    defer {
        for (cx.control.items) |*fr| fr.end_patches.deinit(allocator);
        cx.control.deinit(allocator);
    }

    // The function body itself acts as an implicit outermost block whose
    // label is the function's return.
    const fn_arity: u8 = if (ft.results.len == 1) 1 else 0;
    const fn_type: ?module.ValType = if (ft.results.len == 1) ft.results[0] else null;
    try cx.control.append(allocator, .{
        .kind = .block,
        .entry_depth = 0,
        .label_arity = fn_arity,
        .label_type = fn_type,
        .end_arity = fn_arity,
        .end_type = fn_type,
        .loop_head_pos = 0,
        .end_patches = .empty,
    });

    var pos: usize = 0;
    // Any compile error means "unsupported" — report null for fallback.
    const end_kind = compileBlock(&cx, body.code, &pos, false) catch return null;
    if (end_kind != .hit_end) return null;

    // Move the single result (if any) into rax; the depth must come out
    // exactly right or the body was not compiled faithfully.
    if (ft.results.len == 1) {
        if (cx.stack_depth != 1) return null;
        switch (ft.results[0]) {
            .i32, .f32 => try popW(&cx, @intFromEnum(Reg.rax)),
            .i64, .f64 => try popX(&cx, @intFromEnum(Reg.rax)),
        }
    } else {
        if (cx.stack_depth != 0) return null;
    }

    emitEpilogueAndRet(&buf);
    try buf.finalize();
    // NOTE(review): arity is hard-coded to 0 even for a single-result
    // function — confirm JitResult.arity semantics against aarch64.zig.
    return .{ .buf = buf, .arity = 0 };
}

/// Compile one structured block's body until its `end` (0x0B) or — when
/// `allow_else` is true — an `else` (0x05), emitting x86-64 machine code
/// into cx.buf.
///
/// Code model visible below: wasm operands live in 8-byte slots addressed
/// by rbx (a dedicated operand-stack pointer bumped by push*/pop*), locals
/// live at rsp + local_base_bytes + 8*i, r13 carries the runtime context,
/// and most numeric ops are delegated to out-of-line helpers whose
/// arguments are placed in rdi/rsi/rdx/rcx/r8/r9 (System V x86-64 order).
fn compileBlock(cx: *Context, code: []const u8, pos: *usize, allow_else: bool) !EndKind {
    while (pos.* < code.len) {
        const op = code[pos.*];
        pos.* += 1;
        switch (op) {
            // end: resolve every forward branch aimed at this label, then
            // normalize the tracked depth to entry + result arity.
            0x0B => {
                const fr = currentFrame(cx);
                const end_pos = cx.buf.cursor();
                for (fr.end_patches.items) |patch_imm_pos| patchRel32(cx.buf, patch_imm_pos, end_pos);
                try setStackDepth(cx, fr.entry_depth + fr.end_arity);
                return .hit_end;
            },
            // else: only legal directly inside an if's then-arm.
            0x05 => {
                if (!allow_else) return error.MalformedControlFlow;
                return .hit_else;
            },
            // br
            0x0C => {
                const depth = try binary.readULEB128(u32, code, pos);
                try emitBrToDepth(cx, depth);
            },
            // br_if: emitBrToDepth mutates cx.stack_depth for the taken
            // path, so the fall-through depth is saved and restored.
            0x0D => {
                const depth = try binary.readULEB128(u32, code, pos);
                try popW(cx, @intFromEnum(Reg.r11));
                emitTestWReg(cx.buf, @intFromEnum(Reg.r11), @intFromEnum(Reg.r11));
                const not_taken = emitJccPlaceholder(cx.buf, .z);
                const fallthrough_depth = cx.stack_depth;
                try emitBrToDepth(cx, depth);
                cx.stack_depth = fallthrough_depth;
                patchRel32(cx.buf, not_taken, cx.buf.cursor());
            },
            // br_table: compiled as a linear compare-and-branch chain with
            // the default target last.
            0x0E => {
                const n = try binary.readULEB128(u32, code, pos);
                const table = try cx.allocator.alloc(u32, n + 1);
                defer cx.allocator.free(table);
                for (table) |*d| d.* = try binary.readULEB128(u32, code, pos);

                try popW(cx, @intFromEnum(Reg.r11));
                const fallthrough_depth = cx.stack_depth;
                var i: u32 = 0;
                while (i < n) : (i += 1) {
                    emitCmpWImm32(cx.buf, @intFromEnum(Reg.r11), @bitCast(i));
                    const skip = emitJccPlaceholder(cx.buf, .ne);
                    try emitBrToDepth(cx, table[i]);
                    cx.stack_depth = fallthrough_depth;
                    patchRel32(cx.buf, skip, cx.buf.cursor());
                }
                try emitBrToDepth(cx, table[n]);
            },
            // local.get: 32-bit types move 4 bytes, 64-bit types 8 bytes;
            // every local occupies one 8-byte slot either way.
            0x20 => {
                const idx = try binary.readULEB128(u32, code, pos);
                if (idx >= cx.local_count) return error.UnsupportedOpcode;
                const off: i32 = @intCast(local_base_bytes + idx * 8);
                switch (cx.local_types[idx]) {
                    .i32, .f32 => {
                        emitMovWFromRspDisp(cx.buf, @intFromEnum(Reg.r9), off);
                        try pushW(cx, @intFromEnum(Reg.r9));
                    },
                    .i64, .f64 => {
                        emitMovXFromRspDisp(cx.buf, @intFromEnum(Reg.r9), off);
                        try pushX(cx, @intFromEnum(Reg.r9));
                    },
                }
            },
            // local.set
            0x21 => {
                const idx = try binary.readULEB128(u32, code, pos);
                if (idx >= cx.local_count) return error.UnsupportedOpcode;
                const off: i32 = @intCast(local_base_bytes + idx * 8);
                switch (cx.local_types[idx]) {
                    .i32, .f32 => {
                        try popW(cx, @intFromEnum(Reg.r9));
                        emitMovRspDispFromW(cx.buf, off, @intFromEnum(Reg.r9));
                    },
                    .i64, .f64 => {
                        try popX(cx, @intFromEnum(Reg.r9));
                        emitMovRspDispFromX(cx.buf, off, @intFromEnum(Reg.r9));
                    },
                }
            },
            // local.tee: store, then leave the value on the stack.
            0x22 => {
                const idx = try binary.readULEB128(u32, code, pos);
                if (idx >= cx.local_count) return error.UnsupportedOpcode;
                const off: i32 = @intCast(local_base_bytes + idx * 8);
                switch (cx.local_types[idx]) {
                    .i32, .f32 => {
                        try popW(cx, @intFromEnum(Reg.r9));
                        emitMovRspDispFromW(cx.buf, off, @intFromEnum(Reg.r9));
                        try pushW(cx, @intFromEnum(Reg.r9));
                    },
                    .i64, .f64 => {
                        try popX(cx, @intFromEnum(Reg.r9));
                        emitMovRspDispFromX(cx.buf, off, @intFromEnum(Reg.r9));
                        try pushX(cx, @intFromEnum(Reg.r9));
                    },
                }
            },
            // Constants; float constants are moved as raw bit patterns.
            0x41 => {
                const v = try binary.readSLEB128(i32, code, pos);
                emitMovImm32(cx.buf, @intFromEnum(Reg.r9), @bitCast(v));
                try pushW(cx, @intFromEnum(Reg.r9));
            },
            0x42 => {
                const v = try binary.readSLEB128(i64, code, pos);
                emitMovImm64(cx.buf, @intFromEnum(Reg.r9), @bitCast(v));
                try pushX(cx, @intFromEnum(Reg.r9));
            },
            0x43 => {
                if (pos.* + 4 > code.len) return error.UnexpectedEof;
                const bits = std.mem.readInt(u32, code[pos.*..][0..4], .little);
                pos.* += 4;
                emitMovImm32(cx.buf, @intFromEnum(Reg.r9), bits);
                try pushW(cx, @intFromEnum(Reg.r9));
            },
            0x44 => {
                if (pos.* + 8 > code.len) return error.UnexpectedEof;
                const bits = std.mem.readInt(u64, code[pos.*..][0..8], .little);
                pos.* += 8;
                emitMovImm64(cx.buf, @intFromEnum(Reg.r9), bits);
                try pushX(cx, @intFromEnum(Reg.r9));
            },
            // drop
            0x1A => {
                try popX(cx, @intFromEnum(Reg.r9));
            },
            // select (untyped): r11 = condition, r10 = on-false, r9 = on-true.
            0x1B => {
                try popW(cx, @intFromEnum(Reg.r11));
                try popX(cx, @intFromEnum(Reg.r10));
                try popX(cx, @intFromEnum(Reg.r9));
                emitTestWReg(cx.buf, @intFromEnum(Reg.r11), @intFromEnum(Reg.r11));
                const use_rhs = emitJccPlaceholder(cx.buf, .z);
                const done = emitJmpPlaceholder(cx.buf);
                patchRel32(cx.buf, use_rhs, cx.buf.cursor());
                emitMovXReg(cx.buf, @intFromEnum(Reg.r9), @intFromEnum(Reg.r10));
                patchRel32(cx.buf, done, cx.buf.cursor());
                try pushX(cx, @intFromEnum(Reg.r9));
            },
            // select t: exactly one type immediate is supported; move width
            // follows the declared type.
            0x1C => {
                const n = try binary.readULEB128(u32, code, pos);
                if (n != 1) return error.UnsupportedOpcode;
                if (pos.* >= code.len) return error.UnexpectedEof;
                const vt = try decodeValType(code[pos.*]);
                pos.* += 1;

                try popW(cx, @intFromEnum(Reg.r11));
                switch (vt) {
                    .i32, .f32 => {
                        try popW(cx, @intFromEnum(Reg.r10));
                        try popW(cx, @intFromEnum(Reg.r9));
                        emitTestWReg(cx.buf, @intFromEnum(Reg.r11), @intFromEnum(Reg.r11));
                        const use_rhs = emitJccPlaceholder(cx.buf, .z);
                        const done = emitJmpPlaceholder(cx.buf);
                        patchRel32(cx.buf, use_rhs, cx.buf.cursor());
                        emitMovWReg(cx.buf, @intFromEnum(Reg.r9), @intFromEnum(Reg.r10));
                        patchRel32(cx.buf, done, cx.buf.cursor());
                        try pushW(cx, @intFromEnum(Reg.r9));
                    },
                    .i64, .f64 => {
                        try popX(cx, @intFromEnum(Reg.r10));
                        try popX(cx, @intFromEnum(Reg.r9));
                        emitTestWReg(cx.buf, @intFromEnum(Reg.r11), @intFromEnum(Reg.r11));
                        const use_rhs = emitJccPlaceholder(cx.buf, .z);
                        const done = emitJmpPlaceholder(cx.buf);
                        patchRel32(cx.buf, use_rhs, cx.buf.cursor());
                        emitMovXReg(cx.buf, @intFromEnum(Reg.r9), @intFromEnum(Reg.r10));
                        patchRel32(cx.buf, done, cx.buf.cursor());
                        try pushX(cx, @intFromEnum(Reg.r9));
                    },
                }
            },
            // Numeric operator families: the opcode itself is passed in the
            // first argument register so one helper serves the whole family.
            0x45, 0x67...0x69 => {
                try popW(cx, @intFromEnum(Reg.rsi));
                emitMovImm32(cx.buf, @intFromEnum(Reg.rdi), op);
                emitCallAbs(cx.buf, cx.helpers.i32_unary);
                try pushW(cx, @intFromEnum(Reg.rax));
            },
            0x46...0x4F => {
                try popW(cx, @intFromEnum(Reg.rdx));
                try popW(cx, @intFromEnum(Reg.rsi));
                emitMovImm32(cx.buf, @intFromEnum(Reg.rdi), op);
                emitCallAbs(cx.buf, cx.helpers.i32_cmp);
                try pushW(cx, @intFromEnum(Reg.rax));
            },
            0x6A...0x78 => {
                try popW(cx, @intFromEnum(Reg.rdx));
                try popW(cx, @intFromEnum(Reg.rsi));
                emitMovImm32(cx.buf, @intFromEnum(Reg.rdi), op);
                emitCallAbs(cx.buf, cx.helpers.i32_binary);
                try pushW(cx, @intFromEnum(Reg.rax));
            },
            0x50 => {
                try popX(cx, @intFromEnum(Reg.rdi));
                emitCallAbs(cx.buf, cx.helpers.i64_eqz);
                try pushW(cx, @intFromEnum(Reg.rax));
            },
            0x51...0x5A => {
                try popX(cx, @intFromEnum(Reg.rdx));
                try popX(cx, @intFromEnum(Reg.rsi));
                emitMovImm32(cx.buf, @intFromEnum(Reg.rdi), op);
                emitCallAbs(cx.buf, cx.helpers.i64_cmp);
                try pushW(cx, @intFromEnum(Reg.rax));
            },
            0x79...0x7B => {
                try popX(cx, @intFromEnum(Reg.rsi));
                emitMovImm32(cx.buf, @intFromEnum(Reg.rdi), op);
                emitCallAbs(cx.buf, cx.helpers.i64_unary);
                try pushX(cx, @intFromEnum(Reg.rax));
            },
            0x7C...0x8A => {
                try popX(cx, @intFromEnum(Reg.rdx));
                try popX(cx, @intFromEnum(Reg.rsi));
                emitMovImm32(cx.buf, @intFromEnum(Reg.rdi), op);
                emitCallAbs(cx.buf, cx.helpers.i64_binary);
                try pushX(cx, @intFromEnum(Reg.rax));
            },
            0x5B...0x60 => {
                try popW(cx, @intFromEnum(Reg.rdx));
                try popW(cx, @intFromEnum(Reg.rsi));
                emitMovImm32(cx.buf, @intFromEnum(Reg.rdi), op);
                emitCallAbs(cx.buf, cx.helpers.f32_cmp);
                try pushW(cx, @intFromEnum(Reg.rax));
            },
            0x61...0x66 => {
                try popX(cx, @intFromEnum(Reg.rdx));
                try popX(cx, @intFromEnum(Reg.rsi));
                emitMovImm32(cx.buf, @intFromEnum(Reg.rdi), op);
                emitCallAbs(cx.buf, cx.helpers.f64_cmp);
                try pushW(cx, @intFromEnum(Reg.rax));
            },
            0x8B...0x91 => {
                try popW(cx, @intFromEnum(Reg.rsi));
                emitMovImm32(cx.buf, @intFromEnum(Reg.rdi), op);
                emitCallAbs(cx.buf, cx.helpers.f32_unary);
                try pushW(cx, @intFromEnum(Reg.rax));
            },
            0x92...0x98 => {
                try popW(cx, @intFromEnum(Reg.rdx));
                try popW(cx, @intFromEnum(Reg.rsi));
                emitMovImm32(cx.buf, @intFromEnum(Reg.rdi), op);
                emitCallAbs(cx.buf, cx.helpers.f32_binary);
                try pushW(cx, @intFromEnum(Reg.rax));
            },
            0x99...0x9F => {
                try popX(cx, @intFromEnum(Reg.rsi));
                emitMovImm32(cx.buf, @intFromEnum(Reg.rdi), op);
                emitCallAbs(cx.buf, cx.helpers.f64_unary);
                try pushX(cx, @intFromEnum(Reg.rax));
            },
            0xA0...0xA6 => {
                try popX(cx, @intFromEnum(Reg.rdx));
                try popX(cx, @intFromEnum(Reg.rsi));
                emitMovImm32(cx.buf, @intFromEnum(Reg.rdi), op);
                emitCallAbs(cx.buf, cx.helpers.f64_binary);
                try pushX(cx, @intFromEnum(Reg.rax));
            },
            // Conversions: operand width (W vs X pop) is chosen per opcode,
            // result width follows convertResultType.
            0xA7...0xBF => {
                switch (op) {
                    0xA8, 0xA9, 0xAC, 0xAD, 0xB2, 0xB3, 0xB7, 0xB8, 0xBB, 0xBC, 0xBE => try popW(cx, @intFromEnum(Reg.rsi)),
                    else => try popX(cx, @intFromEnum(Reg.rsi)),
                }
                emitMovImm32(cx.buf, @intFromEnum(Reg.rdi), op);
                emitCallAbs(cx.buf, cx.helpers.convert);
                switch (convertResultType(op)) {
                    .i32, .f32 => try pushW(cx, @intFromEnum(Reg.rax)),
                    .i64, .f64 => try pushX(cx, @intFromEnum(Reg.rax)),
                }
            },
            // Sign-extension ops (i32.extend8_s .. i64.extend32_s).
            0xC0...0xC4 => {
                switch (op) {
                    0xC0, 0xC1 => try popW(cx, @intFromEnum(Reg.rsi)),
                    else => try popX(cx, @intFromEnum(Reg.rsi)),
                }
                emitMovImm32(cx.buf, @intFromEnum(Reg.rdi), op);
                emitCallAbs(cx.buf, cx.helpers.i_extend);
                switch (op) {
                    0xC0, 0xC1 => try pushW(cx, @intFromEnum(Reg.rax)),
                    else => try pushX(cx, @intFromEnum(Reg.rax)),
                }
            },
            // call: arguments stay in place on the operand stack; the helper
            // receives a pointer to the first argument slot (rbx - args*8).
            0x10 => {
                const fidx = try binary.readULEB128(u32, code, pos);
                const cft = try getFuncType(cx.mod, cx.num_imported_funcs, fidx);
                if (cft.results.len > 1) return error.UnsupportedOpcode;
                if (cx.stack_depth < cft.params.len) return error.StackUnderflow;

                emitMovXReg(cx.buf, @intFromEnum(Reg.rdi), @intFromEnum(Reg.r13));
                emitMovImm32(cx.buf, @intFromEnum(Reg.rsi), fidx);
                if (cft.params.len == 0) {
                    emitMovXReg(cx.buf, @intFromEnum(Reg.rdx), @intFromEnum(Reg.rbx));
                } else {
                    const bytes = cft.params.len * 8;
                    emitMovXReg(cx.buf, @intFromEnum(Reg.rdx), @intFromEnum(Reg.rbx));
                    emitSubXImm32(cx.buf, @intFromEnum(Reg.rdx), @intCast(bytes));
                }
                emitMovImm32(cx.buf, @intFromEnum(Reg.rcx), @intCast(cft.params.len));
                emitCallAbs(cx.buf, cx.helpers.call);

                // The call consumed its arguments: drop them from both the
                // emitted stack pointer and the compile-time depth.
                if (cft.params.len > 0) {
                    const bytes = cft.params.len * 8;
                    emitSubXImm32(cx.buf, @intFromEnum(Reg.rbx), @intCast(bytes));
                    cx.stack_depth -= cft.params.len;
                }
                if (cft.results.len == 1) {
                    switch (cft.results[0]) {
                        .i32, .f32 => try pushW(cx, @intFromEnum(Reg.rax)),
                        .i64, .f64 => try pushX(cx, @intFromEnum(Reg.rax)),
                    }
                }
            },
            // call_indirect: like call, plus the table element index (rcx)
            // popped from the operand stack.
            0x11 => {
                const type_idx = try binary.readULEB128(u32, code, pos);
                const table_idx = try binary.readULEB128(u32, code, pos);
                if (type_idx >= cx.mod.types.len) return error.UnsupportedOpcode;
                const cft = &cx.mod.types[type_idx];
                if (cft.results.len > 1) return error.UnsupportedOpcode;
                if (cx.stack_depth < cft.params.len + 1) return error.StackUnderflow;

                emitMovXReg(cx.buf, @intFromEnum(Reg.rdi), @intFromEnum(Reg.r13));
                emitMovImm32(cx.buf, @intFromEnum(Reg.rsi), type_idx);
                emitMovImm32(cx.buf, @intFromEnum(Reg.rdx), table_idx);
                try popW(cx, @intFromEnum(Reg.rcx));
                if (cft.params.len == 0) {
                    emitMovXReg(cx.buf, @intFromEnum(Reg.r8), @intFromEnum(Reg.rbx));
                } else {
                    const bytes = cft.params.len * 8;
                    emitMovXReg(cx.buf, @intFromEnum(Reg.r8), @intFromEnum(Reg.rbx));
                    emitSubXImm32(cx.buf, @intFromEnum(Reg.r8), @intCast(bytes));
                }
                emitMovImm32(cx.buf, @intFromEnum(Reg.r9), @intCast(cft.params.len));
                emitCallAbs(cx.buf, cx.helpers.call_indirect);

                if (cft.params.len > 0) {
                    const bytes = cft.params.len * 8;
                    emitSubXImm32(cx.buf, @intFromEnum(Reg.rbx), @intCast(bytes));
                    cx.stack_depth -= cft.params.len;
                }
                if (cft.results.len == 1) {
                    switch (cft.results[0]) {
                        .i32, .f32 => try pushW(cx, @intFromEnum(Reg.rax)),
                        .i64, .f64 => try pushX(cx, @intFromEnum(Reg.rax)),
                    }
                }
            },
            // global.get / global.set via runtime helpers.
            0x23 => {
                const gidx = try binary.readULEB128(u32, code, pos);
                const gvt = try getGlobalValType(cx.mod, gidx);
                emitMovXReg(cx.buf, @intFromEnum(Reg.rdi), @intFromEnum(Reg.r13));
                emitMovImm32(cx.buf, @intFromEnum(Reg.rsi), gidx);
                emitCallAbs(cx.buf, cx.helpers.global_get);
                switch (gvt) {
                    .i32, .f32 => try pushW(cx, @intFromEnum(Reg.rax)),
                    .i64, .f64 => try pushX(cx, @intFromEnum(Reg.rax)),
                }
            },
            0x24 => {
                const gidx = try binary.readULEB128(u32, code, pos);
                const gvt = try getGlobalValType(cx.mod, gidx);
                switch (gvt) {
                    .i32, .f32 => try popW(cx, @intFromEnum(Reg.rdx)),
                    .i64, .f64 => try popX(cx, @intFromEnum(Reg.rdx)),
                }
                emitMovXReg(cx.buf, @intFromEnum(Reg.rdi), @intFromEnum(Reg.r13));
                emitMovImm32(cx.buf, @intFromEnum(Reg.rsi), gidx);
                emitCallAbs(cx.buf, cx.helpers.global_set);
            },
            // Memory loads; the align immediate is read and discarded
            // (bounds/trap handling lives in the helper).
            0x28...0x35 => {
                _ = try binary.readULEB128(u32, code, pos);
                const offset = try binary.readULEB128(u32, code, pos);
                try popW(cx, @intFromEnum(Reg.rsi));
                emitMovXReg(cx.buf, @intFromEnum(Reg.rdi), @intFromEnum(Reg.r13));
                emitMovImm32(cx.buf, @intFromEnum(Reg.rdx), offset);
                emitMovImm32(cx.buf, @intFromEnum(Reg.rcx), op);
                emitCallAbs(cx.buf, cx.helpers.mem_load);
                switch (memLoadResultType(op)) {
                    .i32, .f32 => try pushW(cx, @intFromEnum(Reg.rax)),
                    .i64, .f64 => try pushX(cx, @intFromEnum(Reg.rax)),
                }
            },
            // Memory stores: value first (r8), then address (rsi).
            0x36...0x3E => {
                _ = try binary.readULEB128(u32, code, pos);
                const offset = try binary.readULEB128(u32, code, pos);
                switch (memStoreValueType(op)) {
                    .i32, .f32 => try popW(cx, @intFromEnum(Reg.r8)),
                    .i64, .f64 => try popX(cx, @intFromEnum(Reg.r8)),
                }
                try popW(cx, @intFromEnum(Reg.rsi));
                emitMovXReg(cx.buf, @intFromEnum(Reg.rdi), @intFromEnum(Reg.r13));
                emitMovImm32(cx.buf, @intFromEnum(Reg.rdx), offset);
                emitMovImm32(cx.buf, @intFromEnum(Reg.rcx), op);
                emitCallAbs(cx.buf, cx.helpers.mem_store);
            },
            // memory.size / memory.grow (single-memory index byte skipped).
            0x3F => {
                _ = try binary.readULEB128(u8, code, pos);
                emitMovXReg(cx.buf, @intFromEnum(Reg.rdi), @intFromEnum(Reg.r13));
                emitCallAbs(cx.buf, cx.helpers.memory_size);
                try pushW(cx, @intFromEnum(Reg.rax));
            },
            0x40 => {
                _ = try binary.readULEB128(u8, code, pos);
                try popW(cx, @intFromEnum(Reg.rsi));
                emitMovXReg(cx.buf, @intFromEnum(Reg.rdi), @intFromEnum(Reg.r13));
                emitCallAbs(cx.buf, cx.helpers.memory_grow);
                try pushW(cx, @intFromEnum(Reg.rax));
            },
            // 0xFC prefix: saturating truncation plus a bulk-memory/table
            // subset; everything else is rejected.
            0xFC => {
                const subop = try binary.readULEB128(u32, code, pos);
                switch (subop) {
                    0...7 => {
                        switch (subop) {
                            0, 1, 4, 5 => try popW(cx, @intFromEnum(Reg.rsi)),
                            2, 3, 6, 7 => try popX(cx, @intFromEnum(Reg.rsi)),
                            else => unreachable,
                        }
                        emitMovImm32(cx.buf, @intFromEnum(Reg.rdi), @intCast(subop));
                        emitCallAbs(cx.buf, cx.helpers.trunc_sat);
                        switch (subop) {
                            0, 1, 2, 3 => try pushW(cx, @intFromEnum(Reg.rax)),
                            4, 5, 6, 7 => try pushX(cx, @intFromEnum(Reg.rax)),
                            else => unreachable,
                        }
                    },
                    8 => {
                        const data_idx = try binary.readULEB128(u32, code, pos);
                        const mem_idx = try binary.readULEB128(u32, code, pos);
                        if (mem_idx != 0) return error.UnsupportedOpcode;
                        try popW(cx, @intFromEnum(Reg.rcx));
                        try popW(cx, @intFromEnum(Reg.rdx));
                        try popW(cx, @intFromEnum(Reg.rsi));
                        emitMovXReg(cx.buf, @intFromEnum(Reg.rdi), @intFromEnum(Reg.r13));
                        emitMovImm32(cx.buf, @intFromEnum(Reg.r8), data_idx);
                        emitCallAbs(cx.buf, cx.helpers.memory_init);
                    },
                    9 => {
                        const data_idx = try binary.readULEB128(u32, code, pos);
                        emitMovXReg(cx.buf, @intFromEnum(Reg.rdi), @intFromEnum(Reg.r13));
                        emitMovImm32(cx.buf, @intFromEnum(Reg.rsi), data_idx);
                        emitCallAbs(cx.buf, cx.helpers.data_drop);
                    },
                    10 => {
                        const dst_mem = try binary.readULEB128(u32, code, pos);
                        const src_mem = try binary.readULEB128(u32, code, pos);
                        if (dst_mem != 0 or src_mem != 0) return error.UnsupportedOpcode;
                        try popW(cx, @intFromEnum(Reg.rcx));
                        try popW(cx, @intFromEnum(Reg.rdx));
                        try popW(cx, @intFromEnum(Reg.rsi));
                        emitMovXReg(cx.buf, @intFromEnum(Reg.rdi), @intFromEnum(Reg.r13));
                        emitCallAbs(cx.buf, cx.helpers.memory_copy);
                    },
                    11 => {
                        const mem_idx = try binary.readULEB128(u32, code, pos);
                        if (mem_idx != 0) return error.UnsupportedOpcode;
                        try popW(cx, @intFromEnum(Reg.rcx));
                        try popW(cx, @intFromEnum(Reg.rdx));
                        try popW(cx, @intFromEnum(Reg.rsi));
                        emitMovXReg(cx.buf, @intFromEnum(Reg.rdi), @intFromEnum(Reg.r13));
                        emitCallAbs(cx.buf, cx.helpers.memory_fill);
                    },
                    16 => {
                        const table_idx = try binary.readULEB128(u32, code, pos);
                        emitMovXReg(cx.buf, @intFromEnum(Reg.rdi), @intFromEnum(Reg.r13));
                        emitMovImm32(cx.buf, @intFromEnum(Reg.rsi), table_idx);
                        emitCallAbs(cx.buf, cx.helpers.table_size);
                        try pushW(cx, @intFromEnum(Reg.rax));
                    },
                    else => return error.UnsupportedOpcode,
                }
            },
            // block / loop: recurse; a loop's branch target is its head.
            // NOTE(review): readBlockSig already forces arity 0 for loops,
            // so end_arity is 0 too and a value-producing loop's result is
            // not counted at its `end` — the depth mismatch then makes
            // compileFunctionI32 bail, but confirm this is intentional.
            0x02, 0x03 => {
                const is_loop = op == 0x03;
                const sig = try readBlockSig(cx, code, pos, is_loop);
                try cx.control.append(cx.allocator, .{
                    .kind = if (is_loop) .loop else .block,
                    .entry_depth = cx.stack_depth,
                    .label_arity = if (is_loop) 0 else sig.arity,
                    .label_type = if (is_loop) null else sig.val_type,
                    .end_arity = sig.arity,
                    .end_type = sig.val_type,
                    .loop_head_pos = cx.buf.cursor(),
                    .end_patches = .empty,
                });
                const nested_end = try compileBlock(cx, code, pos, false);
                if (nested_end != .hit_end) return error.MalformedControlFlow;
                var fr = cx.control.pop().?;
                fr.end_patches.deinit(cx.allocator);
            },
            // if / else.
            0x04 => {
                const sig = try readBlockSig(cx, code, pos, false);
                try popW(cx, @intFromEnum(Reg.r9));
                emitTestWReg(cx.buf, @intFromEnum(Reg.r9), @intFromEnum(Reg.r9));
                const entry_depth = cx.stack_depth;
                try cx.control.append(cx.allocator, .{
                    .kind = .@"if",
                    .entry_depth = entry_depth,
                    .label_arity = sig.arity,
                    .label_type = sig.val_type,
                    .end_arity = sig.arity,
                    .end_type = sig.val_type,
                    .loop_head_pos = 0,
                    .end_patches = .empty,
                });
                const jz_pos = emitJccPlaceholder(cx.buf, .z);
                const then_end = try compileBlock(cx, code, pos, true);
                if (then_end == .hit_else) {
                    // Then-arm jumps over the else-arm to the shared end.
                    const jump_end = emitJmpPlaceholder(cx.buf);
                    try currentFrame(cx).end_patches.append(cx.allocator, jump_end);
                    patchRel32(cx.buf, jz_pos, cx.buf.cursor());
                    try setStackDepth(cx, entry_depth);
                    const else_end = try compileBlock(cx, code, pos, false);
                    if (else_end != .hit_end) return error.MalformedControlFlow;
                } else {
                    // NOTE(review): an arity-1 if without an else would leave
                    // the false path one value short here — this relies on
                    // such modules being rejected upstream; confirm.
                    patchRel32(cx.buf, jz_pos, cx.buf.cursor());
                }
                var fr = cx.control.pop().?;
                fr.end_patches.deinit(cx.allocator);
            },
            // unreachable: trapping is delegated to the runtime helper.
            0x00 => {
                emitMovXReg(cx.buf, @intFromEnum(Reg.rdi), @intFromEnum(Reg.r13));
                emitCallAbs(cx.buf, cx.helpers.@"unreachable");
            },
            // nop
            0x01 => {},
            // return: move the result (if any) into rax and emit the
            // epilogue in place; the rest of the block is not reached.
            0x0F => {
                if (cx.result_type) |rt| {
                    switch (rt) {
                        .i32, .f32 => try popW(cx, @intFromEnum(Reg.rax)),
                        .i64, .f64 => try popX(cx, @intFromEnum(Reg.rax)),
                    }
                }
                emitEpilogueAndRet(cx.buf);
                return .hit_end;
            },
            else => return error.UnsupportedOpcode,
        }
    }
    return error.UnexpectedEof;
}

// Decoded block signature: result arity (0 or 1) and its value type.
const BlockSig = struct {
    arity: u8,
    val_type: ?module.ValType,
};

/// Read a blocktype immediate. Shorthand encodings map to 0/1 results; a
/// non-negative value indexes the type section (only param-less, ≤1-result
/// signatures are accepted). For loops the reported arity is 0 because a
/// branch targets the loop head, which takes no values.
fn readBlockSig(cx: *Context, code: []const u8, pos: *usize, is_loop: bool) !BlockSig {
    const bt = try binary.readSLEB128(i33, code, pos);
    if (bt == -0x40) return .{ .arity = 0, .val_type = null };
    if (bt == -0x01) return .{ .arity = if (is_loop) 0 else 1, .val_type = .i32 };
    if (bt == -0x02) return .{ .arity = if (is_loop) 0 else 1, .val_type = .i64 };
    if (bt == -0x03) return .{ .arity = if (is_loop) 0 else 1, .val_type = .f32 };
    if (bt == -0x04) return .{ .arity = if (is_loop) 0 else 1, .val_type = .f64 };
    if (bt < 0) return error.UnsupportedOpcode;
    const type_idx: u32 = @intCast(bt);
    if (type_idx >= cx.mod.types.len) return error.UnsupportedOpcode;
    const ft = &cx.mod.types[type_idx];
    if (ft.params.len != 0) return error.UnsupportedOpcode;
    if (ft.results.len == 0) return .{ .arity = 0, .val_type = null };
    if (ft.results.len == 1) return .{ .arity = if (is_loop) 0 else 1, .val_type = ft.results[0] };
    return error.UnsupportedOpcode;
}

/// Decode a value-type byte per the wasm binary format.
fn decodeValType(b: u8) !module.ValType {
/// Result type pushed by a memory-load opcode (0x28..0x35).
/// Unknown opcodes fall through to .i32; callers gate on valid opcodes.
fn memLoadResultType(op: u8) module.ValType {
    return switch (op) {
        0x29, 0x30...0x35 => .i64, // i64.load and i64.load8/16/32_{s,u}
        0x2A => .f32,
        0x2B => .f64,
        else => .i32, // 0x28, 0x2C..0x2F (i32.load and i32.load8/16_{s,u})
    };
}

/// Type of the value operand consumed by a memory-store opcode (0x36..0x3E).
fn memStoreValueType(op: u8) module.ValType {
    return switch (op) {
        0x37, 0x3C...0x3E => .i64, // i64.store and i64.store8/16/32
        0x38 => .f32,
        0x39 => .f64,
        else => .i32, // 0x36, 0x3A, 0x3B (i32.store and i32.store8/16)
    };
}

/// Result type produced by a numeric conversion opcode (0xA7..0xBF).
fn convertResultType(op: u8) module.ValType {
    return switch (op) {
        0xA7...0xAB, 0xBC => .i32, // wrap, i32.trunc_*, i32.reinterpret_f32
        0xAC...0xB1, 0xBD => .i64, // extend, i64.trunc_*, i64.reinterpret_f64
        0xB2...0xB6, 0xBE => .f32, // f32.convert_*, demote, reinterpret
        0xB7...0xBB, 0xBF => .f64, // f64.convert_*, promote, reinterpret
        else => .i32,
    };
}

/// Innermost (top-of-stack) control frame. Caller guarantees non-empty.
fn currentFrame(cx: *Context) *ControlFrame {
    const frames = cx.control.items;
    return &frames[frames.len - 1];
}

/// Move the runtime operand-stack pointer (rbx) so the compile-time model
/// `cx.stack_depth` matches `depth`, emitting one add/sub of 8*delta bytes.
fn setStackDepth(cx: *Context, depth: usize) !void {
    if (depth > cx.max_stack_depth) return error.StackOverflow;
    const cur = cx.stack_depth;
    if (depth == cur) return;
    if (depth > cur) {
        emitAddXImm32(cx.buf, @intFromEnum(Reg.rbx), @intCast((depth - cur) * 8));
    } else {
        emitSubXImm32(cx.buf, @intFromEnum(Reg.rbx), @intCast((cur - depth) * 8));
    }
    cx.stack_depth = depth;
}
/// Spill the 32-bit view of `reg` onto the wasm operand stack (rbx points at
/// the next free slot; every slot is 8 bytes) and grow the modeled depth.
fn pushW(cx: *Context, reg: u4) !void {
    if (cx.stack_depth >= cx.max_stack_depth) return error.StackOverflow;
    const sp = @intFromEnum(Reg.rbx);
    emitMovMemRbxFromW(cx.buf, reg);
    emitAddXImm32(cx.buf, sp, 8);
    cx.stack_depth += 1;
}

/// Pop the top stack slot into the 32-bit view of `reg`.
fn popW(cx: *Context, reg: u4) !void {
    if (cx.stack_depth == 0) return error.StackUnderflow;
    const sp = @intFromEnum(Reg.rbx);
    emitSubXImm32(cx.buf, sp, 8);
    emitMovWFromMemRbx(cx.buf, reg);
    cx.stack_depth -= 1;
}

/// Spill the full 64-bit `reg` onto the wasm operand stack.
fn pushX(cx: *Context, reg: u4) !void {
    if (cx.stack_depth >= cx.max_stack_depth) return error.StackOverflow;
    const sp = @intFromEnum(Reg.rbx);
    emitMovMemRbxFromX(cx.buf, reg);
    emitAddXImm32(cx.buf, sp, 8);
    cx.stack_depth += 1;
}

/// Pop the top stack slot into the full 64-bit `reg`.
fn popX(cx: *Context, reg: u4) !void {
    if (cx.stack_depth == 0) return error.StackUnderflow;
    const sp = @intFromEnum(Reg.rbx);
    emitSubXImm32(cx.buf, sp, 8);
    emitMovXFromMemRbx(cx.buf, reg);
    cx.stack_depth -= 1;
}

/// Resolve the FuncType of function index `fidx` in the module's combined
/// function index space (imported functions first, then local ones).
fn getFuncType(mod: *const module.Module, num_imported: u32, fidx: u32) !*const module.FuncType {
    if (fidx >= num_imported) {
        // Local function: index into the function section after the imports.
        const local_idx = fidx - num_imported;
        if (local_idx >= mod.functions.len) return error.InvalidFunctionIndex;
        const type_idx = mod.functions[local_idx];
        if (type_idx >= mod.types.len) return error.InvalidTypeIndex;
        return &mod.types[type_idx];
    }
    // Imported function: count func imports until the fidx-th one is found.
    var seen: u32 = 0;
    for (mod.imports) |imp| switch (imp.desc) {
        .func => |tidx| {
            if (seen == fidx) return &mod.types[tidx];
            seen += 1;
        },
        else => {},
    };
    return error.InvalidFunctionIndex;
}
/// Emit the JIT function prologue:
///  - reserve the fixed frame (`frame_size_bytes`, a file-level constant),
///  - save callee-saved rbx/r12/r13 into the first three frame slots,
///  - stash the instance pointer (arg0, rdi) in r13 for the function's lifetime,
///  - copy `param_count` incoming args from the args array (arg1, rsi) into
///    the local slots at `local_base_bytes`, zero the remaining declared locals,
///  - point rbx at the operand-stack base (`operand_base_bytes` into the frame).
/// Must stay in lockstep with emitEpilogueAndRet below.
fn emitPrologue(buf: *codebuf.CodeBuffer, param_count: u32, local_count: u32, operand_base_bytes: u32) void {
    emitSubXImm32(buf, @intFromEnum(Reg.rsp), @intCast(frame_size_bytes));

    // Save callee-saved registers this JIT clobbers (System V: rbx, r12-r15).
    emitMovRspDispFromX(buf, 0, @intFromEnum(Reg.rbx));
    emitMovRspDispFromX(buf, 8, @intFromEnum(Reg.r12));
    emitMovRspDispFromX(buf, 16, @intFromEnum(Reg.r13));

    // r13 := instance pointer (first C argument).
    emitMovXReg(buf, @intFromEnum(Reg.r13), @intFromEnum(Reg.rdi));

    // Copy parameters from the caller-provided args array into local slots.
    var i: u32 = 0;
    while (i < param_count) : (i += 1) {
        const src_off: i32 = @intCast(i * 8);
        const dst_off: i32 = @intCast(local_base_bytes + i * 8);
        emitMovXFromBaseDisp(buf, @intFromEnum(Reg.r9), @intFromEnum(Reg.rsi), src_off);
        emitMovBaseDispFromX(buf, @intFromEnum(Reg.rsp), dst_off, @intFromEnum(Reg.r9));
    }

    // Declared (non-parameter) locals start zeroed, as wasm requires.
    emitMovImm64(buf, @intFromEnum(Reg.r9), 0);
    var j: u32 = param_count;
    while (j < local_count) : (j += 1) {
        const dst_off: i32 = @intCast(local_base_bytes + j * 8);
        emitMovRspDispFromX(buf, dst_off, @intFromEnum(Reg.r9));
    }

    // rbx = operand-stack pointer (next free 8-byte slot).
    emitLeaRegRspDisp(buf, @intFromEnum(Reg.rbx), @intCast(operand_base_bytes));
}

/// Restore the callee-saved registers saved by emitPrologue, release the
/// frame, and return. Any function result is already in rax at this point.
fn emitEpilogueAndRet(buf: *codebuf.CodeBuffer) void {
    // Restore in reverse order of the prologue's saves.
    emitMovXFromRspDisp(buf, @intFromEnum(Reg.r13), 16);
    emitMovXFromRspDisp(buf, @intFromEnum(Reg.r12), 8);
    emitMovXFromRspDisp(buf, @intFromEnum(Reg.rbx), 0);
    emitAddXImm32(buf, @intFromEnum(Reg.rsp), @intCast(frame_size_bytes));
    buf.emit1(0xC3); // ret
}

/// x86-64 general-purpose register numbers as used in ModRM/REX encoding.
/// Values >= 8 need the REX.B/R/X extension bit.
const Reg = enum(u4) {
    rax = 0,
    rcx = 1,
    rdx = 2,
    rbx = 3,
    rsp = 4,
    rbp = 5,
    rsi = 6,
    rdi = 7,
    r8 = 8,
    r9 = 9,
    r10 = 10,
    r11 = 11,
    r12 = 12,
    r13 = 13,
    r14 = 14,
    r15 = 15,
};
/// Emit a REX prefix byte when needed: W selects 64-bit operand size, R/X/B
/// extend the ModRM.reg, SIB.index, and ModRM.rm/base fields respectively.
/// A bare 0x40 REX is redundant here (this JIT never addresses the byte
/// registers spl/bpl/sil/dil, the only case where 0x40 matters), so it is
/// omitted to keep the encoding minimal.
fn emitRex(buf: *codebuf.CodeBuffer, w: bool, r: u1, x: u1, b: u1) void {
    // Hoisted into a typed const so both branches coerce to u8 (a runtime
    // `if` over two bare comptime_int literals does not resolve).
    const w_bit: u8 = if (w) 0x08 else 0;
    const rex: u8 = 0x40 | w_bit | (@as(u8, r) << 2) | (@as(u8, x) << 1) | @as(u8, b);
    if (rex != 0x40) buf.emit1(rex);
}

/// ModRM byte: mod (addressing mode) | reg | rm.
fn emitModRM(buf: *codebuf.CodeBuffer, mod: u2, reg: u3, rm: u3) void {
    buf.emit1((@as(u8, mod) << 6) | (@as(u8, reg) << 3) | @as(u8, rm));
}

/// SIB byte: scale | index | base. index=4 means "no index".
fn emitSIB(buf: *codebuf.CodeBuffer, scale: u2, index: u3, base: u3) void {
    buf.emit1((@as(u8, scale) << 6) | (@as(u8, index) << 3) | @as(u8, base));
}

/// mov r32, imm32 (B8+rd). Writing a 32-bit register zero-extends to 64 bits.
fn emitMovImm32(buf: *codebuf.CodeBuffer, reg: u4, imm: u32) void {
    emitRex(buf, false, 0, 0, @truncate(reg >> 3));
    buf.emit1(0xB8 + @as(u8, reg & 7));
    buf.emitU32Le(imm);
}

/// mov r64, imm64 (REX.W + B8+rd). The immediate is emitted through the
/// CodeBuffer API as two little-endian 32-bit halves instead of poking
/// buf.buf/buf.pos directly, so the buffer's own cursor and capacity
/// handling is respected like in every other emitter.
fn emitMovImm64(buf: *codebuf.CodeBuffer, reg: u4, imm: u64) void {
    emitRex(buf, true, 0, 0, @truncate(reg >> 3));
    buf.emit1(0xB8 + @as(u8, reg & 7));
    buf.emitU32Le(@truncate(imm)); // low dword first: little endian
    buf.emitU32Le(@truncate(imm >> 32));
}

/// mov r64, r64 (REX.W 89 /r; src goes in ModRM.reg, dst in ModRM.rm).
fn emitMovXReg(buf: *codebuf.CodeBuffer, dst: u4, src: u4) void {
    emitRex(buf, true, @truncate(src >> 3), 0, @truncate(dst >> 3));
    buf.emit1(0x89);
    emitModRM(buf, 0b11, @truncate(src & 7), @truncate(dst & 7));
}

/// mov r32, r32 (89 /r) — zero-extends dst's upper 32 bits.
fn emitMovWReg(buf: *codebuf.CodeBuffer, dst: u4, src: u4) void {
    emitRex(buf, false, @truncate(src >> 3), 0, @truncate(dst >> 3));
    buf.emit1(0x89);
    emitModRM(buf, 0b11, @truncate(src & 7), @truncate(dst & 7));
}

/// add r64, imm32 (REX.W 81 /0; immediate is sign-extended to 64 bits).
fn emitAddXImm32(buf: *codebuf.CodeBuffer, reg: u4, imm: i32) void {
    emitRex(buf, true, 0, 0, @truncate(reg >> 3));
    buf.emit1(0x81);
    emitModRM(buf, 0b11, 0, @truncate(reg & 7));
    buf.emitI32Le(imm);
}

/// sub r64, imm32 (REX.W 81 /5; immediate is sign-extended to 64 bits).
fn emitSubXImm32(buf: *codebuf.CodeBuffer, reg: u4, imm: i32) void {
    emitRex(buf, true, 0, 0, @truncate(reg >> 3));
    buf.emit1(0x81);
    emitModRM(buf, 0b11, 5, @truncate(reg & 7));
    buf.emitI32Le(imm);
}
/// test r32, r32 (85 /r) — sets ZF from a & b without writing a register.
fn emitTestWReg(buf: *codebuf.CodeBuffer, a: u4, b: u4) void {
    emitRex(buf, false, @truncate(b >> 3), 0, @truncate(a >> 3));
    buf.emit1(0x85);
    emitModRM(buf, 0b11, @truncate(b & 7), @truncate(a & 7));
}

/// mov r64, [base + disp32] (REX.W 8B /r).
/// When (base & 7) == 4 (rsp/r12), ModRM.rm=4 means "SIB follows", so a
/// no-index SIB byte (index=4) is required to actually address that base.
fn emitMovXFromBaseDisp(buf: *codebuf.CodeBuffer, dst: u4, base: u4, disp: i32) void {
    emitRex(buf, true, @truncate(dst >> 3), 0, @truncate(base >> 3));
    buf.emit1(0x8B);
    if ((base & 7) == 4) {
        emitModRM(buf, 0b10, @truncate(dst & 7), 4);
        emitSIB(buf, 0, 4, @truncate(base & 7));
    } else {
        emitModRM(buf, 0b10, @truncate(dst & 7), @truncate(base & 7));
    }
    buf.emitI32Le(disp);
}

/// mov [base + disp32], r64 (REX.W 89 /r). Same rsp/r12 SIB caveat as above.
fn emitMovBaseDispFromX(buf: *codebuf.CodeBuffer, base: u4, disp: i32, src: u4) void {
    emitRex(buf, true, @truncate(src >> 3), 0, @truncate(base >> 3));
    buf.emit1(0x89);
    if ((base & 7) == 4) {
        emitModRM(buf, 0b10, @truncate(src & 7), 4);
        emitSIB(buf, 0, 4, @truncate(base & 7));
    } else {
        emitModRM(buf, 0b10, @truncate(src & 7), @truncate(base & 7));
    }
    buf.emitI32Le(disp);
}

/// mov r32, [base + disp32] (8B /r) — zero-extends into the full register.
fn emitMovWFromBaseDisp(buf: *codebuf.CodeBuffer, dst: u4, base: u4, disp: i32) void {
    emitRex(buf, false, @truncate(dst >> 3), 0, @truncate(base >> 3));
    buf.emit1(0x8B);
    if ((base & 7) == 4) {
        emitModRM(buf, 0b10, @truncate(dst & 7), 4);
        emitSIB(buf, 0, 4, @truncate(base & 7));
    } else {
        emitModRM(buf, 0b10, @truncate(dst & 7), @truncate(base & 7));
    }
    buf.emitI32Le(disp);
}

/// mov [base + disp32], r32 (89 /r).
fn emitMovBaseDispFromW(buf: *codebuf.CodeBuffer, base: u4, disp: i32, src: u4) void {
    emitRex(buf, false, @truncate(src >> 3), 0, @truncate(base >> 3));
    buf.emit1(0x89);
    if ((base & 7) == 4) {
        emitModRM(buf, 0b10, @truncate(src & 7), 4);
        emitSIB(buf, 0, 4, @truncate(base & 7));
    } else {
        emitModRM(buf, 0b10, @truncate(src & 7), @truncate(base & 7));
    }
    buf.emitI32Le(disp);
}

/// mov r64, [rsp + disp32] — frame-slot load.
fn emitMovXFromRspDisp(buf: *codebuf.CodeBuffer, dst: u4, disp: i32) void {
    emitMovXFromBaseDisp(buf, dst, @intFromEnum(Reg.rsp), disp);
}
/// mov [rsp + disp32], r64 — frame-slot store.
fn emitMovRspDispFromX(buf: *codebuf.CodeBuffer, disp: i32, src: u4) void {
    emitMovBaseDispFromX(buf, @intFromEnum(Reg.rsp), disp, src);
}

/// mov r32, [rsp + disp32].
fn emitMovWFromRspDisp(buf: *codebuf.CodeBuffer, dst: u4, disp: i32) void {
    emitMovWFromBaseDisp(buf, dst, @intFromEnum(Reg.rsp), disp);
}

/// mov [rsp + disp32], r32.
fn emitMovRspDispFromW(buf: *codebuf.CodeBuffer, disp: i32, src: u4) void {
    emitMovBaseDispFromW(buf, @intFromEnum(Reg.rsp), disp, src);
}

/// mov [rbx], r64 — store to the operand-stack slot rbx points at.
/// rm=3 encodes [rbx] with no displacement.
fn emitMovMemRbxFromX(buf: *codebuf.CodeBuffer, src: u4) void {
    emitRex(buf, true, @truncate(src >> 3), 0, 0);
    buf.emit1(0x89);
    emitModRM(buf, 0b00, @truncate(src & 7), 3);
}

/// mov r64, [rbx] — load from the operand-stack slot rbx points at.
fn emitMovXFromMemRbx(buf: *codebuf.CodeBuffer, dst: u4) void {
    emitRex(buf, true, @truncate(dst >> 3), 0, 0);
    buf.emit1(0x8B);
    emitModRM(buf, 0b00, @truncate(dst & 7), 3);
}

/// mov [rbx], r32.
fn emitMovMemRbxFromW(buf: *codebuf.CodeBuffer, src: u4) void {
    emitRex(buf, false, @truncate(src >> 3), 0, 0);
    buf.emit1(0x89);
    emitModRM(buf, 0b00, @truncate(src & 7), 3);
}

/// mov r32, [rbx] — zero-extends into the full register.
fn emitMovWFromMemRbx(buf: *codebuf.CodeBuffer, dst: u4) void {
    emitRex(buf, false, @truncate(dst >> 3), 0, 0);
    buf.emit1(0x8B);
    emitModRM(buf, 0b00, @truncate(dst & 7), 3);
}

/// lea r64, [rsp + disp32] — compute a frame address (used to set up rbx).
/// rm=4 forces a SIB; index=4, base=4 selects rsp with no index.
fn emitLeaRegRspDisp(buf: *codebuf.CodeBuffer, dst: u4, disp: i32) void {
    emitRex(buf, true, @truncate(dst >> 3), 0, 0);
    buf.emit1(0x8D);
    emitModRM(buf, 0b10, @truncate(dst & 7), 4);
    emitSIB(buf, 0, 4, 4);
    buf.emitI32Le(disp);
}

/// call r64 (FF /2) — indirect near call through a register.
fn emitCallReg(buf: *codebuf.CodeBuffer, reg: u4) void {
    emitRex(buf, false, 0, 0, @truncate(reg >> 3));
    buf.emit1(0xFF);
    emitModRM(buf, 0b11, 2, @truncate(reg & 7));
}

/// Call an absolute address: materialize it in rax, then call rax.
/// Clobbers rax (which is also the helper-call return register).
fn emitCallAbs(buf: *codebuf.CodeBuffer, addr: usize) void {
    emitMovImm64(buf, @intFromEnum(Reg.rax), @intCast(addr));
    emitCallReg(buf, @intFromEnum(Reg.rax));
}

/// Emit `jmp rel32` with a zero displacement and return the position of the
/// rel32 immediate, to be fixed up later via patchRel32.
fn emitJmpPlaceholder(buf: *codebuf.CodeBuffer) usize {
    buf.emit1(0xE9);
    const imm_pos = buf.cursor();
    buf.emitI32Le(0);
    return imm_pos;
}
/// Fix up a rel32 jump/branch immediate previously emitted at `imm_pos` so
/// control lands on `target_pos`. x86 rel32 displacements are relative to the
/// first byte after the immediate, i.e. imm_pos + 4.
fn patchRel32(buf: *codebuf.CodeBuffer, imm_pos: usize, target_pos: usize) void {
    const origin: i64 = @intCast(imm_pos + 4);
    const dest: i64 = @intCast(target_pos);
    buf.patchI32(imm_pos, @intCast(dest - origin));
}
const std = @import("std");

/// The four wasm 1.0 number types, tagged with their binary-format encoding
/// bytes so decoders can @enumFromInt directly from the wire value.
pub const ValType = enum(u8) {
    i32 = 0x7F,
    i64 = 0x7E,
    f32 = 0x7D,
    f64 = 0x7C,
};

/// Section ids as they appear in the wasm binary format, in file order.
pub const SectionId = enum(u8) {
    custom = 0,
    type = 1,
    import = 2,
    function = 3,
    table = 4,
    memory = 5,
    global = 6,
    @"export" = 7,
    start = 8,
    element = 9,
    code = 10,
    data = 11,
};

/// A function signature: parameter and result types, in order.
pub const FuncType = struct {
    params: []const ValType,
    results: []const ValType,
};

/// Memory limits in 64 KiB pages; max == null means "no declared maximum".
pub const MemoryType = struct { min: u32, max: ?u32 };
/// Table limits; elem_type is the raw reftype byte from the binary.
pub const TableType = struct { elem_type: u8, min: u32, max: ?u32 };
/// A global's value type plus its mutability flag.
pub const GlobalType = struct { valtype: ValType, mutable: bool };

/// A constant initializer expression (the subset wasm allows in init exprs).
pub const ConstExpr = union(enum) {
    i32_const: i32,
    i64_const: i64,
    f32_const: f32,
    f64_const: f64,
    global_get: u32,
};

/// A module-defined global: its type and constant initializer.
pub const GlobalDef = struct { type: GlobalType, init: ConstExpr };

/// What an import provides; .func carries a type-section index.
pub const ImportDesc = union(enum) {
    func: u32,
    table: TableType,
    memory: MemoryType,
    global: GlobalType,
};

/// One import entry: module/name strings (owned by the Module) and its kind.
pub const Import = struct {
    module: []const u8,
    name: []const u8,
    desc: ImportDesc,
};

/// What an export refers to; each payload is an index into the corresponding
/// combined (imports-first) index space.
pub const ExportDesc = union(enum) {
    func: u32,
    table: u32,
    memory: u32,
    global: u32,
};
/// Size of one wasm linear-memory page: 64 KiB.
pub const PAGE_SIZE: u32 = 65536;

/// A wasm linear memory: a zero-initialized byte buffer sized in whole pages.
pub const Memory = struct {
    bytes: []u8,
    max_pages: ?u32,

    /// wasm32 addresses are 32 bits, so a memory can never exceed 65536 pages
    /// (4 GiB) even when the module declares no explicit maximum.
    const absolute_max_pages: u32 = 65536;

    /// Allocate `min_pages` pages of zeroed memory. Caller frees via deinit
    /// with the same allocator.
    pub fn init(allocator: std.mem.Allocator, min_pages: u32, max_pages: ?u32) !Memory {
        const size = @as(usize, min_pages) * PAGE_SIZE;
        const bytes = try allocator.alloc(u8, size);
        @memset(bytes, 0);
        return .{ .bytes = bytes, .max_pages = max_pages };
    }

    pub fn deinit(self: *Memory, allocator: std.mem.Allocator) void {
        allocator.free(self.bytes);
        self.* = undefined;
    }

    /// Grow by delta_pages, zeroing the new region. Returns the old page
    /// count. Errors: error.GrowthExceedsMax when the new size would exceed
    /// the declared maximum (or, when none is declared, the 65536-page wasm32
    /// ceiling); error.OutOfMemory on allocation failure.
    pub fn grow(self: *Memory, allocator: std.mem.Allocator, delta_pages: u32) !u32 {
        const old_pages: u32 = @intCast(self.bytes.len / PAGE_SIZE);
        // Widen to u64 so old + delta cannot wrap before the limit check.
        const requested = @as(u64, old_pages) + delta_pages;
        const limit: u32 = self.max_pages orelse absolute_max_pages;
        if (requested > limit) return error.GrowthExceedsMax;
        const new_pages: u32 = @intCast(requested);
        // Capture the old length before realloc invalidates self.bytes.
        const old_len = self.bytes.len;
        const new_bytes = try allocator.realloc(self.bytes, @as(usize, new_pages) * PAGE_SIZE);
        @memset(new_bytes[old_len..], 0);
        self.bytes = new_bytes;
        return old_pages;
    }

    /// Little-endian bounds-checked load of T at byte address `addr`.
    /// Errors: error.OutOfBounds when [addr, addr+@sizeOf(T)) exceeds memory.
    pub fn load(self: *const Memory, comptime T: type, addr: u32) !T {
        const size = @sizeOf(T);
        if (@as(usize, addr) + size > self.bytes.len) return error.OutOfBounds;
        const Bits = std.meta.Int(.unsigned, @bitSizeOf(T));
        const raw = std.mem.readInt(Bits, self.bytes[addr..][0..size], .little);
        return @bitCast(raw);
    }

    /// Little-endian bounds-checked store of `value` at byte address `addr`.
    /// Errors: error.OutOfBounds when [addr, addr+@sizeOf(T)) exceeds memory.
    pub fn store(self: *Memory, comptime T: type, addr: u32, value: T) !void {
        const size = @sizeOf(T);
        if (@as(usize, addr) + size > self.bytes.len) return error.OutOfBounds;
        const Bits = std.meta.Int(.unsigned, @bitSizeOf(T));
        const raw: Bits = @bitCast(value);
        std.mem.writeInt(Bits, self.bytes[addr..][0..size], raw, .little);
    }
};
/// Every category of runtime trap the VM can raise. Backed by u32 so the
/// code can be passed across the JIT helper ABI as a plain integer.
pub const TrapCode = enum(u32) {
    @"unreachable",
    memory_out_of_bounds,
    undefined_global,
    undefined_table,
    invalid_function,
    integer_divide_by_zero,
    integer_overflow,
    invalid_conversion_to_integer,
    stack_overflow,
    indirect_call_type_mismatch,
    undefined_element,
    uninitialized_element,
    call_stack_exhausted,
};

/// A trap paired with a human-readable message.
/// NOTE(review): message lifetime/ownership is not established here —
/// presumably a static string; confirm at the raise sites.
pub const Trap = struct {
    code: TrapCode,
    message: []const u8,
};
const std = @import("std");
const module = @import("module.zig");

/// Everything module validation can report. OutOfMemory is included because
/// validation allocates scratch state (frames, stacks) from an arena.
pub const ValidationError = error{
    TypeMismatch,
    StackUnderflow,
    UndefinedFunction,
    UndefinedLocal,
    UndefinedGlobal,
    UndefinedMemory,
    UndefinedTable,
    InvalidLabelDepth,
    ImmutableGlobal,
    InvalidTypeIndex,
    InvalidFunctionIndex,
    ElseWithoutIf,
    InvalidAlignment,
    UnsupportedOpcode,
    OutOfMemory,
};

/// Validate a parsed Module. Returns void on success, error on failure.
/// Checks, in order: import type indices, per-function body validation,
/// export/element/data index bounds. All index checks use the combined
/// (imports-first) index spaces.
pub fn validate(mod: *const module.Module) ValidationError!void {
    // Scratch allocations live only for the duration of this call.
    var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
    defer arena.deinit();
    const ally = arena.allocator();

    // Count imports per kind to build the combined index spaces; imported
    // globals are collected because their types are needed for body checks.
    var num_imported_funcs: u32 = 0;
    var num_imported_tables: u32 = 0;
    var num_imported_memories: u32 = 0;
    var imported_globals: std.ArrayList(module.GlobalType) = .empty;
    defer imported_globals.deinit(ally);

    for (mod.imports) |imp| {
        switch (imp.desc) {
            .func => |type_idx| {
                if (type_idx >= mod.types.len) return ValidationError.InvalidTypeIndex;
                num_imported_funcs += 1;
            },
            .table => num_imported_tables += 1,
            .memory => num_imported_memories += 1,
            .global => |gt| try imported_globals.append(ally, gt),
        }
    }
    const total_funcs: u32 = num_imported_funcs + @as(u32, @intCast(mod.functions.len));
    const total_tables: u32 = num_imported_tables + @as(u32, @intCast(mod.tables.len));
    const total_memories: u32 = num_imported_memories + @as(u32, @intCast(mod.memories.len));
    const total_globals: u32 = @as(u32, @intCast(imported_globals.items.len)) + @as(u32, @intCast(mod.globals.len));

    // The function and code sections must pair up one-to-one.
    if (mod.codes.len != mod.functions.len) return ValidationError.InvalidFunctionIndex;
    if (mod.start) |start_idx| {
        if (start_idx >= total_funcs) return ValidationError.InvalidFunctionIndex;
    }

    // Validate every locally defined function body against its signature.
    for (mod.codes, 0..) |body, i| {
        const type_idx = mod.functions[i];
        if (type_idx >= mod.types.len) return ValidationError.InvalidTypeIndex;
        const func_type = &mod.types[type_idx];
        try validateFunction(mod, func_type, &body, total_funcs, num_imported_funcs, total_tables, total_memories, imported_globals.items);
    }

    // Exports must reference existing entities in their index space.
    for (mod.exports) |exp| {
        switch (exp.desc) {
            .func => |idx| if (idx >= total_funcs) return ValidationError.InvalidFunctionIndex,
            .table => |idx| if (idx >= total_tables) return ValidationError.UndefinedTable,
            .memory => |idx| if (idx >= total_memories) return ValidationError.UndefinedMemory,
            .global => |idx| if (idx >= total_globals) return ValidationError.UndefinedGlobal,
        }
    }

    // Element segments: table and every referenced function must exist.
    for (mod.elements) |elem| {
        if (elem.table_idx >= total_tables) return ValidationError.UndefinedTable;
        for (elem.func_indices) |fi| {
            if (fi >= total_funcs) return ValidationError.InvalidFunctionIndex;
        }
    }

    // Active data segments require their target memory to exist.
    for (mod.datas) |seg| {
        if (seg.kind == .active and seg.memory_idx >= total_memories) return ValidationError.UndefinedMemory;
    }
}

/// An abstract operand-stack entry: a concrete value type, or `any` for the
/// polymorphic values that follow an unconditional transfer (unreachable/br).
const StackVal = union(enum) {
    val: module.ValType,
    any, // polymorphic (after unreachable)
};

/// A control frame tracked during body validation.
/// label_types: what a branch *to* this frame must supply (params for loops,
/// results otherwise). result_types: what falls out at the frame's `end`.
const Frame = struct {
    kind: Kind,
    start_height: usize,
    label_types: []const module.ValType,
    result_types: []const module.ValType,
    reachable: bool,
    const Kind = enum { block, loop, @"if", @"else" };
};
(0..decl.count) |_| try local_types.append(ally, decl.valtype); + } + + var stack: std.ArrayList(StackVal) = .empty; + var frames: std.ArrayList(Frame) = .empty; + + // Implicit function frame + try frames.append(ally, .{ + .kind = .block, + .start_height = 0, + .label_types = func_type.results, + .result_types = func_type.results, + .reachable = true, + }); + + var pos: usize = 0; + const code = body.code; + + while (pos < code.len) { + const op = code[pos]; + pos += 1; + + const frame = &frames.items[frames.items.len - 1]; + const reachable = frame.reachable; + + switch (op) { + 0x00 => { // unreachable + if (reachable) { + stack.shrinkRetainingCapacity(frame.start_height); + frame.reachable = false; + } + }, + 0x01 => {}, // nop + 0x02 => { // block bt + const bt = try readBlockType(code, &pos); + const res = blockTypeResults(mod, bt); + try frames.append(ally, .{ + .kind = .block, + .start_height = stack.items.len, + .label_types = res, + .result_types = res, + .reachable = reachable, + }); + }, + 0x03 => { // loop bt + const bt = try readBlockType(code, &pos); + const params = blockTypeParams(mod, bt); + const res = blockTypeResults(mod, bt); + try frames.append(ally, .{ + .kind = .loop, + .start_height = stack.items.len, + .label_types = params, + .result_types = res, + .reachable = reachable, + }); + }, + 0x04 => { // if bt + const bt = try readBlockType(code, &pos); + const res = blockTypeResults(mod, bt); + if (reachable) try popExpect(ally, &stack, frame.start_height, .i32); + try frames.append(ally, .{ + .kind = .@"if", + .start_height = stack.items.len, + .label_types = res, + .result_types = res, + .reachable = reachable, + }); + }, + 0x05 => { // else + const cur = &frames.items[frames.items.len - 1]; + if (cur.kind != .@"if") return ValidationError.ElseWithoutIf; + if (cur.reachable) { + try checkStackTypes(&stack, cur.start_height, cur.result_types); + } + stack.shrinkRetainingCapacity(cur.start_height); + cur.kind = .@"else"; + cur.reachable = 
frames.items[frames.items.len - 2].reachable; + }, + 0x0B => { // end + if (frames.items.len == 1) { + const f = frames.items[0]; + if (f.reachable) try checkStackTypes(&stack, f.start_height, f.result_types); + break; + } + const cur = frames.pop().?; + if (cur.reachable) try checkStackTypes(&stack, cur.start_height, cur.result_types); + stack.shrinkRetainingCapacity(cur.start_height); + for (cur.result_types) |rt| try stack.append(ally, .{ .val = rt }); + // propagate unreachability + const parent = &frames.items[frames.items.len - 1]; + if (!cur.reachable) parent.reachable = false; + }, + 0x0C => { // br l + const depth = try readULEB128(u32, code, &pos); + if (depth >= frames.items.len) return ValidationError.InvalidLabelDepth; + const target = &frames.items[frames.items.len - 1 - depth]; + if (reachable) try checkStackTypes(&stack, frame.start_height, target.label_types); + stack.shrinkRetainingCapacity(frame.start_height); + frame.reachable = false; + }, + 0x0D => { // br_if l + const depth = try readULEB128(u32, code, &pos); + if (depth >= frames.items.len) return ValidationError.InvalidLabelDepth; + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .i32); + const target = &frames.items[frames.items.len - 1 - depth]; + try checkStackTypes(&stack, frame.start_height, target.label_types); + } + }, + 0x0E => { // br_table + const n = try readULEB128(u32, code, &pos); + var label_types: ?[]const module.ValType = null; + var i: u32 = 0; + while (i <= n) : (i += 1) { + const depth = try readULEB128(u32, code, &pos); + if (depth >= frames.items.len) return ValidationError.InvalidLabelDepth; + const target = &frames.items[frames.items.len - 1 - depth]; + if (label_types == null) { + label_types = target.label_types; + } else if (!sameValTypeSlice(label_types.?, target.label_types)) { + return ValidationError.TypeMismatch; + } + } + if (reachable) try popExpect(ally, &stack, frame.start_height, .i32); + 
stack.shrinkRetainingCapacity(frame.start_height); + frame.reachable = false; + }, + 0x0F => { // return + if (reachable) { + try checkStackTypes(&stack, frame.start_height, frames.items[0].result_types); + } + stack.shrinkRetainingCapacity(frame.start_height); + frame.reachable = false; + }, + 0x10 => { // call + const fidx = try readULEB128(u32, code, &pos); + if (fidx >= total_funcs) return ValidationError.UndefinedFunction; + const ft = getFuncType(mod, fidx, num_imported_funcs); + if (reachable) { + var pi = ft.params.len; + while (pi > 0) : (pi -= 1) try popExpect(ally, &stack, frame.start_height, ft.params[pi - 1]); + for (ft.results) |rt| try stack.append(ally, .{ .val = rt }); + } + }, + 0x11 => { // call_indirect + const type_idx = try readULEB128(u32, code, &pos); + const table_idx = try readULEB128(u32, code, &pos); + if (type_idx >= mod.types.len) return ValidationError.InvalidTypeIndex; + if (table_idx >= total_tables) return ValidationError.UndefinedTable; + const ft = &mod.types[type_idx]; + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .i32); + var pi = ft.params.len; + while (pi > 0) : (pi -= 1) try popExpect(ally, &stack, frame.start_height, ft.params[pi - 1]); + for (ft.results) |rt| try stack.append(ally, .{ .val = rt }); + } + }, + 0x1A => { // drop + if (reachable) _ = try popAny(&stack, frame.start_height); + }, + 0x1B => { // select + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .i32); + const t2 = try popAny(&stack, frame.start_height); + const t1 = try popAny(&stack, frame.start_height); + if (t1 == .val and t2 == .val and t1.val != t2.val) return ValidationError.TypeMismatch; + try stack.append(ally, t1); + } + }, + 0x1C => { // select (typed) + const n = try readULEB128(u32, code, &pos); + if (n != 1) return ValidationError.TypeMismatch; + const b = try readByte(code, &pos); + const t: module.ValType = switch (b) { + 0x7F => .i32, + 0x7E => .i64, + 0x7D => .f32, + 0x7C => .f64, + else => return 
ValidationError.TypeMismatch, + }; + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .i32); + try popExpect(ally, &stack, frame.start_height, t); + try popExpect(ally, &stack, frame.start_height, t); + try stack.append(ally, .{ .val = t }); + } + }, + 0x20 => { // local.get + const idx = try readULEB128(u32, code, &pos); + if (idx >= local_types.items.len) return ValidationError.UndefinedLocal; + if (reachable) try stack.append(ally, .{ .val = local_types.items[idx] }); + }, + 0x21 => { // local.set + const idx = try readULEB128(u32, code, &pos); + if (idx >= local_types.items.len) return ValidationError.UndefinedLocal; + if (reachable) try popExpect(ally, &stack, frame.start_height, local_types.items[idx]); + }, + 0x22 => { // local.tee + const idx = try readULEB128(u32, code, &pos); + if (idx >= local_types.items.len) return ValidationError.UndefinedLocal; + if (reachable) { + try popExpect(ally, &stack, frame.start_height, local_types.items[idx]); + try stack.append(ally, .{ .val = local_types.items[idx] }); + } + }, + 0x23 => { // global.get + const idx = try readULEB128(u32, code, &pos); + const gt = try getGlobalType(mod, imported_globals, idx); + if (reachable) try stack.append(ally, .{ .val = gt.valtype }); + }, + 0x24 => { // global.set + const idx = try readULEB128(u32, code, &pos); + const gt = try getGlobalType(mod, imported_globals, idx); + if (!gt.mutable) return ValidationError.ImmutableGlobal; + if (reachable) try popExpect(ally, &stack, frame.start_height, gt.valtype); + }, + 0x28...0x35 => { // memory loads + const mem_align = try readULEB128(u32, code, &pos); + _ = try readULEB128(u32, code, &pos); // offset + if (total_memories == 0) return ValidationError.UndefinedMemory; + if (mem_align > naturalAlignmentLog2ForLoad(op)) return ValidationError.InvalidAlignment; + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .i32); + try stack.append(ally, .{ .val = memLoadResultType(op) }); + } + }, + 0x36...0x3E => { 
// memory stores + const mem_align = try readULEB128(u32, code, &pos); + _ = try readULEB128(u32, code, &pos); // offset + if (total_memories == 0) return ValidationError.UndefinedMemory; + if (mem_align > naturalAlignmentLog2ForStore(op)) return ValidationError.InvalidAlignment; + if (reachable) { + try popExpect(ally, &stack, frame.start_height, memStoreValType(op)); + try popExpect(ally, &stack, frame.start_height, .i32); + } + }, + 0x3F, 0x40 => { // memory.size / memory.grow + _ = try readByte(code, &pos); + if (total_memories == 0) return ValidationError.UndefinedMemory; + if (reachable) { + if (op == 0x40) try popExpect(ally, &stack, frame.start_height, .i32); + try stack.append(ally, .{ .val = .i32 }); + } + }, + 0xFC => { // bulk memory + const subop = try readULEB128(u32, code, &pos); + switch (subop) { + 0, 1 => { // i32.trunc_sat_f32_{s,u} + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .f32); + try stack.append(ally, .{ .val = .i32 }); + } + }, + 2, 3 => { // i32.trunc_sat_f64_{s,u} + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .f64); + try stack.append(ally, .{ .val = .i32 }); + } + }, + 4, 5 => { // i64.trunc_sat_f32_{s,u} + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .f32); + try stack.append(ally, .{ .val = .i64 }); + } + }, + 6, 7 => { // i64.trunc_sat_f64_{s,u} + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .f64); + try stack.append(ally, .{ .val = .i64 }); + } + }, + 8 => { // memory.init + const data_idx = try readULEB128(u32, code, &pos); + const mem_idx = try readULEB128(u32, code, &pos); + if (data_idx >= mod.datas.len) return ValidationError.TypeMismatch; + if (total_memories == 0) return ValidationError.UndefinedMemory; + if (mem_idx != 0) return ValidationError.UndefinedMemory; + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .i32); + try popExpect(ally, &stack, frame.start_height, .i32); + try popExpect(ally, &stack, 
frame.start_height, .i32); + } + }, + 9 => { // data.drop + const data_idx = try readULEB128(u32, code, &pos); + if (data_idx >= mod.datas.len) return ValidationError.TypeMismatch; + }, + 10 => { // memory.copy + const dst_mem = try readULEB128(u32, code, &pos); + const src_mem = try readULEB128(u32, code, &pos); + if (total_memories == 0) return ValidationError.UndefinedMemory; + if (dst_mem != 0 or src_mem != 0) return ValidationError.UndefinedMemory; + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .i32); + try popExpect(ally, &stack, frame.start_height, .i32); + try popExpect(ally, &stack, frame.start_height, .i32); + } + }, + 11 => { // memory.fill + const mem_idx = try readULEB128(u32, code, &pos); + if (total_memories == 0) return ValidationError.UndefinedMemory; + if (mem_idx != 0) return ValidationError.UndefinedMemory; + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .i32); + try popExpect(ally, &stack, frame.start_height, .i32); + try popExpect(ally, &stack, frame.start_height, .i32); + } + }, + 16 => { // table.size + const table_idx = try readULEB128(u32, code, &pos); + if (table_idx >= total_tables) return ValidationError.UndefinedTable; + if (reachable) try stack.append(ally, .{ .val = .i32 }); + }, + else => return ValidationError.TypeMismatch, + } + }, + 0x41 => { + _ = try readSLEB128(i32, code, &pos); + if (reachable) try stack.append(ally, .{ .val = .i32 }); + }, + 0x42 => { + _ = try readSLEB128(i64, code, &pos); + if (reachable) try stack.append(ally, .{ .val = .i64 }); + }, + 0x43 => { + if (pos + 4 > code.len) return ValidationError.StackUnderflow; + pos += 4; + if (reachable) try stack.append(ally, .{ .val = .f32 }); + }, + 0x44 => { + if (pos + 8 > code.len) return ValidationError.StackUnderflow; + pos += 8; + if (reachable) try stack.append(ally, .{ .val = .f64 }); + }, + // i32 eqz (unary) + 0x45 => { + if (reachable) { try popExpect(ally, &stack, frame.start_height, .i32); try stack.append(ally, 
.{ .val = .i32 }); } + }, + // i32 comparisons (binary -> i32) + 0x46...0x4F => { + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .i32); + try popExpect(ally, &stack, frame.start_height, .i32); + try stack.append(ally, .{ .val = .i32 }); + } + }, + // i64 eqz + 0x50 => { + if (reachable) { try popExpect(ally, &stack, frame.start_height, .i64); try stack.append(ally, .{ .val = .i32 }); } + }, + // i64 comparisons + 0x51...0x5A => { + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .i64); + try popExpect(ally, &stack, frame.start_height, .i64); + try stack.append(ally, .{ .val = .i32 }); + } + }, + // f32 comparisons + 0x5B...0x60 => { + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .f32); + try popExpect(ally, &stack, frame.start_height, .f32); + try stack.append(ally, .{ .val = .i32 }); + } + }, + // f64 comparisons + 0x61...0x66 => { + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .f64); + try popExpect(ally, &stack, frame.start_height, .f64); + try stack.append(ally, .{ .val = .i32 }); + } + }, + // i32 unary ops (clz, ctz, popcnt) + 0x67, 0x68, 0x69 => { + if (reachable) { try popExpect(ally, &stack, frame.start_height, .i32); try stack.append(ally, .{ .val = .i32 }); } + }, + // i32 binary ops + 0x6A...0x78 => { + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .i32); + try popExpect(ally, &stack, frame.start_height, .i32); + try stack.append(ally, .{ .val = .i32 }); + } + }, + // i64 unary ops + 0x79, 0x7A, 0x7B => { + if (reachable) { try popExpect(ally, &stack, frame.start_height, .i64); try stack.append(ally, .{ .val = .i64 }); } + }, + // i64 binary ops + 0x7C...0x8A => { + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .i64); + try popExpect(ally, &stack, frame.start_height, .i64); + try stack.append(ally, .{ .val = .i64 }); + } + }, + // f32 unary ops + 0x8B...0x91 => { + if (reachable) { try popExpect(ally, &stack, frame.start_height, 
.f32); try stack.append(ally, .{ .val = .f32 }); } + }, + // f32 binary ops + 0x92...0x98 => { + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .f32); + try popExpect(ally, &stack, frame.start_height, .f32); + try stack.append(ally, .{ .val = .f32 }); + } + }, + // f64 unary ops + 0x99...0x9F => { + if (reachable) { try popExpect(ally, &stack, frame.start_height, .f64); try stack.append(ally, .{ .val = .f64 }); } + }, + // f64 binary ops + 0xA0...0xA6 => { + if (reachable) { + try popExpect(ally, &stack, frame.start_height, .f64); + try popExpect(ally, &stack, frame.start_height, .f64); + try stack.append(ally, .{ .val = .f64 }); + } + }, + // Conversions + 0xA7 => { if (reachable) { try popExpect(ally, &stack, frame.start_height, .i64); try stack.append(ally, .{ .val = .i32 }); } }, // i32.wrap_i64 + 0xA8, 0xA9 => { if (reachable) { try popExpect(ally, &stack, frame.start_height, .f32); try stack.append(ally, .{ .val = .i32 }); } }, + 0xAA, 0xAB => { if (reachable) { try popExpect(ally, &stack, frame.start_height, .f64); try stack.append(ally, .{ .val = .i32 }); } }, + 0xAC, 0xAD => { if (reachable) { try popExpect(ally, &stack, frame.start_height, .i32); try stack.append(ally, .{ .val = .i64 }); } }, + 0xAE, 0xAF => { if (reachable) { try popExpect(ally, &stack, frame.start_height, .f32); try stack.append(ally, .{ .val = .i64 }); } }, + 0xB0, 0xB1 => { if (reachable) { try popExpect(ally, &stack, frame.start_height, .f64); try stack.append(ally, .{ .val = .i64 }); } }, + 0xB2, 0xB3 => { if (reachable) { try popExpect(ally, &stack, frame.start_height, .i32); try stack.append(ally, .{ .val = .f32 }); } }, + 0xB4, 0xB5 => { if (reachable) { try popExpect(ally, &stack, frame.start_height, .i64); try stack.append(ally, .{ .val = .f32 }); } }, + 0xB6 => { if (reachable) { try popExpect(ally, &stack, frame.start_height, .f64); try stack.append(ally, .{ .val = .f32 }); } }, + 0xB7, 0xB8 => { if (reachable) { try popExpect(ally, &stack, 
frame.start_height, .i32); try stack.append(ally, .{ .val = .f64 }); } }, + 0xB9, 0xBA => { if (reachable) { try popExpect(ally, &stack, frame.start_height, .i64); try stack.append(ally, .{ .val = .f64 }); } }, + 0xBB => { if (reachable) { try popExpect(ally, &stack, frame.start_height, .f32); try stack.append(ally, .{ .val = .f64 }); } }, + 0xBC => { if (reachable) { try popExpect(ally, &stack, frame.start_height, .f32); try stack.append(ally, .{ .val = .i32 }); } }, + 0xBD => { if (reachable) { try popExpect(ally, &stack, frame.start_height, .f64); try stack.append(ally, .{ .val = .i64 }); } }, + 0xBE => { if (reachable) { try popExpect(ally, &stack, frame.start_height, .i32); try stack.append(ally, .{ .val = .f32 }); } }, + 0xBF => { if (reachable) { try popExpect(ally, &stack, frame.start_height, .i64); try stack.append(ally, .{ .val = .f64 }); } }, + 0xC0, 0xC1 => { if (reachable) { try popExpect(ally, &stack, frame.start_height, .i32); try stack.append(ally, .{ .val = .i32 }); } }, + 0xC2, 0xC3, 0xC4 => { if (reachable) { try popExpect(ally, &stack, frame.start_height, .i64); try stack.append(ally, .{ .val = .i64 }); } }, + else => return ValidationError.UnsupportedOpcode, + } + } +} + +// Block type is encoded as SLEB128: +// -1 = i32, -2 = i64, -3 = f32, -4 = f64, -64 = void, >=0 = type index +fn readBlockType(code: []const u8, pos: *usize) ValidationError!i33 { + return @import("binary.zig").readSLEB128(i33, code, pos) catch ValidationError.TypeMismatch; +} + +fn blockTypeResults(mod: *const module.Module, bt: i33) []const module.ValType { + return switch (bt) { + -1 => &[_]module.ValType{.i32}, + -2 => &[_]module.ValType{.i64}, + -3 => &[_]module.ValType{.f32}, + -4 => &[_]module.ValType{.f64}, + -64 => &.{}, // void + else => if (bt >= 0) blk: { + const idx: u32 = @intCast(bt); + if (idx >= mod.types.len) break :blk &.{}; + break :blk mod.types[idx].results; + } else &.{}, + }; +} + +fn blockTypeParams(mod: *const module.Module, bt: i33) []const 
module.ValType { + if (bt < 0) return &.{}; + const idx: u32 = @intCast(bt); + if (idx >= mod.types.len) return &.{}; + return mod.types[idx].params; +} + +fn getFuncType(mod: *const module.Module, fidx: u32, num_imported: u32) *const module.FuncType { + if (fidx < num_imported) { + var count: u32 = 0; + for (mod.imports) |imp| { + if (imp.desc == .func) { + if (count == fidx) return &mod.types[imp.desc.func]; + count += 1; + } + } + } + const local_idx = fidx - num_imported; + return &mod.types[mod.functions[local_idx]]; +} + +fn getGlobalType(mod: *const module.Module, imported_globals: []const module.GlobalType, idx: u32) ValidationError!module.GlobalType { + if (idx < imported_globals.len) return imported_globals[idx]; + const local_idx = idx - @as(u32, @intCast(imported_globals.len)); + if (local_idx >= mod.globals.len) return ValidationError.UndefinedGlobal; + return mod.globals[local_idx].type; +} + +fn popAny(stack: *std.ArrayList(StackVal), min_height: usize) ValidationError!StackVal { + if (stack.items.len <= min_height) return ValidationError.StackUnderflow; + return stack.pop().?; +} + +fn popExpect(ally: std.mem.Allocator, stack: *std.ArrayList(StackVal), min_height: usize, expected: module.ValType) ValidationError!void { + _ = ally; + if (stack.items.len <= min_height) return ValidationError.StackUnderflow; + const top = stack.pop().?; + switch (top) { + .any => {}, + .val => |v| if (v != expected) return ValidationError.TypeMismatch, + } +} + +fn checkStackTypes(stack: *std.ArrayList(StackVal), base: usize, expected: []const module.ValType) ValidationError!void { + const actual_count = stack.items.len - base; + if (actual_count < expected.len) return ValidationError.StackUnderflow; + for (expected, 0..) 
|et, i| { + const stack_idx = base + i; + switch (stack.items[stack_idx]) { + .any => {}, + .val => |v| if (v != et) return ValidationError.TypeMismatch, + } + } +} + +fn readByte(code: []const u8, pos: *usize) ValidationError!u8 { + if (pos.* >= code.len) return ValidationError.StackUnderflow; + const b = code[pos.*]; + pos.* += 1; + return b; +} + +fn readULEB128(comptime T: type, code: []const u8, pos: *usize) ValidationError!T { + return @import("binary.zig").readULEB128(T, code, pos) catch ValidationError.TypeMismatch; +} + +fn readSLEB128(comptime T: type, code: []const u8, pos: *usize) ValidationError!T { + return @import("binary.zig").readSLEB128(T, code, pos) catch ValidationError.TypeMismatch; +} + +fn memLoadResultType(op: u8) module.ValType { + return switch (op) { + 0x28, 0x2C, 0x2D, 0x2E, 0x2F => .i32, + 0x29, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35 => .i64, + 0x2A => .f32, + 0x2B => .f64, + else => .i32, + }; +} + +fn memStoreValType(op: u8) module.ValType { + return switch (op) { + 0x36, 0x3A, 0x3B => .i32, + 0x37, 0x3C, 0x3D, 0x3E => .i64, + 0x38 => .f32, + 0x39 => .f64, + else => .i32, + }; +} + +fn naturalAlignmentLog2ForLoad(op: u8) u32 { + return switch (op) { + 0x28 => 2, // i32.load + 0x29 => 3, // i64.load + 0x2A => 2, // f32.load + 0x2B => 3, // f64.load + 0x2C, 0x2D => 0, // i32.load8_{s,u} + 0x2E, 0x2F => 1, // i32.load16_{s,u} + 0x30, 0x31 => 0, // i64.load8_{s,u} + 0x32, 0x33 => 1, // i64.load16_{s,u} + 0x34, 0x35 => 2, // i64.load32_{s,u} + else => 0, + }; +} + +fn naturalAlignmentLog2ForStore(op: u8) u32 { + return switch (op) { + 0x36 => 2, // i32.store + 0x37 => 3, // i64.store + 0x38 => 2, // f32.store + 0x39 => 3, // f64.store + 0x3A, 0x3C => 0, // i32/i64.store8 + 0x3B, 0x3D => 1, // i32/i64.store16 + 0x3E => 2, // i64.store32 + else => 0, + }; +} + +fn sameValTypeSlice(a: []const module.ValType, b: []const module.ValType) bool { + if (a.len != b.len) return false; + for (a, 0..) 
|vt, i| {
        if (vt != b[i]) return false;
    }
    return true;
}

// Hand-assembled reference module used by the tests below: exports a single
// recursive `fib(i32) -> i32` function.
const fib_wasm = [_]u8{
    0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
    // type section: (i32) -> (i32)
    0x01, 0x06, 0x01, 0x60, 0x01, 0x7f, 0x01, 0x7f,
    // function section: one function, type 0
    0x03, 0x02, 0x01, 0x00,
    // export section: "fib" = func 0
    0x07, 0x07, 0x01, 0x03, 0x66, 0x69, 0x62, 0x00, 0x00,
    // code section: recursive fib body
    0x0a, 0x1e, 0x01, 0x1c, 0x00, 0x20, 0x00, 0x41, 0x02, 0x48, 0x04,
    0x7f, 0x20, 0x00, 0x05, 0x20, 0x00, 0x41, 0x01, 0x6b, 0x10, 0x00,
    0x20, 0x00, 0x41, 0x02, 0x6b, 0x10, 0x00, 0x6a, 0x0b, 0x0b,
};

test "validate fib module" {
    const binary = @import("binary.zig");
    const ally = std.testing.allocator;
    var mod = try binary.parse(ally, &fib_wasm);
    defer mod.deinit();
    try validate(&mod);
}

test "validate minimal module" {
    // Just the wasm magic + version: an empty module must validate.
    const binary = @import("binary.zig");
    const ally = std.testing.allocator;
    var mod = try binary.parse(ally, "\x00asm\x01\x00\x00\x00");
    defer mod.deinit();
    try validate(&mod);
}

test "validate rejects out-of-bounds local" {
    // (func (result i32) local.get 99) — 99 as LEB128 = 0x63
    // body: 0x00(locals) 0x20(local.get) 0x63(99) 0x0b(end) = 4 bytes
    // section: 0x01(count) 0x04(body_size) + 4 body bytes = 6 bytes total
    const binary = @import("binary.zig");
    const wasm = [_]u8{
        0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
        // type section: () -> (i32)
        0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7f,
        // function section: type 0
        0x03, 0x02, 0x01, 0x00,
        // code section: body_size=4, 0 locals, local.get 99, end
        0x0a, 0x06, 0x01, 0x04, 0x00, 0x20, 0x63, 0x0b,
    };
    const ally = std.testing.allocator;
    var mod = try binary.parse(ally, &wasm);
    defer mod.deinit();
    try std.testing.expectError(ValidationError.UndefinedLocal, validate(&mod));
}

test "validate rejects function and code section length mismatch" {
    // One function declared, but no code section at all.
    const binary = @import("binary.zig");
    const wasm = [_]u8{
        0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
        // type section: () -> ()
        0x01, 0x04, 0x01, 0x60, 0x00, 0x00,
        // function section: one local function
        0x03, 0x02, 0x01, 0x00,
    };
    const ally = std.testing.allocator;
    var mod = try binary.parse(ally, &wasm);
    defer mod.deinit();
    try std.testing.expectError(ValidationError.InvalidFunctionIndex, validate(&mod));
}

test "validate rejects call_indirect with out-of-range table index" {
    const binary = @import("binary.zig");
    const wasm = [_]u8{
        0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
        // type section: () -> ()
        0x01, 0x04, 0x01, 0x60, 0x00, 0x00,
        // function section: one function, type 0
        0x03, 0x02, 0x01, 0x00,
        // table section: one funcref table
        0x04, 0x04, 0x01, 0x70, 0x00, 0x01,
        // code section: i32.const 0; call_indirect type=0 table=1; end
        // (only table 0 exists, so table index 1 must be rejected)
        0x0a, 0x09, 0x01, 0x07, 0x00, 0x41, 0x00, 0x11, 0x00, 0x01, 0x0b,
    };
    const ally = std.testing.allocator;
    var mod = try binary.parse(ally, &wasm);
    defer mod.deinit();
    try std.testing.expectError(ValidationError.UndefinedTable, validate(&mod));
}

test "validate accepts imported globals in global index space" {
    // Imported globals come first in the combined index space, so
    // `global.get 0` must resolve to the import.
    const binary = @import("binary.zig");
    const wasm = [_]u8{
        0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
        // type section: () -> (i32)
        0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7f,
        // import section: (import "env" "g" (global i32))
        0x02, 0x0a, 0x01, 0x03, 0x65, 0x6e, 0x76, 0x01, 0x67, 0x03, 0x7f, 0x00,
        // function section: one local function, type 0
        0x03, 0x02, 0x01, 0x00,
        // code section: global.get 0; end
        0x0a, 0x06, 0x01, 0x04, 0x00, 0x23, 0x00, 0x0b,
    };
    const ally = std.testing.allocator;
    var mod = try binary.parse(ally, &wasm);
    defer mod.deinit();
    try validate(&mod);
}

test "validate rejects memory alignment larger than natural alignment" {
    const binary = @import("binary.zig");
    const wasm = [_]u8{
        0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
        // type section: () -> ()
        0x01, 0x04, 0x01, 0x60, 0x00, 0x00,
        // function section: one local function, type 0
        0x03, 0x02, 0x01, 0x00,
        // memory section: one memory min=1
        0x05, 0x03, 0x01, 0x00, 0x01,
        // code: i32.const 0; i32.load align=3 offset=0; drop; end
        // (i32.load's natural alignment is 2^2, so align=3 is invalid)
        0x0a, 0x0a, 0x01, 0x08, 0x00, 0x41, 0x00, 0x28, 0x03, 0x00, 0x1a, 0x0b,
    };
    const ally = std.testing.allocator;
    var mod = try binary.parse(ally, &wasm);
    defer mod.deinit();
    try std.testing.expectError(ValidationError.InvalidAlignment, validate(&mod));
}

test "validate rejects br_table with incompatible target label types" {
    // The two br_table targets have different arities (void vs i32), which
    // must be rejected.
    const binary = @import("binary.zig");
    const wasm = [_]u8{
        0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
        // type section: () -> (i32)
        0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7f,
        // function section: one local function, type 0
        0x03, 0x02, 0x01, 0x00,
        // code:
        //   block (result i32)
        //     block
        //       i32.const 0
        //       br_table 0 1
        //     end
        //     i32.const 1
        //   end
        // end
        0x0a, 0x12, 0x01, 0x10, 0x00, 0x02, 0x7f, 0x02, 0x40, 0x41, 0x00, 0x0e, 0x01, 0x00, 0x01, 0x0b, 0x41, 0x01, 0x0b, 0x0b,
    };
    const ally = std.testing.allocator;
    var mod = try binary.parse(ally, &wasm);
    defer mod.deinit();
    try std.testing.expectError(ValidationError.TypeMismatch, validate(&mod));
}

test "validate rejects unknown opcode" {
    const binary_mod = @import("binary.zig");
    const wasm = [_]u8{
        0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
        0x01, 0x04, 0x01, 0x60, 0x00, 0x00,
        0x03, 0x02, 0x01, 0x00,
        // body: 0 locals, 0xFF (invalid), end
        0x0a, 0x05, 0x01, 0x03, 0x00, 0xff, 0x0b,
    };
    const ally = std.testing.allocator;
    var mod = try binary_mod.parse(ally, &wasm);
    defer mod.deinit();
    try std.testing.expectError(ValidationError.UnsupportedOpcode, validate(&mod));
}

test "validate rejects truncated f32.const immediate" {
    const binary_mod = @import("binary.zig");
    const wasm = [_]u8{
        0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
        // type: () -> ()
        0x01, 0x04, 0x01, 0x60, 0x00, 0x00,
        // function: type 0
        0x03, 0x02, 0x01, 0x00,
        // code body declares size 4: locals=0, f32.const, only 2 bytes of immediate
        0x0a, 0x06,
0x01, 0x04, 0x00, 0x43, 0x00, 0x00, + }; + const ally = std.testing.allocator; + var mod = try binary_mod.parse(ally, &wasm); + defer mod.deinit(); + try std.testing.expectError(ValidationError.StackUnderflow, validate(&mod)); +} diff --git a/tests/wasm/fib.wasm b/tests/wasm/fib.wasm new file mode 100755 index 0000000000000000000000000000000000000000..efa9c53e4afebd1952571c0a4d9fc5e425ee1c7e GIT binary patch literal 116 zcmW;BK@Ng26h+bZeiZ{h8gRlCmcq5f)hYwiXS>m_}-L*`=pjz`!(A&d07n*OF I+^a3H1LWcung9R* literal 0 HcmV?d00001 diff --git a/tests/wasm/fib.wat b/tests/wasm/fib.wat new file mode 100644 index 0000000..12832c0 --- /dev/null +++ b/tests/wasm/fib.wat @@ -0,0 +1,35 @@ +;; tests/wasm/fib.wat +(module + (import "env" "log" (func $log (param i32 i32))) + + (memory (export "memory") 1) + (data (i32.const 0) "fib called\n") + + (func $fib_impl (param $n i32) (result i32) + local.get $n + i32.const 2 + i32.lt_s + if (result i32) + local.get $n + else + local.get $n + i32.const 1 + i32.sub + call $fib_impl + local.get $n + i32.const 2 + i32.sub + call $fib_impl + i32.add + end + ) + + (func (export "fib") (param $n i32) (result i32) + i32.const 0 + i32.const 11 + call $log + + local.get $n + call $fib_impl + ) +) diff --git a/tests/wasm/fib.zig b/tests/wasm/fib.zig new file mode 100644 index 0000000..0ff1393 --- /dev/null +++ b/tests/wasm/fib.zig @@ -0,0 +1,5 @@ +extern "env" fn log(ptr: [*]u8, len: i32) void; + +export fn init() void { + log(@ptrCast(@constCast("Hello world!\n")), 13); +}