mirror of
https://github.com/AvaLovelace1/LegoGPT.git
synced 2026-05-25 01:58:21 -05:00
Add inference masking option for LEGO brick generation
This commit is contained in:
@@ -59,6 +59,13 @@ class LegoGPTConfig:
|
||||
metadata={'help': 'The maximum number of rejections per generated brick during rejection sampling. '
|
||||
'Set to 0 if you want to disable rejection sampling.'},
|
||||
)
|
||||
use_inference_masking: bool = field(
|
||||
default=False,
|
||||
kw_only=True,
|
||||
metadata={'help': 'Whether to use logit masking during inference '
|
||||
'to enforce compliance with the LEGO brick syntax. '
|
||||
'If False, the LEGO brick will be checked for validity after generation.'},
|
||||
)
|
||||
max_regenerations: int = field(
|
||||
default=100,
|
||||
kw_only=True,
|
||||
@@ -78,6 +85,7 @@ class LegoGPT:
|
||||
self.world_dim = cfg.world_dim
|
||||
self.max_bricks = cfg.max_bricks
|
||||
self.max_brick_rejections = cfg.max_brick_rejections
|
||||
self.use_inference_masking = cfg.use_inference_masking
|
||||
self.max_regenerations = cfg.max_regenerations
|
||||
self.temperature = cfg.temperature
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
@@ -201,8 +209,29 @@ class LegoGPT:
|
||||
return 'success'
|
||||
|
||||
def generate_brick(self, prompt: str | None = None) -> str:
|
||||
if self.use_inference_masking:
|
||||
return self._generate_brick_with_inference_masking(prompt)
|
||||
else:
|
||||
return self._generate_brick_no_inference_masking(prompt)
|
||||
|
||||
def _generate_brick_no_inference_masking(self, prompt: str | None = None) -> str:
|
||||
"""
|
||||
Generates a LEGO brick in txt format, using inference masking to enforce compliance with the LEGO brick syntax.
|
||||
Generates a LEGO brick in txt format without logit masking.
|
||||
:param prompt: The prompt to be given to the LLM preceding brick generation.
|
||||
:return: A LEGO brick in txt format, or the empty string if generation is finished.
|
||||
"""
|
||||
result_ids = self.llm(
|
||||
prompt,
|
||||
return_as_ids=True,
|
||||
max_new_tokens=20,
|
||||
temperature=self.temperature,
|
||||
top_k=self.top_k,
|
||||
)
|
||||
return self.llm.tokenizer.decode(result_ids, skip_special_tokens=True)
|
||||
|
||||
def _generate_brick_with_inference_masking(self, prompt: str | None = None) -> str:
|
||||
"""
|
||||
Generates a LEGO brick in txt format, using logit masking to enforce compliance with the LEGO brick syntax.
|
||||
WARNING: Assumes each number in the brick dimensions and positions is represented by 1 token.
|
||||
:param prompt: The prompt to be given to the LLM preceding brick generation.
|
||||
:return: A LEGO brick in txt format, or the empty string if generation is finished.
|
||||
|
||||
Reference in New Issue
Block a user