diff --git a/zig-ecs/examples/simple.zig b/zig-ecs/examples/simple.zig index a67c325..9a559b1 100644 --- a/zig-ecs/examples/simple.zig +++ b/zig-ecs/examples/simple.zig @@ -19,7 +19,7 @@ pub fn main() !void { reg.add(e2, Position{ .x = 10, .y = 10 }); reg.add(e2, Velocity{ .x = 15, .y = 17 }); - var view = reg.view(.{ Velocity, Position }); + var view = reg.view(.{ Velocity, Position }, .{}); var iter = view.iterator(); while (iter.next()) |entity| { diff --git a/zig-ecs/src/ecs/registry.zig b/zig-ecs/src/ecs/registry.zig index 96d8944..4de5f06 100644 --- a/zig-ecs/src/ecs/registry.zig +++ b/zig-ecs/src/ecs/registry.zig @@ -254,24 +254,34 @@ pub const Registry = struct { null; } - pub fn view(self: *Registry, comptime includes: var) ViewType(includes) { + pub fn view(self: *Registry, comptime includes: var, comptime excludes: var) ViewType(includes, excludes) { + if (@typeInfo(@TypeOf(includes)) != .Struct) + @compileError("Expected tuple or struct argument, found " ++ @typeName(@TypeOf(args))); + if (@typeInfo(@TypeOf(excludes)) != .Struct) + @compileError("Expected tuple or struct argument, found " ++ @typeName(@TypeOf(excludes))); std.debug.assert(includes.len > 0); - if (includes.len == 1) + if (includes.len == 1 and excludes.len == 0) return BasicView(includes[0]).init(self.assure(includes[0])); - var arr: [includes.len]u32 = undefined; + var includes_arr: [includes.len]u32 = undefined; inline for (includes) |t, i| { _ = self.assure(t); - arr[i] = @as(u32, self.typemap.get(t)); + includes_arr[i] = @as(u32, self.typemap.get(t)); } - return BasicMultiView(includes.len).init(arr, self); + var excludes_arr: [excludes.len]u32 = undefined; + inline for (excludes) |t, i| { + _ = self.assure(t); + excludes_arr[i] = @as(u32, self.typemap.get(t)); + } + + return BasicMultiView(includes.len, excludes.len).init(includes_arr, excludes_arr, self); } - fn ViewType(comptime includes: var) type { - if (includes.len == 1) return BasicView(includes[0]); - return BasicMultiView(includes.len); + fn ViewType(comptime includes: var, comptime excludes: var) type { + if (includes.len == 1 and excludes.len == 0) return BasicView(includes[0]); + return BasicMultiView(includes.len, excludes.len); } }; diff --git a/zig-ecs/src/ecs/view.zig b/zig-ecs/src/ecs/view.zig index e921556..bb21caa 100644 --- a/zig-ecs/src/ecs/view.zig +++ b/zig-ecs/src/ecs/view.zig @@ -5,7 +5,6 @@ const Registry = @import("registry.zig").Registry; const Storage = @import("registry.zig").Storage; const Entity = @import("registry.zig").Entity; - /// single item view. Iterating raw() directly is the fastest way to get at the data. pub fn BasicView(comptime T: type) type { return struct { @@ -14,7 +13,7 @@ pub fn BasicView(comptime T: type) type { storage: *Storage(T), pub fn init(storage: *Storage(T)) Self { - return Self { + return Self{ .storage = storage, }; } @@ -37,14 +36,19 @@ pub fn BasicView(comptime T: type) type { pub fn get(self: Self, entity: Entity) *T { return self.storage.get(entity); } + + pub fn getConst(self: *Self, comptime T: type, entity: Entity) T { + return self.storage.getConst(entity); + } }; } -pub fn BasicMultiView(comptime n: usize) type { +pub fn BasicMultiView(comptime n_includes: usize, comptime n_excludes: usize) type { return struct { const Self = @This(); - type_ids: [n]u32, + type_ids: [n_includes]u32, + exclude_type_ids: [n_excludes]u32, registry: *Registry, pub const Iterator = struct { @@ -69,10 +73,19 @@ pub fn BasicMultiView(comptime n: usize) type { // entity must be in all other Storages for (it.view.type_ids) |tid| { const ptr = it.view.registry.components.getValue(@intCast(u8, tid)).?; - if (!@intToPtr(*Storage(u8), ptr).contains(entity)) { + if (!@intToPtr(*Storage(u1), ptr).contains(entity)) { break :blk; } } + + // entity must not be in all other excluded Storages + for (it.view.exclude_type_ids) |tid| { + const ptr = it.view.registry.components.getValue(@intCast(u8, tid)).?; + if (@intToPtr(*Storage(u1), ptr).contains(entity)) { + break :blk; + } + } + it.index += 1; return entity; } @@ -86,9 +99,10 @@ pub fn BasicMultiView(comptime n: usize) type { } }; - pub fn init(type_ids: [n]u32, registry: *Registry) Self { + pub fn init(type_ids: [n_includes]u32, exclude_type_ids: [n_excludes]u32, registry: *Registry) Self { return Self{ .type_ids = type_ids, + .exclude_type_ids = exclude_type_ids, .registry = registry, }; } @@ -97,8 +111,6 @@ pub fn BasicMultiView(comptime n: usize) type { const type_id = self.registry.typemap.get(T); const ptr = self.registry.components.getValue(type_id).?; const store = @intToPtr(*Storage(T), ptr); - - std.debug.assert(store.contains(entity)); return store.get(entity); } @@ -106,14 +118,12 @@ pub fn BasicMultiView(comptime n: usize) type { const type_id = self.registry.typemap.get(T); const ptr = self.registry.components.getValue(type_id).?; const store = @intToPtr(*Storage(T), ptr); - - std.debug.assert(store.contains(entity)); return store.getConst(entity); } fn sort(self: *Self) void { // get our component counts in an array so we can sort the type_ids based on how many entities are in each - var sub_items: [n]usize = undefined; + var sub_items: [n_includes]usize = undefined; for (self.type_ids) |tid, i| { const ptr = self.registry.components.getValue(@intCast(u8, tid)).?; const store = @intToPtr(*Storage(u8), ptr); @@ -130,7 +140,6 @@ pub fn BasicMultiView(comptime n: usize) type { }; } - test "single basic view" { var store = Storage(f32).init(std.testing.allocator); defer store.deinit(); @@ -194,8 +203,8 @@ test "basic multi view" { reg.add(e0, @as(u32, 0)); reg.add(e2, @as(u32, 2)); - var single_view = reg.view(.{u32}); - var view = reg.view(.{ i32, u32 }); + var single_view = reg.view(.{u32}, .{}); + var view = reg.view(.{ i32, u32 }, .{}); var iterated_entities: usize = 0; var iter = view.iterator(); @@ -214,4 +223,42 @@ test "basic multi view" { } std.testing.expectEqual(iterated_entities, 1); -} \ No newline at end of file +} + +test "basic multi view with excludes" { + var reg = Registry.init(std.testing.allocator); + defer reg.deinit(); + + var e0 = reg.create(); + var e1 = reg.create(); + var e2 = reg.create(); + + reg.add(e0, @as(i32, -0)); + reg.add(e1, @as(i32, -1)); + reg.add(e2, @as(i32, -2)); + + reg.add(e0, @as(u32, 0)); + reg.add(e2, @as(u32, 2)); + + reg.add(e2, @as(u8, 255)); + + var view = reg.view(.{ i32, u32 }, .{u8}); + + var iterated_entities: usize = 0; + var iter = view.iterator(); + while (iter.next()) |entity| { + iterated_entities += 1; + } + + std.testing.expectEqual(iterated_entities, 1); + iterated_entities = 0; + + reg.remove(u8, e2); + + iter.reset(); + while (iter.next()) |entity| { + iterated_entities += 1; + } + + std.testing.expectEqual(iterated_entities, 2); +}