mirror of
https://github.com/AvaLovelace1/LegoGPT.git
synced 2026-05-25 10:19:23 -05:00
Add out-of-bounds check to rejection sampling
This commit is contained in:
@@ -120,10 +120,14 @@ class LegoStructure:
|
||||
self.voxel_occupancy[brick.slice] -= 1
|
||||
self.bricks.pop()
|
||||
|
||||
def brick_in_bounds(self, brick: LegoBrick) -> bool:
|
||||
return (all(slice_.start >= 0 and slice_.stop <= self.world_dim for slice_ in brick.slice_2d)
|
||||
and brick.z >= 0 and brick.z < self.world_dim)
|
||||
|
||||
def has_collisions(self) -> bool:
|
||||
return np.any(self.voxel_occupancy > 1)
|
||||
|
||||
def brick_collides(self, brick: LegoBrick):
|
||||
def brick_collides(self, brick: LegoBrick) -> bool:
|
||||
return np.any(self.voxel_occupancy[brick.slice])
|
||||
|
||||
def has_floating_bricks(self) -> bool:
|
||||
|
||||
@@ -95,12 +95,11 @@ class LegoGPT:
|
||||
except ValueError: # Brick is badly formatted
|
||||
return 'ill_formatted'
|
||||
|
||||
# Try adding brick to the LEGO structure
|
||||
if not lego.brick_in_bounds(brick):
|
||||
return 'out_of_bounds'
|
||||
if lego.brick_collides(brick):
|
||||
result = 'collision'
|
||||
else:
|
||||
result = 'success'
|
||||
return result
|
||||
return 'collision'
|
||||
return 'success'
|
||||
|
||||
def generate_brick(self, prompt: str | None = None) -> str:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user