mirror of
https://github.com/trycua/computer.git
synced 2026-01-04 04:19:57 -06:00
Improved image retention callback
This commit is contained in:
@@ -50,90 +50,41 @@ class ImageRetentionCallback(AsyncCallbackHandler):
|
||||
"""
|
||||
if self.only_n_most_recent_images is None:
|
||||
return messages
|
||||
|
||||
# First pass: Assign call_id to reasoning items based on the next computer_call
|
||||
messages_with_call_ids = []
|
||||
for i, msg in enumerate(messages):
|
||||
msg_copy = msg.copy() if isinstance(msg, dict) else msg
|
||||
|
||||
# If this is a reasoning item without a call_id, find the next computer_call
|
||||
if (msg_copy.get("type") == "reasoning" and
|
||||
not msg_copy.get("call_id")):
|
||||
# Look ahead for the next computer_call
|
||||
for j in range(i + 1, len(messages)):
|
||||
next_msg = messages[j]
|
||||
if (next_msg.get("type") == "computer_call" and
|
||||
next_msg.get("call_id")):
|
||||
msg_copy["call_id"] = next_msg.get("call_id")
|
||||
break
|
||||
|
||||
messages_with_call_ids.append(msg_copy)
|
||||
|
||||
# Find all computer_call_output items with images and their call_ids
|
||||
image_call_ids = []
|
||||
for msg in reversed(messages_with_call_ids): # Process in reverse to get most recent first
|
||||
if (msg.get("type") == "computer_call_output" and
|
||||
isinstance(msg.get("output"), dict) and
|
||||
"image_url" in msg.get("output", {})):
|
||||
call_id = msg.get("call_id")
|
||||
if call_id and call_id not in image_call_ids:
|
||||
image_call_ids.append(call_id)
|
||||
if len(image_call_ids) >= self.only_n_most_recent_images:
|
||||
break
|
||||
|
||||
# Keep the most recent N image call_ids (the list is already ordered most recent first; order is irrelevant once it becomes a set)
|
||||
keep_call_ids = set(image_call_ids[:self.only_n_most_recent_images])
|
||||
|
||||
# Filter messages: remove computer_call, computer_call_output, and reasoning for old images
|
||||
filtered_messages = []
|
||||
for msg in messages_with_call_ids:
|
||||
msg_type = msg.get("type")
|
||||
call_id = msg.get("call_id")
|
||||
|
||||
# Remove old computer_call items
|
||||
if msg_type == "computer_call" and call_id not in keep_call_ids:
|
||||
# Check if this call_id corresponds to an image call
|
||||
has_image_output = any(
|
||||
m.get("type") == "computer_call_output" and
|
||||
m.get("call_id") == call_id and
|
||||
isinstance(m.get("output"), dict) and
|
||||
"image_url" in m.get("output", {})
|
||||
for m in messages_with_call_ids
|
||||
)
|
||||
if has_image_output:
|
||||
continue # Skip this computer_call
|
||||
|
||||
# Remove old computer_call_output items with images
|
||||
if (msg_type == "computer_call_output" and
|
||||
call_id not in keep_call_ids and
|
||||
isinstance(msg.get("output"), dict) and
|
||||
"image_url" in msg.get("output", {})):
|
||||
continue # Skip this computer_call_output
|
||||
|
||||
# Remove old reasoning items that are paired with removed computer calls
|
||||
if (msg_type == "reasoning" and
|
||||
call_id and call_id not in keep_call_ids):
|
||||
# Check if this call_id corresponds to an image call that's being removed
|
||||
has_image_output = any(
|
||||
m.get("type") == "computer_call_output" and
|
||||
m.get("call_id") == call_id and
|
||||
isinstance(m.get("output"), dict) and
|
||||
"image_url" in m.get("output", {})
|
||||
for m in messages_with_call_ids
|
||||
)
|
||||
if has_image_output:
|
||||
continue # Skip this reasoning item
|
||||
|
||||
filtered_messages.append(msg)
|
||||
|
||||
# Clean up: Remove call_id from reasoning items before returning
|
||||
final_messages = []
|
||||
for msg in filtered_messages:
|
||||
if msg.get("type") == "reasoning" and "call_id" in msg:
|
||||
# Create a copy without call_id for reasoning items
|
||||
cleaned_msg = {k: v for k, v in msg.items() if k != "call_id"}
|
||||
final_messages.append(cleaned_msg)
|
||||
else:
|
||||
final_messages.append(msg)
|
||||
|
||||
return final_messages
|
||||
|
||||
# Gather indices of all computer_call_output messages that contain an image_url
|
||||
output_indices: List[int] = []
|
||||
for idx, msg in enumerate(messages):
|
||||
if msg["type"] == "computer_call_output":
|
||||
out = msg["output"]
|
||||
if isinstance(out, dict) and ("image_url" in out):
|
||||
output_indices.append(idx)
|
||||
|
||||
# Nothing to trim
|
||||
if len(output_indices) <= self.only_n_most_recent_images:
|
||||
return messages
|
||||
|
||||
# Determine which outputs to keep (most recent N)
|
||||
keep_output_indices = set(output_indices[-self.only_n_most_recent_images :])
|
||||
|
||||
# Build set of indices to remove in one pass
|
||||
to_remove: set[int] = set()
|
||||
|
||||
for idx in output_indices:
|
||||
if idx in keep_output_indices:
|
||||
continue # keep this screenshot and its context
|
||||
|
||||
to_remove.add(idx) # remove the computer_call_output itself
|
||||
|
||||
# Remove the immediately preceding computer_call with matching call_id (if present)
|
||||
call_id = messages[idx]["call_id"]
|
||||
prev_idx = idx - 1
|
||||
if prev_idx >= 0 and messages[prev_idx]["type"] == "computer_call" and messages[prev_idx]["call_id"] == call_id:
|
||||
to_remove.add(prev_idx)
|
||||
# Check a single reasoning immediately before that computer_call
|
||||
r_idx = prev_idx - 1
|
||||
if r_idx >= 0 and messages[r_idx]["type"] == "reasoning":
|
||||
to_remove.add(r_idx)
|
||||
|
||||
# Construct filtered list
|
||||
filtered = [m for i, m in enumerate(messages) if i not in to_remove]
|
||||
return filtered
|
||||
Reference in New Issue
Block a user