From 15d508f86abaecfdc88c1f7fe74ff3f7696d2e22 Mon Sep 17 00:00:00 2001 From: apun Date: Mon, 31 Mar 2025 16:22:39 -0400 Subject: [PATCH] Move LEGO generation parameters for improved configurability --- src/legogpt/models/legogpt.py | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/src/legogpt/models/legogpt.py b/src/legogpt/models/legogpt.py index 4be3d66..9247485 100644 --- a/src/legogpt/models/legogpt.py +++ b/src/legogpt/models/legogpt.py @@ -40,22 +40,22 @@ class LegoGPT: self, *, world_dim: int = 20, - temperature: float = 0.6, + max_bricks: int = 2000, + max_brick_rejections: int = 100, max_regenerations: int = 100, + temperature: float = 0.6, device: str = 'cuda' if torch.cuda.is_available() else 'cpu', ): self.world_dim = world_dim - self.temperature = temperature + self.max_bricks = max_bricks + self.max_brick_rejections = max_brick_rejections self.max_regenerations = max_regenerations + self.temperature = temperature self.device = device self.llm = LLM('/data/apun/finetuned_hf/Llama-3.2-1B-Instruct_finetuned_combined_2', self.device) - def __call__( - self, - caption: str, - max_bricks: int = 2000, - ) -> dict: + def __call__(self, caption: str) -> dict: lego = None starting_lego = LegoStructure([]) rejection_reasons = Counter() @@ -63,11 +63,7 @@ class LegoGPT: # Generate LEGO structure. If it is unstable, remove all bricks after the first unstable brick and regenerate. for regeneration_num in range(self.max_regenerations + 1): - lego, rejection_reasons_lego = self._generate_structure( - caption, - starting_lego=starting_lego, - max_bricks=max_bricks, - ) + lego, rejection_reasons_lego = self._generate_structure(caption, starting_lego=starting_lego) rejection_reasons.update(rejection_reasons_lego) if regeneration_num == self.max_regenerations or lego.is_stable: break @@ -83,13 +79,11 @@ class LegoGPT: self, caption: str, starting_lego: LegoStructure = LegoStructure([]), - max_bricks: int = 2000, ) -> (LegoStructure, Counter): """ Generates a LEGO structure based on the given caption, starting with a partial LEGO structure. :param caption: A caption for the LEGO structure to be generated. :param starting_lego: A partial LEGO structure to which the generated bricks will be added. - :param max_bricks: The maximum number of bricks to generate. :return: A tuple containing the generated LEGO structure and a brick rejection reasons. """ starting_lego = copy.deepcopy(starting_lego) @@ -111,7 +105,7 @@ class LegoGPT: # Generate bricks with rejection sampling rejection_reasons = Counter() - for brick_num in range(max_bricks): + for brick_num in range(self.max_bricks): brick, rejection_reasons_brick = self.generate_brick_with_rejection_sampling( prompt if brick_num == 0 else None, lego=starting_lego ) @@ -126,7 +120,6 @@ class LegoGPT: self, prompt: str | None = None, lego: LegoStructure = LegoStructure([]), - max_generations_per_brick: int = 10, ) -> (str, Counter): """ Generates a LEGO brick to add to the LEGO structure, using rejection sampling to ensure the brick is valid. @@ -135,14 +128,14 @@ class LegoGPT: rejected_bricks = set() brick = '' - for generation_num in range(max_generations_per_brick): + for generation_num in range(self.max_brick_rejections + 1): self.llm.save_state() brick = self.generate_brick(prompt) if not brick: # Generation is finished break add_brick_result = self._try_adding_brick(brick, lego, rejected_bricks) - if add_brick_result == 'success' or generation_num == max_generations_per_brick - 1: + if add_brick_result == 'success' or generation_num == self.max_brick_rejections: break self.llm.rollback_to_saved_state()