Add inference masking option for LEGO brick generation

This commit is contained in:
apun
2025-04-08 17:01:36 -04:00
parent 6aaa64b13d
commit bb0acfa239
+30 -1
View File
@@ -59,6 +59,13 @@ class LegoGPTConfig:
metadata={'help': 'The maximum number of rejections per generated brick during rejection sampling. '
'Set to 0 if you want to disable rejection sampling.'},
)
use_inference_masking: bool = field(
default=False,
kw_only=True,
metadata={'help': 'Whether to use logit masking during inference '
'to enforce compliance with the LEGO brick syntax. '
'If False, the LEGO brick will be checked for validity after generation.'},
)
max_regenerations: int = field(
default=100,
kw_only=True,
@@ -78,6 +85,7 @@ class LegoGPT:
self.world_dim = cfg.world_dim
self.max_bricks = cfg.max_bricks
self.max_brick_rejections = cfg.max_brick_rejections
self.use_inference_masking = cfg.use_inference_masking
self.max_regenerations = cfg.max_regenerations
self.temperature = cfg.temperature
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -201,8 +209,29 @@ class LegoGPT:
return 'success'
def generate_brick(self, prompt: str | None = None) -> str:
if self.use_inference_masking:
return self._generate_brick_with_inference_masking(prompt)
else:
return self._generate_brick_no_inference_masking(prompt)
def _generate_brick_no_inference_masking(self, prompt: str | None = None) -> str:
"""
Generates a LEGO brick in txt format, using inference masking to enforce compliance with the LEGO brick syntax.
Generates a LEGO brick in txt format without logit masking.
:param prompt: The prompt to be given to the LLM preceding brick generation.
:return: A LEGO brick in txt format, or the empty string if generation is finished.
"""
result_ids = self.llm(
prompt,
return_as_ids=True,
max_new_tokens=20,
temperature=self.temperature,
top_k=self.top_k,
)
return self.llm.tokenizer.decode(result_ids, skip_special_tokens=True)
def _generate_brick_with_inference_masking(self, prompt: str | None = None) -> str:
"""
Generates a LEGO brick in txt format, using logit masking to enforce compliance with the LEGO brick syntax.
WARNING: Assumes each number in the brick dimensions and positions is represented by 1 token.
:param prompt: The prompt to be given to the LLM preceding brick generation.
:return: A LEGO brick in txt format, or the empty string if generation is finished.