Move LEGO generation parameters for improved configurability

This commit is contained in:
apun
2025-03-31 16:22:39 -04:00
parent f615a37e0a
commit 15d508f86a
+11 -18
View File
@@ -40,22 +40,22 @@ class LegoGPT:
self,
*,
world_dim: int = 20,
temperature: float = 0.6,
max_bricks: int = 2000,
max_brick_rejections: int = 100,
max_regenerations: int = 100,
temperature: float = 0.6,
device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
):
self.world_dim = world_dim
self.temperature = temperature
self.max_bricks = max_bricks
self.max_brick_rejections = max_brick_rejections
self.max_regenerations = max_regenerations
self.temperature = temperature
self.device = device
self.llm = LLM('/data/apun/finetuned_hf/Llama-3.2-1B-Instruct_finetuned_combined_2', self.device)
def __call__(
self,
caption: str,
max_bricks: int = 2000,
) -> dict:
def __call__(self, caption: str) -> dict:
lego = None
starting_lego = LegoStructure([])
rejection_reasons = Counter()
@@ -63,11 +63,7 @@ class LegoGPT:
# Generate LEGO structure. If it is unstable, remove all bricks after the first unstable brick and regenerate.
for regeneration_num in range(self.max_regenerations + 1):
lego, rejection_reasons_lego = self._generate_structure(
caption,
starting_lego=starting_lego,
max_bricks=max_bricks,
)
lego, rejection_reasons_lego = self._generate_structure(caption, starting_lego=starting_lego)
rejection_reasons.update(rejection_reasons_lego)
if regeneration_num == self.max_regenerations or lego.is_stable:
break
@@ -83,13 +79,11 @@ class LegoGPT:
self,
caption: str,
starting_lego: LegoStructure = LegoStructure([]),
max_bricks: int = 2000,
) -> (LegoStructure, Counter):
"""
Generates a LEGO structure based on the given caption, starting with a partial LEGO structure.
:param caption: A caption for the LEGO structure to be generated.
:param starting_lego: A partial LEGO structure to which the generated bricks will be added.
:param max_bricks: The maximum number of bricks to generate.
:return: A tuple containing the generated LEGO structure and a brick rejection reasons.
"""
starting_lego = copy.deepcopy(starting_lego)
@@ -111,7 +105,7 @@ class LegoGPT:
# Generate bricks with rejection sampling
rejection_reasons = Counter()
for brick_num in range(max_bricks):
for brick_num in range(self.max_bricks):
brick, rejection_reasons_brick = self.generate_brick_with_rejection_sampling(
prompt if brick_num == 0 else None, lego=starting_lego
)
@@ -126,7 +120,6 @@ class LegoGPT:
self,
prompt: str | None = None,
lego: LegoStructure = LegoStructure([]),
max_generations_per_brick: int = 10,
) -> (str, Counter):
"""
Generates a LEGO brick to add to the LEGO structure, using rejection sampling to ensure the brick is valid.
@@ -135,14 +128,14 @@ class LegoGPT:
rejected_bricks = set()
brick = ''
for generation_num in range(max_generations_per_brick):
for generation_num in range(self.max_brick_rejections + 1):
self.llm.save_state()
brick = self.generate_brick(prompt)
if not brick: # Generation is finished
break
add_brick_result = self._try_adding_brick(brick, lego, rejected_bricks)
if add_brick_result == 'success' or generation_num == max_generations_per_brick - 1:
if add_brick_result == 'success' or generation_num == self.max_brick_rejections:
break
self.llm.rollback_to_saved_state()