Inference masking fix (allow ")<eot>" in addition to ")\n<eot>")

This commit is contained in:
apun
2025-03-31 10:49:18 -04:00
parent 72ebc032e5
commit c608390d65
+1 -1
View File
@@ -125,7 +125,7 @@ class LegoGPT:
result_ids = []
for allowed_strs in [
allowed_dims + (self.llm.tokenizer.eos_token,), ('x',), allowed_dims, (' (',), allowed_posns, (',',),
allowed_posns, (',',), allowed_posns, (')\n',),
allowed_posns, (',',), allowed_posns, (')\n', ')'),
]:
next_token_id = self.llm(
prompt,