mirror of
https://github.com/AvaLovelace1/LegoGPT.git
synced 2026-05-25 01:58:21 -05:00
Move LEGO generation parameters for improved configurability
This commit is contained in:
@@ -40,22 +40,22 @@ class LegoGPT:
|
||||
self,
|
||||
*,
|
||||
world_dim: int = 20,
|
||||
temperature: float = 0.6,
|
||||
max_bricks: int = 2000,
|
||||
max_brick_rejections: int = 100,
|
||||
max_regenerations: int = 100,
|
||||
temperature: float = 0.6,
|
||||
device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
|
||||
):
|
||||
self.world_dim = world_dim
|
||||
self.temperature = temperature
|
||||
self.max_bricks = max_bricks
|
||||
self.max_brick_rejections = max_brick_rejections
|
||||
self.max_regenerations = max_regenerations
|
||||
self.temperature = temperature
|
||||
self.device = device
|
||||
|
||||
self.llm = LLM('/data/apun/finetuned_hf/Llama-3.2-1B-Instruct_finetuned_combined_2', self.device)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
caption: str,
|
||||
max_bricks: int = 2000,
|
||||
) -> dict:
|
||||
def __call__(self, caption: str) -> dict:
|
||||
lego = None
|
||||
starting_lego = LegoStructure([])
|
||||
rejection_reasons = Counter()
|
||||
@@ -63,11 +63,7 @@ class LegoGPT:
|
||||
|
||||
# Generate LEGO structure. If it is unstable, remove all bricks after the first unstable brick and regenerate.
|
||||
for regeneration_num in range(self.max_regenerations + 1):
|
||||
lego, rejection_reasons_lego = self._generate_structure(
|
||||
caption,
|
||||
starting_lego=starting_lego,
|
||||
max_bricks=max_bricks,
|
||||
)
|
||||
lego, rejection_reasons_lego = self._generate_structure(caption, starting_lego=starting_lego)
|
||||
rejection_reasons.update(rejection_reasons_lego)
|
||||
if regeneration_num == self.max_regenerations or lego.is_stable:
|
||||
break
|
||||
@@ -83,13 +79,11 @@ class LegoGPT:
|
||||
self,
|
||||
caption: str,
|
||||
starting_lego: LegoStructure = LegoStructure([]),
|
||||
max_bricks: int = 2000,
|
||||
) -> (LegoStructure, Counter):
|
||||
"""
|
||||
Generates a LEGO structure based on the given caption, starting with a partial LEGO structure.
|
||||
:param caption: A caption for the LEGO structure to be generated.
|
||||
:param starting_lego: A partial LEGO structure to which the generated bricks will be added.
|
||||
:param max_bricks: The maximum number of bricks to generate.
|
||||
:return: A tuple containing the generated LEGO structure and a brick rejection reasons.
|
||||
"""
|
||||
starting_lego = copy.deepcopy(starting_lego)
|
||||
@@ -111,7 +105,7 @@ class LegoGPT:
|
||||
|
||||
# Generate bricks with rejection sampling
|
||||
rejection_reasons = Counter()
|
||||
for brick_num in range(max_bricks):
|
||||
for brick_num in range(self.max_bricks):
|
||||
brick, rejection_reasons_brick = self.generate_brick_with_rejection_sampling(
|
||||
prompt if brick_num == 0 else None, lego=starting_lego
|
||||
)
|
||||
@@ -126,7 +120,6 @@ class LegoGPT:
|
||||
self,
|
||||
prompt: str | None = None,
|
||||
lego: LegoStructure = LegoStructure([]),
|
||||
max_generations_per_brick: int = 10,
|
||||
) -> (str, Counter):
|
||||
"""
|
||||
Generates a LEGO brick to add to the LEGO structure, using rejection sampling to ensure the brick is valid.
|
||||
@@ -135,14 +128,14 @@ class LegoGPT:
|
||||
rejected_bricks = set()
|
||||
|
||||
brick = ''
|
||||
for generation_num in range(max_generations_per_brick):
|
||||
for generation_num in range(self.max_brick_rejections + 1):
|
||||
self.llm.save_state()
|
||||
brick = self.generate_brick(prompt)
|
||||
if not brick: # Generation is finished
|
||||
break
|
||||
|
||||
add_brick_result = self._try_adding_brick(brick, lego, rejected_bricks)
|
||||
if add_brick_result == 'success' or generation_num == max_generations_per_brick - 1:
|
||||
if add_brick_result == 'success' or generation_num == self.max_brick_rejections:
|
||||
break
|
||||
|
||||
self.llm.rollback_to_saved_state()
|
||||
|
||||
Reference in New Issue
Block a user