From ddd01ee719a85ee80a9b6156d80b56190237bf52 Mon Sep 17 00:00:00 2001
From: Dillon DuPont
Date: Thu, 28 Aug 2025 18:18:40 -0400
Subject: [PATCH] Improved image retention callback

---
Review notes (ignored by git am; not part of the commit message):
 * The added lines read message fields with .get() instead of subscripting,
   as the removed implementation did, so a message without a "type",
   "output" or "call_id" key no longer raises KeyError.
 * The keep-window slice uses an explicit start index, so
   only_n_most_recent_images=0 removes every screenshot instead of keeping
   all of them (output_indices[-0:] is the whole list).
 * Line breaks and indentation were rebuilt from a whitespace-collapsed copy
   of this patch. Trailing whitespace on removed blank lines is unknown, so
   applying may need `git apply --ignore-whitespace`. The blob id on the
   index line is the original one and does not describe the amended
   post-image.

 .../agent/agent/callbacks/image_retention.py  | 125 ++++++------------
 1 file changed, 38 insertions(+), 87 deletions(-)

diff --git a/libs/python/agent/agent/callbacks/image_retention.py b/libs/python/agent/agent/callbacks/image_retention.py
index d91754b1..ff38a6dd 100644
--- a/libs/python/agent/agent/callbacks/image_retention.py
+++ b/libs/python/agent/agent/callbacks/image_retention.py
@@ -50,90 +50,41 @@ class ImageRetentionCallback(AsyncCallbackHandler):
         """
         if self.only_n_most_recent_images is None:
             return messages
-
-        # First pass: Assign call_id to reasoning items based on the next computer_call
-        messages_with_call_ids = []
-        for i, msg in enumerate(messages):
-            msg_copy = msg.copy() if isinstance(msg, dict) else msg
-
-            # If this is a reasoning item without a call_id, find the next computer_call
-            if (msg_copy.get("type") == "reasoning" and
-                not msg_copy.get("call_id")):
-                # Look ahead for the next computer_call
-                for j in range(i + 1, len(messages)):
-                    next_msg = messages[j]
-                    if (next_msg.get("type") == "computer_call" and
-                        next_msg.get("call_id")):
-                        msg_copy["call_id"] = next_msg.get("call_id")
-                        break
-
-            messages_with_call_ids.append(msg_copy)
-
-        # Find all computer_call_output items with images and their call_ids
-        image_call_ids = []
-        for msg in reversed(messages_with_call_ids):  # Process in reverse to get most recent first
-            if (msg.get("type") == "computer_call_output" and
-                isinstance(msg.get("output"), dict) and
-                "image_url" in msg.get("output", {})):
-                call_id = msg.get("call_id")
-                if call_id and call_id not in image_call_ids:
-                    image_call_ids.append(call_id)
-                    if len(image_call_ids) >= self.only_n_most_recent_images:
-                        break
-
-        # Keep the most recent N image call_ids (reverse to get chronological order)
-        keep_call_ids = set(image_call_ids[:self.only_n_most_recent_images])
-
-        # Filter messages: remove computer_call, computer_call_output, and reasoning for old images
-        filtered_messages = []
-        for msg in messages_with_call_ids:
-            msg_type = msg.get("type")
-            call_id = msg.get("call_id")
-
-            # Remove old computer_call items
-            if msg_type == "computer_call" and call_id not in keep_call_ids:
-                # Check if this call_id corresponds to an image call
-                has_image_output = any(
-                    m.get("type") == "computer_call_output" and
-                    m.get("call_id") == call_id and
-                    isinstance(m.get("output"), dict) and
-                    "image_url" in m.get("output", {})
-                    for m in messages_with_call_ids
-                )
-                if has_image_output:
-                    continue  # Skip this computer_call
-
-            # Remove old computer_call_output items with images
-            if (msg_type == "computer_call_output" and
-                call_id not in keep_call_ids and
-                isinstance(msg.get("output"), dict) and
-                "image_url" in msg.get("output", {})):
-                continue  # Skip this computer_call_output
-
-            # Remove old reasoning items that are paired with removed computer calls
-            if (msg_type == "reasoning" and
-                call_id and call_id not in keep_call_ids):
-                # Check if this call_id corresponds to an image call that's being removed
-                has_image_output = any(
-                    m.get("type") == "computer_call_output" and
-                    m.get("call_id") == call_id and
-                    isinstance(m.get("output"), dict) and
-                    "image_url" in m.get("output", {})
-                    for m in messages_with_call_ids
-                )
-                if has_image_output:
-                    continue  # Skip this reasoning item
-
-            filtered_messages.append(msg)
-
-        # Clean up: Remove call_id from reasoning items before returning
-        final_messages = []
-        for msg in filtered_messages:
-            if msg.get("type") == "reasoning" and "call_id" in msg:
-                # Create a copy without call_id for reasoning items
-                cleaned_msg = {k: v for k, v in msg.items() if k != "call_id"}
-                final_messages.append(cleaned_msg)
-            else:
-                final_messages.append(msg)
-
-        return final_messages
\ No newline at end of file
+
+        # Gather indices of all computer_call_output messages that contain an image_url
+        output_indices: List[int] = []
+        for idx, msg in enumerate(messages):
+            if msg.get("type") == "computer_call_output":
+                out = msg.get("output")
+                if isinstance(out, dict) and ("image_url" in out):
+                    output_indices.append(idx)
+
+        # Nothing to trim
+        if len(output_indices) <= self.only_n_most_recent_images:
+            return messages
+
+        # Determine which outputs to keep (most recent N); explicit start index because [-0:] would keep everything
+        keep_output_indices = set(output_indices[len(output_indices) - self.only_n_most_recent_images :])
+
+        # Build set of indices to remove in one pass
+        to_remove: set[int] = set()
+
+        for idx in output_indices:
+            if idx in keep_output_indices:
+                continue  # keep this screenshot and its context
+
+            to_remove.add(idx)  # remove the computer_call_output itself
+
+            # Remove the immediately preceding computer_call with matching call_id (if present)
+            call_id = messages[idx].get("call_id")
+            prev_idx = idx - 1
+            if prev_idx >= 0 and messages[prev_idx].get("type") == "computer_call" and messages[prev_idx].get("call_id") == call_id:
+                to_remove.add(prev_idx)
+                # Check a single reasoning immediately before that computer_call
+                r_idx = prev_idx - 1
+                if r_idx >= 0 and messages[r_idx].get("type") == "reasoning":
+                    to_remove.add(r_idx)
+
+        # Construct filtered list
+        filtered = [m for i, m in enumerate(messages) if i not in to_remove]
+        return filtered
\ No newline at end of file