Improved image retention callback

This commit is contained in:
Dillon DuPont
2025-08-28 18:18:40 -04:00
parent 8fa5d7d314
commit ddd01ee719

View File

@@ -50,90 +50,41 @@ class ImageRetentionCallback(AsyncCallbackHandler):
"""
if self.only_n_most_recent_images is None:
return messages
# First pass: Assign call_id to reasoning items based on the next computer_call
messages_with_call_ids = []
for i, msg in enumerate(messages):
msg_copy = msg.copy() if isinstance(msg, dict) else msg
# If this is a reasoning item without a call_id, find the next computer_call
if (msg_copy.get("type") == "reasoning" and
not msg_copy.get("call_id")):
# Look ahead for the next computer_call
for j in range(i + 1, len(messages)):
next_msg = messages[j]
if (next_msg.get("type") == "computer_call" and
next_msg.get("call_id")):
msg_copy["call_id"] = next_msg.get("call_id")
break
messages_with_call_ids.append(msg_copy)
# Find all computer_call_output items with images and their call_ids
image_call_ids = []
for msg in reversed(messages_with_call_ids): # Process in reverse to get most recent first
if (msg.get("type") == "computer_call_output" and
isinstance(msg.get("output"), dict) and
"image_url" in msg.get("output", {})):
call_id = msg.get("call_id")
if call_id and call_id not in image_call_ids:
image_call_ids.append(call_id)
if len(image_call_ids) >= self.only_n_most_recent_images:
break
# Keep the most recent N image call_ids (reverse to get chronological order)
keep_call_ids = set(image_call_ids[:self.only_n_most_recent_images])
# Filter messages: remove computer_call, computer_call_output, and reasoning for old images
filtered_messages = []
for msg in messages_with_call_ids:
msg_type = msg.get("type")
call_id = msg.get("call_id")
# Remove old computer_call items
if msg_type == "computer_call" and call_id not in keep_call_ids:
# Check if this call_id corresponds to an image call
has_image_output = any(
m.get("type") == "computer_call_output" and
m.get("call_id") == call_id and
isinstance(m.get("output"), dict) and
"image_url" in m.get("output", {})
for m in messages_with_call_ids
)
if has_image_output:
continue # Skip this computer_call
# Remove old computer_call_output items with images
if (msg_type == "computer_call_output" and
call_id not in keep_call_ids and
isinstance(msg.get("output"), dict) and
"image_url" in msg.get("output", {})):
continue # Skip this computer_call_output
# Remove old reasoning items that are paired with removed computer calls
if (msg_type == "reasoning" and
call_id and call_id not in keep_call_ids):
# Check if this call_id corresponds to an image call that's being removed
has_image_output = any(
m.get("type") == "computer_call_output" and
m.get("call_id") == call_id and
isinstance(m.get("output"), dict) and
"image_url" in m.get("output", {})
for m in messages_with_call_ids
)
if has_image_output:
continue # Skip this reasoning item
filtered_messages.append(msg)
# Clean up: Remove call_id from reasoning items before returning
final_messages = []
for msg in filtered_messages:
if msg.get("type") == "reasoning" and "call_id" in msg:
# Create a copy without call_id for reasoning items
cleaned_msg = {k: v for k, v in msg.items() if k != "call_id"}
final_messages.append(cleaned_msg)
else:
final_messages.append(msg)
return final_messages
# Gather indices of all computer_call_output messages that contain an image_url
output_indices: List[int] = []
for idx, msg in enumerate(messages):
if msg["type"] == "computer_call_output":
out = msg["output"]
if isinstance(out, dict) and ("image_url" in out):
output_indices.append(idx)
# Nothing to trim
if len(output_indices) <= self.only_n_most_recent_images:
return messages
# Determine which outputs to keep (most recent N)
keep_output_indices = set(output_indices[-self.only_n_most_recent_images :])
# Build set of indices to remove in one pass
to_remove: set[int] = set()
for idx in output_indices:
if idx in keep_output_indices:
continue # keep this screenshot and its context
to_remove.add(idx) # remove the computer_call_output itself
# Remove the immediately preceding computer_call with matching call_id (if present)
call_id = messages[idx]["call_id"]
prev_idx = idx - 1
if prev_idx >= 0 and messages[prev_idx]["type"] == "computer_call" and messages[prev_idx]["call_id"] == call_id:
to_remove.add(prev_idx)
# Check a single reasoning immediately before that computer_call
r_idx = prev_idx - 1
if r_idx >= 0 and messages[r_idx]["type"] == "reasoning":
to_remove.add(r_idx)
# Construct filtered list
filtered = [m for i, m in enumerate(messages) if i not in to_remove]
return filtered