mirror of
https://github.com/trycua/computer.git
synced 2025-12-20 12:29:50 -06:00
Add Cua Preview
This commit is contained in:
233
.cursorignore
Normal file
233
.cursorignore
Normal file
@@ -0,0 +1,233 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
!libs/lume/scripts/build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Scripts
|
||||
server/scripts/
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# Ruff stuff:
|
||||
.ruff_cache/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
# Conda
|
||||
.conda/
|
||||
|
||||
# Local environment
|
||||
.env.local
|
||||
|
||||
# macOS DS_Store
|
||||
.DS_Store
|
||||
|
||||
weights/
|
||||
weights/icon_detect/
|
||||
weights/icon_detect/model.pt
|
||||
weights/icon_detect/model.pt.zip
|
||||
weights/icon_detect/model.pt.zip.part*
|
||||
|
||||
libs/omniparser/weights/icon_detect/model.pt
|
||||
|
||||
# Example test data and output
|
||||
examples/test_data/
|
||||
examples/output/
|
||||
|
||||
/screenshots/
|
||||
|
||||
/experiments/
|
||||
|
||||
/logs/
|
||||
|
||||
# Xcode
|
||||
#
|
||||
# gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore
|
||||
|
||||
## User settings
|
||||
xcuserdata/
|
||||
|
||||
## Obj-C/Swift specific
|
||||
*.hmap
|
||||
|
||||
## App packaging
|
||||
*.ipa
|
||||
*.dSYM.zip
|
||||
*.dSYM
|
||||
|
||||
## Playgrounds
|
||||
timeline.xctimeline
|
||||
playground.xcworkspace
|
||||
|
||||
# Swift Package Manager
|
||||
#
|
||||
# Add this line if you want to avoid checking in source code from Swift Package Manager dependencies.
|
||||
# Packages/
|
||||
# Package.pins
|
||||
# Package.resolved
|
||||
# *.xcodeproj
|
||||
#
|
||||
# Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata
|
||||
# hence it is not needed unless you have added a package configuration file to your project
|
||||
.swiftpm/
|
||||
.build/
|
||||
|
||||
# CocoaPods
|
||||
#
|
||||
# We recommend against adding the Pods directory to your .gitignore. However
|
||||
# you should judge for yourself, the pros and cons are mentioned at:
|
||||
# https://guides.cocoapods.org/using/using-cocoapods.html#should-i-check-the-pods-directory-into-source-control
|
||||
#
|
||||
# Pods/
|
||||
#
|
||||
# Add this line if you want to avoid checking in source code from the Xcode workspace
|
||||
# *.xcworkspace
|
||||
|
||||
# Carthage
|
||||
#
|
||||
# Add this line if you want to avoid checking in source code from Carthage dependencies.
|
||||
# Carthage/Checkouts
|
||||
Carthage/Build/
|
||||
|
||||
# fastlane
|
||||
#
|
||||
# It is recommended to not store the screenshots in the git repo.
|
||||
# Instead, use fastlane to re-generate the screenshots whenever they are needed.
|
||||
# For more information about the recommended setup visit:
|
||||
# https://docs.fastlane.tools/best-practices/source-control/#source-control
|
||||
fastlane/report.xml
|
||||
fastlane/Preview.html
|
||||
fastlane/screenshots/**/*.png
|
||||
fastlane/test_output
|
||||
|
||||
# Ignore folder
|
||||
ignore
|
||||
|
||||
# .release
|
||||
.release/
|
||||
162
.github/workflows/publish-agent.yml
vendored
Normal file
162
.github/workflows/publish-agent.yml
vendored
Normal file
@@ -0,0 +1,162 @@
|
||||
name: Publish Agent Package
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'agent-v*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version to publish (without v prefix)'
|
||||
required: true
|
||||
default: '0.1.0'
|
||||
workflow_call:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version to publish'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
# Adding permissions at workflow level
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
prepare:
|
||||
runs-on: macos-latest
|
||||
outputs:
|
||||
version: ${{ steps.get-version.outputs.version }}
|
||||
computer_version: ${{ steps.update-deps.outputs.computer_version }}
|
||||
som_version: ${{ steps.update-deps.outputs.som_version }}
|
||||
core_version: ${{ steps.update-deps.outputs.core_version }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Determine version
|
||||
id: get-version
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" == "push" ]; then
|
||||
# Extract version from tag (for package-specific tags)
|
||||
if [[ "${{ github.ref }}" =~ ^refs/tags/agent-v([0-9]+\.[0-9]+\.[0-9]+) ]]; then
|
||||
VERSION=${BASH_REMATCH[1]}
|
||||
else
|
||||
echo "Invalid tag format for agent"
|
||||
exit 1
|
||||
fi
|
||||
elif [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
# Use version from workflow dispatch
|
||||
VERSION=${{ github.event.inputs.version }}
|
||||
else
|
||||
# Use version from workflow_call
|
||||
VERSION=${{ inputs.version }}
|
||||
fi
|
||||
echo "VERSION=$VERSION"
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Update dependencies to latest versions
|
||||
id: update-deps
|
||||
run: |
|
||||
cd libs/agent
|
||||
|
||||
# Install required package for PyPI API access
|
||||
pip install requests
|
||||
|
||||
# Create a more robust Python script for PyPI version checking
|
||||
cat > get_latest_versions.py << 'EOF'
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
|
||||
def get_package_version(package_name, fallback="0.1.0"):
|
||||
try:
|
||||
response = requests.get(f'https://pypi.org/pypi/{package_name}/json')
|
||||
print(f"API Response Status for {package_name}: {response.status_code}", file=sys.stderr)
|
||||
|
||||
if response.status_code != 200:
|
||||
print(f"API request failed for {package_name}, using fallback version", file=sys.stderr)
|
||||
return fallback
|
||||
|
||||
data = json.loads(response.text)
|
||||
|
||||
if 'info' not in data:
|
||||
print(f"Missing 'info' key in API response for {package_name}, using fallback version", file=sys.stderr)
|
||||
return fallback
|
||||
|
||||
return data['info']['version']
|
||||
except Exception as e:
|
||||
print(f"Error fetching version for {package_name}: {str(e)}", file=sys.stderr)
|
||||
return fallback
|
||||
|
||||
# Get latest versions
|
||||
print(get_package_version('cua-computer'))
|
||||
print(get_package_version('cua-som'))
|
||||
print(get_package_version('cua-core'))
|
||||
EOF
|
||||
|
||||
# Execute the script to get the versions
|
||||
VERSIONS=($(python get_latest_versions.py))
|
||||
LATEST_COMPUTER=${VERSIONS[0]}
|
||||
LATEST_SOM=${VERSIONS[1]}
|
||||
LATEST_CORE=${VERSIONS[2]}
|
||||
|
||||
echo "Latest cua-computer version: $LATEST_COMPUTER"
|
||||
echo "Latest cua-som version: $LATEST_SOM"
|
||||
echo "Latest cua-core version: $LATEST_CORE"
|
||||
|
||||
# Output the versions for the next job
|
||||
echo "computer_version=$LATEST_COMPUTER" >> $GITHUB_OUTPUT
|
||||
echo "som_version=$LATEST_SOM" >> $GITHUB_OUTPUT
|
||||
echo "core_version=$LATEST_CORE" >> $GITHUB_OUTPUT
|
||||
|
||||
# Determine major version for version constraint
|
||||
COMPUTER_MAJOR=$(echo $LATEST_COMPUTER | cut -d. -f1)
|
||||
SOM_MAJOR=$(echo $LATEST_SOM | cut -d. -f1)
|
||||
CORE_MAJOR=$(echo $LATEST_CORE | cut -d. -f1)
|
||||
|
||||
NEXT_COMPUTER_MAJOR=$((COMPUTER_MAJOR + 1))
|
||||
NEXT_SOM_MAJOR=$((SOM_MAJOR + 1))
|
||||
NEXT_CORE_MAJOR=$((CORE_MAJOR + 1))
|
||||
|
||||
# Update dependencies in pyproject.toml
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
# macOS version of sed needs an empty string for -i
|
||||
sed -i '' "s/\"cua-computer>=.*,<.*\"/\"cua-computer>=$LATEST_COMPUTER,<$NEXT_COMPUTER_MAJOR.0.0\"/" pyproject.toml
|
||||
sed -i '' "s/\"cua-som>=.*,<.*\"/\"cua-som>=$LATEST_SOM,<$NEXT_SOM_MAJOR.0.0\"/" pyproject.toml
|
||||
sed -i '' "s/\"cua-core>=.*,<.*\"/\"cua-core>=$LATEST_CORE,<$NEXT_CORE_MAJOR.0.0\"/" pyproject.toml
|
||||
else
|
||||
# Linux version
|
||||
sed -i "s/\"cua-computer>=.*,<.*\"/\"cua-computer>=$LATEST_COMPUTER,<$NEXT_COMPUTER_MAJOR.0.0\"/" pyproject.toml
|
||||
sed -i "s/\"cua-som>=.*,<.*\"/\"cua-som>=$LATEST_SOM,<$NEXT_SOM_MAJOR.0.0\"/" pyproject.toml
|
||||
sed -i "s/\"cua-core>=.*,<.*\"/\"cua-core>=$LATEST_CORE,<$NEXT_CORE_MAJOR.0.0\"/" pyproject.toml
|
||||
fi
|
||||
|
||||
# Display the updated dependencies
|
||||
echo "Updated dependencies in pyproject.toml:"
|
||||
grep -E "cua-computer|cua-som|cua-core" pyproject.toml
|
||||
|
||||
publish:
|
||||
needs: prepare
|
||||
uses: ./.github/workflows/reusable-publish.yml
|
||||
with:
|
||||
package_name: "agent"
|
||||
package_dir: "libs/agent"
|
||||
version: ${{ needs.prepare.outputs.version }}
|
||||
is_lume_package: false
|
||||
base_package_name: "cua-agent"
|
||||
secrets:
|
||||
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
|
||||
|
||||
set-env-variables:
|
||||
needs: [prepare, publish]
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Set environment variables for use in other jobs
|
||||
run: |
|
||||
echo "COMPUTER_VERSION=${{ needs.prepare.outputs.computer_version }}" >> $GITHUB_ENV
|
||||
echo "SOM_VERSION=${{ needs.prepare.outputs.som_version }}" >> $GITHUB_ENV
|
||||
echo "CORE_VERSION=${{ needs.prepare.outputs.core_version }}" >> $GITHUB_ENV
|
||||
227
.github/workflows/publish-all.yml
vendored
Normal file
227
.github/workflows/publish-all.yml
vendored
Normal file
@@ -0,0 +1,227 @@
|
||||
name: Publish All Packages
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*' # For global releases (vX.Y.Z format)
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version to publish (without v prefix)'
|
||||
required: true
|
||||
default: '0.1.0'
|
||||
packages:
|
||||
description: 'Packages to publish (comma-separated)'
|
||||
required: false
|
||||
default: 'core,pylume,computer,som,agent,computer-server'
|
||||
|
||||
# Adding permissions at workflow level
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: macos-latest
|
||||
outputs:
|
||||
version: ${{ steps.get-version.outputs.version }}
|
||||
publish_core: ${{ steps.set-packages.outputs.publish_core }}
|
||||
publish_pylume: ${{ steps.set-packages.outputs.publish_pylume }}
|
||||
publish_computer: ${{ steps.set-packages.outputs.publish_computer }}
|
||||
publish_som: ${{ steps.set-packages.outputs.publish_som }}
|
||||
publish_agent: ${{ steps.set-packages.outputs.publish_agent }}
|
||||
publish_computer_server: ${{ steps.set-packages.outputs.publish_computer_server }}
|
||||
steps:
|
||||
- name: Determine version
|
||||
id: get-version
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" == "push" ]; then
|
||||
# Extract version from tag (for global releases)
|
||||
if [[ "${{ github.ref }}" =~ ^refs/tags/v([0-9]+\.[0-9]+\.[0-9]+) ]]; then
|
||||
VERSION=${BASH_REMATCH[1]}
|
||||
else
|
||||
echo "Invalid tag format for global release"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
# Use version from workflow dispatch
|
||||
VERSION=${{ github.event.inputs.version }}
|
||||
fi
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
echo "Using version: $VERSION"
|
||||
|
||||
- name: Determine packages to publish
|
||||
id: set-packages
|
||||
run: |
|
||||
# Default to all packages for tag-based releases
|
||||
if [ "${{ github.event_name }}" == "push" ]; then
|
||||
PACKAGES="core,pylume,computer,som,agent,computer-server"
|
||||
else
|
||||
# Use packages from workflow dispatch
|
||||
PACKAGES="${{ github.event.inputs.packages }}"
|
||||
fi
|
||||
|
||||
# Set individual flags for each package
|
||||
echo "publish_core=$(echo $PACKAGES | grep -q "core" && echo "true" || echo "false")" >> $GITHUB_OUTPUT
|
||||
echo "publish_pylume=$(echo $PACKAGES | grep -q "pylume" && echo "true" || echo "false")" >> $GITHUB_OUTPUT
|
||||
echo "publish_computer=$(echo $PACKAGES | grep -q "computer" && echo "true" || echo "false")" >> $GITHUB_OUTPUT
|
||||
echo "publish_som=$(echo $PACKAGES | grep -q "som" && echo "true" || echo "false")" >> $GITHUB_OUTPUT
|
||||
echo "publish_agent=$(echo $PACKAGES | grep -q "agent" && echo "true" || echo "false")" >> $GITHUB_OUTPUT
|
||||
echo "publish_computer_server=$(echo $PACKAGES | grep -q "computer-server" && echo "true" || echo "false")" >> $GITHUB_OUTPUT
|
||||
|
||||
echo "Publishing packages: $PACKAGES"
|
||||
|
||||
publish-core:
|
||||
needs: setup
|
||||
if: ${{ needs.setup.outputs.publish_core == 'true' }}
|
||||
uses: ./.github/workflows/publish-core.yml
|
||||
with:
|
||||
version: ${{ needs.setup.outputs.version }}
|
||||
secrets: inherit
|
||||
|
||||
# Add a delay to ensure PyPI has registered the new core version
|
||||
wait-for-core:
|
||||
needs: [setup, publish-core]
|
||||
if: ${{ needs.setup.outputs.publish_core == 'true' && (needs.setup.outputs.publish_computer == 'true' || needs.setup.outputs.publish_agent == 'true') }}
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Wait for PyPI to update
|
||||
run: |
|
||||
echo "Waiting for PyPI to register the new core version..."
|
||||
sleep 60 # Wait 60 seconds for PyPI to update its index
|
||||
|
||||
publish-pylume:
|
||||
needs: setup
|
||||
if: ${{ needs.setup.outputs.publish_pylume == 'true' }}
|
||||
uses: ./.github/workflows/publish-pylume.yml
|
||||
with:
|
||||
version: ${{ needs.setup.outputs.version }}
|
||||
secrets: inherit
|
||||
|
||||
# Add a delay to ensure PyPI has registered the new pylume version
|
||||
wait-for-pylume:
|
||||
needs: [setup, publish-pylume]
|
||||
if: ${{ needs.setup.outputs.publish_pylume == 'true' && (needs.setup.outputs.publish_computer == 'true' || needs.setup.outputs.publish_som == 'true') }}
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Wait for PyPI to update
|
||||
run: |
|
||||
echo "Waiting for PyPI to register the new pylume version..."
|
||||
sleep 60 # Wait 60 seconds for PyPI to update its index
|
||||
|
||||
publish-computer:
|
||||
needs: [setup, publish-core, publish-pylume, wait-for-core, wait-for-pylume]
|
||||
if: ${{ needs.setup.outputs.publish_computer == 'true' }}
|
||||
uses: ./.github/workflows/publish-computer.yml
|
||||
with:
|
||||
version: ${{ needs.setup.outputs.version }}
|
||||
secrets: inherit
|
||||
|
||||
publish-som:
|
||||
needs: [setup, publish-pylume, wait-for-pylume]
|
||||
if: ${{ needs.setup.outputs.publish_som == 'true' }}
|
||||
uses: ./.github/workflows/publish-som.yml
|
||||
with:
|
||||
version: ${{ needs.setup.outputs.version }}
|
||||
secrets: inherit
|
||||
|
||||
# Add a delay to ensure PyPI has registered the new computer and som versions
|
||||
wait-for-deps:
|
||||
needs: [setup, publish-computer, publish-som]
|
||||
if: ${{ (needs.setup.outputs.publish_computer == 'true' || needs.setup.outputs.publish_som == 'true') && needs.setup.outputs.publish_agent == 'true' }}
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Wait for PyPI to update
|
||||
run: |
|
||||
echo "Waiting for PyPI to register new dependency versions..."
|
||||
sleep 60 # Wait 60 seconds for PyPI to update its index
|
||||
|
||||
publish-agent:
|
||||
needs: [setup, publish-core, publish-computer, publish-som, wait-for-core, wait-for-deps]
|
||||
if: ${{ needs.setup.outputs.publish_agent == 'true' }}
|
||||
uses: ./.github/workflows/publish-agent.yml
|
||||
with:
|
||||
version: ${{ needs.setup.outputs.version }}
|
||||
secrets: inherit
|
||||
|
||||
publish-computer-server:
|
||||
needs: [setup, publish-computer]
|
||||
if: ${{ needs.setup.outputs.publish_computer_server == 'true' }}
|
||||
uses: ./.github/workflows/publish-computer-server.yml
|
||||
with:
|
||||
version: ${{ needs.setup.outputs.version }}
|
||||
secrets: inherit
|
||||
|
||||
# Create a global release for the entire CUA project
|
||||
create-global-release:
|
||||
needs: [setup, publish-core, publish-pylume, publish-computer, publish-som, publish-agent, publish-computer-server]
|
||||
if: ${{ github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') }}
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Create Summary Release Notes
|
||||
id: release_notes
|
||||
run: |
|
||||
VERSION=${{ needs.setup.outputs.version }}
|
||||
|
||||
# Create the release notes file
|
||||
echo "# CUA v${VERSION} Release" > release_notes.md
|
||||
echo "" >> release_notes.md
|
||||
echo "This is a global release of the Computer Universal Automation (CUA) project." >> release_notes.md
|
||||
echo "" >> release_notes.md
|
||||
|
||||
echo "## Released Packages" >> release_notes.md
|
||||
echo "" >> release_notes.md
|
||||
|
||||
# Add links to individual package releases
|
||||
if [[ "${{ needs.setup.outputs.publish_core }}" == "true" ]]; then
|
||||
echo "- [cua-core v${VERSION}](https://github.com/${{ github.repository }}/releases/tag/core-v${VERSION})" >> release_notes.md
|
||||
fi
|
||||
|
||||
if [[ "${{ needs.setup.outputs.publish_pylume }}" == "true" ]]; then
|
||||
echo "- [pylume v${VERSION}](https://github.com/${{ github.repository }}/releases/tag/pylume-v${VERSION})" >> release_notes.md
|
||||
fi
|
||||
|
||||
if [[ "${{ needs.setup.outputs.publish_computer }}" == "true" ]]; then
|
||||
echo "- [cua-computer v${VERSION}](https://github.com/${{ github.repository }}/releases/tag/computer-v${VERSION})" >> release_notes.md
|
||||
fi
|
||||
|
||||
if [[ "${{ needs.setup.outputs.publish_som }}" == "true" ]]; then
|
||||
echo "- [cua-som v${VERSION}](https://github.com/${{ github.repository }}/releases/tag/som-v${VERSION})" >> release_notes.md
|
||||
fi
|
||||
|
||||
if [[ "${{ needs.setup.outputs.publish_agent }}" == "true" ]]; then
|
||||
echo "- [cua-agent v${VERSION}](https://github.com/${{ github.repository }}/releases/tag/agent-v${VERSION})" >> release_notes.md
|
||||
fi
|
||||
|
||||
if [[ "${{ needs.setup.outputs.publish_computer_server }}" == "true" ]]; then
|
||||
echo "- [cua-computer-server v${VERSION}](https://github.com/${{ github.repository }}/releases/tag/computer-server-v${VERSION})" >> release_notes.md
|
||||
fi
|
||||
|
||||
echo "" >> release_notes.md
|
||||
echo "## Installation" >> release_notes.md
|
||||
echo "" >> release_notes.md
|
||||
echo "### Core Libraries" >> release_notes.md
|
||||
echo "```bash" >> release_notes.md
|
||||
echo "pip install cua-core==${VERSION}" >> release_notes.md
|
||||
echo "pip install cua-computer==${VERSION}" >> release_notes.md
|
||||
echo "pip install pylume==${VERSION}" >> release_notes.md
|
||||
echo "```" >> release_notes.md
|
||||
echo "" >> release_notes.md
|
||||
echo "### Agent with SOM (Recommended)" >> release_notes.md
|
||||
echo "```bash" >> release_notes.md
|
||||
echo "pip install cua-agent[som]==${VERSION}" >> release_notes.md
|
||||
echo "```" >> release_notes.md
|
||||
|
||||
echo "Release notes created:"
|
||||
cat release_notes.md
|
||||
|
||||
- name: Create GitHub Global Release
|
||||
uses: softprops/action-gh-release@v1
|
||||
with:
|
||||
name: "CUA v${{ needs.setup.outputs.version }}"
|
||||
body_path: release_notes.md
|
||||
draft: false
|
||||
prerelease: false
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
80
.github/workflows/publish-computer-server.yml
vendored
Normal file
80
.github/workflows/publish-computer-server.yml
vendored
Normal file
@@ -0,0 +1,80 @@
|
||||
name: Publish Computer Server Package
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'computer-server-v*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version to publish (without v prefix)'
|
||||
required: true
|
||||
default: '0.1.0'
|
||||
workflow_call:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version to publish'
|
||||
required: true
|
||||
type: string
|
||||
outputs:
|
||||
version:
|
||||
description: "The version that was published"
|
||||
value: ${{ jobs.prepare.outputs.version }}
|
||||
|
||||
# Adding permissions at workflow level
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
prepare:
|
||||
runs-on: macos-latest
|
||||
outputs:
|
||||
version: ${{ steps.get-version.outputs.version }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Determine version
|
||||
id: get-version
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" == "push" ]; then
|
||||
# Extract version from tag (for package-specific tags)
|
||||
if [[ "${{ github.ref }}" =~ ^refs/tags/computer-server-v([0-9]+\.[0-9]+\.[0-9]+) ]]; then
|
||||
VERSION=${BASH_REMATCH[1]}
|
||||
else
|
||||
echo "Invalid tag format for computer-server"
|
||||
exit 1
|
||||
fi
|
||||
elif [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
# Use version from workflow dispatch
|
||||
VERSION=${{ github.event.inputs.version }}
|
||||
else
|
||||
# Use version from workflow_call
|
||||
VERSION=${{ inputs.version }}
|
||||
fi
|
||||
echo "VERSION=$VERSION"
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
publish:
|
||||
needs: prepare
|
||||
uses: ./.github/workflows/reusable-publish.yml
|
||||
with:
|
||||
package_name: "computer-server"
|
||||
package_dir: "libs/computer-server"
|
||||
version: ${{ needs.prepare.outputs.version }}
|
||||
is_lume_package: false
|
||||
base_package_name: "cua-computer-server"
|
||||
secrets:
|
||||
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
|
||||
|
||||
set-env-variables:
|
||||
needs: [prepare, publish]
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Set environment variables for use in other jobs
|
||||
run: |
|
||||
echo "COMPUTER_VERSION=${{ needs.prepare.outputs.version }}" >> $GITHUB_ENV
|
||||
148
.github/workflows/publish-computer.yml
vendored
Normal file
148
.github/workflows/publish-computer.yml
vendored
Normal file
@@ -0,0 +1,148 @@
|
||||
name: Publish Computer Package
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'computer-v*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version to publish (without v prefix)'
|
||||
required: true
|
||||
default: '0.1.0'
|
||||
workflow_call:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version to publish'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
# Adding permissions at workflow level
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
prepare:
|
||||
runs-on: macos-latest
|
||||
outputs:
|
||||
version: ${{ steps.get-version.outputs.version }}
|
||||
pylume_version: ${{ steps.update-deps.outputs.pylume_version }}
|
||||
core_version: ${{ steps.update-deps.outputs.core_version }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Determine version
|
||||
id: get-version
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" == "push" ]; then
|
||||
# Extract version from tag (for package-specific tags)
|
||||
if [[ "${{ github.ref }}" =~ ^refs/tags/computer-v([0-9]+\.[0-9]+\.[0-9]+) ]]; then
|
||||
VERSION=${BASH_REMATCH[1]}
|
||||
else
|
||||
echo "Invalid tag format for computer"
|
||||
exit 1
|
||||
fi
|
||||
elif [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
# Use version from workflow dispatch
|
||||
VERSION=${{ github.event.inputs.version }}
|
||||
else
|
||||
# Use version from workflow_call
|
||||
VERSION=${{ inputs.version }}
|
||||
fi
|
||||
echo "VERSION=$VERSION"
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Update dependencies to latest versions
|
||||
id: update-deps
|
||||
run: |
|
||||
cd libs/computer
|
||||
# Install required package for PyPI API access
|
||||
pip install requests
|
||||
|
||||
# Create a more robust Python script for PyPI version checking
|
||||
cat > get_latest_versions.py << 'EOF'
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
|
||||
def get_package_version(package_name, fallback="0.1.0"):
|
||||
try:
|
||||
response = requests.get(f'https://pypi.org/pypi/{package_name}/json')
|
||||
print(f"API Response Status for {package_name}: {response.status_code}", file=sys.stderr)
|
||||
|
||||
if response.status_code != 200:
|
||||
print(f"API request failed for {package_name}, using fallback version", file=sys.stderr)
|
||||
return fallback
|
||||
|
||||
data = json.loads(response.text)
|
||||
|
||||
if 'info' not in data:
|
||||
print(f"Missing 'info' key in API response for {package_name}, using fallback version", file=sys.stderr)
|
||||
return fallback
|
||||
|
||||
return data['info']['version']
|
||||
except Exception as e:
|
||||
print(f"Error fetching version for {package_name}: {str(e)}", file=sys.stderr)
|
||||
return fallback
|
||||
|
||||
# Get latest versions
|
||||
print(get_package_version('pylume'))
|
||||
print(get_package_version('cua-core'))
|
||||
EOF
|
||||
|
||||
# Execute the script to get the versions
|
||||
VERSIONS=($(python get_latest_versions.py))
|
||||
LATEST_PYLUME=${VERSIONS[0]}
|
||||
LATEST_CORE=${VERSIONS[1]}
|
||||
|
||||
echo "Latest pylume version: $LATEST_PYLUME"
|
||||
echo "Latest cua-core version: $LATEST_CORE"
|
||||
|
||||
# Output the versions for the next job
|
||||
echo "pylume_version=$LATEST_PYLUME" >> $GITHUB_OUTPUT
|
||||
echo "core_version=$LATEST_CORE" >> $GITHUB_OUTPUT
|
||||
|
||||
# Determine major version for version constraint
|
||||
CORE_MAJOR=$(echo $LATEST_CORE | cut -d. -f1)
|
||||
NEXT_CORE_MAJOR=$((CORE_MAJOR + 1))
|
||||
|
||||
# Update dependencies in pyproject.toml
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
# macOS version of sed needs an empty string for -i
|
||||
sed -i '' "s/\"pylume>=.*\"/\"pylume>=$LATEST_PYLUME\"/" pyproject.toml
|
||||
sed -i '' "s/\"cua-core>=.*,<.*\"/\"cua-core>=$LATEST_CORE,<$NEXT_CORE_MAJOR.0.0\"/" pyproject.toml
|
||||
else
|
||||
# Linux version
|
||||
sed -i "s/\"pylume>=.*\"/\"pylume>=$LATEST_PYLUME\"/" pyproject.toml
|
||||
sed -i "s/\"cua-core>=.*,<.*\"/\"cua-core>=$LATEST_CORE,<$NEXT_CORE_MAJOR.0.0\"/" pyproject.toml
|
||||
fi
|
||||
|
||||
# Display the updated dependencies
|
||||
echo "Updated dependencies in pyproject.toml:"
|
||||
grep -E "pylume|cua-core" pyproject.toml
|
||||
|
||||
publish:
|
||||
needs: prepare
|
||||
uses: ./.github/workflows/reusable-publish.yml
|
||||
with:
|
||||
package_name: "computer"
|
||||
package_dir: "libs/computer"
|
||||
version: ${{ needs.prepare.outputs.version }}
|
||||
is_lume_package: false
|
||||
base_package_name: "cua-computer"
|
||||
secrets:
|
||||
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
|
||||
|
||||
set-env-variables:
|
||||
needs: [prepare, publish]
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Set environment variables for use in other jobs
|
||||
run: |
|
||||
echo "PYLUME_VERSION=${{ needs.prepare.outputs.pylume_version }}" >> $GITHUB_ENV
|
||||
echo "CORE_VERSION=${{ needs.prepare.outputs.core_version }}" >> $GITHUB_ENV
|
||||
63
.github/workflows/publish-core.yml
vendored
Normal file
63
.github/workflows/publish-core.yml
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
name: Publish Core Package
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'core-v*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version to publish (without v prefix)'
|
||||
required: true
|
||||
default: '0.1.0'
|
||||
workflow_call:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version to publish'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
# Adding permissions at workflow level
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
prepare:
|
||||
runs-on: macos-latest
|
||||
outputs:
|
||||
version: ${{ steps.get-version.outputs.version }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Determine version
|
||||
id: get-version
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" == "push" ]; then
|
||||
# Extract version from tag (for package-specific tags)
|
||||
if [[ "${{ github.ref }}" =~ ^refs/tags/core-v([0-9]+\.[0-9]+\.[0-9]+) ]]; then
|
||||
VERSION=${BASH_REMATCH[1]}
|
||||
else
|
||||
echo "Invalid tag format for core"
|
||||
exit 1
|
||||
fi
|
||||
elif [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
# Use version from workflow dispatch
|
||||
VERSION=${{ github.event.inputs.version }}
|
||||
else
|
||||
# Use version from workflow_call
|
||||
VERSION=${{ inputs.version }}
|
||||
fi
|
||||
echo "VERSION=$VERSION"
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
publish:
|
||||
needs: prepare
|
||||
uses: ./.github/workflows/reusable-publish.yml
|
||||
with:
|
||||
package_name: "core"
|
||||
package_dir: "libs/core"
|
||||
version: ${{ needs.prepare.outputs.version }}
|
||||
is_lume_package: false
|
||||
base_package_name: "cua-core"
|
||||
secrets:
|
||||
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
|
||||
67
.github/workflows/publish-omniparser.yml
vendored
Normal file
67
.github/workflows/publish-omniparser.yml
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
name: Publish OmniParser Package
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'omniparser-v*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version to publish (without v prefix)'
|
||||
required: true
|
||||
default: '0.1.0'
|
||||
workflow_call:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version to publish'
|
||||
required: true
|
||||
type: string
|
||||
outputs:
|
||||
version:
|
||||
description: "The version that was published"
|
||||
value: ${{ jobs.determine-version.outputs.version }}
|
||||
|
||||
# Adding permissions at workflow level
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
determine-version:
|
||||
runs-on: macos-latest
|
||||
outputs:
|
||||
version: ${{ steps.get-version.outputs.version }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Determine version
|
||||
id: get-version
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" == "push" ]; then
|
||||
# Extract version from tag (for package-specific tags)
|
||||
if [[ "${{ github.ref }}" =~ ^refs/tags/omniparser-v([0-9]+\.[0-9]+\.[0-9]+) ]]; then
|
||||
VERSION=${BASH_REMATCH[1]}
|
||||
else
|
||||
echo "Invalid tag format for omniparser"
|
||||
exit 1
|
||||
fi
|
||||
elif [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
# Use version from workflow dispatch
|
||||
VERSION=${{ github.event.inputs.version }}
|
||||
else
|
||||
# Use version from workflow_call
|
||||
VERSION=${{ inputs.version }}
|
||||
fi
|
||||
echo "VERSION=$VERSION"
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
publish:
|
||||
needs: determine-version
|
||||
uses: ./.github/workflows/reusable-publish.yml
|
||||
with:
|
||||
package_name: "omniparser"
|
||||
package_dir: "libs/omniparser"
|
||||
version: ${{ needs.determine-version.outputs.version }}
|
||||
is_lume_package: false
|
||||
base_package_name: "cua-omniparser"
|
||||
secrets:
|
||||
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
|
||||
67
.github/workflows/publish-pylume.yml
vendored
Normal file
67
.github/workflows/publish-pylume.yml
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
name: Publish Pylume Package
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'pylume-v*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version to publish (without v prefix)'
|
||||
required: true
|
||||
default: '0.1.0'
|
||||
workflow_call:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version to publish'
|
||||
required: true
|
||||
type: string
|
||||
outputs:
|
||||
version:
|
||||
description: "The version that was published"
|
||||
value: ${{ jobs.determine-version.outputs.version }}
|
||||
|
||||
# Adding permissions at workflow level
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
determine-version:
|
||||
runs-on: macos-latest
|
||||
outputs:
|
||||
version: ${{ steps.get-version.outputs.version }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Determine version
|
||||
id: get-version
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" == "push" ]; then
|
||||
# Extract version from tag (for package-specific tags)
|
||||
if [[ "${{ github.ref }}" =~ ^refs/tags/pylume-v([0-9]+\.[0-9]+\.[0-9]+) ]]; then
|
||||
VERSION=${BASH_REMATCH[1]}
|
||||
else
|
||||
echo "Invalid tag format for pylume"
|
||||
exit 1
|
||||
fi
|
||||
elif [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
# Use version from workflow dispatch
|
||||
VERSION=${{ github.event.inputs.version }}
|
||||
else
|
||||
# Use version from workflow_call
|
||||
VERSION=${{ inputs.version }}
|
||||
fi
|
||||
echo "VERSION=$VERSION"
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
publish:
|
||||
needs: determine-version
|
||||
uses: ./.github/workflows/reusable-publish.yml
|
||||
with:
|
||||
package_name: "pylume"
|
||||
package_dir: "libs/pylume"
|
||||
version: ${{ needs.determine-version.outputs.version }}
|
||||
is_lume_package: true
|
||||
base_package_name: "pylume"
|
||||
secrets:
|
||||
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
|
||||
67
.github/workflows/publish-som.yml
vendored
Normal file
67
.github/workflows/publish-som.yml
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
name: Publish SOM Package
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'som-v*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version to publish (without v prefix)'
|
||||
required: true
|
||||
default: '0.1.0'
|
||||
workflow_call:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version to publish'
|
||||
required: true
|
||||
type: string
|
||||
outputs:
|
||||
version:
|
||||
description: "The version that was published"
|
||||
value: ${{ jobs.determine-version.outputs.version }}
|
||||
|
||||
# Adding permissions at workflow level
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
determine-version:
|
||||
runs-on: macos-latest
|
||||
outputs:
|
||||
version: ${{ steps.get-version.outputs.version }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Determine version
|
||||
id: get-version
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" == "push" ]; then
|
||||
# Extract version from tag (for package-specific tags)
|
||||
if [[ "${{ github.ref }}" =~ ^refs/tags/som-v([0-9]+\.[0-9]+\.[0-9]+) ]]; then
|
||||
VERSION=${BASH_REMATCH[1]}
|
||||
else
|
||||
echo "Invalid tag format for som"
|
||||
exit 1
|
||||
fi
|
||||
elif [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
# Use version from workflow dispatch
|
||||
VERSION=${{ github.event.inputs.version }}
|
||||
else
|
||||
# Use version from workflow_call
|
||||
VERSION=${{ inputs.version }}
|
||||
fi
|
||||
echo "VERSION=$VERSION"
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
publish:
|
||||
needs: determine-version
|
||||
uses: ./.github/workflows/reusable-publish.yml
|
||||
with:
|
||||
package_name: "som"
|
||||
package_dir: "libs/som" # Updated to the new directory name
|
||||
version: ${{ needs.determine-version.outputs.version }}
|
||||
is_lume_package: false
|
||||
base_package_name: "cua-som"
|
||||
secrets:
|
||||
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
|
||||
269
.github/workflows/reusable-publish.yml
vendored
Normal file
269
.github/workflows/reusable-publish.yml
vendored
Normal file
@@ -0,0 +1,269 @@
|
||||
name: Reusable Package Publish Workflow
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
package_name:
|
||||
description: 'Name of the package (e.g. pylume, computer, agent)'
|
||||
required: true
|
||||
type: string
|
||||
package_dir:
|
||||
description: 'Directory containing the package relative to workspace root (e.g. libs/pylume)'
|
||||
required: true
|
||||
type: string
|
||||
version:
|
||||
description: 'Version to publish'
|
||||
required: true
|
||||
type: string
|
||||
is_lume_package:
|
||||
description: 'Whether this package includes the lume binary'
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
base_package_name:
|
||||
description: 'PyPI package name (e.g. pylume, cua-agent)'
|
||||
required: true
|
||||
type: string
|
||||
secrets:
|
||||
PYPI_TOKEN:
|
||||
required: true
|
||||
outputs:
|
||||
version:
|
||||
description: "The version that was published"
|
||||
value: ${{ jobs.build-and-publish.outputs.version }}
|
||||
|
||||
jobs:
|
||||
build-and-publish:
|
||||
runs-on: macos-latest
|
||||
permissions:
|
||||
contents: write # This permission is needed for creating releases
|
||||
outputs:
|
||||
version: ${{ steps.set-version.outputs.version }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # Full history for release creation
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Create root pdm.lock file
|
||||
run: |
|
||||
# Create an empty pdm.lock file in the root
|
||||
touch pdm.lock
|
||||
|
||||
- name: Install PDM
|
||||
uses: pdm-project/setup-pdm@v3
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: true
|
||||
|
||||
- name: Set version
|
||||
id: set-version
|
||||
run: |
|
||||
echo "VERSION=${{ inputs.version }}" >> $GITHUB_ENV
|
||||
echo "version=${{ inputs.version }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Initialize PDM in package directory
|
||||
run: |
|
||||
# Make sure we're working with a properly initialized PDM project
|
||||
cd ${{ inputs.package_dir }}
|
||||
|
||||
# Create pdm.lock if it doesn't exist
|
||||
if [ ! -f "pdm.lock" ]; then
|
||||
echo "No pdm.lock found, initializing PDM project..."
|
||||
pdm lock
|
||||
fi
|
||||
|
||||
- name: Set version in package
|
||||
run: |
|
||||
cd ${{ inputs.package_dir }}
|
||||
# Replace pdm bump with direct edit of pyproject.toml
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
# macOS version of sed needs an empty string for -i
|
||||
sed -i '' "s/version = \".*\"/version = \"$VERSION\"/" pyproject.toml
|
||||
else
|
||||
# Linux version
|
||||
sed -i "s/version = \".*\"/version = \"$VERSION\"/" pyproject.toml
|
||||
fi
|
||||
# Verify version was updated
|
||||
echo "Updated version in pyproject.toml:"
|
||||
grep "version =" pyproject.toml
|
||||
|
||||
# Conditional step for lume binary download (only for pylume package)
|
||||
- name: Download and setup lume binary
|
||||
if: inputs.is_lume_package
|
||||
run: |
|
||||
# Install required packages
|
||||
pip install requests
|
||||
|
||||
# Create a simple Python script for better error handling
|
||||
cat > get_lume_version.py << 'EOF'
|
||||
import requests, re, json, sys
|
||||
try:
|
||||
response = requests.get('https://api.github.com/repos/trycua/lume/releases/latest')
|
||||
sys.stderr.write(f"API Status Code: {response.status_code}\n")
|
||||
sys.stderr.write(f"API Response: {response.text}\n")
|
||||
data = json.loads(response.text)
|
||||
if 'tag_name' not in data:
|
||||
sys.stderr.write(f"Warning: tag_name not found in API response. Keys: {list(data.keys())}\n")
|
||||
print('0.1.9')
|
||||
else:
|
||||
tag_name = data['tag_name']
|
||||
match = re.match(r'v(\d+\.\d+\.\d+)', tag_name)
|
||||
if match:
|
||||
print(match.group(1))
|
||||
else:
|
||||
sys.stderr.write("Error: Could not parse version from tag\n")
|
||||
print('0.1.9')
|
||||
except Exception as e:
|
||||
sys.stderr.write(f"Error fetching release info: {str(e)}\n")
|
||||
print('0.1.9')
|
||||
EOF
|
||||
|
||||
# Execute the script to get the version
|
||||
LUME_VERSION=$(python get_lume_version.py)
|
||||
echo "Using lume version: $LUME_VERSION"
|
||||
|
||||
# Create a temporary directory for extraction
|
||||
mkdir -p temp_lume
|
||||
|
||||
# Download the lume release (silently)
|
||||
echo "Downloading lume version v${LUME_VERSION}..."
|
||||
curl -sL "https://github.com/trycua/lume/releases/download/v${LUME_VERSION}/lume.tar.gz" -o temp_lume/lume.tar.gz
|
||||
|
||||
# Extract the tar file (ignore ownership and suppress warnings)
|
||||
cd temp_lume && tar --no-same-owner -xzf lume.tar.gz
|
||||
|
||||
# Make the binary executable
|
||||
chmod +x lume
|
||||
|
||||
# Copy the lume binary to the correct location in the pylume package
|
||||
mkdir -p "${GITHUB_WORKSPACE}/${{ inputs.package_dir }}/pylume"
|
||||
cp lume "${GITHUB_WORKSPACE}/${{ inputs.package_dir }}/pylume/lume"
|
||||
|
||||
# Verify the binary exists and is executable
|
||||
test -x "${GITHUB_WORKSPACE}/${{ inputs.package_dir }}/pylume/lume" || { echo "lume binary not found or not executable"; exit 1; }
|
||||
|
||||
# Cleanup
|
||||
cd "${GITHUB_WORKSPACE}" && rm -rf temp_lume
|
||||
|
||||
# Save the lume version for reference
|
||||
echo "LUME_VERSION=${LUME_VERSION}" >> $GITHUB_ENV
|
||||
|
||||
- name: Build and publish
|
||||
env:
|
||||
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
|
||||
run: |
|
||||
cd ${{ inputs.package_dir }}
|
||||
# Build with PDM
|
||||
pdm build
|
||||
|
||||
# For pylume package, verify the binary is in the wheel
|
||||
if [ "${{ inputs.is_lume_package }}" = "true" ]; then
|
||||
python -m pip install wheel
|
||||
wheel unpack dist/*.whl --dest temp_wheel
|
||||
echo "Listing contents of wheel directory:"
|
||||
find temp_wheel -type f
|
||||
test -f temp_wheel/pylume-*/pylume/lume || { echo "lume binary not found in wheel"; exit 1; }
|
||||
rm -rf temp_wheel
|
||||
echo "Publishing ${{ inputs.base_package_name }} ${VERSION} with lume ${LUME_VERSION}"
|
||||
else
|
||||
echo "Publishing ${{ inputs.base_package_name }} ${VERSION}"
|
||||
fi
|
||||
|
||||
# Install and use twine directly instead of PDM publish
|
||||
echo "Installing twine for direct publishing..."
|
||||
pip install twine
|
||||
|
||||
echo "Publishing to PyPI using twine..."
|
||||
TWINE_USERNAME="__token__" TWINE_PASSWORD="$PYPI_TOKEN" python -m twine upload dist/*
|
||||
|
||||
# Save the wheel file path for the release
|
||||
WHEEL_FILE=$(ls dist/*.whl | head -1)
|
||||
echo "WHEEL_FILE=${WHEEL_FILE}" >> $GITHUB_ENV
|
||||
|
||||
- name: Prepare Simple Release Notes
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
run: |
|
||||
# Create release notes based on package type
|
||||
echo "# ${{ inputs.base_package_name }} v${VERSION}" > release_notes.md
|
||||
echo "" >> release_notes.md
|
||||
|
||||
if [ "${{ inputs.package_name }}" = "pylume" ]; then
|
||||
echo "## Python SDK for lume - run macOS and Linux VMs on Apple Silicon" >> release_notes.md
|
||||
echo "" >> release_notes.md
|
||||
echo "This package provides Python bindings for the lume virtualization tool." >> release_notes.md
|
||||
echo "" >> release_notes.md
|
||||
echo "## Dependencies" >> release_notes.md
|
||||
echo "* lume binary: v${LUME_VERSION}" >> release_notes.md
|
||||
elif [ "${{ inputs.package_name }}" = "computer" ]; then
|
||||
echo "## Computer control library for the Computer Universal Automation (CUA) project" >> release_notes.md
|
||||
echo "" >> release_notes.md
|
||||
echo "## Dependencies" >> release_notes.md
|
||||
echo "* pylume: ${PYLUME_VERSION:-latest}" >> release_notes.md
|
||||
elif [ "${{ inputs.package_name }}" = "agent" ]; then
|
||||
echo "## Dependencies" >> release_notes.md
|
||||
echo "* cua-computer: ${COMPUTER_VERSION:-latest}" >> release_notes.md
|
||||
echo "* cua-som: ${SOM_VERSION:-latest}" >> release_notes.md
|
||||
echo "" >> release_notes.md
|
||||
echo "## Installation Options" >> release_notes.md
|
||||
echo "" >> release_notes.md
|
||||
echo "### Basic installation with Anthropic" >> release_notes.md
|
||||
echo '```bash' >> release_notes.md
|
||||
echo "pip install cua-agent[anthropic]==${VERSION}" >> release_notes.md
|
||||
echo '```' >> release_notes.md
|
||||
echo "" >> release_notes.md
|
||||
echo "### With SOM (recommended)" >> release_notes.md
|
||||
echo '```bash' >> release_notes.md
|
||||
echo "pip install cua-agent[som]==${VERSION}" >> release_notes.md
|
||||
echo '```' >> release_notes.md
|
||||
echo "" >> release_notes.md
|
||||
echo "### All features" >> release_notes.md
|
||||
echo '```bash' >> release_notes.md
|
||||
echo "pip install cua-agent[all]==${VERSION}" >> release_notes.md
|
||||
echo '```' >> release_notes.md
|
||||
elif [ "${{ inputs.package_name }}" = "som" ]; then
|
||||
echo "## Computer Vision and OCR library for detecting and analyzing UI elements" >> release_notes.md
|
||||
echo "" >> release_notes.md
|
||||
echo "This package provides enhanced UI understanding capabilities through computer vision and OCR." >> release_notes.md
|
||||
elif [ "${{ inputs.package_name }}" = "computer-server" ]; then
|
||||
echo "## Computer Server for the Computer Universal Automation (CUA) project" >> release_notes.md
|
||||
echo "" >> release_notes.md
|
||||
echo "A FastAPI-based server implementation for computer control." >> release_notes.md
|
||||
echo "" >> release_notes.md
|
||||
echo "## Dependencies" >> release_notes.md
|
||||
echo "* cua-computer: ${COMPUTER_VERSION:-latest}" >> release_notes.md
|
||||
echo "" >> release_notes.md
|
||||
echo "## Usage" >> release_notes.md
|
||||
echo '```bash' >> release_notes.md
|
||||
echo "# Run the server" >> release_notes.md
|
||||
echo "cua-computer-server" >> release_notes.md
|
||||
echo '```' >> release_notes.md
|
||||
fi
|
||||
|
||||
# Add installation section if not agent (which has its own installation section)
|
||||
if [ "${{ inputs.package_name }}" != "agent" ]; then
|
||||
echo "" >> release_notes.md
|
||||
echo "## Installation" >> release_notes.md
|
||||
echo '```bash' >> release_notes.md
|
||||
echo "pip install ${{ inputs.base_package_name }}==${VERSION}" >> release_notes.md
|
||||
echo '```' >> release_notes.md
|
||||
fi
|
||||
|
||||
echo "Release notes created:"
|
||||
cat release_notes.md
|
||||
|
||||
- name: Create GitHub Release
|
||||
uses: softprops/action-gh-release@v1
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
name: "${{ inputs.base_package_name }} v${{ env.VERSION }}"
|
||||
body_path: release_notes.md
|
||||
files: ${{ inputs.package_dir }}/${{ env.WHEEL_FILE }}
|
||||
draft: false
|
||||
prerelease: false
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
188
.gitignore
vendored
188
.gitignore
vendored
@@ -1,9 +1,175 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
!libs/lume/scripts/build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Scripts
|
||||
server/scripts/
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# Ruff stuff:
|
||||
.ruff_cache/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
# Conda
|
||||
.conda/
|
||||
|
||||
# Local environment
|
||||
.env.local
|
||||
|
||||
# macOS DS_Store
|
||||
.DS_Store
|
||||
|
||||
weights/
|
||||
weights/icon_detect/
|
||||
weights/icon_detect/model.pt
|
||||
weights/icon_detect/model.pt.zip
|
||||
weights/icon_detect/model.pt.zip.part*
|
||||
|
||||
libs/omniparser/weights/icon_detect/model.pt
|
||||
|
||||
# Example test data and output
|
||||
examples/test_data/
|
||||
examples/output/
|
||||
|
||||
/screenshots/
|
||||
|
||||
/experiments/
|
||||
|
||||
/logs/
|
||||
|
||||
# Xcode
|
||||
#
|
||||
# gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore
|
||||
|
||||
.DS_Store
|
||||
|
||||
## User settings
|
||||
xcuserdata/
|
||||
|
||||
@@ -29,8 +195,7 @@ playground.xcworkspace
|
||||
#
|
||||
# Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata
|
||||
# hence it is not needed unless you have added a package configuration file to your project
|
||||
# .swiftpm
|
||||
|
||||
.swiftpm/
|
||||
.build/
|
||||
|
||||
# CocoaPods
|
||||
@@ -48,7 +213,6 @@ playground.xcworkspace
|
||||
#
|
||||
# Add this line if you want to avoid checking in source code from Carthage dependencies.
|
||||
# Carthage/Checkouts
|
||||
|
||||
Carthage/Build/
|
||||
|
||||
# fastlane
|
||||
@@ -57,20 +221,22 @@ Carthage/Build/
|
||||
# Instead, use fastlane to re-generate the screenshots whenever they are needed.
|
||||
# For more information about the recommended setup visit:
|
||||
# https://docs.fastlane.tools/best-practices/source-control/#source-control
|
||||
|
||||
fastlane/report.xml
|
||||
fastlane/Preview.html
|
||||
fastlane/screenshots/**/*.png
|
||||
fastlane/test_output
|
||||
|
||||
# Local environment variables
|
||||
.env.local
|
||||
|
||||
# Ignore folder
|
||||
ignore
|
||||
|
||||
# .release
|
||||
.release/
|
||||
|
||||
# Swift Package Manager
|
||||
.swiftpm/
|
||||
# Shared folder
|
||||
shared
|
||||
|
||||
# Trajectories
|
||||
trajectories/
|
||||
|
||||
# Installation ID Storage
|
||||
.storage/
|
||||
241
.vscode/launch.json
vendored
241
.vscode/launch.json
vendored
@@ -1,241 +0,0 @@
|
||||
{
|
||||
"configurations": [
|
||||
{
|
||||
"type": "bashdb",
|
||||
"request": "launch",
|
||||
"name": "Bash-Debug (select script from list of sh files)",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"program": "${command:SelectScriptName}",
|
||||
"pathBash": "/opt/homebrew/bin/bash",
|
||||
"args": []
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"serve"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume serve",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"create",
|
||||
"macos-vm",
|
||||
"--cpu",
|
||||
"4",
|
||||
"--memory",
|
||||
"4GB",
|
||||
"--disk-size",
|
||||
"40GB",
|
||||
"--ipsw",
|
||||
"/Users/<USER>/Downloads/UniversalMac_15.2_24C101_Restore.ipsw"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume create --os macos --ipsw 'path/to/ipsw' (macos)",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"create",
|
||||
"macos-vm",
|
||||
"--cpu",
|
||||
"4",
|
||||
"--memory",
|
||||
"4GB",
|
||||
"--disk-size",
|
||||
"20GB",
|
||||
"--ipsw",
|
||||
"latest"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume create --os macos --ipsw latest (macos)",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"create",
|
||||
"linux-vm",
|
||||
"--os",
|
||||
"linux",
|
||||
"--cpu",
|
||||
"4",
|
||||
"--memory",
|
||||
"4GB",
|
||||
"--disk-size",
|
||||
"20GB"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume create --os linux (linux)",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"pull",
|
||||
"macos-sequoia-vanilla:15.2",
|
||||
"--name",
|
||||
"macos-vm-cloned"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume pull (macos)",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"run",
|
||||
"macos-vm",
|
||||
"--shared-dir",
|
||||
"/Users/<USER>/repos/trycua/lume/shared_folder:rw",
|
||||
"--start-vnc"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume run (macos)",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"run",
|
||||
"linux-vm",
|
||||
"--start-vnc",
|
||||
"--mount",
|
||||
"/Users/<USER>/Downloads/ubuntu-24.04.1-live-server-arm64.iso"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume run with setup mount (linux)",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"run",
|
||||
"linux-vm",
|
||||
"--start-vnc"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume run (linux)",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"get",
|
||||
"macos-vm"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume get (macos)",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"ls"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume ls",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"images"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume images",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"stop",
|
||||
"macos-vm"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume stop (macos)",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "swift: Build Debug lume"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Release lume",
|
||||
"program": "${workspaceFolder:lume}/.build/release/lume",
|
||||
"preLaunchTask": "swift: Build Release lume"
|
||||
}
|
||||
]
|
||||
}
|
||||
319
.vscode/lume.code-workspace
vendored
Normal file
319
.vscode/lume.code-workspace
vendored
Normal file
@@ -0,0 +1,319 @@
|
||||
{
|
||||
"folders": [
|
||||
{
|
||||
"name": "lume",
|
||||
"path": "../libs/lume"
|
||||
}
|
||||
],
|
||||
"settings": {
|
||||
"files.exclude": {
|
||||
"**/.git": true,
|
||||
"**/.svn": true,
|
||||
"**/.hg": true,
|
||||
"**/CVS": true,
|
||||
"**/.DS_Store": true
|
||||
},
|
||||
"swift.path.swift_driver_bin": "/usr/bin/swift",
|
||||
"swift.enableLanguageServer": true,
|
||||
"files.associations": {
|
||||
"*.swift": "swift"
|
||||
},
|
||||
"[swift]": {
|
||||
"editor.formatOnSave": true,
|
||||
"editor.detectIndentation": true,
|
||||
"editor.tabSize": 4
|
||||
},
|
||||
"swift.path": "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin",
|
||||
"swift.swiftEnvironmentVariables": {
|
||||
"DEVELOPER_DIR": "/Applications/Xcode.app"
|
||||
},
|
||||
"lldb.library": "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/LLDB",
|
||||
"lldb.launch.expressions": "native"
|
||||
},
|
||||
"tasks": {
|
||||
"version": "2.0.0",
|
||||
"tasks": [
|
||||
{
|
||||
"label": "build-debug",
|
||||
"type": "shell",
|
||||
"command": "${workspaceFolder:lume}/scripts/build/build-debug.sh",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder:lume}"
|
||||
},
|
||||
"group": {
|
||||
"kind": "build",
|
||||
"isDefault": true
|
||||
},
|
||||
"presentation": {
|
||||
"reveal": "silent",
|
||||
"panel": "shared"
|
||||
},
|
||||
"problemMatcher": []
|
||||
},
|
||||
{
|
||||
"label": "swift: Build Debug lume",
|
||||
"type": "shell",
|
||||
"command": "${workspaceFolder:lume}/scripts/build/build-debug.sh",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder:lume}"
|
||||
},
|
||||
"group": "build",
|
||||
"presentation": {
|
||||
"reveal": "silent",
|
||||
"panel": "shared"
|
||||
},
|
||||
"problemMatcher": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"launch": {
|
||||
"configurations": [
|
||||
{
|
||||
"type": "bashdb",
|
||||
"request": "launch",
|
||||
"name": "Bash-Debug (select script from list of sh files)",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"program": "${command:SelectScriptName}",
|
||||
"pathBash": "/opt/homebrew/bin/bash",
|
||||
"args": []
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"serve"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume serve",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"create",
|
||||
"macos-vm",
|
||||
"--cpu",
|
||||
"4",
|
||||
"--memory",
|
||||
"4GB",
|
||||
"--disk-size",
|
||||
"40GB",
|
||||
"--ipsw",
|
||||
"/Users/<USER>/Downloads/UniversalMac_15.2_24C101_Restore.ipsw"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume create --os macos --ipsw 'path/to/ipsw' (macos)",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"create",
|
||||
"macos-vm",
|
||||
"--cpu",
|
||||
"4",
|
||||
"--memory",
|
||||
"4GB",
|
||||
"--disk-size",
|
||||
"20GB",
|
||||
"--ipsw",
|
||||
"latest"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume create --os macos --ipsw latest (macos)",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"create",
|
||||
"linux-vm",
|
||||
"--os",
|
||||
"linux",
|
||||
"--cpu",
|
||||
"4",
|
||||
"--memory",
|
||||
"4GB",
|
||||
"--disk-size",
|
||||
"20GB"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume create --os linux (linux)",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"pull",
|
||||
"macos-sequoia-vanilla:15.2",
|
||||
"--name",
|
||||
"macos-vm-cloned"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume pull (macos)",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"run",
|
||||
"macos-vm",
|
||||
"--shared-dir",
|
||||
"/Users/<USER>/repos/trycua/lume/shared_folder:rw",
|
||||
"--start-vnc"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume run (macos)",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"run",
|
||||
"linux-vm",
|
||||
"--start-vnc",
|
||||
"--mount",
|
||||
"/Users/<USER>/Downloads/ubuntu-24.04.1-live-server-arm64.iso"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume run with setup mount (linux)",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"run",
|
||||
"linux-vm",
|
||||
"--start-vnc"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume run (linux)",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"get",
|
||||
"macos-vm"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume get (macos)",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"ls"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume ls",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"images"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume images",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"sourceLanguages": [
|
||||
"swift"
|
||||
],
|
||||
"args": [
|
||||
"stop",
|
||||
"macos-vm"
|
||||
],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume stop (macos)",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "build-debug"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Debug lume",
|
||||
"program": "${workspaceFolder:lume}/.build/debug/lume",
|
||||
"preLaunchTask": "swift: Build Debug lume"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"name": "Release lume",
|
||||
"program": "${workspaceFolder:lume}/.build/release/lume",
|
||||
"preLaunchTask": "swift: Build Release lume"
|
||||
},
|
||||
{
|
||||
"type": "bashdb",
|
||||
"request": "launch",
|
||||
"name": "Bash-Debug (select script)",
|
||||
"cwd": "${workspaceFolder:lume}",
|
||||
"program": "${command:SelectScriptName}",
|
||||
"pathBash": "/opt/homebrew/bin/bash",
|
||||
"args": []
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
292
.vscode/py.code-workspace
vendored
Normal file
292
.vscode/py.code-workspace
vendored
Normal file
@@ -0,0 +1,292 @@
|
||||
{
|
||||
"folders": [
|
||||
{
|
||||
"name": "cua-root",
|
||||
"path": ".."
|
||||
},
|
||||
{
|
||||
"name": "computer",
|
||||
"path": "../libs/computer"
|
||||
},
|
||||
{
|
||||
"name": "agent",
|
||||
"path": "../libs/agent"
|
||||
},
|
||||
{
|
||||
"name": "som",
|
||||
"path": "../libs/som"
|
||||
},
|
||||
{
|
||||
"name": "computer-server",
|
||||
"path": "../libs/computer-server"
|
||||
},
|
||||
{
|
||||
"name": "pylume",
|
||||
"path": "../libs/pylume"
|
||||
},
|
||||
{
|
||||
"name": "core",
|
||||
"path": "../libs/core"
|
||||
}
|
||||
],
|
||||
"settings": {
|
||||
"files.exclude": {
|
||||
"**/.git": true,
|
||||
"**/.svn": true,
|
||||
"**/.hg": true,
|
||||
"**/CVS": true,
|
||||
"**/.DS_Store": true,
|
||||
"**/__pycache__": true,
|
||||
"**/.pytest_cache": true,
|
||||
"**/*.pyc": true
|
||||
},
|
||||
"python.testing.pytestEnabled": true,
|
||||
"python.testing.unittestEnabled": false,
|
||||
"python.testing.nosetestsEnabled": false,
|
||||
"python.testing.pytestArgs": [
|
||||
"libs"
|
||||
],
|
||||
"python.analysis.extraPaths": [
|
||||
"${workspaceFolder:cua-root}/libs/core",
|
||||
"${workspaceFolder:cua-root}/libs/computer",
|
||||
"${workspaceFolder:cua-root}/libs/agent",
|
||||
"${workspaceFolder:cua-root}/libs/som",
|
||||
"${workspaceFolder:cua-root}/libs/pylume",
|
||||
"${workspaceFolder:cua-root}/.vscode/typings"
|
||||
],
|
||||
"python.envFile": "${workspaceFolder:cua-root}/.env",
|
||||
"python.defaultInterpreterPath": "${workspaceFolder:cua-root}/.venv/bin/python",
|
||||
"python.analysis.diagnosticMode": "workspace",
|
||||
"python.analysis.typeCheckingMode": "basic",
|
||||
"python.analysis.autoSearchPaths": true,
|
||||
"python.analysis.stubPath": "${workspaceFolder:cua-root}/.vscode/typings",
|
||||
"python.analysis.indexing": false,
|
||||
"python.analysis.exclude": [
|
||||
"**/node_modules/**",
|
||||
"**/__pycache__/**",
|
||||
"**/.*/**",
|
||||
"**/venv/**",
|
||||
"**/.venv/**",
|
||||
"**/dist/**",
|
||||
"**/build/**",
|
||||
".pdm-build/**",
|
||||
"**/.git/**",
|
||||
"examples/**",
|
||||
"notebooks/**",
|
||||
"logs/**",
|
||||
"screenshots/**"
|
||||
],
|
||||
"python.analysis.packageIndexDepths": [
|
||||
{
|
||||
"name": "computer",
|
||||
"depth": 2
|
||||
},
|
||||
{
|
||||
"name": "agent",
|
||||
"depth": 2
|
||||
},
|
||||
{
|
||||
"name": "som",
|
||||
"depth": 2
|
||||
},
|
||||
{
|
||||
"name": "pylume",
|
||||
"depth": 2
|
||||
},
|
||||
{
|
||||
"name": "core",
|
||||
"depth": 2
|
||||
}
|
||||
],
|
||||
"python.autoComplete.extraPaths": [
|
||||
"${workspaceFolder:cua-root}/libs/core",
|
||||
"${workspaceFolder:cua-root}/libs/computer",
|
||||
"${workspaceFolder:cua-root}/libs/agent",
|
||||
"${workspaceFolder:cua-root}/libs/som",
|
||||
"${workspaceFolder:cua-root}/libs/pylume"
|
||||
],
|
||||
"python.languageServer": "Pylance",
|
||||
"[python]": {
|
||||
"editor.formatOnSave": true,
|
||||
"editor.defaultFormatter": "ms-python.black-formatter",
|
||||
"editor.codeActionsOnSave": {
|
||||
"source.organizeImports": "explicit"
|
||||
}
|
||||
},
|
||||
"files.associations": {
|
||||
"examples/computer_examples.py": "python",
|
||||
"examples/agent_examples.py": "python"
|
||||
},
|
||||
"python.interpreterPaths": {
|
||||
"examples/computer_examples.py": "${workspaceFolder}/libs/computer/.venv/bin/python",
|
||||
"examples/agent_examples.py": "${workspaceFolder}/libs/agent/.venv/bin/python"
|
||||
}
|
||||
},
|
||||
"tasks": {
|
||||
"version": "2.0.0",
|
||||
"tasks": [
|
||||
{
|
||||
"label": "Build Dependencies",
|
||||
"type": "shell",
|
||||
"command": "${workspaceFolder}/scripts/build.sh",
|
||||
"presentation": {
|
||||
"reveal": "always",
|
||||
"panel": "new",
|
||||
"clear": true
|
||||
},
|
||||
"group": {
|
||||
"kind": "build",
|
||||
"isDefault": true
|
||||
},
|
||||
"options": {
|
||||
"shell": {
|
||||
"executable": "/bin/bash",
|
||||
"args": ["-l", "-c"]
|
||||
}
|
||||
},
|
||||
"problemMatcher": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"launch": {
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Run Computer Examples",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "examples/computer_examples.py",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true,
|
||||
"python": "${workspaceFolder:cua-root}/.venv/bin/python",
|
||||
"cwd": "${workspaceFolder:cua-root}",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder:cua-root}/libs/core:${workspaceFolder:cua-root}/libs/computer:${workspaceFolder:cua-root}/libs/agent:${workspaceFolder:cua-root}/libs/som:${workspaceFolder:cua-root}/libs/pylume"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Run Agent Examples",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "examples/agent_examples.py",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false,
|
||||
"python": "${workspaceFolder:cua-root}/.venv/bin/python",
|
||||
"cwd": "${workspaceFolder:cua-root}",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder:cua-root}/libs/core:${workspaceFolder:cua-root}/libs/computer:${workspaceFolder:cua-root}/libs/agent:${workspaceFolder:cua-root}/libs/som:${workspaceFolder:cua-root}/libs/pylume"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Run PyLume Examples",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "examples/pylume_examples.py",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true,
|
||||
"python": "${workspaceFolder:cua-root}/.venv/bin/python",
|
||||
"cwd": "${workspaceFolder:cua-root}",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder:cua-root}/libs/core:${workspaceFolder:cua-root}/libs/computer:${workspaceFolder:cua-root}/libs/agent:${workspaceFolder:cua-root}/libs/som:${workspaceFolder:cua-root}/libs/pylume"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "SOM: Run Experiments (No OCR)",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "examples/som_examples.py",
|
||||
"args": [
|
||||
"examples/test_data",
|
||||
"--output-dir", "examples/output",
|
||||
"--ocr", "none",
|
||||
"--mode", "experiment"
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false,
|
||||
"python": "${workspaceFolder:cua-root}/.venv/bin/python",
|
||||
"cwd": "${workspaceFolder:cua-root}",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder:cua-root}/libs/core:${workspaceFolder:cua-root}/libs/computer:${workspaceFolder:cua-root}/libs/agent:${workspaceFolder:cua-root}/libs/som:${workspaceFolder:cua-root}/libs/pylume"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "SOM: Run Experiments (EasyOCR)",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "examples/som_examples.py",
|
||||
"args": [
|
||||
"examples/test_data",
|
||||
"--output-dir", "examples/output",
|
||||
"--ocr", "easyocr",
|
||||
"--mode", "experiment"
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false,
|
||||
"python": "${workspaceFolder:cua-root}/.venv/bin/python",
|
||||
"cwd": "${workspaceFolder:cua-root}",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder:cua-root}/libs/core:${workspaceFolder:cua-root}/libs/computer:${workspaceFolder:cua-root}/libs/agent:${workspaceFolder:cua-root}/libs/som:${workspaceFolder:cua-root}/libs/pylume"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Run Computer Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/libs/computer-server/run_server.py",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true,
|
||||
"python": "${workspaceFolder:cua-root}/.venv/bin/python",
|
||||
"cwd": "${workspaceFolder:cua-root}",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder:cua-root}/libs/core:${workspaceFolder:cua-root}/libs/computer:${workspaceFolder:cua-root}/libs/agent:${workspaceFolder:cua-root}/libs/som:${workspaceFolder:cua-root}/libs/pylume"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Run Computer Server with Args",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/libs/computer-server/run_server.py",
|
||||
"args": [
|
||||
"--host", "0.0.0.0",
|
||||
"--port", "8000",
|
||||
"--log-level", "debug"
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false,
|
||||
"python": "${workspaceFolder:cua-root}/.venv/bin/python",
|
||||
"cwd": "${workspaceFolder:cua-root}",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder:cua-root}/libs/core:${workspaceFolder:cua-root}/libs/computer-server"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"compounds": [
|
||||
{
|
||||
"name": "Run Computer Examples + Server",
|
||||
"configurations": ["Run Computer Examples", "Run Computer Server"],
|
||||
"stopAll": true,
|
||||
"presentation": {
|
||||
"group": "Computer",
|
||||
"order": 1
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Run Server with Keep-Alive Client",
|
||||
"configurations": ["Run Computer Server", "Test Server Connection (Keep Alive)"],
|
||||
"stopAll": true,
|
||||
"presentation": {
|
||||
"group": "Computer",
|
||||
"order": 2
|
||||
}
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"id": "imagePath",
|
||||
"type": "promptString",
|
||||
"description": "Path to the image file or directory for icon detection",
|
||||
"default": "${workspaceFolder}/examples/test_data"
|
||||
}
|
||||
]
|
||||
}
|
||||
18
.vscode/tasks.json
vendored
18
.vscode/tasks.json
vendored
@@ -1,18 +0,0 @@
|
||||
{
|
||||
"version": "2.0.0",
|
||||
"tasks": [
|
||||
{
|
||||
"label": "build-debug",
|
||||
"type": "shell",
|
||||
"command": "${workspaceFolder:lume}/scripts/build/build-debug.sh",
|
||||
"group": {
|
||||
"kind": "build",
|
||||
"isDefault": true
|
||||
},
|
||||
"presentation": {
|
||||
"reveal": "silent"
|
||||
},
|
||||
"problemMatcher": []
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
# Contributing to lume
|
||||
# Contributing to cua
|
||||
|
||||
We deeply appreciate your interest in contributing to lume! Whether you're reporting bugs, suggesting enhancements, improving docs, or submitting pull requests, your contributions help improve the project for everyone.
|
||||
We deeply appreciate your interest in contributing to cua! Whether you're reporting bugs, suggesting enhancements, improving docs, or submitting pull requests, your contributions help improve the project for everyone.
|
||||
|
||||
## Reporting Bugs
|
||||
|
||||
@@ -34,6 +34,6 @@ Documentation improvements are always welcome. You can:
|
||||
- Improve API documentation
|
||||
- Add tutorials or guides
|
||||
|
||||
For detailed instructions on setting up your development environment and submitting code contributions, please see our [Development.md](docs/Development.md) guide.
|
||||
For detailed instructions on setting up your development environment and submitting code contributions, please see our [Developer-Guide.md](docs/Developer-Guide.md) guide.
|
||||
|
||||
Feel free to join our [Discord community](https://discord.com/invite/mVnXXpdE85) to discuss ideas or get help with your contributions.
|
||||
@@ -1,6 +1,6 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 trycua
|
||||
Copyright (c) 2025 trycua
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
179
README.md
179
README.md
@@ -1,145 +1,82 @@
|
||||
<div align="center">
|
||||
<h1>
|
||||
<div class="image-wrapper" style="display: inline-block;">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" alt="logo" height="150" srcset="img/logo_white.png" style="display: block; margin: auto;">
|
||||
<source media="(prefers-color-scheme: light)" alt="logo" height="150" srcset="img/logo_black.png" style="display: block; margin: auto;">
|
||||
<img alt="Shows my svg">
|
||||
</picture>
|
||||
</div>
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" alt="Cua logo" height="150" srcset="img/logo_white.png">
|
||||
<source media="(prefers-color-scheme: light)" alt="Cua logo" height="150" srcset="img/logo_black.png">
|
||||
<img alt="Cua logo" height="150" src="img/logo_black.png">
|
||||
</picture>
|
||||
|
||||
[](#)
|
||||
<!-- <h1>Cua</h1> -->
|
||||
|
||||
[](#)
|
||||
[](#)
|
||||
[](#)
|
||||
[](#install)
|
||||
[](https://discord.com/invite/mVnXXpdE85)
|
||||
[](#contributors)
|
||||
</h1>
|
||||
</div>
|
||||
|
||||
# Cua
|
||||
|
||||
**lume** is a lightweight Command Line Interface and local API server to create, run and manage macOS and Linux virtual machines (VMs) with near-native performance on Apple Silicon, using Apple's `Virtualization.Framework`.
|
||||
Create and run high-performance macOS and Linux VMs on Apple Silicon, with built-in support for AI agents.
|
||||
|
||||
### Run prebuilt macOS images in just 1 step
|
||||
## Libraries
|
||||
|
||||
<div align="center">
|
||||
<img src="img/cli.png" alt="lume cli">
|
||||
</div>
|
||||
| Library | Description | Installation | Version |
|
||||
|---------|-------------|--------------|---------|
|
||||
| [**Lume**](./libs/lume/README.md) | CLI for running macOS/Linux VMs with near-native performance using Apple's `Virtualization.Framework`. | `brew install lume` | [](https://formulae.brew.sh/formula/lume) |
|
||||
| [**Computer**](./libs/computer/README.md) | Computer-Use Interface (CUI) framework for interacting with macOS/Linux sandboxes | `pip install cua-computer` | [](https://pypi.org/project/cua-computer/) |
|
||||
| [**Agent**](./libs/agent/README.md) | Computer-Use Agent (CUA) framework for running agentic workflows in macOS/Linux dedicated sandboxes | `pip install cua-agent` | [](https://pypi.org/project/cua-agent/) |
|
||||
|
||||
## Lume
|
||||
|
||||
```bash
|
||||
lume run macos-sequoia-vanilla:latest
|
||||
```
|
||||
|
||||
For a python interface, check out [pylume](https://github.com/trycua/pylume).
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
lume <command>
|
||||
|
||||
Commands:
|
||||
lume create <name> Create a new macOS or Linux VM
|
||||
lume run <name> Run a VM
|
||||
lume ls List all VMs
|
||||
lume get <name> Get detailed information about a VM
|
||||
lume set <name> Modify VM configuration
|
||||
lume stop <name> Stop a running VM
|
||||
lume delete <name> Delete a VM
|
||||
lume pull <image> Pull a macOS image from container registry
|
||||
lume clone <name> <new-name> Clone an existing VM
|
||||
lume images List available macOS images in local cache
|
||||
lume ipsw Get the latest macOS restore image URL
|
||||
lume prune Remove cached images
|
||||
lume serve Start the API server
|
||||
|
||||
Options:
|
||||
--help Show help [boolean]
|
||||
--version Show version number [boolean]
|
||||
|
||||
Command Options:
|
||||
create:
|
||||
--os <os> Operating system to install (macOS or linux, default: macOS)
|
||||
--cpu <cores> Number of CPU cores (default: 4)
|
||||
--memory <size> Memory size, e.g., 8GB (default: 4GB)
|
||||
--disk-size <size> Disk size, e.g., 50GB (default: 40GB)
|
||||
--display <res> Display resolution (default: 1024x768)
|
||||
--ipsw <path> Path to IPSW file or 'latest' for macOS VMs
|
||||
|
||||
run:
|
||||
--no-display Do not start the VNC client app
|
||||
--shared-dir <dir> Share directory with VM (format: path[:ro|rw])
|
||||
--mount <path> For Linux VMs only, attach a read-only disk image
|
||||
--registry <url> Container registry URL (default: ghcr.io)
|
||||
--organization <org> Organization to pull from (default: trycua)
|
||||
--vnc-port <port> Port to use for the VNC server (default: 0 for auto-assign)
|
||||
--recovery-mode <boolean> For MacOS VMs only, start VM in recovery mode (default: false)
|
||||
|
||||
set:
|
||||
--cpu <cores> New number of CPU cores (e.g., 4)
|
||||
--memory <size> New memory size (e.g., 8192MB or 8GB)
|
||||
--disk-size <size> New disk size (e.g., 40960MB or 40GB)
|
||||
--display <res> New display resolution in format WIDTHxHEIGHT (e.g., 1024x768)
|
||||
|
||||
delete:
|
||||
--force Force deletion without confirmation
|
||||
|
||||
pull:
|
||||
--registry <url> Container registry URL (default: ghcr.io)
|
||||
--organization <org> Organization to pull from (default: trycua)
|
||||
|
||||
serve:
|
||||
--port <port> Port to listen on (default: 3000)
|
||||
```
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
brew tap trycua/lume
|
||||
brew install lume
|
||||
```
|
||||
|
||||
You can also download the `lume.pkg.tar.gz` archive from the [latest release](https://github.com/trycua/lume/releases), extract it, and install the package manually.
|
||||
|
||||
## Prebuilt Images
|
||||
|
||||
Pre-built images are available in the registry [ghcr.io/trycua](https://github.com/orgs/trycua/packages).
|
||||
These images come with an SSH server pre-configured and auto-login enabled.
|
||||
|
||||
For the security of your VM, change the default password `lume` immediately after your first login.
|
||||
|
||||
| Image | Tag | Description | Size |
|
||||
|-------|------------|-------------|------|
|
||||
| `macos-sequoia-vanilla` | `latest`, `15.2` | macOS Sequoia 15.2 | 40GB |
|
||||
| `macos-sequoia-xcode` | `latest`, `15.2` | macOS Sequoia 15.2 with Xcode command line tools | 50GB |
|
||||
| `ubuntu-noble-vanilla` | `latest`, `24.04.1` | [Ubuntu Server for ARM 24.04.1 LTS](https://ubuntu.com/download/server/arm) with Ubuntu Desktop | 20GB |
|
||||
|
||||
For additional disk space, resize the VM disk after pulling the image using the `lume set <name> --disk-size <size>` command.
|
||||
|
||||
## Local API Server
|
||||
|
||||
`lume` exposes a local HTTP API server that listens on `http://localhost:3000/lume`, enabling automated management of VMs.
|
||||
|
||||
```bash
|
||||
lume serve
|
||||
```
|
||||
|
||||
For detailed API documentation, please refer to [API Reference](docs/API-Reference.md).
|
||||
**Originally looking for Lume?** If you're here for the original Lume project, it's now part of this monorepo. Simply install with `brew` and refer to its [documentation](./libs/lume/README.md).
|
||||
|
||||
## Docs
|
||||
|
||||
- [API Reference](docs/API-Reference.md)
|
||||
- [Development](docs/Development.md)
|
||||
- [FAQ](docs/FAQ.md)
|
||||
For optimal onboarding, we recommend starting with the [Computer](./libs/computer/README.md) documentation to cover the core functionality of the Computer sandbox, then exploring the [Agent](./libs/agent/README.md) documentation to understand Cua's AI agent capabilities, and finally working through the Notebook examples to try out the Computer-Use interface and agent.
|
||||
|
||||
- [Computer](./libs/computer/README.md)
|
||||
- [Agent](./libs/agent/README.md)
|
||||
- [Notebooks](./notebooks/)
|
||||
|
||||
## Demos
|
||||
|
||||
Demos of the Computer-Use Agent in action. Share your most impressive demos in Cua's [Discord community](https://discord.com/invite/mVnXXpdE85)!
|
||||
|
||||
<details open>
|
||||
<summary><b>AI-Gradio: multi-app workflow requiring browser, VS Code and terminal access</b></summary>
|
||||
<br>
|
||||
<div align="center">
|
||||
<video src="https://github.com/user-attachments/assets/c1efb4e3-2a2e-4fd5-8675-d39d9b34b2d0" width="800" controls></video>
|
||||
</div>
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Notebook: Fix GitHub issue in Cursor</b></summary>
|
||||
<br>
|
||||
<div align="center">
|
||||
<video src="https://github.com/user-attachments/assets/" width="800" controls></video>
|
||||
</div>
|
||||
|
||||
</details>
|
||||
|
||||
## Accessory Libraries
|
||||
|
||||
| Library | Description | Installation | Version |
|
||||
|---------|-------------|--------------|---------|
|
||||
| [**Core**](./libs/core/README.md) | Core functionality and utilities used by other Cua packages | `pip install cua-core` | [](https://pypi.org/project/cua-core/) |
|
||||
| [**PyLume**](./libs/pylume/README.md) | Python bindings for Lume | `pip install pylume` | [](https://pypi.org/project/pylume/) |
|
||||
| [**Computer Server**](./libs/computer-server/README.md) | Server component for the Computer-Use Interface (CUI) framework | `pip install cua-computer-server` | [](https://pypi.org/project/cua-computer-server/) |
|
||||
| [**SOM**](./libs/som/README.md) | Self-of-Mark library for Agent | `pip install cua-som` | [](https://pypi.org/project/cua-som/) |
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome and greatly appreciate contributions to lume! Whether you're improving documentation, adding new features, fixing bugs, or adding new VM images, your efforts help make lume better for everyone. For detailed instructions on how to contribute, please refer to our [Contributing Guidelines](CONTRIBUTING.md).
|
||||
We welcome and greatly appreciate contributions to Cua! Whether you're improving documentation, adding new features, fixing bugs, or adding new VM images, your efforts help make lume better for everyone. For detailed instructions on how to contribute, please refer to our [Contributing Guidelines](CONTRIBUTING.md).
|
||||
|
||||
Join our [Discord community](https://discord.com/invite/mVnXXpdE85) to discuss ideas or get assistance.
|
||||
|
||||
## License
|
||||
|
||||
lume is open-sourced under the MIT License - see the [LICENSE](LICENSE) file for details.
|
||||
Cua is open-sourced under the MIT License - see the [LICENSE](LICENSE) file for details.
|
||||
|
||||
## Trademarks
|
||||
|
||||
@@ -147,7 +84,7 @@ Apple, macOS, and Apple Silicon are trademarks of Apple Inc. Ubuntu and Canonica
|
||||
|
||||
## Stargazers over time
|
||||
|
||||
[](https://starchart.cc/trycua/lume)
|
||||
[](https://starchart.cc/trycua/cua)
|
||||
|
||||
## Contributors
|
||||
|
||||
|
||||
147
docs/Developer-Guide.md
Normal file
147
docs/Developer-Guide.md
Normal file
@@ -0,0 +1,147 @@
|
||||
## Developer Guide
|
||||
|
||||
### Project Structure
|
||||
|
||||
The project is organized as a monorepo with these main packages:
|
||||
- `libs/core/` - Base package with telemetry support
|
||||
- `libs/pylume/` - Python bindings for Lume
|
||||
- `libs/computer/` - Core computer interaction library
|
||||
- `libs/agent/` - AI agent library with multi-provider support
|
||||
- `libs/som/` - Computer vision and NLP processing library (formerly omniparser)
|
||||
- `libs/computer-server/` - Server implementation for computer control
|
||||
- `libs/lume/` - Swift implementation for enhanced macOS integration
|
||||
|
||||
Each package has its own virtual environment and dependencies, managed through PDM.
|
||||
|
||||
### Local Development Setup
|
||||
|
||||
1. Clone the repository:
|
||||
```bash
|
||||
git clone https://github.com/trycua/cua.git
|
||||
cd cua
|
||||
```
|
||||
|
||||
2. Create a `.env.local` file in the root directory with your API keys:
|
||||
```bash
|
||||
# Required for Anthropic provider
|
||||
ANTHROPIC_API_KEY=your_anthropic_key_here
|
||||
|
||||
# Required for OpenAI provider
|
||||
OPENAI_API_KEY=your_openai_key_here
|
||||
```
|
||||
|
||||
3. Run the build script to set up all packages:
|
||||
```bash
|
||||
./scripts/build.sh
|
||||
```
|
||||
|
||||
This will:
|
||||
- Create a virtual environment for the project
|
||||
- Install all packages in development mode
|
||||
- Set up the correct Python path
|
||||
- Install development tools
|
||||
|
||||
4. Open the workspace in VSCode or Cursor:
|
||||
```bash
|
||||
# Using VSCode or Cursor
|
||||
code .vscode/py.code-workspace
|
||||
|
||||
# For Lume (Swift) development
|
||||
code .vscode/lume.code-workspace
|
||||
```
|
||||
|
||||
Using the workspace file is strongly recommended as it:
|
||||
- Sets up correct Python environments for each package
|
||||
- Configures proper import paths
|
||||
- Enables debugging configurations
|
||||
- Maintains consistent settings across packages
|
||||
|
||||
### Cleanup and Reset
|
||||
|
||||
If you need to clean up the environment and start fresh:
|
||||
|
||||
```bash
|
||||
./scripts/cleanup.sh
|
||||
```
|
||||
|
||||
This will:
|
||||
- Remove all virtual environments
|
||||
- Clean Python cache files and directories
|
||||
- Remove build artifacts
|
||||
- Clean PDM-related files
|
||||
- Reset environment configurations
|
||||
|
||||
### Package Virtual Environments
|
||||
|
||||
The build script creates a shared virtual environment for all packages. The workspace configuration automatically handles import paths with the correct Python path settings.
|
||||
|
||||
### Running Examples
|
||||
|
||||
The Python workspace includes launch configurations for all packages:
|
||||
|
||||
- "Run Computer Examples" - Runs computer examples
|
||||
- "Run Computer API Server" - Runs the computer-server
|
||||
- "Run Omni Agent Examples" - Runs agent examples
|
||||
- "SOM" configurations - Various settings for running SOM
|
||||
|
||||
To run examples:
|
||||
1. Open the workspace file (`.vscode/py.code-workspace`)
|
||||
2. Press F5 or use the Run/Debug view
|
||||
3. Select the desired configuration
|
||||
|
||||
The workspace also includes compound launch configurations:
|
||||
- "Run Computer Examples + Server" - Runs both the Computer Examples and Server simultaneously
|
||||
|
||||
## Release and Publishing Process
|
||||
|
||||
This monorepo contains multiple Python packages that can be published to PyPI. The packages
|
||||
have dependencies on each other in the following order:
|
||||
|
||||
1. `pylume` - Base package for VM management
|
||||
2. `cua-computer` - Computer control interface (depends on pylume)
|
||||
3. `cua-som` - Parser for UI elements (independent, formerly omniparser)
|
||||
4. `cua-agent` - AI agent (depends on cua-computer and optionally cua-som)
|
||||
5. `computer-server` - Server component installed on the sandbox
|
||||
|
||||
#### Workflow Structure
|
||||
|
||||
The publishing process is managed by these GitHub workflow files:
|
||||
|
||||
- **Package-specific workflows**:
|
||||
- `.github/workflows/publish-pylume.yml`
|
||||
- `.github/workflows/publish-computer.yml`
|
||||
- `.github/workflows/publish-som.yml`
|
||||
- `.github/workflows/publish-agent.yml`
|
||||
- `.github/workflows/publish-computer-server.yml`
|
||||
|
||||
- **Coordinator workflow**:
|
||||
- `.github/workflows/publish-all.yml` - Manages global releases and manual selections
|
||||
|
||||
### Version Management
|
||||
|
||||
#### Special Considerations for Pylume
|
||||
|
||||
The `pylume` package requires special handling as it incorporates the binary executable from the [lume repository](https://github.com/trycua/lume):
|
||||
|
||||
- When releasing `pylume`, ensure the version matches a corresponding release in the lume repository
|
||||
- The workflow automatically downloads the matching lume binary and includes it in the pylume package
|
||||
- If you need to release a new version of pylume, make sure to coordinate with a matching lume release
|
||||
|
||||
## Development Workspaces
|
||||
|
||||
This monorepo includes multiple VS Code workspace configurations to optimize the development experience based on which components you're working with:
|
||||
|
||||
### Available Workspace Files
|
||||
|
||||
- **[py.code-workspace](.vscode/py.code-workspace)**: For Python package development (Computer, Agent, SOM, etc.)
|
||||
- **[lume.code-workspace](.vscode/lume.code-workspace)**: For Swift-based Lume development
|
||||
|
||||
To open a specific workspace:
|
||||
|
||||
```bash
|
||||
# For Python development
|
||||
code .vscode/py.code-workspace
|
||||
|
||||
# For Lume (Swift) development
|
||||
code .vscode/lume.code-workspace
|
||||
```
|
||||
53
docs/FAQ.md
53
docs/FAQ.md
@@ -1,55 +1,50 @@
|
||||
# FAQs
|
||||
|
||||
### Where are the VMs stored?
|
||||
### Why a local sandbox?
|
||||
|
||||
VMs are stored in `~/.lume`.
|
||||
A local sandbox is a dedicated environment that is isolated from the rest of the system. As AI agents rapidly evolve towards 70-80% success rates on average tasks, having a controlled and secure environment becomes crucial. Cua's Computer-Use AI agents run in a local sandbox to ensure reliability, safety, and controlled execution.
|
||||
|
||||
### How are images cached?
|
||||
Benefits of using a local sandbox rather than running the Computer-Use AI agent in the host system:
|
||||
|
||||
Images are cached in `~/.lume/cache`. When doing `lume pull <image>`, it will check if the image is already cached. If not, it will download the image and cache it, removing any older versions.
|
||||
- **Reliability**: The sandbox provides a reproducible environment - critical for benchmarking and debugging agent behavior. Frameworks like [OSWorld](https://github.com/xlang-ai/OSWorld), [Simular AI](https://github.com/simular-ai/Agent-S), Microsoft's [OmniTool](https://github.com/microsoft/OmniParser/tree/master/omnitool), [WindowsAgentArena](https://github.com/microsoft/WindowsAgentArena) and more are using Computer-Use AI agents running in local sandboxes.
|
||||
- **Safety & Isolation**: The sandbox is isolated from the rest of the system, protecting sensitive data and system resources. As CUA agent capabilities grow, this isolation becomes increasingly important for preventing potential safety breaches.
|
||||
- **Control**: The sandbox can be easily monitored and terminated if needed, providing oversight for autonomous agent operation.
|
||||
|
||||
### Are VM disks taking up all the disk space?
|
||||
### Where are the sandbox images stored?
|
||||
|
||||
Sandboxes are stored in `~/.lume`, and cached images are stored in `~/.lume/cache`.
|
||||
|
||||
### Which image is Computer using?
|
||||
|
||||
Computer uses an optimized macOS image for Computer-Use interactions, with pre-installed apps and settings for optimal performance.
|
||||
The image is available on our [ghcr registry](https://github.com/orgs/trycua/packages/container/package/macos-sequoia-cua).
|
||||
|
||||
### Are Sandbox disks taking up all the disk space?
|
||||
|
||||
No, macOS uses sparse files, which only allocate space as needed. For example, VM disks totaling 50 GB may only use 20 GB on disk.
|
||||
|
||||
### How do I get the latest macOS restore image URL?
|
||||
|
||||
```bash
|
||||
lume ipsw
|
||||
```
|
||||
|
||||
### How do I delete a VM?
|
||||
|
||||
```bash
|
||||
lume delete <name>
|
||||
```
|
||||
|
||||
### How to Install macOS from an IPSW Image
|
||||
### How do I troubleshoot Computer not connecting to lume daemon?
|
||||
|
||||
#### Create a new macOS VM using the latest supported IPSW image:
|
||||
Run the following command to create a new macOS virtual machine using the latest available IPSW image:
|
||||
If you're experiencing connection issues between Computer and the lume daemon, it could be because the port 3000 (used by lume) is already in use by an orphaned process. You can diagnose this issue with:
|
||||
|
||||
```bash
|
||||
lume create <name> --os macos --ipsw latest
|
||||
sudo lsof -i :3000
|
||||
```
|
||||
|
||||
#### Create a new macOS VM using a specific IPSW image:
|
||||
To create a macOS virtual machine from an older or specific IPSW file, first download the desired IPSW (UniversalMac) from a trusted source.
|
||||
|
||||
Then, use the downloaded IPSW path:
|
||||
This command will show all processes using port 3000. If you see a lume process already running, you can terminate it with:
|
||||
|
||||
```bash
|
||||
lume create <name> --os macos --ipsw <downloaded_ipsw_path>
|
||||
kill <PID>
|
||||
```
|
||||
|
||||
### How do I install a custom Linux image?
|
||||
Where `<PID>` is the process ID shown in the output of the `lsof` command. After terminating the process, run `lume serve` again to start the lume daemon.
|
||||
|
||||
The process for creating a custom Linux image differs from macOS, with IPSW restore files not being used. You need to create a Linux VM first, then mount a setup image file to the VM for the first boot.
|
||||
### What information does Cua track?
|
||||
|
||||
```bash
|
||||
lume create <name> --os linux
|
||||
|
||||
lume run <name> --mount <path-to-setup-image>
|
||||
|
||||
lume run <name>
|
||||
```
|
||||
Cua tracks anonymized usage and error report statistics; we ascribe to Posthog's approach as detailed [here](https://posthog.com/blog/open-source-telemetry-ethical). If you would like to opt out of sending anonymized info, you can set `telemetry_enabled` to false in the Computer or Agent constructor. Check out our [Telemetry](Telemetry.md) documentation for more details.
|
||||
|
||||
74
docs/Telemetry.md
Normal file
74
docs/Telemetry.md
Normal file
@@ -0,0 +1,74 @@
|
||||
# Telemetry in CUA
|
||||
|
||||
This document explains how telemetry works in CUA libraries and how you can control it.
|
||||
|
||||
CUA tracks anonymized usage and error report statistics; we ascribe to Posthog's approach as detailed [here](https://posthog.com/blog/open-source-telemetry-ethical). If you would like to opt out of sending anonymized info, you can set `telemetry_enabled` to false.
|
||||
|
||||
## What telemetry data we collect
|
||||
|
||||
CUA libraries collect minimal anonymous usage data to help improve our software. The telemetry data we collect is specifically limited to:
|
||||
|
||||
- Basic system information:
|
||||
- Operating system (e.g., 'darwin', 'win32', 'linux')
|
||||
- Python version (e.g., '3.10.0')
|
||||
- Module initialization events:
|
||||
- When a module (like 'computer' or 'agent') is imported
|
||||
- Version of the module being used
|
||||
|
||||
We do NOT collect:
|
||||
- Personal information
|
||||
- Contents of files
|
||||
- Specific text being typed
|
||||
- Actual screenshots or screen contents
|
||||
- User-specific identifiers
|
||||
- API keys
|
||||
- File contents
|
||||
- Application data or content
|
||||
- User interactions with the computer
|
||||
- Information about files being accessed
|
||||
|
||||
## Controlling Telemetry
|
||||
|
||||
We are committed to transparency and user control over telemetry. There are two ways to control telemetry:
|
||||
|
||||
## 1. Environment Variable (Global Control)
|
||||
|
||||
Telemetry is enabled by default. To disable telemetry, set the `CUA_TELEMETRY_ENABLED` environment variable to a falsy value (`0`, `false`, `no`, or `off`):
|
||||
|
||||
```bash
|
||||
# Disable telemetry before running your script
|
||||
export CUA_TELEMETRY_ENABLED=false
|
||||
|
||||
# Or as part of the command
|
||||
CUA_TELEMETRY_ENABLED=false python your_script.py
|
||||
|
||||
```
|
||||
Or from Python:
|
||||
```python
|
||||
import os
|
||||
os.environ["CUA_TELEMETRY_ENABLED"] = "false"
|
||||
```
|
||||
|
||||
## 2. Instance-Level Control
|
||||
|
||||
You can control telemetry for specific CUA instances by setting `telemetry_enabled` when creating them:
|
||||
|
||||
```python
|
||||
# Disable telemetry for a specific Computer instance
|
||||
computer = Computer(telemetry_enabled=False)
|
||||
|
||||
# Enable telemetry for a specific Agent instance
|
||||
agent = ComputerAgent(telemetry_enabled=True)
|
||||
```
|
||||
|
||||
You can check if telemetry is enabled for an instance:
|
||||
|
||||
```python
|
||||
print(computer.telemetry_enabled) # Will print True or False
|
||||
```
|
||||
|
||||
Note that telemetry settings must be configured during initialization and cannot be changed after the object is created.
|
||||
|
||||
## Transparency
|
||||
|
||||
We believe in being transparent about the data we collect. If you have any questions about our telemetry practices, please open an issue on our GitHub repository.
|
||||
99
examples/agent_examples.py
Normal file
99
examples/agent_examples.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Example demonstrating the ComputerAgent capabilities with the Omni provider."""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
import signal
|
||||
|
||||
from computer import Computer
|
||||
|
||||
# Import the unified agent class and types
|
||||
from agent import ComputerAgent, AgentLoop, LLMProvider, LLM
|
||||
|
||||
# Import utility functions
|
||||
from utils import load_dotenv_files, handle_sigint
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def run_omni_agent_example():
    """Run an example of using the ComputerAgent with the Anthropic loop.

    Creates a local ``Computer`` sandbox, wires it into a ``ComputerAgent``,
    streams the agent's results for each task to stdout, and stops the
    sandbox on exit regardless of success or failure.

    Raises:
        Exception: re-raised after logging if agent setup or execution fails.
    """
    print(f"\n=== Example: ComputerAgent with OpenAI and Omni provider ===")
    # Bind the name up front so the finally block cannot hit a NameError
    # when Computer(...) itself raises before assignment.
    computer = None
    try:
        # Create Computer instance with default parameters
        computer = Computer(verbosity=logging.DEBUG)

        # Create agent with loop and provider
        agent = ComputerAgent(
            computer=computer,
            # loop=AgentLoop.OMNI,
            loop=AgentLoop.ANTHROPIC,
            # model=LLM(provider=LLMProvider.OPENAI, name="gpt-4.5-preview"),
            model=LLM(provider=LLMProvider.ANTHROPIC, name="claude-3-7-sonnet-20250219"),
            save_trajectory=True,
            trajectory_dir=str(Path("trajectories")),
            only_n_most_recent_images=3,
            verbosity=logging.INFO,
        )

        tasks = [
            """
            1. Look for a repository named trycua/lume on GitHub.
            2. Check the open issues, open the most recent one and read it.
            3. Clone the repository in users/lume/projects if it doesn't exist yet.
            4. Open the repository with an app named Cursor (on the dock, black background and white cube icon).
            5. From Cursor, open Composer if not already open.
            6. Focus on the Composer text area, then write and submit a task to help resolve the GitHub issue.
            """
        ]

        async with agent:
            for i, task in enumerate(tasks, 1):
                print(f"\nExecuting task {i}/{len(tasks)}: {task}")
                async for result in agent.run(task):
                    # Check if result has the expected structure
                    if "role" in result and "content" in result and "metadata" in result:
                        title = result["metadata"].get("title", "Screen Analysis")
                        content = result["content"]
                    else:
                        # Fall back to safe lookups for partially-shaped results.
                        title = result.get("metadata", {}).get("title", "Screen Analysis")
                        content = result.get("content", str(result))

                    print(f"\n{title}")
                    print(content)
                print(f"Task {i} completed")

    except Exception as e:
        # Fixed: previously logged the wrong function name
        # ("run_anthropic_agent_example").
        logger.error(f"Error in run_omni_agent_example: {e}")
        traceback.print_exc()
        raise
    finally:
        # Clean up resources; only stop a sandbox that actually started.
        if computer and computer._initialized:
            try:
                await computer.stop()
            except Exception as e:
                logger.warning(f"Error stopping computer: {e}")
|
||||
|
||||
|
||||
def main():
    """Entry point for the Omni agent example.

    Loads environment variables (API keys) from dotenv files, installs a
    SIGINT handler for graceful Ctrl+C shutdown, then runs the async
    example to completion. Errors are printed rather than raised so the
    script exits cleanly.
    """
    try:
        load_dotenv_files()

        # Register signal handler for graceful exit
        signal.signal(signal.SIGINT, handle_sigint)

        asyncio.run(run_omni_agent_example())
    except Exception as e:
        print(f"Error running example: {e}")
        traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
97
examples/computer_examples.py
Normal file
97
examples/computer_examples.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import os
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import json
|
||||
import traceback
|
||||
|
||||
# Load environment variables from .env file
|
||||
project_root = Path(__file__).parent.parent
|
||||
env_file = project_root / ".env"
|
||||
print(f"Loading environment from: {env_file}")
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(env_file)
|
||||
|
||||
# Add paths to sys.path if needed
|
||||
pythonpath = os.environ.get("PYTHONPATH", "")
|
||||
for path in pythonpath.split(":"):
|
||||
if path and path not in sys.path:
|
||||
sys.path.append(path)
|
||||
print(f"Added to sys.path: {path}")
|
||||
|
||||
from computer.computer import Computer
|
||||
from computer.logger import LogLevel
|
||||
from computer.utils import get_image_size
|
||||
|
||||
|
||||
async def main():
    """Demonstrate direct ``Computer`` initialization and the interface APIs.

    Walks through accessibility, screen, mouse, keyboard and clipboard
    actions against a macOS sandbox. Cleanup is intentionally left
    commented out so the sandbox keeps running after the demo.
    """
    try:
        print("\n=== Using direct initialization ===")
        computer = Computer(
            display="1024x768",  # Higher resolution
            memory="8GB",  # More memory
            cpu="4",  # More CPU cores
            os="macos",
            verbosity=LogLevel.NORMAL,  # Use QUIET to suppress most logs
            use_host_computer_server=False,
        )
        try:
            await computer.run()

            # Open Spotlight with the standard shortcut.
            await computer.interface.hotkey("command", "space")

            # res = await computer.interface.run_command("touch ./Downloads/empty_file")
            # print(f"Run command result: {res}")

            tree = await computer.interface.get_accessibility_tree()
            print(f"Accessibility tree: {tree}")

            # --- Screen actions ---
            print("\n=== Screen Actions ===")
            shot = await computer.interface.screenshot()
            with open("screenshot_direct.png", "wb") as f:
                f.write(shot)

            size = await computer.interface.get_screen_size()
            print(f"Screen size: {size}")

            # Round-trip a point through both coordinate conversions.
            cx, cy = 733, 736
            print(f"Center in screen coordinates: ({cx}, {cy})")

            in_shot = await computer.to_screenshot_coordinates(cx, cy)
            print(f"Center in screenshot coordinates: {in_shot}")

            on_screen = await computer.to_screen_coordinates(*in_shot)
            print(f"Back to screen coordinates: {on_screen}")

            # --- Mouse actions ---
            print("\n=== Mouse Actions ===")
            await computer.interface.move_cursor(100, 100)
            await computer.interface.left_click()
            await computer.interface.right_click(300, 300)
            await computer.interface.double_click(400, 400)

            # --- Keyboard actions ---
            print("\n=== Keyboard Actions ===")
            await computer.interface.type_text("Hello, World!")
            await computer.interface.press_key("enter")

            # --- Clipboard actions ---
            print("\n=== Clipboard Actions ===")
            await computer.interface.set_clipboard("Test clipboard")
            clipboard = await computer.interface.copy_to_clipboard()
            print(f"Clipboard content: {clipboard}")

        finally:
            # Important to clean up resources
            pass
            # await computer.stop()
    except Exception as e:
        print(f"Error in main: {e}")
        traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
98
examples/pylume_examples.py
Normal file
98
examples/pylume_examples.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import asyncio
|
||||
from pylume import (
|
||||
PyLume,
|
||||
ImageRef,
|
||||
VMRunOpts,
|
||||
SharedDirectory,
|
||||
VMConfig,
|
||||
VMUpdateOpts
|
||||
)
|
||||
|
||||
async def main():
    """Exercise the PyLume API end to end.

    Creates, inspects, updates, runs, clones, stops and deletes VMs
    against a locally spawned lume server, printing each step's result.
    """
    async with PyLume(port=3000, use_existing_server=False, debug=True) as pylume:

        # Fetch the newest supported macOS restore image URL.
        print("\n=== Getting Latest IPSW URL ===")
        ipsw_url = await pylume.get_latest_ipsw_url()
        print("Latest IPSW URL:", ipsw_url)

        # Provision a brand-new macOS VM from the latest IPSW.
        print("\n=== Creating a new VM ===")
        config = VMConfig(
            name="lume-vm-new",
            os="macOS",
            cpu=2,
            memory="4GB",
            disk_size="64GB",
            display="1024x768",
            ipsw="latest",
        )
        await pylume.create_vm(config)

        # Same lookup again, after VM creation.
        print("\n=== Getting Latest IPSW URL ===")
        ipsw_url = await pylume.get_latest_ipsw_url()
        print("Latest IPSW URL:", ipsw_url)

        # Registry images available for pulling.
        print("\n=== Listing Available Images ===")
        available = await pylume.get_images()
        print("Available Images:", available)

        # All local VMs, confirming the new one exists.
        print("\n=== Listing All VMs ===")
        all_vms = await pylume.list_vms()
        print("VMs:", all_vms)

        print("\n=== Getting VM Details ===")
        details = await pylume.get_vm("lume-vm")
        print("VM Details:", details)

        print("\n=== Updating VM Settings ===")
        changes = VMUpdateOpts(
            cpu=8,
            memory="4GB",
        )
        await pylume.update_vm("lume-vm", changes)

        # Pull a prebuilt image from the ghcr registry.
        ref = ImageRef(
            image="macos-sequoia-vanilla",
            tag="latest",
            registry="ghcr.io",
            organization="trycua",
        )
        await pylume.pull_image(ref, name="lume-vm-pulled")

        # Run with a shared host directory mounted read/write.
        opts = VMRunOpts(
            no_display=False,
            shared_directories=[
                SharedDirectory(
                    host_path="~/shared",
                    read_only=False,
                )
            ],
        )
        await pylume.run_vm("lume-vm", opts)

        # Or simpler, with defaults:
        await pylume.run_vm("lume-vm")

        print("\n=== Cloning VM ===")
        await pylume.clone_vm("lume-vm", "lume-vm-cloned")

        print("\n=== Stopping VM ===")
        await pylume.stop_vm("lume-vm")

        print("\n=== Deleting VM ===")
        await pylume.delete_vm("lume-vm-cloned")
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
437
examples/som_examples.py
Normal file
437
examples/som_examples.py
Normal file
@@ -0,0 +1,437 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Example script demonstrating the usage of OmniParser's UI element detection functionality.
|
||||
This script shows how to:
|
||||
1. Initialize the OmniParser
|
||||
2. Load and process images
|
||||
3. Visualize detection results
|
||||
4. Compare performance between CPU and MPS (Apple Silicon)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import time
|
||||
from PIL import Image
|
||||
from typing import Dict, Any, List, Optional
|
||||
import numpy as np
|
||||
import io
|
||||
import base64
|
||||
import glob
|
||||
import os
|
||||
|
||||
# Load environment variables from .env file
|
||||
project_root = Path(__file__).parent.parent
|
||||
env_file = project_root / ".env"
|
||||
print(f"Loading environment from: {env_file}")
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(env_file)
|
||||
|
||||
# Add paths to sys.path if needed
|
||||
pythonpath = os.environ.get("PYTHONPATH", "")
|
||||
for path in pythonpath.split(":"):
|
||||
if path and path not in sys.path:
|
||||
sys.path.append(path)
|
||||
print(f"Added to sys.path: {path}")
|
||||
|
||||
# Add the libs directory to the path to find som
|
||||
libs_path = project_root / "libs"
|
||||
if str(libs_path) not in sys.path:
|
||||
sys.path.append(str(libs_path))
|
||||
print(f"Added to sys.path: {libs_path}")
|
||||
|
||||
from som import OmniParser, ParseResult, IconElement, TextElement
|
||||
from som.models import UIElement, ParserMetadata, BoundingBox
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_logging():
    """Install a concise timestamped format on the root logger at INFO level."""
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
    )
|
||||
|
||||
|
||||
class Timer:
    """Context manager that measures and logs the wall-clock time of a block.

    On exit, logs ``"<name>: <seconds>s"`` at INFO level via the supplied
    logger and leaves the measured duration in ``elapsed_time``.
    """

    def __init__(self, name: str, logger):
        # Label used in the emitted log line.
        self.name = name
        self.logger = logger
        self.start_time: float = 0.0
        self.elapsed_time: float = 0.0

    def __enter__(self):
        # Record wall-clock start; elapsed is computed on exit.
        self.start_time = time.time()
        return self

    def __exit__(self, *exc_info):
        self.elapsed_time = time.time() - self.start_time
        self.logger.info(f"{self.name}: {self.elapsed_time:.3f}s")
        # Never suppress exceptions raised inside the block.
        return False
|
||||
|
||||
|
||||
def image_to_bytes(image: Image.Image) -> bytes:
    """Serialize a PIL image to PNG-encoded bytes."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()
|
||||
|
||||
|
||||
def process_image(
    parser: OmniParser, image_path: str, output_dir: Path, use_ocr: bool = False
) -> None:
    """Parse one screenshot with OmniParser and save the annotated result.

    Loads the image as RGB, runs UI-element detection (optionally with OCR),
    decodes the base64-annotated image from the result, writes it as
    ``<stem>_analyzed.png`` under *output_dir*, and logs every detected
    icon/text element. All failures are logged, never raised.

    Args:
        parser: Initialized OmniParser instance.
        image_path: Path to the input PNG/image file.
        output_dir: Directory that receives the annotated output image.
        use_ocr: Whether to also run OCR-based text detection.
    """
    try:
        # Load image
        logger.info(f"Processing image: {image_path}")
        image = Image.open(image_path).convert("RGB")
        logger.info(f"Image loaded successfully, size: {image.size}")

        # Create output filename
        input_filename = Path(image_path).stem
        output_path = output_dir / f"{input_filename}_analyzed.png"

        # Convert image to PNG bytes
        image_bytes = image_to_bytes(image)

        # Process image (the Timer logs elapsed wall-clock time on exit)
        with Timer(f"Processing {input_filename}", logger):
            result = parser.parse(image_bytes, use_ocr=use_ocr)
            logger.info(
                f"Found {result.metadata.num_icons} icons and {result.metadata.num_text} text elements"
            )

        # Save the annotated image
        logger.info(f"Saving annotated image to: {output_path}")
        try:
            # Save image from base64
            img_data = base64.b64decode(result.annotated_image_base64)
            img = Image.open(io.BytesIO(img_data))
            img.save(output_path)

            # Print detailed results
            logger.info("\nDetected Elements:")
            for elem in result.elements:
                if isinstance(elem, IconElement):
                    logger.info(
                        f"Icon: confidence={elem.confidence:.3f}, bbox={elem.bbox.coordinates}"
                    )
                elif isinstance(elem, TextElement):
                    logger.info(
                        f"Text: '{elem.content}', confidence={elem.confidence:.3f}, bbox={elem.bbox.coordinates}"
                    )

            # Verify file exists and log size
            if output_path.exists():
                logger.info(
                    f"Successfully saved image. File size: {output_path.stat().st_size} bytes"
                )
            else:
                logger.error(f"Failed to verify file at {output_path}")
        except Exception as e:
            logger.error(f"Error saving image: {str(e)}", exc_info=True)

    except Exception as e:
        logger.error(f"Error processing image {image_path}: {str(e)}", exc_info=True)
|
||||
|
||||
|
||||
def run_detection_benchmark(
    input_path: str,
    output_dir: Path,
    use_ocr: bool = False,
    box_threshold: float = 0.01,
    iou_threshold: float = 0.1,
):
    """Run OmniParser detection over one image or a directory of PNGs.

    For each image, parses it with the given detection thresholds, saves an
    annotated copy as ``<stem>_analyzed.png`` under *output_dir*, and logs
    every detected element. Per-image failures are logged and skipped;
    setup failures (parser init, directory creation) are re-raised.

    NOTE(review): the per-image body largely duplicates ``process_image``
    but additionally forwards ``box_threshold``/``iou_threshold`` —
    consider unifying the two.

    Args:
        input_path: A single image path, or a directory scanned for ``*.png``.
        output_dir: Directory that receives the annotated output images.
        use_ocr: Whether to also run OCR-based text detection.
        box_threshold: Detection confidence threshold passed to the parser.
        iou_threshold: IOU threshold for box de-duplication.
    """
    logger.info(
        f"Starting benchmark with OCR enabled: {use_ocr}, box_threshold: {box_threshold}, iou_threshold: {iou_threshold}"
    )

    try:
        # Initialize parser
        logger.info("Initializing OmniParser...")
        parser = OmniParser()

        # Create output directory
        output_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Output directory created at: {output_dir}")

        # Get list of PNG files
        if os.path.isdir(input_path):
            image_files = glob.glob(os.path.join(input_path, "*.png"))
        else:
            image_files = [input_path]

        logger.info(f"Found {len(image_files)} images to process")

        # Process each image with specified thresholds
        for image_path in image_files:
            try:
                # Load image
                logger.info(f"Processing image: {image_path}")
                image = Image.open(image_path).convert("RGB")
                logger.info(f"Image loaded successfully, size: {image.size}")

                # Create output filename
                input_filename = Path(image_path).stem
                output_path = output_dir / f"{input_filename}_analyzed.png"

                # Convert image to PNG bytes
                image_bytes = image_to_bytes(image)

                # Process image with specified thresholds
                # (the Timer logs elapsed wall-clock time on exit)
                with Timer(f"Processing {input_filename}", logger):
                    result = parser.parse(
                        image_bytes,
                        use_ocr=use_ocr,
                        box_threshold=box_threshold,
                        iou_threshold=iou_threshold,
                    )
                    logger.info(
                        f"Found {result.metadata.num_icons} icons and {result.metadata.num_text} text elements"
                    )

                # Save the annotated image
                logger.info(f"Saving annotated image to: {output_path}")
                try:
                    # Save image from base64
                    img_data = base64.b64decode(result.annotated_image_base64)
                    img = Image.open(io.BytesIO(img_data))
                    img.save(output_path)

                    # Print detailed results
                    logger.info("\nDetected Elements:")
                    for elem in result.elements:
                        if isinstance(elem, IconElement):
                            logger.info(
                                f"Icon: confidence={elem.confidence:.3f}, bbox={elem.bbox.coordinates}"
                            )
                        elif isinstance(elem, TextElement):
                            logger.info(
                                f"Text: '{elem.content}', confidence={elem.confidence:.3f}, bbox={elem.bbox.coordinates}"
                            )

                    # Verify file exists and log size
                    if output_path.exists():
                        logger.info(
                            f"Successfully saved image. File size: {output_path.stat().st_size} bytes"
                        )
                    else:
                        logger.error(f"Failed to verify file at {output_path}")
                except Exception as e:
                    logger.error(f"Error saving image: {str(e)}", exc_info=True)

            except Exception as e:
                logger.error(f"Error processing image {image_path}: {str(e)}", exc_info=True)

    except Exception as e:
        logger.error(f"Benchmark failed: {str(e)}", exc_info=True)
        raise
|
||||
|
||||
|
||||
def run_experiments(input_path: str, output_dir: Path, use_ocr: bool = False):
    """Run detection experiments over a grid of threshold combinations.

    For every (box_threshold, iou_threshold) pair, parses each input image,
    saves the annotated image plus a per-image detail report into a
    per-combination subdirectory, and appends one aggregate row to
    ``results_summary.txt`` inside a timestamped experiment directory.

    Args:
        input_path: Path to a single image, or a directory of ``*.png`` images.
        output_dir: Root directory under which the experiment directory is created.
        use_ocr: Whether to also run OCR text detection (default: False).

    Raises:
        FileNotFoundError: If no input images are found at ``input_path``.
    """
    # Define threshold values to test
    box_thresholds = [0.01, 0.05, 0.1, 0.3]
    iou_thresholds = [0.05, 0.1, 0.2, 0.5]

    logger.info("Starting threshold experiments...")
    logger.info("Box thresholds to test: %s", box_thresholds)
    logger.info("IOU thresholds to test: %s", iou_thresholds)

    # Create a unique, self-describing results directory for this experiment
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    ocr_suffix = "_ocr" if use_ocr else "_no_ocr"
    exp_dir = output_dir / f"experiment_{timestamp}{ocr_suffix}"
    exp_dir.mkdir(parents=True, exist_ok=True)

    # Resolve the image list once: it is invariant across threshold combinations.
    if os.path.isdir(input_path):
        image_files = glob.glob(os.path.join(input_path, "*.png"))
    else:
        image_files = [input_path]
    if not image_files:
        # Fail fast with a clear message instead of a ZeroDivisionError
        # (from the per-image average) repeated in every grid cell.
        raise FileNotFoundError(f"No input images found at {input_path}")

    # Create a summary file
    summary_file = exp_dir / "results_summary.txt"
    with open(summary_file, "w") as f:
        f.write("Threshold Experiments Results\n")
        f.write("==========================\n\n")
        f.write(f"Input: {input_path}\n")
        f.write(f"OCR Enabled: {use_ocr}\n")
        f.write(f"Date: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        f.write("Results:\n")
        f.write("-" * 80 + "\n")
        f.write(
            f"{'Box Thresh':^10} | {'IOU Thresh':^10} | {'Num Icons':^10} | {'Num Text':^10} | {'Time (s)':^10}\n"
        )
        f.write("-" * 80 + "\n")

        # Initialize parser once for all experiments
        parser = OmniParser()

        # Run experiments with each combination
        for box_thresh in box_thresholds:
            for iou_thresh in iou_thresholds:
                logger.info(f"\nTesting box_threshold={box_thresh}, iou_threshold={iou_thresh}")

                # Create directory for this combination
                combo_dir = exp_dir / f"box_{box_thresh}_iou_{iou_thresh}"
                combo_dir.mkdir(exist_ok=True)

                try:
                    total_icons = 0
                    total_text = 0
                    total_time = 0

                    for image_path in image_files:
                        # Load and convert the image, then parse with current thresholds
                        image = Image.open(image_path).convert("RGB")
                        image_bytes = image_to_bytes(image)

                        with Timer(f"Processing {Path(image_path).stem}", logger) as t:
                            result = parser.parse(
                                image_bytes,
                                use_ocr=use_ocr,
                                box_threshold=box_thresh,
                                iou_threshold=iou_thresh,
                            )

                        # Save annotated image (returned base64-encoded by the parser)
                        output_path = combo_dir / f"{Path(image_path).stem}_analyzed.png"
                        img_data = base64.b64decode(result.annotated_image_base64)
                        img = Image.open(io.BytesIO(img_data))
                        img.save(output_path)

                        # Update per-combination totals
                        total_icons += result.metadata.num_icons
                        total_text += result.metadata.num_text
                        total_time += t.elapsed_time

                        # Write the per-image detail report
                        _write_detail_report(combo_dir, Path(image_path), result, use_ocr)

                    # Write summary row for this combination
                    avg_time = total_time / len(image_files)
                    f.write(
                        f"{box_thresh:^10.3f} | {iou_thresh:^10.3f} | {total_icons:^10d} | {total_text:^10d} | {avg_time:^10.3f}\n"
                    )

                except Exception as e:
                    # One failing combination should not abort the rest of the grid;
                    # record an ERROR row and continue.
                    logger.error(
                        f"Error in experiment box={box_thresh}, iou={iou_thresh}: {str(e)}"
                    )
                    f.write(
                        f"{box_thresh:^10.3f} | {iou_thresh:^10.3f} | {'ERROR':^10s} | {'ERROR':^10s} | {'ERROR':^10s}\n"
                    )

        # Write summary footer
        f.write("-" * 80 + "\n")
        f.write("\nExperiment completed successfully!\n")

    logger.info(f"\nExperiment results saved to {exp_dir}")
    logger.info(f"Summary file: {summary_file}")


def _write_detail_report(combo_dir: Path, image_path: Path, result, use_ocr: bool) -> None:
    """Write the per-image detection report (<stem>_details.txt) for one parse result."""
    detail_file = combo_dir / f"{image_path.stem}_details.txt"
    with open(detail_file, "w") as detail_f:
        detail_f.write(f"Results for {image_path.name}\n")
        detail_f.write("-" * 40 + "\n")
        detail_f.write(f"Number of icons: {result.metadata.num_icons}\n")
        detail_f.write(f"Number of text elements: {result.metadata.num_text}\n\n")

        detail_f.write("Icon Detections:\n")
        icon_count = 1
        text_count = result.metadata.num_icons + 1  # Text boxes start after icons

        # First list all icons
        for elem in result.elements:
            if isinstance(elem, IconElement):
                detail_f.write(f"Box #{icon_count}: Icon\n")
                detail_f.write(f" - Confidence: {elem.confidence:.3f}\n")
                detail_f.write(f" - Coordinates: {elem.bbox.coordinates}\n")
                icon_count += 1

        if use_ocr:
            detail_f.write("\nText Detections:\n")
            for elem in result.elements:
                if isinstance(elem, TextElement):
                    detail_f.write(f"Box #{text_count}: Text\n")
                    detail_f.write(f" - Content: '{elem.content}'\n")
                    detail_f.write(f" - Confidence: {elem.confidence:.3f}\n")
                    detail_f.write(f" - Coordinates: {elem.bbox.coordinates}\n")
                    text_count += 1
|
||||
|
||||
|
||||
def main():
    """Parse command-line arguments and dispatch to the selected run mode.

    Returns:
        Process exit status: 0 on success, 1 if the selected run failed.
    """
    arg_parser = argparse.ArgumentParser(description="Run OmniParser benchmark")
    arg_parser.add_argument("input_path", help="Path to input image or directory containing images")

    # Optional flags declared as (flag, options) pairs; registration order is preserved.
    optional_flags = [
        (
            "--output-dir",
            {"default": "examples/output", "help": "Output directory for annotated images"},
        ),
        (
            "--ocr",
            {
                "choices": ["none", "easyocr"],
                "default": "none",
                "help": "OCR engine to use (default: none)",
            },
        ),
        (
            "--mode",
            {
                "choices": ["single", "experiment"],
                "default": "single",
                "help": "Run mode: single run or threshold experiments (default: single)",
            },
        ),
        (
            "--box-threshold",
            {
                "type": float,
                "default": 0.01,
                "help": "Confidence threshold for detection (default: 0.01)",
            },
        ),
        (
            "--iou-threshold",
            {
                "type": float,
                "default": 0.1,
                "help": "IOU threshold for Non-Maximum Suppression (default: 0.1)",
            },
        ),
    ]
    for flag, options in optional_flags:
        arg_parser.add_argument(flag, **options)

    args = arg_parser.parse_args()

    logger.info(f"Starting OmniParser with arguments: {args}")
    ocr_enabled = args.ocr != "none"
    results_dir = Path(args.output_dir)

    try:
        # Dispatch: grid sweep vs. single benchmark run.
        if args.mode == "experiment":
            run_experiments(args.input_path, results_dir, ocr_enabled)
        else:
            run_detection_benchmark(
                args.input_path, results_dir, ocr_enabled, args.box_threshold, args.iou_threshold
            )
    except Exception as e:
        logger.error(f"Process failed: {str(e)}", exc_info=True)
        return 1

    return 0
|
||||
|
||||
|
||||
# Script entry point: propagate main()'s status code (0 success, 1 failure) to the shell.
if __name__ == "__main__":
    sys.exit(main())
|
||||
55
examples/utils.py
Normal file
55
examples/utils.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Utility functions for example scripts."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import signal
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def load_env_file(path: Path) -> bool:
    """Load environment variables from a file.

    Parses simple ``KEY=VALUE`` lines into ``os.environ``. Blank lines and
    lines starting with ``#`` are ignored; keys and values are stored verbatim
    (no quote stripping or whitespace trimming inside the line).

    Args:
        path: Path to the .env file

    Returns:
        True if file was loaded successfully, False otherwise
    """
    if not path.exists():
        return False

    print(f"Loading environment from {path}")
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue

            # Use partition instead of an unconditional split: a line without
            # "=" previously raised ValueError and aborted the whole load.
            key, sep, value = line.partition("=")
            if not sep:
                continue  # malformed line — skip rather than crash
            os.environ[key] = value

    return True
|
||||
|
||||
|
||||
def load_dotenv_files():
    """Load environment variables from .env files.

    Tries to load from .env.local first, then .env if .env.local doesn't exist.
    """
    # Project root is the parent of the examples/ directory containing this file.
    project_root = Path(__file__).parent.parent

    # .env.local takes precedence; .env is only consulted when it is absent.
    for candidate in (project_root / ".env.local", project_root / ".env"):
        if load_env_file(candidate):
            break
|
||||
|
||||
|
||||
def handle_sigint(signum, frame):
    """Handle SIGINT (Ctrl+C) gracefully."""
    # Print a friendly message instead of a KeyboardInterrupt traceback,
    # then terminate with a success status.
    farewell = "\nExiting gracefully..."
    print(farewell)
    sys.exit(0)
|
||||
BIN
img/agent.png
Normal file
BIN
img/agent.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.1 MiB |
BIN
img/computer.png
Normal file
BIN
img/computer.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.6 MiB |
74
libs/agent/README.md
Normal file
74
libs/agent/README.md
Normal file
@@ -0,0 +1,74 @@
|
||||
<div align="center">
|
||||
<h1>
|
||||
<div class="image-wrapper" style="display: inline-block;">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" alt="logo" height="150" srcset="../../img/logo_white.png" style="display: block; margin: auto;">
|
||||
<source media="(prefers-color-scheme: light)" alt="logo" height="150" srcset="../../img/logo_black.png" style="display: block; margin: auto;">
|
||||
<img alt="Shows my svg">
|
||||
</picture>
|
||||
</div>
|
||||
|
||||
[](#)
|
||||
[](#)
|
||||
[](https://discord.com/invite/mVnXXpdE85)
|
||||
[](https://pypi.org/project/cua-computer/)
|
||||
</h1>
|
||||
</div>
|
||||
|
||||
**Agent** is a Computer Use (CUA) framework for running multi-app agentic workflows targeting macOS and Linux sandboxes, supporting local (Ollama) and cloud model providers (OpenAI, Anthropic, Groq, DeepSeek, Qwen). The framework integrates with Microsoft's OmniParser for enhanced UI understanding and interaction.
|
||||
|
||||
### Get started with Agent
|
||||
|
||||
```python
|
||||
from agent import ComputerAgent, AgentLoop, LLMProvider, LLM
|
||||
from computer import Computer
|
||||
|
||||
computer = Computer(verbosity=logging.INFO)
|
||||
|
||||
agent = ComputerAgent(
|
||||
computer=computer,
|
||||
loop=AgentLoop.ANTHROPIC,
|
||||
# loop=AgentLoop.OMNI,
|
||||
model=LLM(provider=LLMProvider.ANTHROPIC, name="claude-3-7-sonnet-20250219"),
|
||||
# model=LLM(provider=LLMProvider.OPENAI, name="gpt-4.5-preview"),
|
||||
save_trajectory=True,
|
||||
trajectory_dir=str(Path("trajectories")),
|
||||
only_n_most_recent_images=3,
|
||||
verbosity=logging.INFO,
|
||||
)
|
||||
|
||||
tasks = [
|
||||
"""
|
||||
Please help me with the following task:
|
||||
1. Open Safari browser
|
||||
2. Go to Wikipedia.org
|
||||
3. Search for "Claude AI"
|
||||
4. Summarize the main points you find about Claude AI
|
||||
"""
|
||||
]
|
||||
|
||||
async with agent:
|
||||
for i, task in enumerate(tasks, 1):
|
||||
print(f"\nExecuting task {i}/{len(tasks)}: {task}")
|
||||
async for result in agent.run(task):
|
||||
print(result)
|
||||
print(f"Task {i} completed")
|
||||
```
|
||||
|
||||
## Install
|
||||
|
||||
### cua-agent
|
||||
|
||||
```bash
|
||||
pip install "cua-agent[all]"
|
||||
|
||||
# or install specific loop providers
|
||||
pip install "cua-agent[anthropic]"
|
||||
pip install "cua-agent[omni]"
|
||||
```
|
||||
|
||||
## Run
|
||||
|
||||
Refer to these notebooks for step-by-step guides on how to use the Computer-Use Agent (CUA):
|
||||
|
||||
- [Agent Notebook](../../notebooks/agent_nb.ipynb) - Complete examples and workflows
|
||||
63
libs/agent/agent/README.md
Normal file
63
libs/agent/agent/README.md
Normal file
@@ -0,0 +1,63 @@
|
||||
# Agent Package Structure
|
||||
|
||||
## Overview
|
||||
The agent package provides a modular and extensible framework for AI-powered computer agents.
|
||||
|
||||
## Directory Structure
|
||||
```
|
||||
agent/
|
||||
├── __init__.py # Package exports
|
||||
├── core/ # Core functionality
|
||||
│ ├── __init__.py
|
||||
│ ├── computer_agent.py # Main entry point
|
||||
│ └── factory.py # Provider factory
|
||||
├── base/ # Base implementations
|
||||
│ ├── __init__.py
|
||||
│ ├── agent.py # Base agent class
|
||||
│ ├── core/ # Core components
|
||||
│ │ ├── callbacks.py
|
||||
│ │ ├── loop.py
|
||||
│ │ └── messages.py
|
||||
│ └── tools/ # Tool implementations
|
||||
├── providers/ # Provider implementations
|
||||
│ ├── __init__.py
|
||||
│ ├── anthropic/ # Anthropic provider
|
||||
│ │ ├── agent.py
|
||||
│ │ ├── loop.py
|
||||
│ │ └── tool_manager.py
|
||||
│ └── omni/ # Omni provider
|
||||
│ ├── agent.py
|
||||
│ ├── loop.py
|
||||
│ └── tool_manager.py
|
||||
└── types/ # Type definitions
|
||||
├── __init__.py
|
||||
├── base.py # Core types
|
||||
├── messages.py # Message types
|
||||
├── tools.py # Tool types
|
||||
└── providers/ # Provider-specific types
|
||||
├── anthropic.py
|
||||
└── omni.py
|
||||
```
|
||||
|
||||
## Key Components
|
||||
|
||||
### Core
|
||||
- `computer_agent.py`: Main entry point for creating and using agents
|
||||
- `factory.py`: Factory for creating provider-specific implementations
|
||||
|
||||
### Base
|
||||
- `agent.py`: Base agent implementation with shared functionality
|
||||
- `core/`: Core components used across providers
|
||||
- `tools/`: Shared tool implementations
|
||||
|
||||
### Providers
|
||||
Each provider follows the same structure:
|
||||
- `agent.py`: Provider-specific agent implementation
|
||||
- `loop.py`: Provider-specific message loop
|
||||
- `tool_manager.py`: Tool management for provider
|
||||
|
||||
### Types
|
||||
- `base.py`: Core type definitions
|
||||
- `messages.py`: Message-related types
|
||||
- `tools.py`: Tool-related types
|
||||
- `providers/`: Provider-specific type definitions
|
||||
56
libs/agent/agent/__init__.py
Normal file
56
libs/agent/agent/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""CUA (Computer Use) Agent for AI-driven computer interaction."""
|
||||
|
||||
import sys
|
||||
import logging
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
# Initialize logging
|
||||
logger = logging.getLogger("cua.agent")
|
||||
|
||||
# Initialize telemetry when the package is imported
|
||||
try:
|
||||
# Import from core telemetry for basic functions
|
||||
from core.telemetry import (
|
||||
is_telemetry_enabled,
|
||||
flush,
|
||||
record_event,
|
||||
)
|
||||
|
||||
# Import set_dimension from our own telemetry module
|
||||
from .core.telemetry import set_dimension
|
||||
|
||||
# Check if telemetry is enabled
|
||||
if is_telemetry_enabled():
|
||||
logger.info("Telemetry is enabled")
|
||||
|
||||
# Record package initialization
|
||||
record_event(
|
||||
"module_init",
|
||||
{
|
||||
"module": "agent",
|
||||
"version": __version__,
|
||||
"python_version": sys.version,
|
||||
},
|
||||
)
|
||||
|
||||
# Set the package version as a dimension
|
||||
set_dimension("agent_version", __version__)
|
||||
|
||||
# Flush events to ensure they're sent
|
||||
flush()
|
||||
else:
|
||||
logger.info("Telemetry is disabled")
|
||||
except ImportError as e:
|
||||
# Telemetry not available
|
||||
logger.warning(f"Telemetry not available: {e}")
|
||||
except Exception as e:
|
||||
# Other issues with telemetry
|
||||
logger.warning(f"Error initializing telemetry: {e}")
|
||||
|
||||
from .core.factory import AgentFactory
|
||||
from .core.agent import ComputerAgent
|
||||
from .providers.omni.types import LLMProvider, LLM
|
||||
from .types.base import Provider, AgentLoop
|
||||
|
||||
__all__ = ["AgentFactory", "Provider", "ComputerAgent", "AgentLoop", "LLMProvider", "LLM"]
|
||||
101
libs/agent/agent/core/README.md
Normal file
101
libs/agent/agent/core/README.md
Normal file
@@ -0,0 +1,101 @@
|
||||
# Unified ComputerAgent
|
||||
|
||||
The `ComputerAgent` class provides a unified implementation that consolidates the previously separate agent implementations (AnthropicComputerAgent and OmniComputerAgent) into a single, configurable class.
|
||||
|
||||
## Features
|
||||
|
||||
- **Multiple Loop Types**: Switch between different agentic loop implementations using the `loop_type` parameter (Anthropic or Omni).
|
||||
- **Provider Support**: Use different AI providers (OpenAI, Anthropic, etc.) with the appropriate loop.
|
||||
- **Trajectory Saving**: Control whether to save screenshots and logs with the `save_trajectory` parameter.
|
||||
- **Consistent Interface**: Maintains a consistent interface regardless of the underlying loop implementation.
|
||||
|
||||
## API Key Requirements
|
||||
|
||||
To use the ComputerAgent, you'll need API keys for the providers you want to use:
|
||||
|
||||
- For **OpenAI**: Set the `OPENAI_API_KEY` environment variable or pass it directly as `api_key`.
|
||||
- For **Anthropic**: Set the `ANTHROPIC_API_KEY` environment variable or pass it directly as `api_key`.
|
||||
- For **Groq**: Set the `GROQ_API_KEY` environment variable or pass it directly as `api_key`.
|
||||
|
||||
You can set environment variables in several ways:
|
||||
|
||||
```bash
|
||||
# In your terminal before running the code
|
||||
export OPENAI_API_KEY=your_api_key_here
|
||||
|
||||
# Or in a .env file
|
||||
OPENAI_API_KEY=your_api_key_here
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
Here's how to use the unified ComputerAgent:
|
||||
|
||||
```python
|
||||
from agent.core.agent import ComputerAgent
|
||||
from agent.types.base import AgentLoop
|
||||
from agent.providers.omni.types import LLMProvider
|
||||
from computer import Computer
|
||||
|
||||
# Create a Computer instance
|
||||
computer = Computer()
|
||||
|
||||
# Create an agent with the OMNI loop and OpenAI provider
|
||||
agent = ComputerAgent(
|
||||
computer=computer,
|
||||
    loop=AgentLoop.OMNI,
|
||||
provider=LLMProvider.OPENAI,
|
||||
model="gpt-4o",
|
||||
api_key="your_api_key_here", # Can also use OPENAI_API_KEY environment variable
|
||||
save_trajectory=True,
|
||||
only_n_most_recent_images=5
|
||||
)
|
||||
|
||||
# Create an agent with the ANTHROPIC loop
|
||||
agent = ComputerAgent(
|
||||
computer=computer,
|
||||
    loop=AgentLoop.ANTHROPIC,
|
||||
model="claude-3-7-sonnet-20250219",
|
||||
api_key="your_api_key_here", # Can also use ANTHROPIC_API_KEY environment variable
|
||||
save_trajectory=True,
|
||||
only_n_most_recent_images=5
|
||||
)
|
||||
|
||||
# Use the agent
|
||||
async with agent:
|
||||
async for result in agent.run("Your task description here"):
|
||||
# Process the result
|
||||
title = result["metadata"].get("title", "Screen Analysis")
|
||||
content = result["content"]
|
||||
print(f"\n{title}")
|
||||
print(content)
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
- `computer`: Computer instance to control
|
||||
- `loop`: The type of loop to use (AgentLoop.ANTHROPIC or AgentLoop.OMNI)
|
||||
- `provider`: AI provider to use (required for Omni loop)
|
||||
- `api_key`: Optional API key (will use environment variable if not provided)
|
||||
- `model`: Optional model name (will use provider default if not specified)
|
||||
- `save_trajectory`: Whether to save screenshots and logs
|
||||
- `only_n_most_recent_images`: Only keep N most recent images
|
||||
- `max_retries`: Maximum number of retry attempts
|
||||
|
||||
## Directory Structure
|
||||
|
||||
When `save_trajectory` is enabled, the agent will create the following directory structure:
|
||||
|
||||
```
|
||||
experiments/
|
||||
├── screenshots/ # Screenshots captured during agent execution
|
||||
└── logs/ # API call logs and other logging information
|
||||
```
|
||||
|
||||
## Extending with New Loop Types
|
||||
|
||||
To add a new loop type:
|
||||
|
||||
1. Implement a new loop class
|
||||
2. Add a new value to the `AgentLoop` enum
|
||||
3. Update the `_initialize_loop` method in `ComputerAgent` to handle the new loop type
|
||||
34
libs/agent/agent/core/__init__.py
Normal file
34
libs/agent/agent/core/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Core agent components."""
|
||||
|
||||
from .base_agent import BaseComputerAgent
|
||||
from .loop import BaseLoop
|
||||
from .messages import (
|
||||
create_user_message,
|
||||
create_assistant_message,
|
||||
create_system_message,
|
||||
create_image_message,
|
||||
create_screen_message,
|
||||
BaseMessageManager,
|
||||
ImageRetentionConfig,
|
||||
)
|
||||
from .callbacks import (
|
||||
CallbackManager,
|
||||
CallbackHandler,
|
||||
BaseCallbackManager,
|
||||
ContentCallback,
|
||||
ToolCallback,
|
||||
APICallback,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseComputerAgent",
|
||||
"BaseLoop",
|
||||
"CallbackManager",
|
||||
"CallbackHandler",
|
||||
"BaseMessageManager",
|
||||
"ImageRetentionConfig",
|
||||
"BaseCallbackManager",
|
||||
"ContentCallback",
|
||||
"ToolCallback",
|
||||
"APICallback",
|
||||
]
|
||||
252
libs/agent/agent/core/agent.py
Normal file
252
libs/agent/agent/core/agent.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""Unified computer agent implementation that supports multiple loops."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, TYPE_CHECKING, Union, cast
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from computer import Computer
|
||||
|
||||
from ..types.base import Provider, AgentLoop
|
||||
from .base_agent import BaseComputerAgent
|
||||
from ..core.telemetry import record_agent_initialization
|
||||
|
||||
# Only import types for type checking to avoid circular imports
|
||||
if TYPE_CHECKING:
|
||||
from ..providers.anthropic.loop import AnthropicLoop
|
||||
from ..providers.omni.loop import OmniLoop
|
||||
from ..providers.omni.parser import OmniParser
|
||||
|
||||
# Import the provider types
|
||||
from ..providers.omni.types import LLMProvider, LLM, Model, LLMModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default models for different providers
|
||||
DEFAULT_MODELS = {
|
||||
LLMProvider.OPENAI: "gpt-4o",
|
||||
LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
|
||||
}
|
||||
|
||||
# Map providers to their environment variable names
|
||||
ENV_VARS = {
|
||||
LLMProvider.OPENAI: "OPENAI_API_KEY",
|
||||
LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
|
||||
}
|
||||
|
||||
|
||||
class ComputerAgent(BaseComputerAgent):
    """Unified implementation of the computer agent supporting multiple loop types.

    This class consolidates the previous AnthropicComputerAgent and OmniComputerAgent
    into a single implementation with configurable loop type.
    """

    def __init__(
        self,
        computer: Computer,
        loop: AgentLoop = AgentLoop.OMNI,
        model: Optional[Union[LLM, Dict[str, str], str]] = None,
        api_key: Optional[str] = None,
        save_trajectory: bool = True,
        trajectory_dir: Optional[str] = "trajectories",
        only_n_most_recent_images: Optional[int] = None,
        max_retries: int = 3,
        verbosity: int = logging.INFO,
        telemetry_enabled: bool = True,
        **kwargs,
    ):
        """Initialize a ComputerAgent instance.

        Args:
            computer: The Computer instance to control
            loop: The agent loop to use: ANTHROPIC or OMNI
            model: The model to use. Can be a string, dict or LLM object.
                Defaults to LLM for the loop type.
            api_key: The API key to use. If None, will use environment variables.
            save_trajectory: Whether to save the trajectory.
            trajectory_dir: The directory to save trajectories to.
            only_n_most_recent_images: Only keep this many most recent images.
            max_retries: Maximum number of retries for failed requests.
            verbosity: Logging level (standard Python logging levels).
            telemetry_enabled: Whether to enable telemetry tracking. Defaults to True.
            **kwargs: Additional keyword arguments to pass to the loop.
        """
        super().__init__(computer)
        # Configure logging first so the remaining init steps are logged at
        # the requested verbosity.
        self._configure_logging(verbosity)
        logger.info(f"Initializing ComputerAgent with {loop} loop")

        # Store telemetry preference
        self.telemetry_enabled = telemetry_enabled

        # Process the model configuration
        self.model = self._process_model_config(model, loop)
        # NOTE: the `loop` *enum* is stored as loop_type; self.loop (below) is
        # rebound to the actual loop *instance* by _init_loop().
        self.loop_type = loop
        # May be None; the loop implementations are expected to fall back to
        # environment variables in that case (see ENV_VARS) — confirm.
        self.api_key = api_key

        # Store computer
        self.computer = computer

        # Save trajectory settings
        self.save_trajectory = save_trajectory
        self.trajectory_dir = trajectory_dir
        self.only_n_most_recent_images = only_n_most_recent_images

        # Store the max retries setting
        self.max_retries = max_retries

        # Initialize message history
        self.messages = []

        # Extra kwargs for the loop
        self.loop_kwargs = kwargs

        # Initialize the actual loop implementation
        self.loop = self._init_loop()

        # Record initialization in telemetry if enabled
        if telemetry_enabled:
            record_agent_initialization()

    def _process_model_config(
        self, model_input: Optional[Union[LLM, Dict[str, str], str]], loop: AgentLoop
    ) -> LLM:
        """Process and normalize model configuration.

        Args:
            model_input: Input model configuration (LLM, dict, string, or None)
            loop: The loop type being used

        Returns:
            Normalized LLM instance

        Raises:
            ValueError: If model_input is not None, LLM, dict, or str.
        """
        # Handle case where model_input is None
        if model_input is None:
            # Use Anthropic for Anthropic loop, OpenAI for Omni loop
            default_provider = (
                LLMProvider.ANTHROPIC if loop == AgentLoop.ANTHROPIC else LLMProvider.OPENAI
            )
            return LLM(provider=default_provider)

        # Handle case where model_input is already a LLM or one of its aliases
        if isinstance(model_input, (LLM, Model, LLMModel)):
            return model_input

        # Handle case where model_input is a dict
        if isinstance(model_input, dict):
            provider = model_input.get("provider", LLMProvider.OPENAI)
            if isinstance(provider, str):
                # Accept provider given as its string value, e.g. "openai".
                provider = LLMProvider(provider)
            return LLM(provider=provider, name=model_input.get("name"))

        # Handle case where model_input is a string (model name)
        if isinstance(model_input, str):
            default_provider = (
                LLMProvider.ANTHROPIC if loop == AgentLoop.ANTHROPIC else LLMProvider.OPENAI
            )
            return LLM(provider=default_provider, name=model_input)

        raise ValueError(f"Unsupported model configuration: {model_input}")

    def _configure_logging(self, verbosity: int):
        """Configure logging based on verbosity level."""
        # Use the logging level directly without mapping
        logger.setLevel(verbosity)
        logging.getLogger("agent").setLevel(verbosity)

        # Log the verbosity level that was set
        if verbosity <= logging.DEBUG:
            logger.info("Agent logging set to DEBUG level (full debug information)")
        elif verbosity <= logging.INFO:
            logger.info("Agent logging set to INFO level (standard output)")
        elif verbosity <= logging.WARNING:
            logger.warning("Agent logging set to WARNING level (warnings and errors only)")
        elif verbosity <= logging.ERROR:
            logger.warning("Agent logging set to ERROR level (errors only)")
        elif verbosity <= logging.CRITICAL:
            logger.warning("Agent logging set to CRITICAL level (critical errors only)")

    def _init_loop(self) -> Any:
        """Initialize the loop based on the loop_type.

        Returns:
            Initialized loop instance
        """
        # Lazy import OmniLoop and OmniParser to avoid circular imports
        # NOTE(review): these two are imported even on the ANTHROPIC path —
        # presumably harmless, but it pays the omni import cost; confirm.
        from ..providers.omni.loop import OmniLoop
        from ..providers.omni.parser import OmniParser

        if self.loop_type == AgentLoop.ANTHROPIC:
            from ..providers.anthropic.loop import AnthropicLoop

            # Ensure we always have a valid model name
            model_name = self.model.name or DEFAULT_MODELS[LLMProvider.ANTHROPIC]

            return AnthropicLoop(
                api_key=self.api_key,
                model=model_name,
                computer=self.computer,
                save_trajectory=self.save_trajectory,
                base_dir=self.trajectory_dir,
                only_n_most_recent_images=self.only_n_most_recent_images,
                **self.loop_kwargs,
            )

        # Initialize parser for OmniLoop with appropriate device
        # (caller-supplied "parser" kwarg wins over the default OmniParser).
        if "parser" not in self.loop_kwargs:
            self.loop_kwargs["parser"] = OmniParser()

        # Ensure we always have a valid model name
        model_name = self.model.name or DEFAULT_MODELS[self.model.provider]

        return OmniLoop(
            provider=self.model.provider,
            api_key=self.api_key,
            model=model_name,
            computer=self.computer,
            save_trajectory=self.save_trajectory,
            base_dir=self.trajectory_dir,
            only_n_most_recent_images=self.only_n_most_recent_images,
            **self.loop_kwargs,
        )

    async def _execute_task(self, task: str) -> AsyncGenerator[Dict[str, Any], None]:
        """Execute a task using the appropriate agent loop.

        Args:
            task: The task to execute

        Returns:
            AsyncGenerator yielding task outputs
        """
        logger.info(f"Executing task: {task}")

        try:
            # Create a message from the task
            # Note: self.messages itself is not mutated here; the task message
            # is only appended to a copy passed to the loop.
            task_message = {"role": "user", "content": task}
            messages_with_task = self.messages + [task_message]

            # Use the run method of the loop
            async for output in self.loop.run(messages_with_task):
                yield output
        except Exception as e:
            logger.error(f"Error executing task: {e}")
            raise
        finally:
            pass

    async def _execute_action(self, action_type: str, **action_params) -> Any:
        """Execute an action with telemetry tracking."""
        try:
            # Execute the action
            result = await super()._execute_action(action_type, **action_params)
            return result
        except Exception as e:
            logger.exception(f"Error executing action {action_type}: {e}")
            raise
        finally:
            pass
|
||||
164
libs/agent/agent/core/base_agent.py
Normal file
164
libs/agent/agent/core/base_agent.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Base computer agent implementation."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
from computer import Computer
|
||||
|
||||
from ..types.base import Provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseComputerAgent(ABC):
    """Base class for computer agents.

    Owns a Computer instance, an asyncio queue used to signal completion, and
    optional screenshot/log output directories. Subclasses implement
    _execute_task() to produce execution updates.
    """

    def __init__(
        self,
        max_retries: int = 3,
        computer: Optional[Computer] = None,
        screenshot_dir: Optional[str] = None,
        log_dir: Optional[str] = None,
        **kwargs,
    ):
        """Initialize the base computer agent.

        Args:
            max_retries: Maximum number of retry attempts
            computer: Optional Computer instance; a default one is created when omitted
            screenshot_dir: Directory to save screenshots
            log_dir: Directory to save logs (set to None to disable logging to files)
            **kwargs: Additional provider-specific arguments
        """
        self.max_retries = max_retries
        self.computer = computer or Computer()
        self.queue = asyncio.Queue()
        self.screenshot_dir = screenshot_dir
        self.log_dir = log_dir
        self._retry_count = 0
        self.provider = Provider.UNKNOWN

        # Create output directories eagerly so later writes cannot fail on a
        # missing path.
        if self.log_dir:
            os.makedirs(self.log_dir, exist_ok=True)
            logger.info(f"Created logs directory: {self.log_dir}")

        if self.screenshot_dir:
            os.makedirs(self.screenshot_dir, exist_ok=True)
            logger.info(f"Created screenshots directory: {self.screenshot_dir}")

        logger.info("BaseComputerAgent initialized")

    async def run(self, task: str) -> AsyncGenerator[Dict[str, Any], None]:
        """Run a task using the computer agent.

        Args:
            task: Task description

        Yields:
            Task execution updates. On failure a single assistant-role error
            message is yielded instead of the exception propagating.
        """
        try:
            logger.info(f"Running task: {task}")

            # Initialize the computer if needed
            await self._init_if_needed()

            # Delegate to the subclass implementation and forward its updates.
            async for result in self._execute_task(task):
                yield result

        except Exception as e:
            # exception() instead of error() so the traceback is preserved in
            # the logs.
            logger.exception(f"Error in agent run method: {str(e)}")
            yield {
                "role": "assistant",
                "content": f"Error: {str(e)}",
                "metadata": {"title": "❌ Error"},
            }

    async def _init_if_needed(self):
        """Initialize the computer interface if it hasn't been initialized yet."""
        # NOTE(review): reaches into Computer's private _initialized flag;
        # mirrors the same check done in __aenter__.
        if not self.computer._initialized:
            logger.info("Computer not initialized, initializing now...")
            try:
                # Call run directly without setting the flag first
                await self.computer.run()
                logger.info("Computer interface initialized successfully")
            except Exception as e:
                logger.exception(f"Error initializing computer interface: {str(e)}")
                raise

    async def __aenter__(self):
        """Initialize the agent when used as a context manager."""
        logger.info("Entering BaseComputerAgent context")

        try:
            # Initialize the computer only if not already initialized.
            logger.info("Checking if computer is already initialized...")
            if not self.computer._initialized:
                logger.info("Initializing computer in __aenter__...")
                # Use the computer's __aenter__ directly instead of calling
                # run() -- this avoids the circular dependency.
                await self.computer.__aenter__()
                logger.info("Computer initialized in __aenter__")
            else:
                logger.info("Computer already initialized, skipping initialization")

            # Take a test screenshot to verify the computer is working.
            logger.info("Testing computer with a screenshot...")
            try:
                test_screenshot = await self.computer.interface.screenshot()
                # The interface may return raw bytes or an object exposing a
                # base64_image attribute; report whichever size is available.
                if isinstance(test_screenshot, bytes):
                    size = len(test_screenshot)
                else:
                    try:
                        size = len(test_screenshot.base64_image)
                    except AttributeError:
                        size = "unknown"
                logger.info(f"Screenshot test successful, size: {size}")
            except Exception as e:
                logger.error(f"Screenshot test failed: {str(e)}")
                # Even though screenshot failed, we continue since some tests
                # might not need it.
        except Exception as e:
            logger.exception(f"Error initializing computer in __aenter__: {str(e)}")
            raise

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Cleanup computer resources if needed."""
        logger.info("Cleaning up agent resources")

        # The computer may be shared with other agents, so it is deliberately
        # not shut down here; just log how the context exits.
        if exc_type:
            logger.error(f"Exiting agent context with error: {exc_type.__name__}: {exc_val}")
        else:
            logger.info("Exiting agent context normally")

        # Signal queue consumers that no more results are coming.
        if hasattr(self, "queue") and self.queue:
            await self.queue.put(None)

    @abstractmethod
    async def _execute_task(self, task: str) -> AsyncGenerator[Dict[str, Any], None]:
        """Execute a task. Must be implemented by subclasses.

        This is an async method that returns an AsyncGenerator. Implementations
        should use 'yield' statements to produce results asynchronously.
        """
        raise NotImplementedError("Subclasses must implement _execute_task")
        # Unreachable yield keeps this stub an async generator so its shape
        # matches concrete implementations. (Previously the stub yielded a
        # bogus assistant message before raising.)
        yield {}  # pragma: no cover
|
||||
147
libs/agent/agent/core/callbacks.py
Normal file
147
libs/agent/agent/core/callbacks.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Callback handlers for agent."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Protocol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ContentCallback(Protocol):
    """Protocol for content callbacks.

    Implementations receive each content-update dict produced by the agent.
    """

    def __call__(self, content: Dict[str, Any]) -> None: ...
|
||||
|
||||
class ToolCallback(Protocol):
    """Protocol for tool callbacks.

    Called with the result of a tool execution and the id of the tool run.
    """

    def __call__(self, result: Any, tool_id: str) -> None: ...
|
||||
|
||||
class APICallback(Protocol):
    """Protocol for API callbacks.

    Called with each request/response pair; ``error`` is set when the call failed.
    """

    def __call__(self, request: Any, response: Any, error: Optional[Exception] = None) -> None: ...
|
||||
|
||||
class BaseCallbackManager(ABC):
    """Abstract base for callback managers.

    Holds the three event callables and requires subclasses to implement the
    dispatch hooks for content, tool results, and API interactions.
    """

    def __init__(
        self,
        content_callback: ContentCallback,
        tool_callback: ToolCallback,
        api_callback: APICallback,
    ):
        """Store the event callbacks.

        Args:
            content_callback: Invoked with content updates.
            tool_callback: Invoked with tool execution results.
            api_callback: Invoked with API request/response pairs.
        """
        self.content_callback = content_callback
        self.tool_callback = tool_callback
        self.api_callback = api_callback

    @abstractmethod
    def on_content(self, content: Any) -> None:
        """Dispatch a content update to the content callback."""
        raise NotImplementedError

    @abstractmethod
    def on_tool_result(self, result: Any, tool_id: str) -> None:
        """Dispatch a tool execution result to the tool callback."""
        raise NotImplementedError

    @abstractmethod
    def on_api_interaction(
        self, request: Any, response: Any, error: Optional[Exception] = None
    ) -> None:
        """Dispatch an API interaction (request, response, optional error)."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class CallbackManager:
    """Fans agent lifecycle events out to a list of registered handlers."""

    def __init__(self, handlers: Optional[List["CallbackHandler"]] = None):
        """Initialize with optional handlers.

        Args:
            handlers: Initial list of callback handlers; defaults to an empty
                list when omitted.
        """
        self.handlers = handlers or []

    def add_handler(self, handler: "CallbackHandler") -> None:
        """Register an additional callback handler.

        Args:
            handler: Callback handler to add
        """
        self.handlers.append(handler)

    async def on_action_start(self, action: str, **kwargs) -> None:
        """Notify every handler that an action is starting.

        Args:
            action: Action name
            **kwargs: Additional data forwarded to each handler
        """
        for registered in self.handlers:
            await registered.on_action_start(action, **kwargs)

    async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
        """Notify every handler that an action finished.

        Args:
            action: Action name
            success: Whether the action was successful
            **kwargs: Additional data forwarded to each handler
        """
        for registered in self.handlers:
            await registered.on_action_end(action, success, **kwargs)

    async def on_error(self, error: Exception, **kwargs) -> None:
        """Notify every handler that an error occurred.

        Args:
            error: Exception that occurred
            **kwargs: Additional data forwarded to each handler
        """
        for registered in self.handlers:
            await registered.on_error(error, **kwargs)
|
||||
|
||||
|
||||
class CallbackHandler(ABC):
    """Base class for callback handlers.

    Subclasses receive lifecycle notifications from CallbackManager; all
    hooks are coroutines so handlers may perform asynchronous work.
    """

    @abstractmethod
    async def on_action_start(self, action: str, **kwargs) -> None:
        """Called when an action starts.

        Args:
            action: Action name
            **kwargs: Additional data
        """
        pass

    @abstractmethod
    async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
        """Called when an action ends.

        Args:
            action: Action name
            success: Whether the action was successful
            **kwargs: Additional data
        """
        pass

    @abstractmethod
    async def on_error(self, error: Exception, **kwargs) -> None:
        """Called when an error occurs.

        Args:
            error: Exception that occurred
            **kwargs: Additional data
        """
        pass
|
||||
69
libs/agent/agent/core/computer_agent.py
Normal file
69
libs/agent/agent/core/computer_agent.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Main entry point for computer agents."""
|
||||
|
||||
import logging
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
from computer import Computer
|
||||
from ..types.base import Provider
|
||||
from .factory import AgentFactory
|
||||
|
||||
# NOTE(review): calling basicConfig() at import time in a library module
# configures the process-wide root logger; libraries normally leave logging
# configuration to the application entry point -- consider removing.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ComputerAgent:
    """A computer agent that can perform automated tasks using natural language instructions."""

    def __init__(self, provider: Provider, computer: Optional[Computer] = None, **kwargs):
        """Initialize the ComputerAgent.

        Args:
            provider: The AI provider to use (e.g., Provider.ANTHROPIC)
            computer: Optional Computer instance. If not provided, one will be created with default settings.
            **kwargs: Additional provider-specific arguments
        """
        self.provider = provider
        self._computer = computer
        self._kwargs = kwargs
        self._agent = None
        self._initialized = False
        self._in_context = False

        # Create provider-specific agent using factory
        self._agent = AgentFactory.create(provider=provider, computer=computer, **kwargs)

    async def __aenter__(self):
        """Enter the async context manager."""
        self._in_context = True
        await self.initialize()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit the async context manager.

        NOTE(review): no teardown of the inner agent or computer happens here;
        only the context flag is cleared.
        """
        self._in_context = False

    async def initialize(self) -> None:
        """Initialize the agent and its components.

        When used inside the context manager the computer is NOT started here;
        initialization is deferred to the inner agent's own lazy init.
        """
        if not self._initialized:
            if not self._in_context and self._computer:
                # If not in context manager but have a computer, initialize it
                await self._computer.run()
            self._initialized = True

    async def run(self, task: str) -> AsyncGenerator[Dict[str, Any], None]:
        """Run the agent with a given task.

        Lazily initializes on first use, then forwards every update dict
        yielded by the provider-specific agent.
        """
        if not self._initialized:
            await self.initialize()

        if self._agent is None:
            logger.error("Agent not initialized properly")
            yield {"error": "Agent not initialized properly"}
            return

        async for result in self._agent.run(task):
            yield result

    @property
    def computer(self) -> Optional[Computer]:
        """Get the underlying computer instance (owned by the inner agent)."""
        return self._agent.computer if self._agent else None
|
||||
232
libs/agent/agent/core/experiment.py
Normal file
232
libs/agent/agent/core/experiment.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""Core experiment management for agents."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from PIL import Image
|
||||
import json
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExperimentManager:
    """Manages experiment directories and logging for the agent.

    Layout: ``<base_dir>/<run timestamp>/turn_NNN/`` -- screenshots, action
    visualizations and API-call logs are written into the current turn
    directory.
    """

    def __init__(
        self,
        base_dir: Optional[str] = None,
        only_n_most_recent_images: Optional[int] = None,
    ):
        """Initialize the experiment manager.

        Args:
            base_dir: Base directory for saving experiment data
            only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
        """
        self.base_dir = base_dir
        self.only_n_most_recent_images = only_n_most_recent_images
        self.run_dir = None
        self.current_turn_dir = None
        self.turn_count = 0
        self.screenshot_count = 0
        # Track all screenshots for potential API request inclusion
        self.screenshot_paths = []

        # Set up experiment directories if base_dir is provided
        if self.base_dir:
            self.setup_experiment_dirs()

    def setup_experiment_dirs(self) -> None:
        """Setup the experiment directory structure (run dir + first turn dir)."""
        if not self.base_dir:
            return

        # Create base experiments directory if it doesn't exist
        os.makedirs(self.base_dir, exist_ok=True)

        # Create timestamped run directory
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.run_dir = os.path.join(self.base_dir, timestamp)
        os.makedirs(self.run_dir, exist_ok=True)
        logger.info(f"Created run directory: {self.run_dir}")

        # Create first turn directory
        self.create_turn_dir()

    def create_turn_dir(self) -> None:
        """Create a new directory for the current turn."""
        if not self.run_dir:
            logger.warning("Cannot create turn directory: run_dir not set")
            return

        # Increment turn counter
        self.turn_count += 1

        # Create turn directory with padded number
        turn_name = f"turn_{self.turn_count:03d}"
        self.current_turn_dir = os.path.join(self.run_dir, turn_name)
        os.makedirs(self.current_turn_dir, exist_ok=True)
        logger.info(f"Created turn directory: {self.current_turn_dir}")

    def sanitize_log_data(self, data: Any) -> Any:
        """Sanitize log data by replacing large binary data with placeholders.

        Recurses through dicts and lists; any string longer than 1000 chars
        that mentions "base64" (case-insensitive) is replaced by a length
        placeholder.

        Args:
            data: Data to sanitize

        Returns:
            Sanitized copy of the data
        """
        if isinstance(data, dict):
            return {k: self.sanitize_log_data(v) for k, v in data.items()}
        elif isinstance(data, list):
            return [self.sanitize_log_data(item) for item in data]
        elif isinstance(data, str) and len(data) > 1000 and "base64" in data.lower():
            return f"[BASE64_DATA_LENGTH_{len(data)}]"
        else:
            return data

    def save_screenshot(self, img_base64: str, action_type: str = "") -> Optional[str]:
        """Save a screenshot to the experiment directory.

        Args:
            img_base64: Base64 encoded screenshot
            action_type: Type of action that triggered the screenshot

        Returns:
            Path of the saved file, or None when saving is disabled or fails.
            (Fixed: annotation previously claimed ``None`` although the
            method returns the filepath on success.)
        """
        if not self.current_turn_dir:
            return None

        try:
            # Increment screenshot counter
            self.screenshot_count += 1

            # Sanitize action_type to ensure a valid filename: replace
            # characters that are unsafe on common filesystems.
            sanitized_action = ""
            if action_type:
                sanitized_action = re.sub(r'[\\/*?:"<>|]', "_", action_type)
                # Limit the length to avoid excessively long filenames
                sanitized_action = sanitized_action[:50]

            # Create a descriptive filename
            timestamp = int(datetime.now().timestamp() * 1000)
            action_suffix = f"_{sanitized_action}" if sanitized_action else ""
            filename = f"screenshot_{self.screenshot_count:03d}{action_suffix}_{timestamp}.png"

            # Save directly to the turn directory
            filepath = os.path.join(self.current_turn_dir, filename)

            # Decode and write the image bytes
            img_data = base64.b64decode(img_base64)
            with open(filepath, "wb") as f:
                f.write(img_data)

            # Keep track of the file path
            self.screenshot_paths.append(filepath)

            return filepath
        except Exception as e:
            logger.error(f"Error saving screenshot: {str(e)}")
            return None

    def save_action_visualization(
        self, img: "Image.Image", action_name: str, details: str = ""
    ) -> str:
        """Save a visualization of an action.

        Args:
            img: PIL image to save
            action_name: Name of the action
            details: Additional details about the action

        Returns:
            Path to the saved image, or "" when saving is disabled or fails.
        """
        if not self.current_turn_dir:
            return ""

        try:
            # Create a descriptive filename
            timestamp = int(datetime.now().timestamp() * 1000)
            details_suffix = f"_{details}" if details else ""
            filename = f"vis_{action_name}{details_suffix}_{timestamp}.png"

            # Save directly to the turn directory
            filepath = os.path.join(self.current_turn_dir, filename)

            # Save the image
            img.save(filepath)

            # Keep track of the file path
            self.screenshot_paths.append(filepath)

            return filepath
        except Exception as e:
            logger.error(f"Error saving action visualization: {str(e)}")
            return ""

    def log_api_call(
        self,
        call_type: str,
        request: Any,
        provider: str = "unknown",
        model: str = "unknown",
        response: Any = None,
        error: Optional[Exception] = None,
    ) -> None:
        """Log API call details to file.

        Args:
            call_type: Type of API call (request, response, error)
            request: Request data
            provider: API provider name
            model: Model name
            response: Response data (for response logs)
            error: Error information (for error logs)
        """
        if not self.current_turn_dir:
            logger.warning("Cannot log API call: current_turn_dir not set")
            return

        try:
            # Create a timestamp for the log file
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            # Create filename based on log type
            filename = f"api_call_{timestamp}_{call_type}.json"
            filepath = os.path.join(self.current_turn_dir, filename)

            # Sanitize data before logging so large base64 blobs stay out of
            # the JSON logs.
            sanitized_request = self.sanitize_log_data(request)
            sanitized_response = self.sanitize_log_data(response) if response is not None else None

            # Prepare log data
            log_data = {
                "timestamp": timestamp,
                "provider": provider,
                "model": model,
                "type": call_type,
                "request": sanitized_request,
            }

            if sanitized_response is not None:
                log_data["response"] = sanitized_response
            if error is not None:
                log_data["error"] = str(error)

            # Write to file; default=str keeps non-serializable values loggable
            with open(filepath, "w") as f:
                json.dump(log_data, f, indent=2, default=str)

            logger.info(f"Logged API {call_type} to {filepath}")

        except Exception as e:
            logger.error(f"Error logging API call: {str(e)}")
|
||||
102
libs/agent/agent/core/factory.py
Normal file
102
libs/agent/agent/core/factory.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Factory for creating provider-specific agents."""
|
||||
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from computer import Computer
|
||||
from ..types.base import Provider
|
||||
from .base_agent import BaseComputerAgent
|
||||
|
||||
# Import provider-specific implementations
|
||||
# Availability flags for optional provider back-ends; each flag is flipped to
# True only when the matching import probe below succeeds.
_ANTHROPIC_AVAILABLE = False
_OPENAI_AVAILABLE = False
_OLLAMA_AVAILABLE = False  # NOTE(review): no probe below ever sets this True
_OMNI_AVAILABLE = False

# Try importing providers
try:
    import anthropic
    from ..providers.anthropic.agent import AnthropicComputerAgent

    _ANTHROPIC_AVAILABLE = True
except ImportError:
    pass

try:
    import openai

    _OPENAI_AVAILABLE = True
except ImportError:
    pass

try:
    from ..providers.omni.agent import OmniComputerAgent

    _OMNI_AVAILABLE = True
except ImportError:
    pass
|
||||
|
||||
|
||||
class AgentFactory:
    """Factory for creating provider-specific agent implementations."""

    @staticmethod
    def create(
        provider: Provider, computer: Optional[Computer] = None, **kwargs: Any
    ) -> BaseComputerAgent:
        """Create an agent based on the specified provider.

        Args:
            provider: The AI provider to use
            computer: Optional Computer instance; a default one is created when omitted
            **kwargs: Additional provider-specific arguments

        Returns:
            A provider-specific agent implementation

        Raises:
            ImportError: If provider dependencies are not installed
            NotImplementedError: If the provider is recognized but not yet implemented
            ValueError: If provider is not supported
        """
        # Create a Computer instance if none is provided
        if computer is None:
            computer = Computer()

        if provider == Provider.ANTHROPIC:
            if not _ANTHROPIC_AVAILABLE:
                raise ImportError(
                    "Anthropic provider requires additional dependencies. "
                    "Install them with: pip install cua-agent[anthropic]"
                )
            return AnthropicComputerAgent(max_retries=3, computer=computer, **kwargs)
        elif provider == Provider.OPENAI:
            if not _OPENAI_AVAILABLE:
                raise ImportError(
                    "OpenAI provider requires additional dependencies. "
                    "Install them with: pip install cua-agent[openai]"
                )
            raise NotImplementedError("OpenAI provider not yet implemented")
        elif provider == Provider.OLLAMA:
            # BUG FIX: the module-level probes never set _OLLAMA_AVAILABLE,
            # so gating on it made this branch unconditionally raise. Import
            # the Ollama dependencies lazily here instead.
            try:
                import ollama  # noqa: F401
                from ..providers.ollama.agent import OllamaComputerAgent
            except ImportError as exc:
                raise ImportError(
                    "Ollama provider requires additional dependencies. "
                    "Install them with: pip install cua-agent[ollama]"
                ) from exc
            return OllamaComputerAgent(max_retries=3, computer=computer, **kwargs)
        elif provider == Provider.OMNI:
            if not _OMNI_AVAILABLE:
                raise ImportError(
                    "Omni provider requires additional dependencies. "
                    "Install them with: pip install cua-agent[omni]"
                )
            return OmniComputerAgent(max_retries=3, computer=computer, **kwargs)
        else:
            raise ValueError(f"Unsupported provider: {provider}")
|
||||
244
libs/agent/agent/core/loop.py
Normal file
244
libs/agent/agent/core/loop.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""Base agent loop implementation."""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
import base64
|
||||
|
||||
from computer import Computer
|
||||
from .experiment import ExperimentManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseLoop(ABC):
    """Base class for agent loops that handle message processing and tool execution."""

    def __init__(
        self,
        computer: Computer,
        model: str,
        api_key: str,
        max_tokens: int = 4096,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        base_dir: Optional[str] = "trajectories",
        save_trajectory: bool = True,
        only_n_most_recent_images: Optional[int] = 2,
        **kwargs,
    ):
        """Initialize base agent loop.

        Args:
            computer: Computer instance to control
            model: Model name to use
            api_key: API key for provider
            max_tokens: Maximum tokens to generate
            max_retries: Maximum number of retries
            retry_delay: Delay between retries in seconds
            base_dir: Base directory for saving experiment data
            save_trajectory: Whether to save trajectory data
            only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
            **kwargs: Additional provider-specific arguments
        """
        self.computer = computer
        self.model = model
        self.api_key = api_key
        self.max_tokens = max_tokens
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.base_dir = base_dir
        self.save_trajectory = save_trajectory
        self.only_n_most_recent_images = only_n_most_recent_images
        self._kwargs = kwargs
        self.message_history = []

        # Initialize experiment manager (trajectory saving is optional).
        if self.save_trajectory and self.base_dir:
            self.experiment_manager = ExperimentManager(
                base_dir=self.base_dir,
                only_n_most_recent_images=only_n_most_recent_images,
            )
            # Track directories for convenience
            self.run_dir = self.experiment_manager.run_dir
            self.current_turn_dir = self.experiment_manager.current_turn_dir
        else:
            self.experiment_manager = None
            self.run_dir = None
            self.current_turn_dir = None

        # Initialize basic tracking
        self.turn_count = 0

    def _setup_experiment_dirs(self) -> None:
        """Setup the experiment directory structure."""
        if self.experiment_manager:
            # Use the experiment manager to set up directories
            self.experiment_manager.setup_experiment_dirs()

            # Update local tracking variables
            self.run_dir = self.experiment_manager.run_dir
            self.current_turn_dir = self.experiment_manager.current_turn_dir

    def _create_turn_dir(self) -> None:
        """Create a new directory for the current turn."""
        if self.experiment_manager:
            # Use the experiment manager to create the turn directory
            self.experiment_manager.create_turn_dir()

            # Update local tracking variables
            self.current_turn_dir = self.experiment_manager.current_turn_dir
            self.turn_count = self.experiment_manager.turn_count

    def _log_api_call(
        self, call_type: str, request: Any, response: Any = None, error: Optional[Exception] = None
    ) -> None:
        """Log API call details to file.

        Args:
            call_type: Type of API call (e.g., 'request', 'response', 'error')
            request: The API request data
            response: Optional API response data
            error: Optional error information
        """
        if self.experiment_manager:
            # Subclasses may set a provider attribute; fall back to "unknown".
            provider = getattr(self, "provider", "unknown")
            provider_str = str(provider) if provider else "unknown"

            self.experiment_manager.log_api_call(
                call_type=call_type,
                request=request,
                provider=provider_str,
                model=self.model,
                response=response,
                error=error,
            )

    def _save_screenshot(self, img_base64: str, action_type: str = "") -> None:
        """Save a screenshot to the experiment directory.

        Args:
            img_base64: Base64 encoded screenshot
            action_type: Type of action that triggered the screenshot
        """
        if self.experiment_manager:
            self.experiment_manager.save_screenshot(img_base64, action_type)

    async def initialize(self) -> None:
        """Initialize both the API client and computer interface with retries."""
        for attempt in range(self.max_retries):
            try:
                logger.info(
                    f"Starting initialization (attempt {attempt + 1}/{self.max_retries})..."
                )

                # Initialize API client
                await self.initialize_client()

                # Initialize computer
                await self.computer.initialize()

                logger.info("Initialization complete.")
                return
            except Exception as e:
                if attempt < self.max_retries - 1:
                    logger.warning(
                        f"Initialization failed (attempt {attempt + 1}/{self.max_retries}): {str(e)}. Retrying..."
                    )
                    await asyncio.sleep(self.retry_delay)
                else:
                    logger.error(
                        f"Initialization failed after {self.max_retries} attempts: {str(e)}"
                    )
                    # Chain the original exception so the root cause is kept.
                    raise RuntimeError(f"Failed to initialize: {str(e)}") from e

    async def _get_parsed_screen_som(self) -> Dict[str, Any]:
        """Get parsed screen information.

        Returns:
            Dict containing screen information (dimensions, base64 screenshot,
            timestamp); on failure an entry with an "error" key and default
            dimensions is returned instead of raising.
        """
        try:
            # Take screenshot
            screenshot = await self.computer.interface.screenshot()

            # Initialize with default values
            width, height = 1024, 768
            base64_image = ""

            # Handle different types of screenshot returns
            if isinstance(screenshot, bytes):
                # Raw bytes screenshot
                base64_image = base64.b64encode(screenshot).decode("utf-8")
            elif hasattr(screenshot, "base64_image"):
                # Object-style screenshot with attributes
                base64_image = screenshot.base64_image
                if hasattr(screenshot, "width") and hasattr(screenshot, "height"):
                    width = screenshot.width
                    height = screenshot.height

            # Create parsed screen data
            parsed_screen = {
                "width": width,
                "height": height,
                "parsed_content_list": [],
                "timestamp": datetime.now().isoformat(),
                "screenshot_base64": base64_image,
            }

            # Save screenshot if requested
            if self.save_trajectory and self.experiment_manager:
                try:
                    # Strip a possible data-URL prefix before saving.
                    img_data = base64_image
                    if "," in img_data:
                        img_data = img_data.split(",")[1]
                    self._save_screenshot(img_data, action_type="state")
                except Exception as e:
                    logger.error(f"Error saving screenshot: {str(e)}")

            return parsed_screen
        except Exception as e:
            logger.error(f"Error taking screenshot: {str(e)}")
            return {
                "width": 1024,
                "height": 768,
                "parsed_content_list": [],
                "timestamp": datetime.now().isoformat(),
                "error": f"Error taking screenshot: {str(e)}",
                "screenshot_base64": "",
            }

    @abstractmethod
    async def initialize_client(self) -> None:
        """Initialize the API client and any provider-specific components."""
        raise NotImplementedError

    @abstractmethod
    async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]:
        """Run the agent loop with provided messages.

        Args:
            messages: List of message objects

        Yields:
            Dict containing response data
        """
        raise NotImplementedError

    @abstractmethod
    async def _process_screen(
        self, parsed_screen: Dict[str, Any], messages: List[Dict[str, Any]]
    ) -> None:
        """Process screen information and add to messages.

        Args:
            parsed_screen: Dictionary containing parsed screen info
            messages: List of messages to update
        """
        raise NotImplementedError
|
||||
245
libs/agent/agent/core/messages.py
Normal file
245
libs/agent/agent/core/messages.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""Message handling utilities for agent."""
|
||||
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from PIL import Image
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ImageRetentionConfig:
    """Settings controlling how many screenshots stay in message history."""

    # How many of the most recent images to keep; None disables retention.
    num_images_to_keep: Optional[int] = None
    # Never remove fewer than this many images at once.
    min_removal_threshold: int = 1
    # Whether caching is applied when preparing messages.
    enable_caching: bool = True

    def should_retain_images(self) -> bool:
        """Return True when a positive image budget is configured."""
        keep = self.num_images_to_keep
        return keep is not None and keep > 0
|
||||
|
||||
|
||||
class BaseMessageManager:
    """Base class for message preparation and management."""

    def __init__(self, image_retention_config: Optional[ImageRetentionConfig] = None):
        """Initialize the message manager.

        Args:
            image_retention_config: Configuration for image retention

        Raises:
            ValueError: If min_removal_threshold is below 1
        """
        self.image_retention_config = image_retention_config or ImageRetentionConfig()
        if self.image_retention_config.min_removal_threshold < 1:
            raise ValueError("min_removal_threshold must be at least 1")

        # Track provider for message formatting (controls cache_control injection)
        self.provider = "openai"  # Default provider

    def set_provider(self, provider: str) -> None:
        """Set the current provider to format messages for.

        Args:
            provider: Provider name (e.g., 'openai', 'anthropic')
        """
        self.provider = provider.lower()

    def prepare_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Prepare messages by applying image retention and caching as configured.

        Args:
            messages: List of messages to prepare (mutated in place)

        Returns:
            The same message list after filtering/caching
        """
        if self.image_retention_config.should_retain_images():
            self._filter_images(messages)
        if self.image_retention_config.enable_caching:
            self._inject_caching(messages)
        return messages

    def _filter_images(self, messages: List[Dict[str, Any]]) -> None:
        """Filter messages to retain only the specified number of most recent images.

        Args:
            messages: Messages to filter (mutated in place)
        """
        # Find all tool result blocks that contain images.  Use .get() so a
        # message without a "content" key cannot raise KeyError (the original
        # indexed message["content"] directly).
        tool_results = [
            item
            for message in messages
            for item in (
                message.get("content") if isinstance(message.get("content"), list) else []
            )
            if isinstance(item, dict) and item.get("type") == "tool_result"
        ]

        # Count total images across all tool results
        total_images = sum(
            1
            for result in tool_results
            for content in result.get("content", [])
            if isinstance(content, dict) and content.get("type") == "image"
        )

        # Calculate how many images to remove, rounded down to a multiple of
        # min_removal_threshold so removals happen in batches.
        images_to_remove = total_images - (self.image_retention_config.num_images_to_keep or 0)
        images_to_remove -= images_to_remove % self.image_retention_config.min_removal_threshold

        # Remove oldest images first (tool_results preserves message order)
        for result in tool_results:
            if isinstance(result.get("content"), list):
                new_content = []
                for content in result["content"]:
                    if isinstance(content, dict) and content.get("type") == "image":
                        if images_to_remove > 0:
                            images_to_remove -= 1
                            continue
                    new_content.append(content)
                result["content"] = new_content

    def _inject_caching(self, messages: List[Dict[str, Any]]) -> None:
        """Inject caching control for recent message turns.

        Args:
            messages: Messages to inject caching into (mutated in place)
        """
        # Only apply cache_control for Anthropic API, not OpenAI
        if self.provider != "anthropic":
            return

        # Default to caching last 3 turns
        turns_to_cache = 3
        for message in reversed(messages):
            content = message.get("content")
            # Guard against missing "role" keys and empty content lists
            # (the original raised KeyError/IndexError on malformed messages).
            if message.get("role") == "user" and isinstance(content, list) and content:
                if turns_to_cache:
                    turns_to_cache -= 1
                    content[-1]["cache_control"] = {"type": "ephemeral"}
                else:
                    # Drop the stale cache marker on the next older turn
                    content[-1].pop("cache_control", None)
                    break
|
||||
|
||||
|
||||
def create_user_message(text: str) -> Dict[str, str]:
    """Build a user-role message dict.

    Args:
        text: The message text

    Returns:
        Message dictionary with "role" and "content" keys
    """
    message: Dict[str, str] = {"role": "user", "content": text}
    return message
|
||||
|
||||
|
||||
def create_assistant_message(text: str) -> Dict[str, str]:
    """Build an assistant-role message dict.

    Args:
        text: The message text

    Returns:
        Message dictionary with "role" and "content" keys
    """
    message: Dict[str, str] = {"role": "assistant", "content": text}
    return message
|
||||
|
||||
|
||||
def create_system_message(text: str) -> Dict[str, str]:
    """Build a system-role message dict.

    Args:
        text: The message text

    Returns:
        Message dictionary with "role" and "content" keys
    """
    message: Dict[str, str] = {"role": "system", "content": text}
    return message
|
||||
|
||||
|
||||
def create_image_message(
    image_base64: Optional[str] = None,
    image_path: Optional[str] = None,
    image_obj: Optional[Image.Image] = None,
) -> Dict[str, Union[str, List[Dict[str, Any]]]]:
    """Create a user message whose content is a single image.

    Exactly one image source is needed; precedence when several are given is
    base64, then path, then PIL object.

    Args:
        image_base64: Base64 encoded image
        image_path: Path to image file
        image_obj: PIL Image object

    Returns:
        Message dictionary with content list

    Raises:
        ValueError: If no image source is provided
    """
    if not (image_base64 or image_path or image_obj):
        raise ValueError("Must provide one of image_base64, image_path, or image_obj")

    # Produce base64 from whichever alternative source was supplied.
    if not image_base64:
        if image_path:
            with open(image_path, "rb") as fh:
                image_base64 = base64.b64encode(fh.read()).decode("utf-8")
        elif image_obj:
            buffer = BytesIO()
            image_obj.save(buffer, format="PNG")
            image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

    return {
        "role": "user",
        "content": [
            {"type": "image", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
        ],
    }
|
||||
|
||||
|
||||
def create_screen_message(
    parsed_screen: Dict[str, Any],
    include_raw: bool = False,
) -> Dict[str, Union[str, List[Dict[str, Any]]]]:
    """Create a user message describing the current screen.

    Args:
        parsed_screen: Dictionary containing parsed screen info
        include_raw: Whether to include raw screenshot base64

    Returns:
        Message dictionary with content (image + text when raw is included,
        otherwise a plain text description)
    """
    dims_text = f"Screen dimensions: {parsed_screen['width']}x{parsed_screen['height']}"

    # Text-only variant when the raw screenshot is absent or not requested.
    if not (include_raw and "screenshot_base64" in parsed_screen):
        return {"role": "user", "content": dims_text}

    image_part = {
        "type": "image",
        "image_url": {"url": f"data:image/png;base64,{parsed_screen['screenshot_base64']}"},
    }
    text_part = {"type": "text", "text": dims_text}
    return {"role": "user", "content": [image_part, text_part]}
|
||||
130
libs/agent/agent/core/telemetry.py
Normal file
130
libs/agent/agent/core/telemetry.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Agent telemetry for tracking anonymous usage and feature usage."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
from typing import Dict, Any
|
||||
|
||||
# Import the core telemetry module.  The wiring below is order-sensitive:
# the try/except sets TELEMETRY_AVAILABLE, and the fallback block afterwards
# replaces the telemetry API with no-ops only when the import failed.
TELEMETRY_AVAILABLE = False

try:
    from core.telemetry import (
        record_event,
        increment,
        get_telemetry_client,
        flush,
        is_telemetry_enabled,
        is_telemetry_globally_disabled,
    )

    def increment_counter(counter_name: str, value: int = 1) -> None:
        """Wrapper for increment to maintain backward compatibility."""
        if is_telemetry_enabled():
            increment(counter_name, value)

    def set_dimension(name: str, value: Any) -> None:
        """Set a dimension that will be attached to all events."""
        # NOTE(review): this only logs the dimension; it does not forward it
        # to the telemetry client — confirm whether that is intentional.
        logger = logging.getLogger("cua.agent.telemetry")
        logger.debug(f"Setting dimension {name}={value}")

    TELEMETRY_AVAILABLE = True
    logger = logging.getLogger("cua.agent.telemetry")
    logger.info("Successfully imported telemetry")
except ImportError as e:
    logger = logging.getLogger("cua.agent.telemetry")
    logger.warning(f"Could not import telemetry: {e}")
    TELEMETRY_AVAILABLE = False


# Local fallbacks in case core telemetry isn't available
def _noop(*args: Any, **kwargs: Any) -> None:
    """No-op function for when telemetry is not available."""
    pass


logger = logging.getLogger("cua.agent.telemetry")

# If telemetry isn't available, use no-op functions so callers never have to
# check availability themselves.
if not TELEMETRY_AVAILABLE:
    logger.debug("Telemetry not available, using no-op functions")
    record_event = _noop  # type: ignore
    increment_counter = _noop  # type: ignore
    set_dimension = _noop  # type: ignore
    get_telemetry_client = lambda: None  # type: ignore
    flush = _noop  # type: ignore
    is_telemetry_enabled = lambda: False  # type: ignore
    is_telemetry_globally_disabled = lambda: True  # type: ignore

# Get system info once to use in telemetry
SYSTEM_INFO = {
    "os": platform.system().lower(),
    "os_version": platform.release(),
    "python_version": platform.python_version(),
}
|
||||
|
||||
|
||||
def enable_telemetry() -> bool:
    """Enable telemetry if available.

    Returns:
        bool: True if telemetry was successfully enabled, False otherwise
    """
    global TELEMETRY_AVAILABLE

    # Check if globally disabled using core function
    if TELEMETRY_AVAILABLE and is_telemetry_globally_disabled():
        logger.info("Telemetry is globally disabled via environment variable - cannot enable")
        return False

    # Already enabled
    if TELEMETRY_AVAILABLE:
        return True

    # Try to import and enable
    try:
        from core import telemetry as _core_telemetry

        # Check again after import
        if _core_telemetry.is_telemetry_globally_disabled():
            logger.info("Telemetry is globally disabled via environment variable - cannot enable")
            return False

        # BUG FIX: the original bound the imported names only in this
        # function's local scope, so the module-level no-op stubs installed at
        # import time were never replaced and "enabling" had no visible
        # effect.  Rebind the real implementations at module level.
        globals()["record_event"] = _core_telemetry.record_event
        globals()["increment"] = _core_telemetry.increment
        globals()["get_telemetry_client"] = _core_telemetry.get_telemetry_client
        globals()["flush"] = _core_telemetry.flush
        globals()["is_telemetry_globally_disabled"] = _core_telemetry.is_telemetry_globally_disabled

        TELEMETRY_AVAILABLE = True
        logger.info("Telemetry successfully enabled")
        return True
    except ImportError as e:
        logger.warning(f"Could not enable telemetry: {e}")
        return False
|
||||
|
||||
|
||||
def is_telemetry_enabled() -> bool:
    """Check if telemetry is enabled.

    Returns:
        bool: True if telemetry is enabled, False otherwise
    """
    # Without the core backend there is nothing to enable.
    if not TELEMETRY_AVAILABLE:
        return False
    # Defer to the core implementation when it imported successfully.
    from core.telemetry import is_telemetry_enabled as core_is_enabled

    return core_is_enabled()
|
||||
|
||||
|
||||
def record_agent_initialization() -> None:
    """Record when an agent instance is initialized."""
    if not (TELEMETRY_AVAILABLE and is_telemetry_enabled()):
        return

    record_event("agent_initialized", SYSTEM_INFO)

    # Attach the system details as dimensions on all subsequent events.
    for key in ("os", "os_version", "python_version"):
        set_dimension(key, SYSTEM_INFO[key])
|
||||
21
libs/agent/agent/core/tools/__init__.py
Normal file
21
libs/agent/agent/core/tools/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Core tools package."""
|
||||
|
||||
from .base import BaseTool, ToolResult, ToolError, ToolFailure, CLIResult
|
||||
from .bash import BaseBashTool
|
||||
from .collection import ToolCollection
|
||||
from .computer import BaseComputerTool
|
||||
from .edit import BaseEditTool
|
||||
from .manager import BaseToolManager
|
||||
|
||||
__all__ = [
|
||||
"BaseTool",
|
||||
"ToolResult",
|
||||
"ToolError",
|
||||
"ToolFailure",
|
||||
"CLIResult",
|
||||
"BaseBashTool",
|
||||
"BaseComputerTool",
|
||||
"BaseEditTool",
|
||||
"ToolCollection",
|
||||
"BaseToolManager",
|
||||
]
|
||||
74
libs/agent/agent/core/tools/base.py
Normal file
74
libs/agent/agent/core/tools/base.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Abstract base classes for tools that can be used with any provider."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass, fields, replace
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class BaseTool(metaclass=ABCMeta):
    """Abstract, provider-agnostic tool interface."""

    # Unique identifier used for lookup and API registration.
    name: str

    @abstractmethod
    async def __call__(self, **kwargs) -> Any:
        """Executes the tool with the given arguments."""
        ...

    @abstractmethod
    def to_params(self) -> Dict[str, Any]:
        """Convert tool to provider-specific API parameters.

        Returns:
            Dictionary with tool parameters specific to the LLM provider
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass(kw_only=True, frozen=True)
|
||||
class ToolResult:
|
||||
"""Represents the result of a tool execution."""
|
||||
|
||||
output: str | None = None
|
||||
error: str | None = None
|
||||
base64_image: str | None = None
|
||||
system: str | None = None
|
||||
content: list[dict] | None = None
|
||||
|
||||
def __bool__(self):
|
||||
return any(getattr(self, field.name) for field in fields(self))
|
||||
|
||||
def __add__(self, other: "ToolResult"):
|
||||
def combine_fields(field: str | None, other_field: str | None, concatenate: bool = True):
|
||||
if field and other_field:
|
||||
if concatenate:
|
||||
return field + other_field
|
||||
raise ValueError("Cannot combine tool results")
|
||||
return field or other_field
|
||||
|
||||
return ToolResult(
|
||||
output=combine_fields(self.output, other.output),
|
||||
error=combine_fields(self.error, other.error),
|
||||
base64_image=combine_fields(self.base64_image, other.base64_image, False),
|
||||
system=combine_fields(self.system, other.system),
|
||||
content=self.content or other.content, # Use first non-None content
|
||||
)
|
||||
|
||||
def replace(self, **kwargs):
|
||||
"""Returns a new ToolResult with the given fields replaced."""
|
||||
return replace(self, **kwargs)
|
||||
|
||||
|
||||
class CLIResult(ToolResult):
    """Tool result intended to be rendered as command-line output."""
|
||||
|
||||
|
||||
class ToolFailure(ToolResult):
    """Tool result marking a failed execution."""
|
||||
|
||||
|
||||
class ToolError(Exception):
    """Raised when a tool encounters an error."""

    def __init__(self, message):
        # Forward to Exception explicitly (args are populated either way);
        # keep .message for callers that read it directly.
        super().__init__(message)
        self.message = message
|
||||
52
libs/agent/agent/core/tools/bash.py
Normal file
52
libs/agent/agent/core/tools/bash.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Abstract base bash/shell tool implementation."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
from computer.computer import Computer
|
||||
|
||||
from .base import BaseTool, ToolResult
|
||||
|
||||
|
||||
class BaseBashTool(BaseTool):
    """Base class for bash/shell command execution tools across providers."""

    name = "bash"
    logger = logging.getLogger(__name__)
    computer: Computer

    def __init__(self, computer: Computer):
        """Initialize the BashTool.

        Args:
            computer: Computer instance, may be used for related operations
        """
        self.computer = computer

    async def run_command(self, command: str) -> Tuple[int, str, str]:
        """Run a shell command and return exit code, stdout, and stderr.

        Args:
            command: Shell command to execute

        Returns:
            Tuple containing (exit_code, stdout, stderr); on failure the exit
            code is 1 and stderr carries the exception text.
        """
        try:
            proc = await asyncio.create_subprocess_shell(
                command,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            out, err = await proc.communicate()
            return proc.returncode or 0, out.decode(), err.decode()
        except Exception as exc:
            self.logger.error(f"Error running command: {str(exc)}")
            return 1, "", str(exc)

    @abstractmethod
    async def __call__(self, **kwargs) -> ToolResult:
        """Execute the tool with the provided arguments."""
        raise NotImplementedError
|
||||
46
libs/agent/agent/core/tools/collection.py
Normal file
46
libs/agent/agent/core/tools/collection.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Collection classes for managing multiple tools."""
|
||||
|
||||
from typing import Any, Dict, List, Type
|
||||
|
||||
from .base import (
|
||||
BaseTool,
|
||||
ToolError,
|
||||
ToolFailure,
|
||||
ToolResult,
|
||||
)
|
||||
|
||||
|
||||
class ToolCollection:
    """A provider-agnostic collection of tools."""

    def __init__(self, *tools: BaseTool):
        self.tools = tools
        # Index by name for O(1) dispatch in run().
        self.tool_map = {t.name: t for t in tools}

    def to_params(self) -> List[Dict[str, Any]]:
        """Convert every tool to its provider-specific parameters.

        Returns:
            List of dictionaries with tool parameters
        """
        return [t.to_params() for t in self.tools]

    async def run(self, *, name: str, tool_input: Dict[str, Any]) -> ToolResult:
        """Run the named tool with the given input.

        Args:
            name: Name of the tool to run
            tool_input: Input parameters for the tool

        Returns:
            Result of the tool execution; any failure is wrapped in ToolFailure
        """
        tool = self.tool_map.get(name)
        if not tool:
            return ToolFailure(error=f"Tool {name} is invalid")
        try:
            return await tool(**tool_input)
        except ToolError as err:
            return ToolFailure(error=err.message)
        except Exception as exc:
            return ToolFailure(error=f"Unexpected error in tool {name}: {str(exc)}")
|
||||
113
libs/agent/agent/core/tools/computer.py
Normal file
113
libs/agent/agent/core/tools/computer.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Abstract base computer tool implementation."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from PIL import Image
|
||||
from computer.computer import Computer
|
||||
|
||||
from .base import BaseTool, ToolError, ToolResult
|
||||
|
||||
|
||||
class BaseComputerTool(BaseTool):
    """Base class for computer interaction tools across different providers."""

    name = "computer"
    logger = logging.getLogger(__name__)

    # Screen geometry; populated by initialize_dimensions().
    width: Optional[int] = None
    height: Optional[int] = None
    display_num: Optional[int] = None
    computer: Computer

    _screenshot_delay = 1.0  # Default delay for most platforms
    _scaling_enabled = True

    def __init__(self, computer: Computer):
        """Initialize the ComputerTool.

        Args:
            computer: Computer instance for screen interactions
        """
        self.computer = computer

    async def initialize_dimensions(self):
        """Fetch and cache the screen dimensions from the computer interface."""
        size = await self.computer.interface.get_screen_size()
        self.width = size["width"]
        self.height = size["height"]
        self.logger.info(f"Initialized screen dimensions to {self.width}x{self.height}")

    @property
    def options(self) -> Dict[str, Any]:
        """Get the options for the tool.

        Returns:
            Dictionary with tool options

        Raises:
            RuntimeError: If dimensions were never initialized
        """
        if self.width is None or self.height is None:
            raise RuntimeError(
                "Screen dimensions not initialized. Call initialize_dimensions() first."
            )
        return {
            "display_width_px": self.width,
            "display_height_px": self.height,
            "display_number": self.display_num,
        }

    async def resize_screenshot_if_needed(self, screenshot: bytes) -> bytes:
        """Resize a screenshot to match the expected dimensions.

        Args:
            screenshot: Raw screenshot data

        Returns:
            Resized screenshot data (original bytes when no resize is needed)

        Raises:
            ToolError: If dimensions are unset or the resize fails
        """
        if self.width is None or self.height is None:
            raise ToolError("Screen dimensions not initialized")

        try:
            img = Image.open(io.BytesIO(screenshot))
            # Flatten alpha/palette transparency so the output PNG is plain RGB.
            if img.mode in ("RGBA", "LA") or (img.mode == "P" and "transparency" in img.info):
                img = img.convert("RGB")

            if img.size == (self.width, self.height):
                return screenshot

            self.logger.info(
                f"Scaling image from {img.size} to {self.width}x{self.height} to match screen dimensions"
            )
            img = img.resize((self.width, self.height), Image.Resampling.LANCZOS)

            out = io.BytesIO()
            img.save(out, format="PNG")
            return out.getvalue()
        except Exception as exc:
            self.logger.error(f"Error during screenshot resizing: {str(exc)}")
            raise ToolError(f"Failed to resize screenshot: {str(exc)}")

    async def screenshot(self) -> ToolResult:
        """Take a screenshot and return it as a ToolResult with base64-encoded image.

        Returns:
            ToolResult with the screenshot, or an error ToolResult on failure
        """
        try:
            raw = await self.computer.interface.screenshot()
            raw = await self.resize_screenshot_if_needed(raw)
            return ToolResult(base64_image=base64.b64encode(raw).decode())
        except Exception as exc:
            self.logger.error(f"Error taking screenshot: {str(exc)}")
            return ToolResult(error=f"Failed to take screenshot: {str(exc)}")

    @abstractmethod
    async def __call__(self, **kwargs) -> ToolResult:
        """Execute the tool with the provided arguments."""
        raise NotImplementedError
|
||||
67
libs/agent/agent/core/tools/edit.py
Normal file
67
libs/agent/agent/core/tools/edit.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Abstract base edit tool implementation."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from computer.computer import Computer
|
||||
|
||||
from .base import BaseTool, ToolError, ToolResult
|
||||
|
||||
|
||||
class BaseEditTool(BaseTool):
    """Base class for text editor tools across different providers."""

    name = "edit"
    logger = logging.getLogger(__name__)
    computer: Computer

    def __init__(self, computer: Computer):
        """Initialize the EditTool.

        Args:
            computer: Computer instance, may be used for related operations
        """
        self.computer = computer

    async def read_file(self, path: str) -> str:
        """Read a file and return its contents.

        Args:
            path: Path to the file to read

        Returns:
            File contents as a string

        Raises:
            ToolError: If the file is missing or unreadable
        """
        try:
            target = Path(path)
            if not target.exists():
                raise ToolError(f"File does not exist: {path}")
            return target.read_text()
        except Exception as exc:
            self.logger.error(f"Error reading file: {str(exc)}")
            raise ToolError(f"Failed to read file: {str(exc)}")

    async def write_file(self, path: str, content: str) -> None:
        """Write content to a file.

        Args:
            path: Path to the file to write
            content: Content to write to the file

        Raises:
            ToolError: If the file cannot be written
        """
        try:
            target = Path(path)
            # Create parent directories if they don't exist
            target.parent.mkdir(parents=True, exist_ok=True)
            target.write_text(content)
        except Exception as exc:
            self.logger.error(f"Error writing file: {str(exc)}")
            raise ToolError(f"Failed to write file: {str(exc)}")

    @abstractmethod
    async def __call__(self, **kwargs) -> ToolResult:
        """Execute the tool with the provided arguments."""
        raise NotImplementedError
|
||||
56
libs/agent/agent/core/tools/manager.py
Normal file
56
libs/agent/agent/core/tools/manager.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Tool manager for initializing and running tools."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from computer.computer import Computer
|
||||
|
||||
from .base import BaseTool, ToolResult
|
||||
from .collection import ToolCollection
|
||||
|
||||
|
||||
class BaseToolManager(ABC):
    """Base class for tool managers across different providers."""

    def __init__(self, computer: Computer):
        """Initialize the tool manager.

        Args:
            computer: Computer instance for computer-related tools
        """
        self.computer = computer
        # Populated by initialize(); stays None until then.
        self.tools: ToolCollection | None = None

    @abstractmethod
    def _initialize_tools(self) -> ToolCollection:
        """Initialize all available tools."""
        ...

    async def initialize(self) -> None:
        """Run provider-specific setup, then build the tool collection."""
        await self._initialize_tools_specific()
        self.tools = self._initialize_tools()

    @abstractmethod
    async def _initialize_tools_specific(self) -> None:
        """Initialize provider-specific tool requirements."""
        ...

    @abstractmethod
    def get_tool_params(self) -> List[Dict[str, Any]]:
        """Get tool parameters for API calls."""
        ...

    async def execute_tool(self, name: str, tool_input: Dict[str, Any]) -> ToolResult:
        """Execute a tool with the given input.

        Args:
            name: Name of the tool to execute
            tool_input: Input parameters for the tool

        Returns:
            Result of the tool execution

        Raises:
            RuntimeError: If initialize() has not been called yet
        """
        if self.tools is None:
            raise RuntimeError("Tools not initialized. Call initialize() first.")
        return await self.tools.run(name=name, tool_input=tool_input)
|
||||
4
libs/agent/agent/providers/__init__.py
Normal file
4
libs/agent/agent/providers/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""Provider implementations for different AI services."""
|
||||
|
||||
# Import specific providers only when needed to avoid circular imports
|
||||
__all__ = [] # Let each provider module handle its own exports
|
||||
6
libs/agent/agent/providers/anthropic/__init__.py
Normal file
6
libs/agent/agent/providers/anthropic/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Anthropic provider implementation."""
|
||||
|
||||
from .loop import AnthropicLoop
|
||||
from .types import LLMProvider
|
||||
|
||||
__all__ = ["AnthropicLoop", "LLMProvider"]
|
||||
219
libs/agent/agent/providers/anthropic/api/client.py
Normal file
219
libs/agent/agent/providers/anthropic/api/client.py
Normal file
@@ -0,0 +1,219 @@
|
||||
from typing import Any
|
||||
import httpx
|
||||
import asyncio
|
||||
from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex
|
||||
from anthropic.types.beta import BetaMessage, BetaMessageParam, BetaToolUnionParam
|
||||
from ..types import LLMProvider
|
||||
from .logging import log_api_interaction
|
||||
import random
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class APIConnectionError(Exception):
    """Raised when the API cannot be reached or keeps failing after retries."""
|
||||
|
||||
|
||||
class BaseAnthropicClient:
    """Shared behavior for Anthropic API clients: retry with exponential backoff."""

    MAX_RETRIES = 10
    INITIAL_RETRY_DELAY = 1.0
    MAX_RETRY_DELAY = 60.0
    JITTER_FACTOR = 0.1

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using the Anthropic API."""
        raise NotImplementedError

    async def _make_api_call_with_retries(self, api_call):
        """Make an API call with exponential backoff retry logic.

        Args:
            api_call: Async function that makes the actual API call

        Returns:
            API response

        Raises:
            APIConnectionError: If all retries fail
        """
        attempt = 0
        last_error = None

        while attempt < self.MAX_RETRIES:
            try:
                return await api_call()
            except Exception as exc:
                last_error = exc
                attempt += 1

                if attempt == self.MAX_RETRIES:
                    break

                # Exponential backoff capped at MAX_RETRY_DELAY...
                base_delay = min(
                    self.INITIAL_RETRY_DELAY * (2 ** (attempt - 1)), self.MAX_RETRY_DELAY
                )
                # ...plus symmetric jitter to avoid thundering herd.
                jitter = base_delay * self.JITTER_FACTOR * (2 * random.random() - 1)
                final_delay = base_delay + jitter

                logger.info(
                    f"Retrying request (attempt {attempt}/{self.MAX_RETRIES}) "
                    f"in {final_delay:.2f} seconds after error: {str(exc)}"
                )
                await asyncio.sleep(final_delay)

        raise APIConnectionError(
            f"Failed after {self.MAX_RETRIES} retries. " f"Last error: {str(last_error)}"
        )
|
||||
|
||||
|
||||
class AnthropicDirectClient(BaseAnthropicClient):
    """Client that talks directly to the Anthropic API."""

    def __init__(self, api_key: str, model: str):
        self.model = model
        self.client = Anthropic(api_key=api_key, http_client=self._create_http_client())

    def _create_http_client(self) -> httpx.Client:
        """Create an HTTP client with appropriate settings."""
        transport = httpx.HTTPTransport(
            retries=3,
            verify=True,
            limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
        )
        return httpx.Client(
            verify=True,
            timeout=httpx.Timeout(connect=30.0, read=300.0, write=30.0, pool=30.0),
            transport=transport,
        )

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using the direct Anthropic API with retry logic."""

        async def api_call():
            # NOTE(review): this SDK call looks synchronous and would block
            # the event loop while it runs — confirm whether it should be
            # offloaded (e.g. asyncio.to_thread).
            raw = self.client.beta.messages.with_raw_response.create(
                max_tokens=max_tokens,
                messages=messages,
                model=self.model,
                system=system,
                tools=tools,
                betas=betas,
            )
            log_api_interaction(raw.http_response.request, raw.http_response, None)
            return raw.parse()

        try:
            return await self._make_api_call_with_retries(api_call)
        except Exception as exc:
            log_api_interaction(None, None, exc)
            raise
|
||||
|
||||
|
||||
class AnthropicVertexClient(BaseAnthropicClient):
    """Google Cloud Vertex AI implementation of Anthropic client."""

    def __init__(self, model: str):
        self.model = model
        self.client = AnthropicVertex()

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using Vertex AI with retry logic."""

        async def api_call():
            raw = self.client.beta.messages.with_raw_response.create(
                max_tokens=max_tokens,
                messages=messages,
                model=self.model,
                system=system,
                tools=tools,
                betas=betas,
            )
            log_api_interaction(raw.http_response.request, raw.http_response, None)
            return raw.parse()

        try:
            return await self._make_api_call_with_retries(api_call)
        except Exception as exc:
            log_api_interaction(None, None, exc)
            raise
|
||||
|
||||
|
||||
class AnthropicBedrockClient(BaseAnthropicClient):
    """AWS Bedrock implementation of Anthropic client."""

    def __init__(self, model: str):
        # AWS credentials/region are resolved by AnthropicBedrock from the
        # environment, so only the model name is needed here.
        self.model = model
        self.client = AnthropicBedrock()

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using AWS Bedrock with retry logic.

        Args:
            messages: Conversation history in Anthropic message format.
            system: System prompt blocks.
            tools: Tool definitions offered to the model.
            max_tokens: Maximum number of tokens to generate.
            betas: Beta feature flags to enable.

        Returns:
            The parsed API response message.

        Raises:
            Exception: Re-raised (after logging) when all retries fail.
        """

        async def api_call():
            # NOTE(review): same as the direct client — the underlying SDK
            # call here is synchronous; confirm event-loop blocking is OK.
            response = self.client.beta.messages.with_raw_response.create(
                max_tokens=max_tokens,
                messages=messages,
                model=self.model,
                system=system,
                tools=tools,
                betas=betas,
            )
            # Log the raw HTTP exchange before parsing the typed response.
            log_api_interaction(response.http_response.request, response.http_response, None)
            return response.parse()

        try:
            return await self._make_api_call_with_retries(api_call)
        except Exception as e:
            # Record the failure for post-mortem analysis, then propagate.
            log_api_interaction(None, None, e)
            raise
|
||||
|
||||
|
||||
class AnthropicClientFactory:
    """Factory for creating appropriate Anthropic client implementations."""

    @staticmethod
    def create_client(provider: LLMProvider, api_key: str, model: str) -> BaseAnthropicClient:
        """Create an appropriate client based on the provider.

        Args:
            provider: Which backend to target.
            api_key: API key (used only by the direct Anthropic backend).
            model: Model identifier passed to the chosen client.

        Returns:
            A concrete ``BaseAnthropicClient`` for the provider.

        Raises:
            ValueError: If the provider is not supported.
        """
        # Dispatch table keeps the provider → constructor mapping in one place.
        builders = {
            LLMProvider.ANTHROPIC: lambda: AnthropicDirectClient(api_key, model),
            LLMProvider.VERTEX: lambda: AnthropicVertexClient(model),
            LLMProvider.BEDROCK: lambda: AnthropicBedrockClient(model),
        }
        builder = builders.get(provider)
        if builder is None:
            raise ValueError(f"Unsupported provider: {provider}")
        return builder()
|
||||
150
libs/agent/agent/providers/anthropic/api/logging.py
Normal file
150
libs/agent/agent/providers/anthropic/api/logging.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""API logging functionality."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import httpx
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _filter_base64_images(content: Any) -> Any:
|
||||
"""Filter out base64 image data from content.
|
||||
|
||||
Args:
|
||||
content: Content to filter
|
||||
|
||||
Returns:
|
||||
Filtered content with base64 data replaced by placeholder
|
||||
"""
|
||||
if isinstance(content, dict):
|
||||
filtered = {}
|
||||
for key, value in content.items():
|
||||
if (
|
||||
isinstance(value, dict)
|
||||
and value.get("type") == "image"
|
||||
and value.get("source", {}).get("type") == "base64"
|
||||
):
|
||||
# Replace base64 data with placeholder
|
||||
filtered[key] = {
|
||||
**value,
|
||||
"source": {
|
||||
**value["source"],
|
||||
"data": "<base64_image_data>"
|
||||
}
|
||||
}
|
||||
else:
|
||||
filtered[key] = _filter_base64_images(value)
|
||||
return filtered
|
||||
elif isinstance(content, list):
|
||||
return [_filter_base64_images(item) for item in content]
|
||||
return content
|
||||
|
||||
def log_api_interaction(
    request: httpx.Request | None,
    response: httpx.Response | object | None,
    error: Exception | None,
    log_dir: Path = Path("/tmp/claude_logs")
) -> None:
    """Log API request, response, and any errors in a structured way.

    Writes one JSON file per interaction into *log_dir* (created on demand)
    and emits a one-line summary to the module logger. Base64 image payloads
    are masked with a placeholder before being written.

    Args:
        request: The HTTP request if available
        response: The HTTP response or response object
        error: Any error that occurred
        log_dir: Directory to store log files
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")

    # Helper to safely decode JSON content from bytes/str/dict.
    def safe_json_decode(content):
        if not content:
            return None
        try:
            if isinstance(content, bytes):
                return json.loads(content.decode())
            elif isinstance(content, str):
                return json.loads(content)
            elif isinstance(content, dict):
                return content
            return None
        except json.JSONDecodeError:
            return {"error": "Could not decode JSON", "raw": str(content)}

    # Process request content, scrubbing inline screenshots.
    request_content = None
    if request and request.content:
        request_content = safe_json_decode(request.content)
        request_content = _filter_base64_images(request_content)

    # Process response content: real httpx responses go through .json(),
    # anything else (e.g. SDK wrapper objects) through the safe decoder.
    response_content = None
    if response:
        if isinstance(response, httpx.Response):
            try:
                response_content = response.json()
            except json.JSONDecodeError:
                response_content = {"error": "Could not decode JSON", "raw": response.text}
        else:
            response_content = safe_json_decode(response)
        response_content = _filter_base64_images(response_content)

    log_entry = {
        "timestamp": timestamp,
        "request": {
            "method": request.method if request else None,
            "url": str(request.url) if request else None,
            "headers": dict(request.headers) if request else None,
            "content": request_content,
        } if request else None,
        "response": {
            "status_code": response.status_code if isinstance(response, httpx.Response) else None,
            "headers": dict(response.headers) if isinstance(response, httpx.Response) else None,
            "content": response_content,
        } if response else None,
        "error": {
            "type": type(error).__name__ if error else None,
            "message": str(error) if error else None,
        } if error else None
    }

    # Log to file with timestamp in filename.
    # parents=True: a caller-supplied nested log_dir whose parent does not
    # exist must not make logging itself raise FileNotFoundError.
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"claude_api_{timestamp.replace(' ', '_').replace(':', '-')}.json"

    with open(log_file, 'w') as f:
        json.dump(log_entry, f, indent=2)

    # Also log a summary to the console.
    if error:
        logger.error(f"API Error at {timestamp}: {error}")
    else:
        logger.info(
            f"API Call at {timestamp}: "
            f"{request.method if request else 'No request'} -> "
            f"{response.status_code if isinstance(response, httpx.Response) else 'No response'}"
        )

    # Log if there are any images in the content.
    if response_content:
        image_count = count_images(response_content)
        if image_count > 0:
            logger.info(f"Response contains {image_count} images")
|
||||
|
||||
def count_images(content: dict | list | Any) -> int:
|
||||
"""Count the number of images in the content.
|
||||
|
||||
Args:
|
||||
content: Content to search for images
|
||||
|
||||
Returns:
|
||||
Number of images found
|
||||
"""
|
||||
if isinstance(content, dict):
|
||||
if content.get("type") == "image":
|
||||
return 1
|
||||
return sum(count_images(v) for v in content.values())
|
||||
elif isinstance(content, list):
|
||||
return sum(count_images(item) for item in content)
|
||||
return 0
|
||||
55
libs/agent/agent/providers/anthropic/callbacks/manager.py
Normal file
55
libs/agent/agent/providers/anthropic/callbacks/manager.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from typing import Callable, Protocol
|
||||
import httpx
|
||||
from anthropic.types.beta import BetaContentBlockParam
|
||||
from ..tools import ToolResult
|
||||
|
||||
class APICallback(Protocol):
    """Protocol for API callbacks."""
    # Invoked once per API interaction. Any argument may be None when that
    # piece of the exchange is unavailable (e.g. no request/response on a
    # pre-flight failure, no error on success).
    def __call__(self, request: httpx.Request | None,
                 response: httpx.Response | object | None,
                 error: Exception | None) -> None: ...
|
||||
|
||||
class ContentCallback(Protocol):
    """Protocol for content callbacks."""
    # Invoked once for each assistant content block as it is produced.
    def __call__(self, content: BetaContentBlockParam) -> None: ...
|
||||
|
||||
class ToolCallback(Protocol):
    """Protocol for tool callbacks."""
    # Invoked with the result of each tool execution and the originating
    # tool_use block's id.
    def __call__(self, result: ToolResult, tool_id: str) -> None: ...
|
||||
|
||||
class CallbackManager:
    """Manages various callbacks for the agent system.

    A thin fan-out layer: each ``on_*`` method simply forwards its arguments
    to the callback registered at construction time.
    """

    def __init__(
        self,
        content_callback: ContentCallback,
        tool_callback: ToolCallback,
        api_callback: APICallback,
    ):
        """Initialize the callback manager.

        Args:
            content_callback: Invoked for each assistant content block.
            tool_callback: Invoked with each tool execution result.
            api_callback: Invoked for each raw API interaction.
        """
        self.content_callback = content_callback
        self.tool_callback = tool_callback
        self.api_callback = api_callback

    def on_content(self, content: BetaContentBlockParam) -> None:
        """Forward a content update to the content callback."""
        self.content_callback(content)

    def on_tool_result(self, result: ToolResult, tool_id: str) -> None:
        """Forward a tool execution result to the tool callback."""
        self.tool_callback(result, tool_id)

    def on_api_interaction(
        self,
        request: httpx.Request | None,
        response: httpx.Response | object | None,
        error: Exception | None,
    ) -> None:
        """Forward an API interaction to the API callback."""
        self.api_callback(request, response, error)
|
||||
521
libs/agent/agent/providers/anthropic/loop.py
Normal file
521
libs/agent/agent/providers/anthropic/loop.py
Normal file
@@ -0,0 +1,521 @@
|
||||
"""Anthropic-specific agent loop implementation."""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, cast
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from httpx import ConnectError, ReadTimeout
|
||||
|
||||
# Anthropic-specific imports
|
||||
from anthropic import AsyncAnthropic
|
||||
from anthropic.types.beta import (
|
||||
BetaMessage,
|
||||
BetaMessageParam,
|
||||
BetaTextBlock,
|
||||
BetaTextBlockParam,
|
||||
BetaToolUseBlockParam,
|
||||
)
|
||||
|
||||
# Computer
|
||||
from computer import Computer
|
||||
|
||||
# Base imports
|
||||
from ...core.loop import BaseLoop
|
||||
from ...core.messages import ImageRetentionConfig
|
||||
|
||||
# Anthropic provider-specific imports
|
||||
from .api.client import AnthropicClientFactory, BaseAnthropicClient
|
||||
from .tools.manager import ToolManager
|
||||
from .messages.manager import MessageManager
|
||||
from .callbacks.manager import CallbackManager
|
||||
from .prompts import SYSTEM_PROMPT
|
||||
from .types import LLMProvider
|
||||
from .tools import ToolResult
|
||||
|
||||
# Constants
|
||||
COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24"
|
||||
PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnthropicLoop(BaseLoop):
|
||||
"""Anthropic-specific implementation of the agent loop."""
|
||||
|
||||
    def __init__(
        self,
        api_key: str,
        model: str = "claude-3-7-sonnet-20250219",  # Fixed model
        computer: Optional[Computer] = None,
        only_n_most_recent_images: Optional[int] = 2,
        base_dir: Optional[str] = "trajectories",
        max_retries: int = 3,
        retry_delay: float = 1.0,
        save_trajectory: bool = True,
        **kwargs,
    ):
        """Initialize the Anthropic loop.

        Args:
            api_key: Anthropic API key
            model: Model name. NOTE: accepted for interface compatibility but
                overridden below — the loop always uses
                "claude-3-7-sonnet-20250219".
            computer: Computer instance
            only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
            base_dir: Base directory for saving experiment data
            max_retries: Maximum number of retries for API calls
            retry_delay: Delay between retries in seconds
            save_trajectory: Whether to save trajectory data
        """
        # Initialize base class
        super().__init__(
            computer=computer,
            model=model,
            api_key=api_key,
            max_retries=max_retries,
            retry_delay=retry_delay,
            base_dir=base_dir,
            save_trajectory=save_trajectory,
            only_n_most_recent_images=only_n_most_recent_images,
            **kwargs,
        )

        # Ensure model is always the fixed one (silently overrides the
        # `model` argument and whatever the base class stored).
        self.model = "claude-3-7-sonnet-20250219"

        # Anthropic-specific attributes; client/managers are created lazily
        # in initialize_client().
        self.provider = LLMProvider.ANTHROPIC
        self.client = None
        self.retry_count = 0
        self.tool_manager = None
        self.message_manager = None
        self.callback_manager = None

        # Configure image retention (how many screenshots stay in context)
        self.image_retention_config = ImageRetentionConfig(
            num_images_to_keep=only_n_most_recent_images
        )

        # Message history, reset at the start of every run()
        self.message_history = []
|
||||
|
||||
    async def initialize_client(self) -> None:
        """Initialize the Anthropic API client and tools.

        Builds the provider-specific client, the message and callback
        managers, and the tool manager.

        Raises:
            RuntimeError: If any step fails; ``self.client`` is reset to None
                so a later call can retry cleanly.
        """
        try:
            logger.info(f"Initializing Anthropic client with model {self.model}...")

            # Initialize client via the provider factory
            self.client = AnthropicClientFactory.create_client(
                provider=self.provider, api_key=self.api_key, model=self.model
            )

            # Initialize message manager (image retention uses the loop-level
            # budget; caching enabled here, though beta flag use is decided
            # per-call in _make_api_call)
            self.message_manager = MessageManager(
                ImageRetentionConfig(
                    num_images_to_keep=self.only_n_most_recent_images, enable_caching=True
                )
            )

            # Initialize callback manager wired to this loop's handlers
            self.callback_manager = CallbackManager(
                content_callback=self._handle_content,
                tool_callback=self._handle_tool_result,
                api_callback=self._handle_api_interaction,
            )

            # Initialize tool manager (async: tools may need the computer connection)
            self.tool_manager = ToolManager(self.computer)
            await self.tool_manager.initialize()

            logger.info(f"Initialized Anthropic client with model {self.model}")
        except Exception as e:
            logger.error(f"Error initializing Anthropic client: {str(e)}")
            # Leave the loop in a clean state so initialization can be retried.
            self.client = None
            raise RuntimeError(f"Failed to initialize Anthropic client: {str(e)}")
|
||||
|
||||
async def _process_screen(
|
||||
self, parsed_screen: Dict[str, Any], messages: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
"""Process screen information and add to messages.
|
||||
|
||||
Args:
|
||||
parsed_screen: Dictionary containing parsed screen info
|
||||
messages: List of messages to update
|
||||
"""
|
||||
try:
|
||||
# Extract screenshot from parsed screen
|
||||
screenshot_base64 = parsed_screen.get("screenshot_base64")
|
||||
|
||||
if screenshot_base64:
|
||||
# Remove data URL prefix if present
|
||||
if "," in screenshot_base64:
|
||||
screenshot_base64 = screenshot_base64.split(",")[1]
|
||||
|
||||
# Create Anthropic-compatible message with image
|
||||
screen_info_msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": screenshot_base64,
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# Add screen info message to messages
|
||||
messages.append(screen_info_msg)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing screen info: {str(e)}")
|
||||
raise
|
||||
|
||||
    async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]:
        """Run the agent loop with provided messages.

        Args:
            messages: List of message objects

        Yields:
            Dict containing response data; always ends with either a
            completion message or an error message.
        """
        try:
            logger.info("Starting Anthropic loop run")

            # Reset message history and add new messages
            self.message_history = []
            self.message_history.extend(messages)

            # Create queue for response streaming
            queue = asyncio.Queue()

            # Lazily initialize client/tooling on first run
            if self.client is None or self.tool_manager is None:
                logger.info("Initializing client...")
                await self.initialize_client()
                if self.client is None:
                    raise RuntimeError("Failed to initialize client")
                logger.info("Client initialized successfully")

            # Start loop in background task
            loop_task = asyncio.create_task(self._run_loop(queue))

            # Process and yield messages as they arrive.
            # NOTE(review): _run_loop only enqueues an error payload and the
            # final None sentinel; intermediate output flows through the
            # callback manager instead — confirm this split is intended.
            while True:
                try:
                    item = await queue.get()
                    if item is None:  # Stop signal
                        break
                    yield item
                    queue.task_done()
                except Exception as e:
                    logger.error(f"Error processing queue item: {str(e)}")
                    continue

            # Wait for loop to complete (surfaces any stray exception)
            await loop_task

            # Send completion message
            yield {
                "role": "assistant",
                "content": "Task completed successfully.",
                "metadata": {"title": "✅ Complete"},
            }

        except Exception as e:
            logger.error(f"Error executing task: {str(e)}")
            yield {
                "role": "assistant",
                "content": f"Error: {str(e)}",
                "metadata": {"title": "❌ Error"},
            }
|
||||
|
||||
    async def _run_loop(self, queue: asyncio.Queue) -> None:
        """Run the agent loop with current message history.

        Repeats screenshot → message preparation → API call → response
        handling until the model stops requesting tools, then signals
        completion by enqueueing ``None``.

        Args:
            queue: Queue for response streaming
        """
        try:
            while True:
                # Get up-to-date screen information
                parsed_screen = await self._get_parsed_screen_som()

                # Process screen info and update messages
                await self._process_screen(parsed_screen, self.message_history)

                # Prepare messages and make API call.
                # NOTE(review): .copy() is shallow — the list object is new but
                # the message dicts are shared, so image filtering / cache
                # injection still mutate the canonical history. Confirm intended.
                prepared_messages = self.message_manager.prepare_messages(
                    cast(List[BetaMessageParam], self.message_history.copy())
                )

                # Create new turn directory for this API call
                self._create_turn_dir()

                # Make API call
                response = await self._make_api_call(prepared_messages)

                # Handle the response; False means the model produced no tool
                # calls, i.e. the task is finished.
                if not await self._handle_response(response, self.message_history):
                    break

            # Signal completion
            await queue.put(None)

        except Exception as e:
            logger.error(f"Error in _run_loop: {str(e)}")
            # Surface the failure to the consumer, then stop the stream.
            await queue.put(
                {
                    "role": "assistant",
                    "content": f"Error in agent loop: {str(e)}",
                    "metadata": {"title": "❌ Error"},
                }
            )
            await queue.put(None)
|
||||
|
||||
    async def _make_api_call(self, messages: List[BetaMessageParam]) -> BetaMessage:
        """Make API call to Anthropic with retry logic.

        Args:
            messages: List of messages to send to the API

        Returns:
            API response

        Raises:
            RuntimeError: If all ``self.max_retries`` attempts fail.
        """
        last_error = None

        for attempt in range(self.max_retries):
            try:
                # Log request
                request_data = {
                    "messages": messages,
                    "max_tokens": self.max_tokens,
                    "system": SYSTEM_PROMPT,
                }
                self._log_api_call("request", request_data)

                # Setup betas and system prompt block
                system = BetaTextBlockParam(
                    type="text",
                    text=SYSTEM_PROMPT,
                )

                betas = [COMPUTER_USE_BETA_FLAG]
                # Temporarily disable prompt caching due to "A maximum of 4 blocks with cache_control may be provided" error
                # if self.message_manager.image_retention_config.enable_caching:
                #     betas.append(PROMPT_CACHING_BETA_FLAG)
                #     system["cache_control"] = {"type": "ephemeral"}

                # Make API call
                response = await self.client.create_message(
                    messages=messages,
                    system=[system],
                    tools=self.tool_manager.get_tool_params(),
                    max_tokens=self.max_tokens,
                    betas=betas,
                )

                # Log success response
                self._log_api_call("response", request_data, response)

                return response
            except Exception as e:
                last_error = e
                logger.error(
                    f"Error in API call (attempt {attempt + 1}/{self.max_retries}): {str(e)}"
                )
                self._log_api_call("error", {"messages": messages}, error=e)

                if attempt < self.max_retries - 1:
                    # Linear backoff: wait retry_delay, 2*retry_delay, ...
                    await asyncio.sleep(self.retry_delay * (attempt + 1))
                    continue

        # If we get here, all retries failed
        error_message = f"API call failed after {self.max_retries} attempts"
        if last_error:
            error_message += f": {str(last_error)}"

        logger.error(error_message)
        raise RuntimeError(error_message)
|
||||
|
||||
    async def _handle_response(self, response: BetaMessage, messages: List[Dict[str, Any]]) -> bool:
        """Handle the Anthropic API response.

        Appends the assistant turn to *messages*, executes any requested
        tools, and appends their results as a user turn.

        Args:
            response: API response
            messages: List of messages to update (mutated in place)

        Returns:
            True if the loop should continue (tools were used), False when
            the task is finished or an error occurred.
        """
        try:
            # Convert response to parameter format
            response_params = self._response_to_params(response)

            # Add response to messages
            messages.append(
                {
                    "role": "assistant",
                    "content": response_params,
                }
            )

            # Handle tool use blocks and collect results
            tool_result_content = []
            for content_block in response_params:
                # Notify callback of content
                self.callback_manager.on_content(content_block)

                # Handle tool use
                if content_block.get("type") == "tool_use":
                    result = await self.tool_manager.execute_tool(
                        name=content_block["name"],
                        tool_input=cast(Dict[str, Any], content_block["input"]),
                    )

                    # Create tool result and add to content
                    tool_result = self._make_tool_result(result, content_block["id"])
                    tool_result_content.append(tool_result)

                    # Notify callback of tool result
                    self.callback_manager.on_tool_result(result, content_block["id"])

            # If no tool results, we're done
            if not tool_result_content:
                # Signal completion via the internal sentinel that
                # _handle_content recognizes and suppresses.
                self.callback_manager.on_content({"type": "text", "text": "<DONE>"})
                return False

            # Add tool results to message history as the next user turn
            messages.append({"content": tool_result_content, "role": "user"})
            return True

        except Exception as e:
            logger.error(f"Error handling response: {str(e)}")
            # Record the failure in the transcript and stop the loop.
            messages.append(
                {
                    "role": "assistant",
                    "content": f"Error: {str(e)}",
                }
            )
            return False
|
||||
|
||||
def _response_to_params(
|
||||
self,
|
||||
response: BetaMessage,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Convert API response to message parameters.
|
||||
|
||||
Args:
|
||||
response: API response message
|
||||
|
||||
Returns:
|
||||
List of content blocks
|
||||
"""
|
||||
result = []
|
||||
for block in response.content:
|
||||
if isinstance(block, BetaTextBlock):
|
||||
result.append({"type": "text", "text": block.text})
|
||||
else:
|
||||
result.append(cast(Dict[str, Any], block.model_dump()))
|
||||
return result
|
||||
|
||||
def _make_tool_result(self, result: ToolResult, tool_use_id: str) -> Dict[str, Any]:
|
||||
"""Convert a tool result to API format.
|
||||
|
||||
Args:
|
||||
result: Tool execution result
|
||||
tool_use_id: ID of the tool use
|
||||
|
||||
Returns:
|
||||
Formatted tool result
|
||||
"""
|
||||
if result.content:
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"content": result.content,
|
||||
"tool_use_id": tool_use_id,
|
||||
"is_error": bool(result.error),
|
||||
}
|
||||
|
||||
tool_result_content = []
|
||||
is_error = False
|
||||
|
||||
if result.error:
|
||||
is_error = True
|
||||
tool_result_content = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": self._maybe_prepend_system_tool_result(result, result.error),
|
||||
}
|
||||
]
|
||||
else:
|
||||
if result.output:
|
||||
tool_result_content.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": self._maybe_prepend_system_tool_result(result, result.output),
|
||||
}
|
||||
)
|
||||
if result.base64_image:
|
||||
tool_result_content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": result.base64_image,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"content": tool_result_content,
|
||||
"tool_use_id": tool_use_id,
|
||||
"is_error": is_error,
|
||||
}
|
||||
|
||||
def _maybe_prepend_system_tool_result(self, result: ToolResult, result_text: str) -> str:
|
||||
"""Prepend system information to tool result if available.
|
||||
|
||||
Args:
|
||||
result: Tool execution result
|
||||
result_text: Text to prepend to
|
||||
|
||||
Returns:
|
||||
Text with system information prepended if available
|
||||
"""
|
||||
if result.system:
|
||||
result_text = f"<s>{result.system}</s>\n{result_text}"
|
||||
return result_text
|
||||
|
||||
def _handle_content(self, content: Dict[str, Any]) -> None:
|
||||
"""Handle content updates from the assistant."""
|
||||
if content.get("type") == "text":
|
||||
text = content.get("text", "")
|
||||
if text == "<DONE>":
|
||||
return
|
||||
|
||||
logger.info(f"Assistant: {text}")
|
||||
|
||||
def _handle_tool_result(self, result: ToolResult, tool_id: str) -> None:
|
||||
"""Handle tool execution results."""
|
||||
if result.error:
|
||||
logger.error(f"Tool {tool_id} error: {result.error}")
|
||||
else:
|
||||
logger.info(f"Tool {tool_id} output: {result.output}")
|
||||
|
||||
def _handle_api_interaction(
|
||||
self, request: Any, response: Any, error: Optional[Exception]
|
||||
) -> None:
|
||||
"""Handle API interactions."""
|
||||
if error:
|
||||
logger.error(f"API error: {error}")
|
||||
else:
|
||||
logger.debug(f"API request: {request}")
|
||||
110
libs/agent/agent/providers/anthropic/messages/manager.py
Normal file
110
libs/agent/agent/providers/anthropic/messages/manager.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
from anthropic.types.beta import (
|
||||
BetaMessageParam,
|
||||
BetaCacheControlEphemeralParam,
|
||||
BetaToolResultBlockParam,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageRetentionConfig:
|
||||
"""Configuration for image retention in messages."""
|
||||
|
||||
num_images_to_keep: int | None = None
|
||||
min_removal_threshold: int = 1
|
||||
enable_caching: bool = True
|
||||
|
||||
def should_retain_images(self) -> bool:
|
||||
"""Check if image retention is enabled."""
|
||||
return self.num_images_to_keep is not None and self.num_images_to_keep > 0
|
||||
|
||||
|
||||
class MessageManager:
    """Manages message preparation, including image retention and caching."""

    def __init__(self, image_retention_config: ImageRetentionConfig):
        """Initialize the message manager.

        Args:
            image_retention_config: Configuration for image retention

        Raises:
            ValueError: If ``min_removal_threshold`` is below 1.
        """
        if image_retention_config.min_removal_threshold < 1:
            raise ValueError("min_removal_threshold must be at least 1")
        self.image_retention_config = image_retention_config

    def prepare_messages(self, messages: list[BetaMessageParam]) -> list[BetaMessageParam]:
        """Prepare messages by applying image retention and caching as configured.

        Mutates *messages* (and the dicts inside it) in place and returns the
        same list.
        """
        if self.image_retention_config.should_retain_images():
            self._filter_images(messages)
        if self.image_retention_config.enable_caching:
            self._inject_caching(messages)
        return messages

    def _filter_images(self, messages: list[BetaMessageParam]) -> None:
        """Filter messages to retain only the specified number of most recent images."""
        # Collect every tool_result block across all messages, oldest first.
        tool_result_blocks = cast(
            list[BetaToolResultBlockParam],
            [
                item
                for message in messages
                for item in (message["content"] if isinstance(message["content"], list) else [])
                if isinstance(item, dict) and item.get("type") == "tool_result"
            ],
        )

        # Count how many images are currently embedded in tool results.
        total_images = sum(
            1
            for tool_result in tool_result_blocks
            for content in tool_result.get("content", [])
            if isinstance(content, dict) and content.get("type") == "image"
        )

        images_to_remove = total_images - (self.image_retention_config.num_images_to_keep or 0)
        # Round down to nearest min_removal_threshold for better cache behavior
        images_to_remove -= images_to_remove % self.image_retention_config.min_removal_threshold

        # Remove oldest images first (blocks were collected oldest-first).
        for tool_result in tool_result_blocks:
            if isinstance(tool_result.get("content"), list):
                new_content = []
                for content in tool_result.get("content", []):
                    if isinstance(content, dict) and content.get("type") == "image":
                        if images_to_remove > 0:
                            images_to_remove -= 1
                            continue
                    new_content.append(content)
                tool_result["content"] = new_content

    def _inject_caching(self, messages: list[BetaMessageParam]) -> None:
        """Inject caching control for the most recent turns, limited to 3 blocks max to avoid API errors."""
        # Anthropic API allows a maximum of 4 blocks with cache_control
        # We use 3 here to be safe, as the system block may also have cache_control
        blocks_with_cache_control = 0
        max_cache_control_blocks = 3

        # Walk newest-to-oldest: the 3 most recent user turns gain an
        # ephemeral cache_control marker; older turns have theirs stripped.
        for message in reversed(messages):
            if message["role"] == "user" and isinstance(content := message["content"], list):
                # Only add cache control to the latest message in each turn
                if blocks_with_cache_control < max_cache_control_blocks:
                    blocks_with_cache_control += 1
                    # Add cache control to the last content block only
                    if content and len(content) > 0:
                        content[-1]["cache_control"] = {"type": "ephemeral"}
                else:
                    # Remove any existing cache control
                    if content and len(content) > 0:
                        content[-1].pop("cache_control", None)

        # Ensure we're not exceeding the limit by checking the total.
        # NOTE(review): the counter only increments while it is below the
        # limit, so this branch looks unreachable (dead code) — confirm
        # before removing.
        if blocks_with_cache_control > max_cache_control_blocks:
            # If we somehow exceeded the limit, remove excess cache controls
            excess = blocks_with_cache_control - max_cache_control_blocks
            for message in messages:
                if excess <= 0:
                    break

                if message["role"] == "user" and isinstance(content := message["content"], list):
                    if content and len(content) > 0 and "cache_control" in content[-1]:
                        content[-1].pop("cache_control", None)
                        excess -= 1
|
||||
20
libs/agent/agent/providers/anthropic/prompts.py
Normal file
20
libs/agent/agent/providers/anthropic/prompts.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""System prompts for Anthropic provider."""

from datetime import datetime
import platform  # NOTE(review): unused in this module — kept in case callers import it from here; confirm before removing

# Build the date without strftime's "%-d": the "-" (no-padding) flag is a
# glibc/BSD extension and fails or misbehaves on other platforms (Windows
# spells it "%#d").  Interpolating the day number directly is portable and
# produces byte-identical output (e.g. "Wednesday, March 5, 2025").
_TODAY = datetime.today()
_CURRENT_DATE = _TODAY.strftime(f"%A, %B {_TODAY.day}, %Y")

SYSTEM_PROMPT = f"""<SYSTEM_CAPABILITY>
* You are utilising a macOS virtual machine using ARM architecture with internet access and Safari as default browser.
* You can feel free to install macOS applications with your bash tool. Use curl instead of wget.
* Using bash tool you can start GUI applications. GUI apps run with bash tool will appear within your desktop environment, but they may take some time to appear. Take a screenshot to confirm it did.
* When using your bash tool with commands that are expected to output very large quantities of text, redirect into a tmp file and use str_replace_editor or `grep -n -B <lines before> -A <lines after> <query> <filename>` to confirm output.
* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
* The current date is {_CURRENT_DATE}.
</SYSTEM_CAPABILITY>

<IMPORTANT>
* Plan at maximum 1 step each time, and evaluate the result of each step before proceeding. Hold back if you're not sure about the result of the step.
* If you're not sure about the location of an application, use start the app using the bash tool.
* If the item you are looking at is a pdf, if after taking a single screenshot of the pdf it seems that you want to read the entire document instead of trying to continue to read the pdf from your screenshots + navigation, determine the URL, use curl to download the pdf, install and use pdftotext to convert it to a text file, and then read that text file directly with your StrReplaceEditTool.
</IMPORTANT>"""
|
||||
33
libs/agent/agent/providers/anthropic/tools/__init__.py
Normal file
33
libs/agent/agent/providers/anthropic/tools/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Anthropic-specific tools for agent."""

# Result/error primitives shared by all Anthropic tools.  The Anthropic*-
# prefixed names are backward-compatibility aliases for the same classes.
from .base import (
    BaseAnthropicTool,
    ToolResult,
    ToolError,
    ToolFailure,
    CLIResult,
    AnthropicToolResult,
    AnthropicToolError,
    AnthropicToolFailure,
    AnthropicCLIResult,
)
# Concrete tool implementations.
from .bash import BashTool
from .computer import ComputerTool
from .edit import EditTool
# Registers and dispatches the tools above.
from .manager import ToolManager

# Explicit public API of this package.
__all__ = [
    "BaseAnthropicTool",
    "ToolResult",
    "ToolError",
    "ToolFailure",
    "CLIResult",
    "AnthropicToolResult",
    "AnthropicToolError",
    "AnthropicToolFailure",
    "AnthropicCLIResult",
    "BashTool",
    "ComputerTool",
    "EditTool",
    "ToolManager",
]
|
||||
88
libs/agent/agent/providers/anthropic/tools/base.py
Normal file
88
libs/agent/agent/providers/anthropic/tools/base.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Anthropic-specific tool base classes."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass, fields, replace
|
||||
from typing import Any, Dict
|
||||
|
||||
from anthropic.types.beta import BetaToolUnionParam
|
||||
|
||||
from ....core.tools.base import BaseTool, ToolError, ToolResult, ToolFailure, CLIResult
|
||||
|
||||
|
||||
class BaseAnthropicTool(BaseTool, metaclass=ABCMeta):
    """Abstract base class for Anthropic-defined tools.

    Subclasses must implement the async ``__call__`` entry point and
    ``to_params`` so the tool can be advertised to the Anthropic API.
    """

    def __init__(self):
        """Initialize the base Anthropic tool."""
        # No specific initialization needed yet, but included for future extensibility
        pass

    @abstractmethod
    async def __call__(self, **kwargs) -> Any:
        """Executes the tool with the given arguments."""
        ...

    @abstractmethod
    def to_params(self) -> Dict[str, Any]:
        """Convert tool to Anthropic-specific API parameters.

        Returns:
            Dictionary with tool parameters for Anthropic API
        """
        # NOTE(review): return type is Dict[str, Any] although the API layer
        # imports BetaToolUnionParam above — presumably these dicts must be
        # BetaToolUnionParam-compatible; confirm against callers.
        raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass(kw_only=True, frozen=True)
|
||||
class ToolResult:
|
||||
"""Represents the result of a tool execution."""
|
||||
|
||||
output: str | None = None
|
||||
error: str | None = None
|
||||
base64_image: str | None = None
|
||||
system: str | None = None
|
||||
content: list[dict] | None = None
|
||||
|
||||
def __bool__(self):
|
||||
return any(getattr(self, field.name) for field in fields(self))
|
||||
|
||||
def __add__(self, other: "ToolResult"):
|
||||
def combine_fields(field: str | None, other_field: str | None, concatenate: bool = True):
|
||||
if field and other_field:
|
||||
if concatenate:
|
||||
return field + other_field
|
||||
raise ValueError("Cannot combine tool results")
|
||||
return field or other_field
|
||||
|
||||
return ToolResult(
|
||||
output=combine_fields(self.output, other.output),
|
||||
error=combine_fields(self.error, other.error),
|
||||
base64_image=combine_fields(self.base64_image, other.base64_image, False),
|
||||
system=combine_fields(self.system, other.system),
|
||||
content=self.content or other.content, # Use first non-None content
|
||||
)
|
||||
|
||||
def replace(self, **kwargs):
|
||||
"""Returns a new ToolResult with the given fields replaced."""
|
||||
return replace(self, **kwargs)
|
||||
|
||||
|
||||
# Marker subclass: same fields as ToolResult, used to tag shell-command output.
class CLIResult(ToolResult):
    """A ToolResult that can be rendered as a CLI output."""
|
||||
|
||||
|
||||
# Marker subclass: same fields as ToolResult, used to tag failed executions.
class ToolFailure(ToolResult):
    """A ToolResult that represents a failure."""
|
||||
|
||||
|
||||
class ToolError(Exception):
    """Raised when a tool encounters an error.

    Attributes:
        message: Human-readable description of the failure.
    """

    def __init__(self, message):
        # Pass the message to Exception so str(exc), repr(exc) and tracebacks
        # show it; previously only self.message was set and the exception
        # rendered as an empty string.
        super().__init__(message)
        self.message = message
|
||||
|
||||
|
||||
# Re-export the core tool classes with Anthropic-specific names for backward compatibility
# (older call sites import the Anthropic*-prefixed names; both spellings are the same class).
AnthropicToolResult = ToolResult
AnthropicToolError = ToolError
AnthropicToolFailure = ToolFailure
AnthropicCLIResult = CLIResult
|
||||
163
libs/agent/agent/providers/anthropic/tools/bash.py
Normal file
163
libs/agent/agent/providers/anthropic/tools/bash.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import ClassVar, Literal, Dict, Any
|
||||
from computer.computer import Computer
|
||||
|
||||
from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
|
||||
from ....core.tools.bash import BaseBashTool
|
||||
|
||||
|
||||
class _BashSession:
    """A session of a bash shell.

    Keeps one long-lived /bin/bash subprocess and serialises commands into
    it, detecting command completion with an echoed sentinel string.
    """

    _started: bool
    _process: asyncio.subprocess.Process

    # Shell binary launched for the session.
    command: str = "/bin/bash"
    _output_delay: float = 0.2  # seconds between polls of the output buffer
    _timeout: float = 120.0  # seconds before a command is abandoned
    # Marker echoed after every command so we can tell where its output ends.
    _sentinel: str = "<<exit>>"

    def __init__(self):
        # start() must run before run(); _timed_out latches permanently once a
        # command exceeds _timeout, after which the session must be restarted.
        self._started = False
        self._timed_out = False

    async def start(self):
        # Idempotent: a second call on a started session is a no-op.
        if self._started:
            return

        self._process = await asyncio.create_subprocess_shell(
            self.command,
            preexec_fn=os.setsid,  # own process group, so signals don't hit the parent
            shell=True,
            bufsize=0,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        self._started = True

    def stop(self):
        """Terminate the bash shell."""
        if not self._started:
            raise ToolError("Session has not started.")
        if self._process.returncode is not None:
            return  # already exited; nothing to terminate
        self._process.terminate()

    async def run(self, command: str):
        """Execute a command in the bash shell.

        Returns a CLIResult with the command's stdout/stderr, or a ToolResult
        asking for a restart when the shell has already exited.

        Raises:
            ToolError: if the session never started, or a previous (or this)
                command timed out.
        """
        if not self._started:
            raise ToolError("Session has not started.")
        if self._process.returncode is not None:
            return ToolResult(
                system="tool must be restarted",
                error=f"bash has exited with returncode {self._process.returncode}",
            )
        if self._timed_out:
            raise ToolError(
                f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
            )

        # we know these are not None because we created the process with PIPEs
        assert self._process.stdin
        assert self._process.stdout
        assert self._process.stderr

        # send command to the process
        self._process.stdin.write(command.encode() + f"; echo '{self._sentinel}'\n".encode())
        await self._process.stdin.drain()

        # read output from the process, until the sentinel is found
        # NOTE(review): asyncio.timeout requires Python 3.11+ — confirm the
        # package's minimum supported version.
        try:
            async with asyncio.timeout(self._timeout):
                while True:
                    await asyncio.sleep(self._output_delay)
                    # if we read directly from stdout/stderr, it will wait forever for
                    # EOF. use the StreamReader buffer directly instead.
                    output = (
                        self._process.stdout._buffer.decode()
                    )  # pyright: ignore[reportAttributeAccessIssue]
                    if self._sentinel in output:
                        # strip the sentinel and break
                        output = output[: output.index(self._sentinel)]
                        break
        except asyncio.TimeoutError:
            self._timed_out = True
            raise ToolError(
                f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
            ) from None

        if output.endswith("\n"):
            output = output[:-1]

        error = self._process.stderr._buffer.decode()  # pyright: ignore[reportAttributeAccessIssue]
        if error.endswith("\n"):
            error = error[:-1]

        # clear the buffers so that the next output can be read correctly
        self._process.stdout._buffer.clear()  # pyright: ignore[reportAttributeAccessIssue]
        self._process.stderr._buffer.clear()  # pyright: ignore[reportAttributeAccessIssue]

        return CLIResult(output=output, error=error)
|
||||
|
||||
|
||||
class BashTool(BaseBashTool, BaseAnthropicTool):
    """
    A tool that allows the agent to run bash commands.
    The tool parameters are defined by Anthropic and are not editable.
    """

    name: ClassVar[Literal["bash"]] = "bash"
    api_type: ClassVar[Literal["bash_20250124"]] = "bash_20250124"
    _timeout: float = 120.0  # seconds per command

    def __init__(self, computer: Computer):
        """Initialize the bash tool.

        Args:
            computer: Computer instance for executing commands
        """
        # Initialize the base bash tool first
        BaseBashTool.__init__(self, computer)
        # Then initialize the Anthropic tool
        BaseAnthropicTool.__init__(self)
        # NOTE(review): __call__ runs commands through computer.interface, so
        # this local session appears unused; kept for backward compatibility in
        # case external code reaches into it — confirm before removing.
        self._session = _BashSession()

    async def __call__(self, command: str | None = None, restart: bool = False, **kwargs):
        """Execute a bash command.

        Args:
            command: The command to execute
            restart: Whether to restart the shell (not used with computer interface)

        Returns:
            CLIResult with stdout/stderr, or a ToolResult for restart requests

        Raises:
            ToolError: If no command was given, the command timed out, or
                execution failed
        """
        if restart:
            return ToolResult(system="Restart not needed with computer interface.")

        if command is None:
            raise ToolError("no command provided.")

        try:
            async with asyncio.timeout(self._timeout):
                stdout, stderr = await self.computer.interface.run_command(command)
                return CLIResult(output=stdout or "", error=stderr or "")
        except asyncio.TimeoutError as e:
            raise ToolError(f"Command timed out after {self._timeout} seconds") from e
        except Exception as e:
            # Chain the original exception (was previously dropped, hiding the
            # real cause from tracebacks) — now consistent with the handler above.
            raise ToolError(f"Failed to execute command: {str(e)}") from e

    def to_params(self) -> Dict[str, Any]:
        """Convert tool to API parameters.

        Returns:
            Dictionary with tool parameters
        """
        return {"name": self.name, "type": self.api_type}
|
||||
34
libs/agent/agent/providers/anthropic/tools/collection.py
Normal file
34
libs/agent/agent/providers/anthropic/tools/collection.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Collection classes for managing multiple tools."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from anthropic.types.beta import BetaToolUnionParam
|
||||
|
||||
from .base import (
|
||||
BaseAnthropicTool,
|
||||
ToolError,
|
||||
ToolFailure,
|
||||
ToolResult,
|
||||
)
|
||||
|
||||
|
||||
class ToolCollection:
    """A collection of anthropic-defined tools."""

    def __init__(self, *tools: BaseAnthropicTool):
        """Index the given tools by the name each reports in to_params()."""
        self.tools = tools
        self.tool_map = {}
        for tool in tools:
            self.tool_map[tool.to_params()["name"]] = tool

    def to_params(
        self,
    ) -> list[BetaToolUnionParam]:
        """Return the API parameter dicts for every registered tool."""
        params = []
        for tool in self.tools:
            params.append(tool.to_params())
        return params

    async def run(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult:
        """Dispatch *tool_input* to the tool registered under *name*.

        Unknown names and ToolError failures are reported as ToolFailure
        results rather than raised.
        """
        tool = self.tool_map.get(name)
        if not tool:
            return ToolFailure(error=f"Tool {name} is invalid")
        try:
            return await tool(**tool_input)
        except ToolError as exc:
            return ToolFailure(error=exc.message)
||||
550
libs/agent/agent/providers/anthropic/tools/computer.py
Normal file
550
libs/agent/agent/providers/anthropic/tools/computer.py
Normal file
@@ -0,0 +1,550 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Literal, TypedDict, Any, Dict
|
||||
import subprocess
|
||||
from PIL import Image
|
||||
from datetime import datetime
|
||||
|
||||
from computer.computer import Computer
|
||||
|
||||
from .base import BaseAnthropicTool, ToolError, ToolResult
|
||||
from .run import run
|
||||
from ....core.tools.computer import BaseComputerTool
|
||||
|
||||
# NOTE(review): these two constants are not referenced anywhere in this
# module — presumably intended for batched typing helpers; confirm usage.
TYPING_DELAY_MS = 12
TYPING_GROUP_SIZE = 50

# Closed set of action names accepted by ComputerTool.__call__.
Action = Literal[
    "key",
    "type",
    "mouse_move",
    "left_click",
    "left_click_drag",
    "right_click",
    "middle_click",
    "double_click",
    "screenshot",
    "cursor_position",
    "scroll",
]
|
||||
|
||||
|
||||
class Resolution(TypedDict):
    """A screen resolution in pixels."""

    # Width and height in pixels.
    width: int
    height: int
|
||||
|
||||
|
||||
class ScalingSource(StrEnum):
    """Tags for which side a coordinate/size value originates from."""

    # NOTE(review): unused in this module — presumably distinguishes native
    # screen coordinates from API-space coordinates; confirm against callers.
    COMPUTER = "computer"
    API = "api"
|
||||
|
||||
|
||||
class ComputerToolOptions(TypedDict):
    """Display parameters reported to the API via ComputerTool.to_params()."""

    # Screen size in pixels.
    display_height_px: int
    display_width_px: int
    # Display index; left as None by this module.
    display_number: int | None
|
||||
|
||||
|
||||
def chunks(s: str, chunk_size: int) -> list[str]:
    """Split *s* into consecutive pieces of at most *chunk_size* characters."""
    pieces: list[str] = []
    start = 0
    while start < len(s):
        pieces.append(s[start : start + chunk_size])
        start += chunk_size
    return pieces
|
||||
|
||||
|
||||
class ComputerTool(BaseComputerTool, BaseAnthropicTool):
|
||||
"""
|
||||
A tool that allows the agent to interact with the screen, keyboard, and mouse of the current macOS computer.
|
||||
The tool parameters are defined by Anthropic and are not editable.
|
||||
"""
|
||||
|
||||
name: Literal["computer"] = "computer"
|
||||
api_type: Literal["computer_20250124"] = "computer_20250124"
|
||||
width: int | None
|
||||
height: int | None
|
||||
display_num: int | None
|
||||
computer: Computer # The CUA Computer instance
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_screenshot_delay = 1.0 # macOS is generally faster than X11
|
||||
_scaling_enabled = True
|
||||
|
||||
@property
|
||||
def options(self) -> ComputerToolOptions:
|
||||
if self.width is None or self.height is None:
|
||||
raise RuntimeError(
|
||||
"Screen dimensions not initialized. Call initialize_dimensions() first."
|
||||
)
|
||||
return {
|
||||
"display_width_px": self.width,
|
||||
"display_height_px": self.height,
|
||||
"display_number": self.display_num,
|
||||
}
|
||||
|
||||
def to_params(self) -> Dict[str, Any]:
|
||||
"""Convert tool to API parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with tool parameters
|
||||
"""
|
||||
return {"name": self.name, "type": self.api_type, **self.options}
|
||||
|
||||
def __init__(self, computer):
|
||||
# Initialize the base computer tool first
|
||||
BaseComputerTool.__init__(self, computer)
|
||||
# Then initialize the Anthropic tool
|
||||
BaseAnthropicTool.__init__(self)
|
||||
|
||||
# Additional initialization
|
||||
self.width = None # Will be initialized from computer interface
|
||||
self.height = None # Will be initialized from computer interface
|
||||
self.display_num = None
|
||||
|
||||
async def initialize_dimensions(self):
|
||||
"""Initialize screen dimensions from the computer interface."""
|
||||
display_size = await self.computer.interface.get_screen_size()
|
||||
self.width = display_size["width"]
|
||||
self.height = display_size["height"]
|
||||
self.logger.info(f"Initialized screen dimensions to {self.width}x{self.height}")
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
*,
|
||||
action: Action,
|
||||
text: str | None = None,
|
||||
coordinate: tuple[int, int] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
# Ensure dimensions are initialized
|
||||
if self.width is None or self.height is None:
|
||||
await self.initialize_dimensions()
|
||||
except Exception as e:
|
||||
raise ToolError(f"Failed to initialize dimensions: {e}")
|
||||
|
||||
if action in ("mouse_move", "left_click_drag"):
|
||||
if coordinate is None:
|
||||
raise ToolError(f"coordinate is required for {action}")
|
||||
if text is not None:
|
||||
raise ToolError(f"text is not accepted for {action}")
|
||||
if not isinstance(coordinate, (list, tuple)) or len(coordinate) != 2:
|
||||
raise ToolError(f"{coordinate} must be a tuple of length 2")
|
||||
if not all(isinstance(i, int) and i >= 0 for i in coordinate):
|
||||
raise ToolError(f"{coordinate} must be a tuple of non-negative ints")
|
||||
|
||||
try:
|
||||
x, y = coordinate
|
||||
self.logger.info(f"Handling {action} action:")
|
||||
self.logger.info(f" Coordinates: ({x}, {y})")
|
||||
|
||||
# Take pre-action screenshot to get current dimensions
|
||||
pre_screenshot = await self.computer.interface.screenshot()
|
||||
pre_img = Image.open(io.BytesIO(pre_screenshot))
|
||||
|
||||
# Scale image to match screen dimensions if needed
|
||||
if pre_img.size != (self.width, self.height):
|
||||
self.logger.info(
|
||||
f"Scaling image from {pre_img.size} to {self.width}x{self.height} to match screen dimensions"
|
||||
)
|
||||
pre_img = pre_img.resize((self.width, self.height), Image.Resampling.LANCZOS)
|
||||
|
||||
self.logger.info(f" Current dimensions: {pre_img.width}x{pre_img.height}")
|
||||
|
||||
if action == "mouse_move":
|
||||
self.logger.info(f"Moving cursor to ({x}, {y})")
|
||||
await self.computer.interface.move_cursor(x, y)
|
||||
elif action == "left_click_drag":
|
||||
self.logger.info(f"Dragging from ({x}, {y})")
|
||||
# First move to the position
|
||||
await self.computer.interface.move_cursor(x, y)
|
||||
# Then perform drag operation - check if drag_to exists or we need to use other methods
|
||||
try:
|
||||
if hasattr(self.computer.interface, "drag_to"):
|
||||
await self.computer.interface.drag_to(x, y)
|
||||
else:
|
||||
# Alternative approach: press mouse down, move, release
|
||||
await self.computer.interface.mouse_down()
|
||||
await asyncio.sleep(0.2)
|
||||
await self.computer.interface.move_cursor(x, y)
|
||||
await asyncio.sleep(0.2)
|
||||
await self.computer.interface.mouse_up()
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during drag operation: {str(e)}")
|
||||
raise ToolError(f"Failed to perform drag: {str(e)}")
|
||||
|
||||
# Wait briefly for any UI changes
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Take post-action screenshot
|
||||
post_screenshot = await self.computer.interface.screenshot()
|
||||
post_img = Image.open(io.BytesIO(post_screenshot))
|
||||
|
||||
# Scale post-action image if needed
|
||||
if post_img.size != (self.width, self.height):
|
||||
self.logger.info(
|
||||
f"Scaling post-action image from {post_img.size} to {self.width}x{self.height}"
|
||||
)
|
||||
post_img = post_img.resize((self.width, self.height), Image.Resampling.LANCZOS)
|
||||
buffer = io.BytesIO()
|
||||
post_img.save(buffer, format="PNG")
|
||||
post_screenshot = buffer.getvalue()
|
||||
|
||||
return ToolResult(
|
||||
output=f"{'Moved cursor to' if action == 'mouse_move' else 'Dragged to'} {x},{y}",
|
||||
base64_image=base64.b64encode(post_screenshot).decode(),
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during {action} action: {str(e)}")
|
||||
raise ToolError(f"Failed to perform {action}: {str(e)}")
|
||||
|
||||
elif action in ("left_click", "right_click", "double_click"):
|
||||
if coordinate:
|
||||
x, y = coordinate
|
||||
self.logger.info(f"Handling {action} action:")
|
||||
self.logger.info(f" Coordinates: ({x}, {y})")
|
||||
|
||||
try:
|
||||
# Take pre-action screenshot to get current dimensions
|
||||
pre_screenshot = await self.computer.interface.screenshot()
|
||||
pre_img = Image.open(io.BytesIO(pre_screenshot))
|
||||
|
||||
# Scale image to match screen dimensions if needed
|
||||
if pre_img.size != (self.width, self.height):
|
||||
self.logger.info(
|
||||
f"Scaling image from {pre_img.size} to {self.width}x{self.height} to match screen dimensions"
|
||||
)
|
||||
pre_img = pre_img.resize(
|
||||
(self.width, self.height), Image.Resampling.LANCZOS
|
||||
)
|
||||
# Save the scaled image back to bytes
|
||||
buffer = io.BytesIO()
|
||||
pre_img.save(buffer, format="PNG")
|
||||
pre_screenshot = buffer.getvalue()
|
||||
|
||||
self.logger.info(f" Current dimensions: {pre_img.width}x{pre_img.height}")
|
||||
|
||||
# Perform the click action
|
||||
if action == "left_click":
|
||||
self.logger.info(f"Clicking at ({x}, {y})")
|
||||
await self.computer.interface.move_cursor(x, y)
|
||||
await self.computer.interface.left_click()
|
||||
elif action == "right_click":
|
||||
self.logger.info(f"Right clicking at ({x}, {y})")
|
||||
await self.computer.interface.move_cursor(x, y)
|
||||
await self.computer.interface.right_click()
|
||||
elif action == "double_click":
|
||||
self.logger.info(f"Double clicking at ({x}, {y})")
|
||||
await self.computer.interface.move_cursor(x, y)
|
||||
await self.computer.interface.double_click()
|
||||
|
||||
# Wait briefly for any UI changes
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Take and save post-action screenshot
|
||||
post_screenshot = await self.computer.interface.screenshot()
|
||||
post_img = Image.open(io.BytesIO(post_screenshot))
|
||||
|
||||
# Scale post-action image if needed
|
||||
if post_img.size != (self.width, self.height):
|
||||
self.logger.info(
|
||||
f"Scaling post-action image from {post_img.size} to {self.width}x{self.height}"
|
||||
)
|
||||
post_img = post_img.resize(
|
||||
(self.width, self.height), Image.Resampling.LANCZOS
|
||||
)
|
||||
buffer = io.BytesIO()
|
||||
post_img.save(buffer, format="PNG")
|
||||
post_screenshot = buffer.getvalue()
|
||||
|
||||
return ToolResult(
|
||||
output=f"Performed {action} at ({x}, {y})",
|
||||
base64_image=base64.b64encode(post_screenshot).decode(),
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during {action} action: {str(e)}")
|
||||
raise ToolError(f"Failed to perform {action}: {str(e)}")
|
||||
else:
|
||||
try:
|
||||
# Take pre-action screenshot
|
||||
pre_screenshot = await self.computer.interface.screenshot()
|
||||
pre_img = Image.open(io.BytesIO(pre_screenshot))
|
||||
|
||||
# Scale image if needed
|
||||
if pre_img.size != (self.width, self.height):
|
||||
self.logger.info(
|
||||
f"Scaling image from {pre_img.size} to {self.width}x{self.height}"
|
||||
)
|
||||
pre_img = pre_img.resize(
|
||||
(self.width, self.height), Image.Resampling.LANCZOS
|
||||
)
|
||||
|
||||
# Perform the click action
|
||||
if action == "left_click":
|
||||
self.logger.info("Performing left click at current position")
|
||||
await self.computer.interface.left_click()
|
||||
elif action == "right_click":
|
||||
self.logger.info("Performing right click at current position")
|
||||
await self.computer.interface.right_click()
|
||||
elif action == "double_click":
|
||||
self.logger.info("Performing double click at current position")
|
||||
await self.computer.interface.double_click()
|
||||
|
||||
# Wait briefly for any UI changes
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Take post-action screenshot
|
||||
post_screenshot = await self.computer.interface.screenshot()
|
||||
post_img = Image.open(io.BytesIO(post_screenshot))
|
||||
|
||||
# Scale post-action image if needed
|
||||
if post_img.size != (self.width, self.height):
|
||||
self.logger.info(
|
||||
f"Scaling post-action image from {post_img.size} to {self.width}x{self.height}"
|
||||
)
|
||||
post_img = post_img.resize(
|
||||
(self.width, self.height), Image.Resampling.LANCZOS
|
||||
)
|
||||
buffer = io.BytesIO()
|
||||
post_img.save(buffer, format="PNG")
|
||||
post_screenshot = buffer.getvalue()
|
||||
|
||||
return ToolResult(
|
||||
output=f"Performed {action} at current position",
|
||||
base64_image=base64.b64encode(post_screenshot).decode(),
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during {action} action: {str(e)}")
|
||||
raise ToolError(f"Failed to perform {action}: {str(e)}")
|
||||
|
||||
elif action in ("key", "type"):
|
||||
if text is None:
|
||||
raise ToolError(f"text is required for {action}")
|
||||
if coordinate is not None:
|
||||
raise ToolError(f"coordinate is not accepted for {action}")
|
||||
if not isinstance(text, str):
|
||||
raise ToolError(f"{text} must be a string")
|
||||
|
||||
try:
|
||||
# Take pre-action screenshot
|
||||
pre_screenshot = await self.computer.interface.screenshot()
|
||||
pre_img = Image.open(io.BytesIO(pre_screenshot))
|
||||
|
||||
# Scale image if needed
|
||||
if pre_img.size != (self.width, self.height):
|
||||
self.logger.info(
|
||||
f"Scaling image from {pre_img.size} to {self.width}x{self.height}"
|
||||
)
|
||||
pre_img = pre_img.resize((self.width, self.height), Image.Resampling.LANCZOS)
|
||||
|
||||
if action == "key":
|
||||
# Special handling for page up/down on macOS
|
||||
if text.lower() in ["pagedown", "page_down", "page down"]:
|
||||
self.logger.info("Converting page down to fn+down for macOS")
|
||||
await self.computer.interface.hotkey("fn", "down")
|
||||
output_text = "fn+down"
|
||||
elif text.lower() in ["pageup", "page_up", "page up"]:
|
||||
self.logger.info("Converting page up to fn+up for macOS")
|
||||
await self.computer.interface.hotkey("fn", "up")
|
||||
output_text = "fn+up"
|
||||
elif text == "fn+down":
|
||||
self.logger.info("Using fn+down combination")
|
||||
await self.computer.interface.hotkey("fn", "down")
|
||||
output_text = text
|
||||
elif text == "fn+up":
|
||||
self.logger.info("Using fn+up combination")
|
||||
await self.computer.interface.hotkey("fn", "up")
|
||||
output_text = text
|
||||
elif "+" in text:
|
||||
# Handle hotkey combinations
|
||||
keys = text.split("+")
|
||||
self.logger.info(f"Pressing hotkey combination: {text}")
|
||||
await self.computer.interface.hotkey(*keys)
|
||||
output_text = text
|
||||
else:
|
||||
# Handle single key press
|
||||
self.logger.info(f"Pressing key: {text}")
|
||||
try:
|
||||
await self.computer.interface.press(text)
|
||||
output_text = text
|
||||
except ValueError as e:
|
||||
raise ToolError(f"Invalid key: {text}. {str(e)}")
|
||||
|
||||
# Wait briefly for UI changes
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Take post-action screenshot
|
||||
post_screenshot = await self.computer.interface.screenshot()
|
||||
post_img = Image.open(io.BytesIO(post_screenshot))
|
||||
|
||||
# Scale post-action image if needed
|
||||
if post_img.size != (self.width, self.height):
|
||||
self.logger.info(
|
||||
f"Scaling post-action image from {post_img.size} to {self.width}x{self.height}"
|
||||
)
|
||||
post_img = post_img.resize(
|
||||
(self.width, self.height), Image.Resampling.LANCZOS
|
||||
)
|
||||
buffer = io.BytesIO()
|
||||
post_img.save(buffer, format="PNG")
|
||||
post_screenshot = buffer.getvalue()
|
||||
|
||||
return ToolResult(
|
||||
output=f"Pressed key: {output_text}",
|
||||
base64_image=base64.b64encode(post_screenshot).decode(),
|
||||
)
|
||||
|
||||
elif action == "type":
|
||||
self.logger.info(f"Typing text: {text}")
|
||||
await self.computer.interface.type_text(text)
|
||||
|
||||
# Wait briefly for UI changes
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Take post-action screenshot
|
||||
post_screenshot = await self.computer.interface.screenshot()
|
||||
post_img = Image.open(io.BytesIO(post_screenshot))
|
||||
|
||||
# Scale post-action image if needed
|
||||
if post_img.size != (self.width, self.height):
|
||||
self.logger.info(
|
||||
f"Scaling post-action image from {post_img.size} to {self.width}x{self.height}"
|
||||
)
|
||||
post_img = post_img.resize(
|
||||
(self.width, self.height), Image.Resampling.LANCZOS
|
||||
)
|
||||
buffer = io.BytesIO()
|
||||
post_img.save(buffer, format="PNG")
|
||||
post_screenshot = buffer.getvalue()
|
||||
|
||||
return ToolResult(
|
||||
output=f"Typed text: {text}",
|
||||
base64_image=base64.b64encode(post_screenshot).decode(),
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during {action} action: {str(e)}")
|
||||
raise ToolError(f"Failed to perform {action}: {str(e)}")
|
||||
|
||||
elif action in ("screenshot", "cursor_position"):
|
||||
if text is not None:
|
||||
raise ToolError(f"text is not accepted for {action}")
|
||||
if coordinate is not None:
|
||||
raise ToolError(f"coordinate is not accepted for {action}")
|
||||
|
||||
try:
|
||||
if action == "screenshot":
|
||||
# Take screenshot
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
img = Image.open(io.BytesIO(screenshot))
|
||||
|
||||
# Scale image if needed
|
||||
if img.size != (self.width, self.height):
|
||||
self.logger.info(
|
||||
f"Scaling image from {img.size} to {self.width}x{self.height}"
|
||||
)
|
||||
img = img.resize((self.width, self.height), Image.Resampling.LANCZOS)
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format="PNG")
|
||||
screenshot = buffer.getvalue()
|
||||
|
||||
return ToolResult(base64_image=base64.b64encode(screenshot).decode())
|
||||
|
||||
elif action == "cursor_position":
|
||||
pos = await self.computer.interface.get_cursor_position()
|
||||
return ToolResult(output=f"X={int(pos[0])},Y={int(pos[1])}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during {action} action: {str(e)}")
|
||||
raise ToolError(f"Failed to perform {action}: {str(e)}")
|
||||
|
||||
elif action == "scroll":
|
||||
# Implement scroll action
|
||||
direction = kwargs.get("direction", "down")
|
||||
amount = kwargs.get("amount", 10)
|
||||
|
||||
if direction not in ["up", "down"]:
|
||||
raise ToolError(f"Invalid scroll direction: {direction}. Must be 'up' or 'down'.")
|
||||
|
||||
try:
|
||||
if direction == "down":
|
||||
# Scroll down (Page Down on macOS)
|
||||
self.logger.info(f"Scrolling down, amount: {amount}")
|
||||
# Use fn+down for page down on macOS
|
||||
for _ in range(amount):
|
||||
await self.computer.interface.hotkey("fn", "down")
|
||||
await asyncio.sleep(0.1)
|
||||
else:
|
||||
# Scroll up (Page Up on macOS)
|
||||
self.logger.info(f"Scrolling up, amount: {amount}")
|
||||
# Use fn+up for page up on macOS
|
||||
for _ in range(amount):
|
||||
await self.computer.interface.hotkey("fn", "up")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Wait briefly for UI changes
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Take post-action screenshot
|
||||
post_screenshot = await self.computer.interface.screenshot()
|
||||
post_img = Image.open(io.BytesIO(post_screenshot))
|
||||
|
||||
# Scale post-action image if needed
|
||||
if post_img.size != (self.width, self.height):
|
||||
self.logger.info(
|
||||
f"Scaling post-action image from {post_img.size} to {self.width}x{self.height}"
|
||||
)
|
||||
post_img = post_img.resize((self.width, self.height), Image.Resampling.LANCZOS)
|
||||
buffer = io.BytesIO()
|
||||
post_img.save(buffer, format="PNG")
|
||||
post_screenshot = buffer.getvalue()
|
||||
|
||||
return ToolResult(
|
||||
output=f"Scrolled {direction} by {amount} steps",
|
||||
base64_image=base64.b64encode(post_screenshot).decode(),
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during scroll action: {str(e)}")
|
||||
raise ToolError(f"Failed to perform scroll: {str(e)}")
|
||||
|
||||
raise ToolError(f"Invalid action: {action}")
|
||||
|
||||
async def screenshot(self):
|
||||
"""Take a screenshot and return it as a base64-encoded string."""
|
||||
try:
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
img = Image.open(io.BytesIO(screenshot))
|
||||
|
||||
# Scale image if needed
|
||||
if img.size != (self.width, self.height):
|
||||
self.logger.info(f"Scaling image from {img.size} to {self.width}x{self.height}")
|
||||
img = img.resize((self.width, self.height), Image.Resampling.LANCZOS)
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format="PNG")
|
||||
screenshot = buffer.getvalue()
|
||||
|
||||
return ToolResult(base64_image=base64.b64encode(screenshot).decode())
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error taking screenshot: {str(e)}")
|
||||
return ToolResult(error=f"Failed to take screenshot: {str(e)}")
|
||||
|
||||
    async def shell(self, command: str, take_screenshot=False) -> ToolResult:
        """Run a shell command and return the output, error, and optionally a screenshot.

        Args:
            command: Shell command line passed to the module-level `run` helper.
            take_screenshot: When True, capture the screen after the command
                finishes (after a short settle delay) and attach it.

        Returns:
            ToolResult with stdout/stderr, plus a base64 screenshot when requested.
        """
        try:
            # NOTE(review): the exit code returned by `run` is discarded here,
            # so callers can only detect failure via stderr text — confirm
            # that is intended.
            _, stdout, stderr = await run(command)
            base64_image = None

            if take_screenshot:
                # delay to let things settle before taking a screenshot
                await asyncio.sleep(self._screenshot_delay)
                screenshot_result = await self.screenshot()
                if screenshot_result.error:
                    # Surface the screenshot failure alongside the command's stderr.
                    return ToolResult(
                        output=stdout,
                        error=f"{stderr}\nScreenshot error: {screenshot_result.error}",
                    )
                base64_image = screenshot_result.base64_image

            return ToolResult(output=stdout, error=stderr, base64_image=base64_image)

        except Exception as e:
            return ToolResult(error=f"Shell command failed: {str(e)}")
|
||||
326
libs/agent/agent/providers/anthropic/tools/edit.py
Normal file
326
libs/agent/agent/providers/anthropic/tools/edit.py
Normal file
@@ -0,0 +1,326 @@
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Literal, get_args, Dict, Any
|
||||
from computer.computer import Computer
|
||||
|
||||
from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
|
||||
from ....core.tools.edit import BaseEditTool
|
||||
from .run import maybe_truncate
|
||||
|
||||
# Commands accepted by the Anthropic `str_replace_editor` tool.
Command = Literal[
    "view",
    "create",
    "str_replace",
    "insert",
    "undo_edit",
]
# Number of context lines shown around an edit in result snippets.
SNIPPET_LINES: int = 4
|
||||
|
||||
|
||||
class EditTool(BaseEditTool, BaseAnthropicTool):
    """
    A filesystem editor tool that allows the agent to view, create, and edit files.
    The tool parameters are defined by Anthropic and are not editable.

    All filesystem access goes through the attached Computer's shell
    (`run_command`), so the tool operates on that machine's filesystem.
    """

    api_type: Literal["text_editor_20250124"] = "text_editor_20250124"
    name: Literal["str_replace_editor"] = "str_replace_editor"
    _timeout: float = 30.0  # seconds

    def __init__(self, computer: Computer):
        """Initialize the edit tool.

        Args:
            computer: Computer instance for file operations
        """
        # Initialize the base edit tool first
        BaseEditTool.__init__(self, computer)
        # Then initialize the Anthropic tool
        BaseAnthropicTool.__init__(self)

        # Edit history for the current session: path -> stack of prior file
        # contents, consumed LIFO by undo_edit().
        self.edit_history = defaultdict(list)

    async def __call__(
        self,
        *,
        command: Command,
        path: str,
        file_text: str | None = None,
        view_range: list[int] | None = None,
        old_str: str | None = None,
        new_str: str | None = None,
        insert_line: int | None = None,
        **kwargs,
    ):
        """Dispatch *command* to the matching file operation.

        Raises:
            ToolError: for unknown commands, invalid paths, or missing
                command-specific parameters.
        """
        _path = Path(path)
        await self.validate_path(command, _path)

        if command == "view":
            return await self.view(_path, view_range)
        elif command == "create":
            if file_text is None:
                raise ToolError("Parameter `file_text` is required for command: create")
            await self.write_file(_path, file_text)
            self.edit_history[_path].append(file_text)
            return ToolResult(output=f"File created successfully at: {_path}")
        elif command == "str_replace":
            if old_str is None:
                raise ToolError("Parameter `old_str` is required for command: str_replace")
            return await self.str_replace(_path, old_str, new_str)
        elif command == "insert":
            if insert_line is None:
                raise ToolError("Parameter `insert_line` is required for command: insert")
            if new_str is None:
                raise ToolError("Parameter `new_str` is required for command: insert")
            return await self.insert(_path, insert_line, new_str)
        elif command == "undo_edit":
            return await self.undo_edit(_path)

        raise ToolError(
            f'Unrecognized command {command}. The allowed commands for the {self.name} tool are: {", ".join(get_args(Command))}'
        )

    async def validate_path(self, command: str, path: Path):
        """Check that the path/command combination is valid.

        Raises:
            ToolError: if the path is relative, missing (for non-create
                commands), already present (for `create`), or a directory
                for any command other than `view`.
        """
        # Check if its an absolute path
        if not path.is_absolute():
            # FIX: was `Path("") / path`, which is identical to `path` and
            # made the suggestion useless; anchor at the filesystem root
            # to match the "should start with `/`" message.
            suggested_path = Path("/") / path
            raise ToolError(
                f"The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?"
            )

        # Check if path exists using bash commands
        try:
            result = await self.computer.interface.run_command(
                f'[ -e "{str(path)}" ] && echo "exists" || echo "not exists"'
            )
            exists = result[0].strip() == "exists"

            if exists:
                result = await self.computer.interface.run_command(
                    f'[ -d "{str(path)}" ] && echo "dir" || echo "file"'
                )
                is_dir = result[0].strip() == "dir"
            else:
                is_dir = False

            # Check path validity
            if not exists and command != "create":
                raise ToolError(f"The path {path} does not exist. Please provide a valid path.")
            if exists and command == "create":
                raise ToolError(
                    f"File already exists at: {path}. Cannot overwrite files using command `create`."
                )
            if is_dir and command != "view":
                raise ToolError(
                    f"The path {path} is a directory and only the `view` command can be used on directories"
                )
        except ToolError:
            # FIX: preserve the specific messages raised above instead of
            # re-wrapping them as "Failed to validate path: ...".
            raise
        except Exception as e:
            raise ToolError(f"Failed to validate path: {str(e)}")

    async def view(self, path: Path, view_range: list[int] | None = None):
        """Implement the view command.

        Directories are listed with `ls -la`; files are printed with line
        numbers, optionally restricted to `view_range` ([start, end], 1-based,
        end == -1 meaning "to end of file").
        """
        try:
            # Check if path is a directory
            result = await self.computer.interface.run_command(
                f'[ -d "{str(path)}" ] && echo "dir" || echo "file"'
            )
            is_dir = result[0].strip() == "dir"

            if is_dir:
                if view_range:
                    raise ToolError(
                        "The `view_range` parameter is not allowed when `path` points to a directory."
                    )

                # List directory contents using ls
                result = await self.computer.interface.run_command(f'ls -la "{str(path)}"')
                contents = result[0]
                if contents:
                    stdout = f"Here's the files and directories in {path}:\n{contents}\n"
                else:
                    stdout = f"Directory {path} is empty\n"
                return CLIResult(output=stdout)

            # Read file content using cat
            file_content = await self.read_file(path)
            init_line = 1

            if view_range:
                if len(view_range) != 2 or not all(isinstance(i, int) for i in view_range):
                    raise ToolError("Invalid `view_range`. It should be a list of two integers.")

                file_lines = file_content.split("\n")
                n_lines_file = len(file_lines)
                init_line, final_line = view_range

                if init_line < 1 or init_line > n_lines_file:
                    raise ToolError(
                        f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
                    )
                if final_line > n_lines_file:
                    raise ToolError(
                        f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`"
                    )
                if final_line != -1 and final_line < init_line:
                    raise ToolError(
                        f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be larger or equal than its first `{init_line}`"
                    )

                if final_line == -1:
                    file_content = "\n".join(file_lines[init_line - 1 :])
                else:
                    file_content = "\n".join(file_lines[init_line - 1 : final_line])

            return CLIResult(output=self._make_output(file_content, str(path), init_line=init_line))
        except ToolError:
            # FIX: keep the precise validation messages; previously they were
            # swallowed and re-raised as "Failed to view path: ...".
            raise
        except Exception as e:
            raise ToolError(f"Failed to view path: {str(e)}")

    async def str_replace(self, path: Path, old_str: str, new_str: str | None):
        """Implement the str_replace command.

        `old_str` must occur exactly once in the file; a None `new_str`
        deletes the occurrence. The previous content is pushed onto the
        undo history.
        """
        # Read the file content
        file_content = await self.read_file(path)
        file_content = file_content.expandtabs()
        old_str = old_str.expandtabs()
        new_str = new_str.expandtabs() if new_str is not None else ""

        # Check if old_str is unique in the file
        occurrences = file_content.count(old_str)
        if occurrences == 0:
            raise ToolError(
                f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}."
            )
        elif occurrences > 1:
            file_content_lines = file_content.split("\n")
            lines = [idx + 1 for idx, line in enumerate(file_content_lines) if old_str in line]
            raise ToolError(
                f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique"
            )

        # Replace old_str with new_str
        new_file_content = file_content.replace(old_str, new_str)

        # Write the new content to the file
        await self.write_file(path, new_file_content)

        # Save the content to history
        self.edit_history[path].append(file_content)

        # Create a snippet of the edited section
        replacement_line = file_content.split(old_str)[0].count("\n")
        start_line = max(0, replacement_line - SNIPPET_LINES)
        end_line = replacement_line + SNIPPET_LINES + new_str.count("\n")
        snippet = "\n".join(new_file_content.split("\n")[start_line : end_line + 1])

        # Prepare the success message
        success_msg = f"The file {path} has been edited. "
        success_msg += self._make_output(snippet, f"a snippet of {path}", start_line + 1)
        success_msg += "Review the changes and make sure they are as expected. Edit the file again if necessary."

        return CLIResult(output=success_msg)

    async def insert(self, path: Path, insert_line: int, new_str: str):
        """Implement the insert command.

        Inserts `new_str` after line `insert_line` (0 inserts at the top).
        The previous content is pushed onto the undo history.
        """
        file_text = await self.read_file(path)
        file_text = file_text.expandtabs()
        new_str = new_str.expandtabs()
        file_text_lines = file_text.split("\n")
        n_lines_file = len(file_text_lines)

        if insert_line < 0 or insert_line > n_lines_file:
            raise ToolError(
                f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}"
            )

        new_str_lines = new_str.split("\n")
        new_file_text_lines = (
            file_text_lines[:insert_line] + new_str_lines + file_text_lines[insert_line:]
        )
        # Snippet shows SNIPPET_LINES of context on each side of the insertion.
        snippet_lines = (
            file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
            + new_str_lines
            + file_text_lines[insert_line : insert_line + SNIPPET_LINES]
        )

        new_file_text = "\n".join(new_file_text_lines)
        snippet = "\n".join(snippet_lines)

        await self.write_file(path, new_file_text)
        self.edit_history[path].append(file_text)

        success_msg = f"The file {path} has been edited. "
        success_msg += self._make_output(
            snippet, "a snippet of the edited file", max(1, insert_line - SNIPPET_LINES + 1)
        )
        success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
        return CLIResult(output=success_msg)

    async def undo_edit(self, path: Path):
        """Implement the undo_edit command: restore the most recent
        pre-edit content recorded for *path*."""
        if not self.edit_history[path]:
            raise ToolError(f"No edit history found for {path}.")

        old_text = self.edit_history[path].pop()
        await self.write_file(path, old_text)

        return CLIResult(
            output=f"Last edit to {path} undone successfully. {self._make_output(old_text, str(path))}"
        )

    async def read_file(self, path: Path) -> str:
        """Read the content of a file using cat command."""
        try:
            result = await self.computer.interface.run_command(f'cat "{str(path)}"')
            if result[1]:  # If there's stderr output
                raise ToolError(f"Error reading file: {result[1]}")
            return result[0]
        except ToolError:
            # FIX: don't double-wrap the specific stderr message above.
            raise
        except Exception as e:
            raise ToolError(f"Failed to read {path}: {str(e)}")

    async def write_file(self, path: Path, content: str):
        """Write content to a file using echo and redirection."""
        try:
            # Create parent directories if they don't exist
            parent = path.parent
            if parent != Path("/"):
                await self.computer.interface.run_command(f'mkdir -p "{str(parent)}"')

            # Write content to file using a quoted heredoc to preserve formatting.
            # NOTE(review): content containing the literal line `EOFCUA` would
            # terminate the heredoc early — confirm this cannot occur in practice.
            cmd = f"""cat > "{str(path)}" << 'EOFCUA'
{content}
EOFCUA"""
            result = await self.computer.interface.run_command(cmd)
            if result[1]:  # If there's stderr output
                raise ToolError(f"Error writing file: {result[1]}")
        except ToolError:
            # FIX: don't double-wrap the specific stderr message above.
            raise
        except Exception as e:
            raise ToolError(f"Failed to write to {path}: {str(e)}")

    def _make_output(
        self,
        file_content: str,
        file_descriptor: str,
        init_line: int = 1,
        expand_tabs: bool = True,
    ) -> str:
        """Generate output for the CLI based on the content of a file."""
        file_content = maybe_truncate(file_content)
        if expand_tabs:
            file_content = file_content.expandtabs()
        # Prefix each line with its 1-based number, `cat -n` style.
        file_content = "\n".join(
            [f"{i + init_line:6}\t{line}" for i, line in enumerate(file_content.split("\n"))]
        )
        return (
            f"Here's the result of running `cat -n` on {file_descriptor}:\n" + file_content + "\n"
        )

    def to_params(self) -> Dict[str, Any]:
        """Convert tool to API parameters.

        Returns:
            Dictionary with tool parameters
        """
        return {
            "name": self.name,
            "type": self.api_type,
        }
|
||||
54
libs/agent/agent/providers/anthropic/tools/manager.py
Normal file
54
libs/agent/agent/providers/anthropic/tools/manager.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from typing import Any, Dict, List
|
||||
from anthropic.types.beta import BetaToolUnionParam
|
||||
from computer.computer import Computer
|
||||
|
||||
from ....core.tools import BaseToolManager, ToolResult
|
||||
from ....core.tools.collection import ToolCollection
|
||||
|
||||
from .bash import BashTool
|
||||
from .computer import ComputerTool
|
||||
from .edit import EditTool
|
||||
|
||||
|
||||
class ToolManager(BaseToolManager):
    """Manages Anthropic-specific tool initialization and execution."""

    def __init__(self, computer: Computer):
        """Initialize the tool manager.

        Args:
            computer: Computer instance for computer-related tools
        """
        super().__init__(computer)
        # Anthropic-specific tool instances, all bound to the same computer.
        self.computer_tool = ComputerTool(self.computer)
        self.bash_tool = BashTool(self.computer)
        self.edit_tool = EditTool(self.computer)

    def _initialize_tools(self) -> ToolCollection:
        """Bundle every available tool into a single collection."""
        available = (self.computer_tool, self.bash_tool, self.edit_tool)
        return ToolCollection(*available)

    async def _initialize_tools_specific(self) -> None:
        """Run Anthropic-specific setup: let the computer tool discover
        its dimensions."""
        await self.computer_tool.initialize_dimensions()

    def get_tool_params(self) -> List[BetaToolUnionParam]:
        """Get tool parameters for Anthropic API calls.

        Raises:
            RuntimeError: if initialize() has not been called yet.
        """
        if self.tools is None:
            raise RuntimeError("Tools not initialized. Call initialize() first.")
        return self.tools.to_params()

    async def execute_tool(self, name: str, tool_input: dict[str, Any]) -> ToolResult:
        """Execute a tool with the given input.

        Args:
            name: Name of the tool to execute
            tool_input: Input parameters for the tool

        Returns:
            Result of the tool execution

        Raises:
            RuntimeError: if initialize() has not been called yet.
        """
        if self.tools is None:
            raise RuntimeError("Tools not initialized. Call initialize() first.")
        return await self.tools.run(name=name, tool_input=tool_input)
|
||||
42
libs/agent/agent/providers/anthropic/tools/run.py
Normal file
42
libs/agent/agent/providers/anthropic/tools/run.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Utility to run shell commands asynchronously with a timeout."""
|
||||
|
||||
import asyncio
|
||||
|
||||
TRUNCATED_MESSAGE: str = "<response clipped><NOTE>To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.</NOTE>"
|
||||
MAX_RESPONSE_LEN: int = 16000
|
||||
|
||||
|
||||
def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN):
|
||||
"""Truncate content and append a notice if content exceeds the specified length."""
|
||||
return (
|
||||
content
|
||||
if not truncate_after or len(content) <= truncate_after
|
||||
else content[:truncate_after] + TRUNCATED_MESSAGE
|
||||
)
|
||||
|
||||
|
||||
async def run(
|
||||
cmd: str,
|
||||
timeout: float | None = 120.0, # seconds
|
||||
truncate_after: int | None = MAX_RESPONSE_LEN,
|
||||
):
|
||||
"""Run a shell command asynchronously with a timeout."""
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||
return (
|
||||
process.returncode or 0,
|
||||
maybe_truncate(stdout.decode(), truncate_after=truncate_after),
|
||||
maybe_truncate(stderr.decode(), truncate_after=truncate_after),
|
||||
)
|
||||
except asyncio.TimeoutError as exc:
|
||||
try:
|
||||
process.kill()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
raise TimeoutError(
|
||||
f"Command '{cmd}' timed out after {timeout} seconds"
|
||||
) from exc
|
||||
16
libs/agent/agent/providers/anthropic/types.py
Normal file
16
libs/agent/agent/providers/anthropic/types.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class LLMProvider(StrEnum):
    """Enum for supported API providers.

    Values are the provider identifiers used as keys in
    PROVIDER_TO_DEFAULT_MODEL_NAME below.
    """

    ANTHROPIC = "anthropic"
    BEDROCK = "bedrock"
    VERTEX = "vertex"


# Default model chosen when the caller does not specify one for a provider.
PROVIDER_TO_DEFAULT_MODEL_NAME: dict[LLMProvider, str] = {
    LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
    LLMProvider.BEDROCK: "anthropic.claude-3-7-sonnet-20250219-v2:0",
    LLMProvider.VERTEX: "claude-3-5-sonnet-v2@20241022",
}
|
||||
27
libs/agent/agent/providers/omni/__init__.py
Normal file
27
libs/agent/agent/providers/omni/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Omni provider implementation."""
|
||||
|
||||
# The OmniComputerAgent has been replaced by the unified ComputerAgent
|
||||
# which can be found in agent.core.agent
|
||||
from .types import LLMProvider
|
||||
from .experiment import ExperimentManager
|
||||
from .visualization import visualize_click, visualize_scroll, calculate_element_center
|
||||
from .image_utils import (
|
||||
decode_base64_image,
|
||||
encode_image_base64,
|
||||
clean_base64_data,
|
||||
extract_base64_from_text,
|
||||
get_image_dimensions,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"LLMProvider",
|
||||
"ExperimentManager",
|
||||
"visualize_click",
|
||||
"visualize_scroll",
|
||||
"calculate_element_center",
|
||||
"decode_base64_image",
|
||||
"encode_image_base64",
|
||||
"clean_base64_data",
|
||||
"extract_base64_from_text",
|
||||
"get_image_dimensions",
|
||||
]
|
||||
78
libs/agent/agent/providers/omni/callbacks.py
Normal file
78
libs/agent/agent/providers/omni/callbacks.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Omni callback manager implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, Set
|
||||
|
||||
from ...core.callbacks import BaseCallbackManager, ContentCallback, ToolCallback, APICallback
|
||||
from ...types.tools import ToolResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class OmniCallbackManager(BaseCallbackManager):
    """Callback manager for multi-provider support.

    Thin dispatcher: each on_* hook logs the event and forwards it to the
    corresponding callback supplied at construction time.
    """

    def __init__(
        self,
        content_callback: ContentCallback,
        tool_callback: ToolCallback,
        api_callback: APICallback,
    ):
        """Initialize Omni callback manager.

        Args:
            content_callback: Callback for content updates
            tool_callback: Callback for tool execution results
            api_callback: Callback for API interactions
        """
        super().__init__(
            content_callback=content_callback,
            tool_callback=tool_callback,
            api_callback=api_callback
        )
        # NOTE(review): nothing in this class ever adds to this set, so
        # get_active_tools() always returns an empty set — confirm whether
        # subclasses or callers are expected to populate it.
        self._active_tools: Set[str] = set()

    def on_content(self, content: Any) -> None:
        """Handle content updates.

        Args:
            content: Content update data
        """
        logger.debug(f"Content update: {content}")
        self.content_callback(content)

    def on_tool_result(self, result: ToolResult, tool_id: str) -> None:
        """Handle tool execution results.

        Args:
            result: Tool execution result
            tool_id: ID of the tool
        """
        logger.debug(f"Tool result for {tool_id}: {result}")
        self.tool_callback(result, tool_id)

    def on_api_interaction(
        self,
        request: Any,
        response: Any,
        error: Optional[Exception] = None
    ) -> None:
        """Handle API interactions.

        Args:
            request: API request data
            response: API response data
            error: Optional error that occurred
        """
        if error:
            logger.error(f"API error: {str(error)}")
        else:
            logger.debug(f"API interaction - Request: {request}, Response: {response}")
        # The callback fires regardless of whether an error occurred.
        self.api_callback(request, response, error)

    def get_active_tools(self) -> Set[str]:
        """Get currently active tools.

        Returns:
            Set of active tool names (a defensive copy)
        """
        return self._active_tools.copy()
|
||||
99
libs/agent/agent/providers/omni/clients/anthropic.py
Normal file
99
libs/agent/agent/providers/omni/clients/anthropic.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Anthropic API client implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
import asyncio
|
||||
from httpx import ConnectError, ReadTimeout
|
||||
|
||||
from anthropic import AsyncAnthropic, Anthropic
|
||||
from anthropic.types import MessageParam
|
||||
from .base import BaseOmniClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnthropicClient(BaseOmniClient):
    """Client for making calls to Anthropic API."""

    def __init__(self, api_key: str, model: str, max_retries: int = 3, retry_delay: float = 1.0):
        """Initialize the Anthropic client.

        Args:
            api_key: Anthropic API key
            model: Anthropic model name (e.g. "claude-3-opus-20240229")
            max_retries: Maximum number of retries for API calls
            retry_delay: Base delay between retries in seconds

        Raises:
            ValueError: if *model* is empty.
        """
        if not model:
            raise ValueError("Model name must be provided")

        self.client = AsyncAnthropic(api_key=api_key)
        self.model: str = model  # Add explicit type annotation
        self.max_retries = max_retries
        self.retry_delay = retry_delay

    def _convert_message_format(self, messages: List[Dict[str, Any]]) -> List[MessageParam]:
        """Convert messages from standard format to Anthropic format.

        Args:
            messages: Messages in standard format

        Returns:
            Messages in Anthropic format

        NOTE(review): messages whose role is neither "user" nor "assistant"
        are silently dropped here — confirm callers never pass other roles.
        """
        anthropic_messages = []

        for message in messages:
            if message["role"] == "user":
                anthropic_messages.append({"role": "user", "content": message["content"]})
            elif message["role"] == "assistant":
                anthropic_messages.append({"role": "assistant", "content": message["content"]})

        # Cast the list to the correct type expected by Anthropic
        return cast(List[MessageParam], anthropic_messages)

    async def run_interleaved(
        self, messages: List[Dict[str, Any]], system: str, max_tokens: int
    ) -> Any:
        """Run model with interleaved conversation format.

        Args:
            messages: List of messages to process
            system: System prompt
            max_tokens: Maximum tokens to generate

        Returns:
            Model response

        Raises:
            RuntimeError: on a non-network API error, or when every retry
                fails with a connection error/timeout.
        """
        last_error = None

        for attempt in range(self.max_retries):
            try:
                # Convert messages to Anthropic format
                anthropic_messages = self._convert_message_format(messages)

                response = await self.client.messages.create(
                    model=self.model,
                    max_tokens=max_tokens,
                    temperature=0,
                    system=system,
                    messages=anthropic_messages,
                )

                return response

            except (ConnectError, ReadTimeout) as e:
                # Transient network failures are retried; others abort below.
                last_error = e
                logger.warning(
                    f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
                )
                if attempt < self.max_retries - 1:
                    await asyncio.sleep(self.retry_delay * (attempt + 1))  # Linear backoff: wait grows with attempt number
                    continue

            except Exception as e:
                logger.error(f"Unexpected error in Anthropic API call: {str(e)}")
                raise RuntimeError(f"Anthropic API call failed: {str(e)}")

        # If we get here, all retries failed
        raise RuntimeError(f"Connection error after {self.max_retries} retries: {str(last_error)}")
|
||||
44
libs/agent/agent/providers/omni/clients/base.py
Normal file
44
libs/agent/agent/providers/omni/clients/base.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Base client implementation for Omni providers."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
import aiohttp
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseOmniClient:
    """Base class for provider-specific clients.

    Stores the shared credentials/model configuration and defines the
    interface subclasses must implement.
    """

    def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
        """Initialize base client.

        Args:
            api_key: Optional API key
            model: Optional model name
        """
        self.api_key = api_key
        self.model = model

    async def run_interleaved(
        self,
        messages: List[Dict[str, Any]],
        system: str,
        max_tokens: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Run interleaved chat completion.

        Args:
            messages: List of message dicts
            system: System prompt
            max_tokens: Optional max tokens override

        Returns:
            Response dict

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError
|
||||
101
libs/agent/agent/providers/omni/clients/groq.py
Normal file
101
libs/agent/agent/providers/omni/clients/groq.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Groq client implementation."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
|
||||
from groq import Groq
|
||||
import re
|
||||
from .utils import is_image_path
|
||||
from .base import BaseOmniClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GroqClient(BaseOmniClient):
    """Client for making Groq API calls."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "deepseek-r1-distill-llama-70b",
        max_tokens: int = 4096,
        temperature: float = 0.6,
    ):
        """Initialize Groq client.

        Args:
            api_key: Groq API key (if not provided, will try to get from env)
            model: Model name to use
            max_tokens: Maximum tokens to generate
            temperature: Temperature for sampling

        Raises:
            ValueError: if no key is given and GROQ_API_KEY is unset.
        """
        super().__init__(api_key=api_key, model=model)
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError("No Groq API key provided")

        self.max_tokens = max_tokens
        self.temperature = temperature
        self.client = Groq(api_key=self.api_key)
        self.model: str = model  # Add explicit type annotation

    def run_interleaved(
        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
    ) -> tuple[str, int]:
        """Run interleaved chat completion.

        NOTE(review): unlike BaseOmniClient.run_interleaved, this override is
        synchronous and returns a (text, token_usage) tuple — confirm callers
        handle both differences.

        Args:
            messages: List of message dicts
            system: System prompt
            max_tokens: Optional max tokens override

        Returns:
            Tuple of (response text, token usage)
        """
        # Avoid using system messages for R1
        final_messages = [{"role": "user", "content": system}]

        # Process messages
        if isinstance(messages, list):
            for item in messages:
                if isinstance(item, dict):
                    # For dict items, concatenate all text content, ignoring images
                    # NOTE(review): assumes item["content"] is an iterable of
                    # parts; a plain string here would be iterated char by char.
                    text_contents = []
                    for cnt in item["content"]:
                        if isinstance(cnt, str):
                            if not is_image_path(cnt):  # Skip image paths
                                text_contents.append(cnt)
                        else:
                            text_contents.append(str(cnt))

                    if text_contents:  # Only add if there's text content
                        message = {"role": "user", "content": " ".join(text_contents)}
                        final_messages.append(message)
                else:  # str
                    message = {"role": "user", "content": item}
                    final_messages.append(message)

        elif isinstance(messages, str):
            final_messages.append({"role": "user", "content": messages})

        try:
            completion = self.client.chat.completions.create(  # type: ignore
                model=self.model,
                messages=final_messages,  # type: ignore
                temperature=self.temperature,
                max_tokens=max_tokens or self.max_tokens,
                top_p=0.95,
                stream=False,
            )

            response = completion.choices[0].message.content
            # Strip the R1 "<think>" reasoning section and <output> wrappers,
            # keeping only the final answer text.
            final_answer = response.split("</think>\n")[-1] if "</think>" in response else response
            final_answer = final_answer.replace("<output>", "").replace("</output>", "")
            token_usage = completion.usage.total_tokens

            return final_answer, token_usage

        except Exception as e:
            logger.error(f"Error in Groq API call: {e}")
            raise
|
||||
159
libs/agent/agent/providers/omni/clients/openai.py
Normal file
159
libs/agent/agent/providers/omni/clients/openai.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""OpenAI client implementation."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
|
||||
import aiohttp
|
||||
import base64
|
||||
import re
|
||||
import json
|
||||
import ssl
|
||||
import certifi
|
||||
from datetime import datetime
|
||||
from .base import BaseOmniClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OpenAI specific client for the OmniLoop
|
||||
|
||||
|
||||
class OpenAIClient(BaseOmniClient):
    """OpenAI vision API client implementation."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "gpt-4o",
        provider_base_url: str = "https://api.openai.com/v1",
        max_tokens: int = 4096,
        temperature: float = 0.0,
    ):
        """Initialize the OpenAI client.

        Args:
            api_key: OpenAI API key; falls back to the OPENAI_API_KEY env var
            model: Model to use
            provider_base_url: API endpoint
            max_tokens: Maximum tokens to generate
            temperature: Generation temperature

        Raises:
            ValueError: If no API key is provided or found in the environment.
        """
        super().__init__(api_key=api_key, model=model)
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("No OpenAI API key provided")

        self.model = model
        self.provider_base_url = provider_base_url
        self.max_tokens = max_tokens
        self.temperature = temperature

    def _extract_base64_image(self, text: str) -> Optional[str]:
        """Extract base64 image data from an HTML img tag / data URL."""
        pattern = r'data:image/[^;]+;base64,([^"]+)'
        match = re.search(pattern, text)
        return match.group(1) if match else None

    def _get_loggable_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Create a loggable version of messages with image data truncated.

        Handles both the legacy "image" content type and the "image_url"
        type that run_interleaved actually emits, so base64 payloads never
        end up in logs.
        """
        loggable_messages = []
        for msg in messages:
            if isinstance(msg.get("content"), list):
                new_content = []
                for content in msg["content"]:
                    # BUG FIX: the original truncated only type == "image",
                    # but messages built by this client use "image_url", so
                    # full base64 payloads leaked into logs.
                    if content.get("type") in ("image", "image_url"):
                        new_content.append(
                            {"type": content["type"], "image_url": {"url": "[BASE64_IMAGE_DATA]"}}
                        )
                    else:
                        new_content.append(content)
                loggable_messages.append({"role": msg["role"], "content": new_content})
            else:
                loggable_messages.append(msg)
        return loggable_messages

    async def run_interleaved(
        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
    ) -> Dict[str, Any]:
        """Run interleaved chat completion.

        Args:
            messages: List of message dicts (or raw strings)
            system: System prompt
            max_tokens: Optional max tokens override

        Returns:
            Raw response dict from the chat completions endpoint.

        Raises:
            Exception: If the API returns a non-200 status or the request fails.
        """
        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

        final_messages = [{"role": "system", "content": system}]

        # Process messages: normalize every item into the OpenAI content-list
        # format, wrapping any embedded base64 images as image_url parts.
        for item in messages:
            if isinstance(item, dict):
                if isinstance(item["content"], list):
                    # Content is already in the correct format
                    final_messages.append(item)
                else:
                    # Single string content, check for image
                    base64_img = self._extract_base64_image(item["content"])
                    if base64_img:
                        message = {
                            "role": item["role"],
                            "content": [
                                {
                                    "type": "image_url",
                                    "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
                                }
                            ],
                        }
                    else:
                        message = {
                            "role": item["role"],
                            "content": [{"type": "text", "text": item["content"]}],
                        }
                    final_messages.append(message)
            else:
                # Bare string content, check for image
                base64_img = self._extract_base64_image(item)
                if base64_img:
                    message = {
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
                            }
                        ],
                    }
                else:
                    message = {"role": "user", "content": [{"type": "text", "text": item}]}
                final_messages.append(message)

        payload = {"model": self.model, "messages": final_messages, "temperature": self.temperature}

        # Reasoning models use different parameter names for token limits.
        if "o1" in self.model or "o3-mini" in self.model:
            payload["reasoning_effort"] = "low"
            payload["max_completion_tokens"] = max_tokens or self.max_tokens
        else:
            payload["max_tokens"] = max_tokens or self.max_tokens

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.provider_base_url}/chat/completions", headers=headers, json=payload
                ) as response:
                    response_json = await response.json()

                    if response.status != 200:
                        error_msg = response_json.get("error", {}).get(
                            "message", str(response_json)
                        )
                        logger.error(f"Error in OpenAI API call: {error_msg}")
                        raise Exception(f"OpenAI API error: {error_msg}")

                    return response_json

        except Exception as e:
            logger.error(f"Error in OpenAI API call: {str(e)}")
            raise
|
||||
25
libs/agent/agent/providers/omni/clients/utils.py
Normal file
25
libs/agent/agent/providers/omni/clients/utils.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import base64
|
||||
|
||||
def is_image_path(text: str) -> bool:
    """Check if a text string is an image file path.

    The check is case-insensitive, so paths like "PHOTO.PNG" produced by
    cameras or case-preserving filesystems are recognized too.

    Args:
        text: Text string to check

    Returns:
        True if text ends with a known image extension, False otherwise
    """
    image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif")
    # lower() makes the suffix match case-insensitive (original missed ".PNG" etc.)
    return text.lower().endswith(image_extensions)
|
||||
|
||||
def encode_image(image_path: str) -> str:
    """Encode an image file as a base64 string.

    Args:
        image_path: Path to the image file on disk.

    Returns:
        The file's bytes encoded as a UTF-8 base64 string.
    """
    with open(image_path, "rb") as fh:
        encoded = base64.b64encode(fh.read())
    return encoded.decode("utf-8")
|
||||
273
libs/agent/agent/providers/omni/experiment.py
Normal file
273
libs/agent/agent/providers/omni/experiment.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""Experiment management for the Cua provider."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import copy
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from PIL import Image
|
||||
import json
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExperimentManager:
    """Manages experiment directories and logging for the agent.

    When constructed without a base_dir, every save/log method is a no-op,
    so callers do not need to guard their calls.
    """

    def __init__(
        self,
        base_dir: Optional[str] = None,
        only_n_most_recent_images: Optional[int] = None,
    ):
        """Initialize the experiment manager.

        Args:
            base_dir: Base directory for saving experiment data
            only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
        """
        self.base_dir = base_dir
        self.only_n_most_recent_images = only_n_most_recent_images
        self.run_dir: Optional[str] = None
        self.current_turn_dir: Optional[str] = None
        self.turn_count = 0
        self.screenshot_count = 0
        # Track all screenshots for potential API request inclusion
        self.screenshot_paths: List[str] = []

        # Set up experiment directories if base_dir is provided
        if self.base_dir:
            self.setup_experiment_dirs()

    def setup_experiment_dirs(self) -> None:
        """Setup the experiment directory structure."""
        if not self.base_dir:
            return

        # Create base experiments directory if it doesn't exist
        os.makedirs(self.base_dir, exist_ok=True)

        # Use the base_dir directly as the run_dir
        self.run_dir = self.base_dir
        logger.info(f"Using directory for experiment: {self.run_dir}")

        # Create first turn directory
        self.create_turn_dir()

    def create_turn_dir(self) -> None:
        """Create a new numbered directory for the current turn."""
        if not self.run_dir:
            return

        self.turn_count += 1
        self.current_turn_dir = os.path.join(self.run_dir, f"turn_{self.turn_count:03d}")
        os.makedirs(self.current_turn_dir, exist_ok=True)
        logger.info(f"Created turn directory: {self.current_turn_dir}")

    def sanitize_log_data(self, data: Any) -> Any:
        """Sanitize data for logging by removing large base64 strings.

        Both the Anthropic content format ({"type": "image", "source": ...})
        and the OpenAI format ({"type": "image_url", ...}) are truncated to a
        short placeholder that records the original payload length.

        Args:
            data: Data to sanitize (dict, list, or primitive)

        Returns:
            Sanitized deep copy of the data; the input is never mutated.
        """
        if isinstance(data, dict):
            result = copy.deepcopy(data)

            # Handle nested dictionaries and lists
            for key, value in result.items():
                # Process content arrays that contain image data
                if key == "content" and isinstance(value, list):
                    for i, item in enumerate(value):
                        if isinstance(item, dict):
                            # Handle Anthropic format
                            if item.get("type") == "image" and isinstance(item.get("source"), dict):
                                source = item["source"]
                                if "data" in source and isinstance(source["data"], str):
                                    # Replace base64 data with a placeholder and length info
                                    data_len = len(source["data"])
                                    source["data"] = f"[BASE64_IMAGE_DATA_LENGTH_{data_len}]"

                            # Handle OpenAI format
                            elif item.get("type") == "image_url" and isinstance(
                                item.get("image_url"), dict
                            ):
                                url_dict = item["image_url"]
                                if "url" in url_dict and isinstance(url_dict["url"], str):
                                    url = url_dict["url"]
                                    if url.startswith("data:"):
                                        # Replace base64 data with placeholder
                                        data_len = len(url)
                                        url_dict["url"] = f"[BASE64_IMAGE_URL_LENGTH_{data_len}]"

                # Handle other nested structures recursively
                if isinstance(value, dict):
                    result[key] = self.sanitize_log_data(value)
                elif isinstance(value, list):
                    result[key] = [self.sanitize_log_data(item) for item in value]

            return result
        elif isinstance(data, list):
            return [self.sanitize_log_data(item) for item in data]
        else:
            return data

    def save_debug_image(self, image_data: str, filename: str) -> None:
        """No-op: debug image saving to an images/ folder has been removed.

        Args:
            image_data: Base64 encoded image data (ignored)
            filename: Filename to save the image as (ignored)
        """
        # Since we no longer want to use the images/ folder, we skip this functionality
        return

    def save_screenshot(self, img_base64: str, action_type: str = "") -> Optional[str]:
        """Save a screenshot into the current turn directory.

        Args:
            img_base64: Base64 encoded screenshot
            action_type: Type of action that triggered the screenshot

        Returns:
            Path to the saved file, or None if there is no turn directory
            or saving failed.
            (BUG FIX: the original annotated this method ``-> None`` even
            though it returns the filepath.)
        """
        if not self.current_turn_dir:
            return None

        try:
            # Increment screenshot counter
            self.screenshot_count += 1

            # Create a descriptive filename
            timestamp = int(time.time() * 1000)
            action_suffix = f"_{action_type}" if action_type else ""
            filename = f"screenshot_{self.screenshot_count:03d}{action_suffix}_{timestamp}.png"

            # Save directly to the turn directory (no screenshots subdirectory)
            filepath = os.path.join(self.current_turn_dir, filename)

            # Save the screenshot
            img_data = base64.b64decode(img_base64)
            with open(filepath, "wb") as f:
                f.write(img_data)

            # Keep track of the file path for reference
            self.screenshot_paths.append(filepath)

            return filepath
        except Exception as e:
            logger.error(f"Error saving screenshot: {str(e)}")
            return None

    def should_save_debug_image(self) -> bool:
        """Determine if debug images should be saved.

        Returns:
            Always False — debug image saving has been disabled.
        """
        # We no longer need to save debug images, so always return False
        return False

    def save_action_visualization(
        self, img: Image.Image, action_name: str, details: str = ""
    ) -> str:
        """Save a visualization of an action into the current turn directory.

        Args:
            img: Image to save
            action_name: Name of the action
            details: Additional details about the action

        Returns:
            Path to the saved image, or "" if saving was skipped or failed.
        """
        if not self.current_turn_dir:
            return ""

        try:
            # Create a descriptive filename
            timestamp = int(time.time() * 1000)
            details_suffix = f"_{details}" if details else ""
            filename = f"vis_{action_name}{details_suffix}_{timestamp}.png"

            # Save directly to the turn directory (no visualizations subdirectory)
            filepath = os.path.join(self.current_turn_dir, filename)

            # Save the image
            img.save(filepath)

            # Keep track of the file path for cleanup
            self.screenshot_paths.append(filepath)

            return filepath
        except Exception as e:
            logger.error(f"Error saving action visualization: {str(e)}")
            return ""

    def extract_and_save_images(self, data: Any, prefix: str) -> None:
        """No-op: extracting images from response data has been removed.

        Args:
            data: Response data to extract images from (ignored)
            prefix: Prefix for saved image filenames (ignored)
        """
        # Since we no longer want to save extracted images separately,
        # we skip this functionality entirely
        return

    def log_api_call(
        self,
        call_type: str,
        request: Any,
        provider: str,
        model: str,
        response: Any = None,
        error: Optional[Exception] = None,
    ) -> None:
        """Log API call details to a JSON file in the current turn directory.

        Args:
            call_type: Type of API call (e.g., 'request', 'response', 'error')
            request: The API request data
            provider: The AI provider used
            model: The AI model used
            response: Optional API response data
            error: Optional error information
        """
        if not self.current_turn_dir:
            return

        try:
            # Create a unique filename with timestamp
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"api_call_{timestamp}_{call_type}.json"
            filepath = os.path.join(self.current_turn_dir, filename)

            # Sanitize data to remove large base64 strings
            sanitized_request = self.sanitize_log_data(request)
            sanitized_response = self.sanitize_log_data(response) if response is not None else None

            # Prepare log data
            log_data = {
                "timestamp": timestamp,
                "provider": provider,
                "model": model,
                "type": call_type,
                "request": sanitized_request,
            }

            if sanitized_response is not None:
                log_data["response"] = sanitized_response
            if error is not None:
                log_data["error"] = str(error)

            # Write to file (default=str keeps non-serializable objects loggable)
            with open(filepath, "w") as f:
                json.dump(log_data, f, indent=2, default=str)

            logger.info(f"Logged API {call_type} to {filepath}")

        except Exception as e:
            logger.error(f"Error logging API call: {str(e)}")
|
||||
106
libs/agent/agent/providers/omni/image_utils.py
Normal file
106
libs/agent/agent/providers/omni/image_utils.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Image processing utilities for the Cua provider."""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import re
|
||||
from io import BytesIO
|
||||
from typing import Optional, Tuple
|
||||
from PIL import Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def decode_base64_image(img_base64: str) -> Optional[Image.Image]:
    """Decode a base64 encoded image into a PIL Image.

    Args:
        img_base64: Base64 encoded image, may include data URL prefix

    Returns:
        PIL Image or None if decoding fails
    """
    try:
        # Strip any "data:image/...;base64," prefix before decoding
        payload = img_base64.split(",")[1] if img_base64.startswith("data:image") else img_base64
        return Image.open(BytesIO(base64.b64decode(payload)))
    except Exception as e:
        logger.error(f"Error decoding base64 image: {str(e)}")
        return None
|
||||
|
||||
|
||||
def encode_image_base64(img: Image.Image, format: str = "PNG") -> str:
    """Encode a PIL Image to base64.

    Args:
        img: PIL Image to encode
        format: Image format (PNG, JPEG, etc.)

    Returns:
        Base64 encoded image string, or "" if encoding fails
    """
    try:
        with BytesIO() as buffer:
            img.save(buffer, format=format)
            payload = buffer.getvalue()
        return base64.b64encode(payload).decode("utf-8")
    except Exception as e:
        logger.error(f"Error encoding image to base64: {str(e)}")
        return ""
|
||||
|
||||
|
||||
def clean_base64_data(img_base64: str) -> str:
    """Strip a data URL prefix from base64 image data, if present.

    Args:
        img_base64: Base64 encoded image, may include data URL prefix

    Returns:
        Clean base64 string without prefix
    """
    if not img_base64.startswith("data:image"):
        return img_base64
    return img_base64.split(",")[1]
|
||||
|
||||
|
||||
def extract_base64_from_text(text: str) -> Optional[str]:
    """Extract base64 image data from a text string.

    Tries a data URL pattern first, then falls back to a heuristic match
    on any sufficiently long base64-looking run of characters.

    Args:
        text: Text potentially containing base64 image data

    Returns:
        Base64 string or None if not found
    """
    patterns = (
        r"data:image/[^;]+;base64,([a-zA-Z0-9+/=]+)",  # data URL
        r"([a-zA-Z0-9+/=]{100,})",  # plain base64 (basic heuristic)
    )
    for pattern in patterns:
        found = re.search(pattern, text)
        if found:
            return found.group(1)
    return None
|
||||
|
||||
|
||||
def get_image_dimensions(img_base64: str) -> Tuple[int, int]:
    """Get the dimensions of a base64 encoded image.

    Args:
        img_base64: Base64 encoded image

    Returns:
        Tuple of (width, height) or (0, 0) if decoding fails
    """
    decoded = decode_base64_image(img_base64)
    return decoded.size if decoded else (0, 0)
|
||||
965
libs/agent/agent/providers/omni/loop.py
Normal file
965
libs/agent/agent/providers/omni/loop.py
Normal file
@@ -0,0 +1,965 @@
|
||||
"""Omni-specific agent loop implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, AsyncGenerator, Union
|
||||
import base64
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
from httpx import ConnectError, ReadTimeout
|
||||
import shutil
|
||||
import copy
|
||||
|
||||
from .parser import OmniParser, ParseResult, ParserMetadata, UIElement
|
||||
from ...core.loop import BaseLoop
|
||||
from computer import Computer
|
||||
from .types import LLMProvider
|
||||
from .clients.base import BaseOmniClient
|
||||
from .clients.openai import OpenAIClient
|
||||
from .clients.groq import GroqClient
|
||||
from .clients.anthropic import AnthropicClient
|
||||
from .prompts import SYSTEM_PROMPT
|
||||
from .utils import compress_image_base64
|
||||
from .visualization import visualize_click, visualize_scroll, calculate_element_center
|
||||
from .image_utils import decode_base64_image, clean_base64_data
|
||||
from ...core.messages import ImageRetentionConfig
|
||||
from .messages import OmniMessageManager
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_data(input_string: str, data_type: str) -> str:
    """Extract the first fenced code block of the given type.

    Returns the input unchanged when no ```<data_type> fence is present.
    An unterminated fence (no closing ```) is matched to end of string.
    """
    fence = f"```{data_type}" + r"(.*?)(```|$)"
    found = re.findall(fence, input_string, re.DOTALL)
    if not found:
        return input_string
    return found[0][0].strip()
|
||||
|
||||
|
||||
class OmniLoop(BaseLoop):
|
||||
"""Omni-specific implementation of the agent loop."""
|
||||
|
||||
def __init__(
    self,
    parser: OmniParser,
    provider: LLMProvider,
    api_key: str,
    model: str,
    computer: Computer,
    only_n_most_recent_images: Optional[int] = 2,
    base_dir: Optional[str] = "trajectories",
    max_retries: int = 3,
    retry_delay: float = 1.0,
    save_trajectory: bool = True,
    **kwargs,
):
    """Initialize the loop.

    Args:
        parser: Parser instance
        provider: API provider
        api_key: API key
        model: Model name
        computer: Computer instance
        only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
        base_dir: Base directory for saving experiment data
        max_retries: Maximum number of retries for API calls
        retry_delay: Delay between retries in seconds
        save_trajectory: Whether to save trajectory data
    """
    # Set parser and provider before initializing base class
    # (ordering matters: presumably BaseLoop.__init__ touches attributes
    # set up here — TODO confirm against BaseLoop)
    self.parser = parser
    self.provider = provider

    # Initialize message manager with image retention config so only the
    # N most recent screenshots are kept in outgoing requests
    image_retention_config = ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
    self.message_manager = OmniMessageManager(config=image_retention_config)

    # Initialize base class (which will set up experiment manager)
    super().__init__(
        computer=computer,
        model=model,
        api_key=api_key,
        max_retries=max_retries,
        retry_delay=retry_delay,
        base_dir=base_dir,
        save_trajectory=save_trajectory,
        only_n_most_recent_images=only_n_most_recent_images,
        **kwargs,
    )

    # Set API client attributes; the concrete client is created lazily
    # by initialize_client()
    self.client = None
    self.retry_count = 0
|
||||
|
||||
def _should_save_debug_image(self) -> bool:
    """Report whether debug images should be saved.

    Returns:
        bool: Always False — debug image saving has been disabled.
    """
    # The debug-image pipeline was removed; the hook remains for callers.
    return False
|
||||
|
||||
def _extract_and_save_images(self, data: Any, prefix: str) -> None:
    """No-op: image extraction from API data has been removed.

    Args:
        data: Data to extract images from (ignored)
        prefix: Prefix for the extracted image filenames (ignored)
    """
    # Stub retained so existing call sites keep working.
    return None
|
||||
|
||||
def _save_debug_image(self, image_data: str, filename: str) -> None:
    """No-op: debug image saving has been removed.

    Args:
        image_data: Base64 encoded image data (ignored)
        filename: Name to use for the saved image (ignored)
    """
    # Stub retained so existing call sites keep working.
    return None
|
||||
|
||||
def _visualize_action(self, x: int, y: int, img_base64: str) -> None:
    """Draw a click marker at (x, y) on the screenshot and save it.

    Silently does nothing when trajectory saving is off or no experiment
    manager is attached; visualization errors are logged, never raised.
    """
    manager = getattr(self, "experiment_manager", None)
    if not self.save_trajectory or not manager:
        return

    try:
        annotated = visualize_click(x, y, img_base64)
        manager.save_action_visualization(annotated, "click", f"x{x}_y{y}")
    except Exception as e:
        logger.error(f"Error visualizing action: {str(e)}")
|
||||
|
||||
def _visualize_scroll(self, direction: str, clicks: int, img_base64: str) -> None:
    """Draw scroll arrows on the screenshot and save the visualization.

    Silently does nothing when trajectory saving is off or no experiment
    manager is attached; visualization errors are logged, never raised.
    """
    manager = getattr(self, "experiment_manager", None)
    if not self.save_trajectory or not manager:
        return

    try:
        annotated = visualize_scroll(direction, clicks, img_base64)
        manager.save_action_visualization(annotated, "scroll", f"{direction}_{clicks}")
    except Exception as e:
        logger.error(f"Error visualizing scroll: {str(e)}")
|
||||
|
||||
def _save_action_visualization(
    self, img: Image.Image, action_name: str, details: str = ""
) -> str:
    """Delegate saving an action visualization to the experiment manager.

    Returns:
        Path to the saved image, or "" when no manager is attached.
    """
    manager = getattr(self, "experiment_manager", None)
    if manager:
        return manager.save_action_visualization(img, action_name, details)
    return ""
|
||||
|
||||
async def initialize_client(self) -> None:
    """Initialize the appropriate client based on provider.

    Instantiates the provider-specific client (OpenAI, Groq, or Anthropic)
    and stores it on ``self.client``. On failure, ``self.client`` is reset
    to None so a later call can retry.

    Raises:
        RuntimeError: If client construction fails for any reason
            (including an unsupported provider).
    """
    try:
        logger.info(f"Initializing {self.provider} client with model {self.model}...")

        if self.provider == LLMProvider.OPENAI:
            self.client = OpenAIClient(api_key=self.api_key, model=self.model)
        elif self.provider == LLMProvider.GROQ:
            self.client = GroqClient(api_key=self.api_key, model=self.model)
        elif self.provider == LLMProvider.ANTHROPIC:
            # Anthropic client handles its own retries, so pass the policy through
            self.client = AnthropicClient(
                api_key=self.api_key,
                model=self.model,
                max_retries=self.max_retries,
                retry_delay=self.retry_delay,
            )
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

        logger.info(f"Initialized {self.provider} client with model {self.model}")
    except Exception as e:
        logger.error(f"Error initializing client: {str(e)}")
        # Leave the loop in a consistent state: no half-initialized client
        self.client = None
        raise RuntimeError(f"Failed to initialize client: {str(e)}")
|
||||
|
||||
async def _make_api_call(self, messages: List[Dict[str, Any]], system_prompt: str) -> Any:
    """Make API call to provider with retry logic.

    The message manager's prepared view (image retention applied) is what
    is actually sent — not the raw ``messages`` argument.

    Args:
        messages: Conversation messages (superseded by the message
            manager's prepared view; kept for interface compatibility)
        system_prompt: System prompt for non-Anthropic providers

    Returns:
        The provider's response object.

    Raises:
        RuntimeError: If the call still fails after max_retries attempts.
    """
    # Create new turn directory for this API call
    self._create_turn_dir()

    request_data = None
    last_error = None

    for attempt in range(self.max_retries):
        try:
            # Ensure client is initialized
            if self.client is None:
                logger.info(
                    f"Client not initialized in _make_api_call (attempt {attempt+1}), initializing now..."
                )
                await self.initialize_client()
                if self.client is None:
                    raise RuntimeError("Failed to initialize client")

            # Set the provider in message manager based on current provider
            provider_name = str(self.provider).split(".")[-1].lower()  # Extract name from enum
            self.message_manager.set_provider(provider_name)

            # Apply image retention and prepare messages
            # This will limit the number of images based on only_n_most_recent_images
            prepared_messages = self.message_manager.get_formatted_messages(provider_name)

            # Anthropic takes the system prompt as a separate argument, so
            # system messages are filtered out of the message list
            if self.provider == LLMProvider.ANTHROPIC:
                filtered_messages = [
                    msg for msg in prepared_messages if msg["role"] != "system"
                ]
                system = self._get_system_prompt()
            else:
                filtered_messages = prepared_messages
                system = system_prompt

            # Log request
            request_data = {"messages": filtered_messages, "max_tokens": self.max_tokens}
            request_data["system"] = system
            self._log_api_call("request", request_data)

            if self.client is None:
                raise RuntimeError("Client not initialized. Call initialize_client() first.")

            # BUG FIX: the original sent the prepared/filtered messages only
            # for Anthropic; every other provider received the raw `messages`
            # argument, bypassing image retention and making the logged
            # request differ from the one actually sent. All providers now
            # receive filtered_messages.
            run_method = self.client.run_interleaved
            if asyncio.iscoroutinefunction(run_method):
                # Async client implementations (e.g. AnthropicClient, OpenAIClient)
                response = await run_method(
                    messages=filtered_messages,
                    system=system,
                    max_tokens=self.max_tokens,
                )
            else:
                # Synchronous client implementations (e.g. GroqClient)
                response = run_method(
                    messages=filtered_messages,
                    system=system,
                    max_tokens=self.max_tokens,
                )

            # Log success response
            self._log_api_call("response", request_data, response)

            return response

        except (ConnectError, ReadTimeout) as e:
            last_error = e
            logger.warning(
                f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
            )
            if attempt < self.max_retries - 1:
                await asyncio.sleep(self.retry_delay * (attempt + 1))  # Exponential backoff
                # Reset client on connection errors to force re-initialization
                self.client = None
                continue

        except RuntimeError as e:
            # Handle client initialization errors specifically
            last_error = e
            self._log_api_call("error", request_data, error=e)
            logger.error(
                f"Client initialization error (attempt {attempt + 1}/{self.max_retries}): {str(e)}"
            )
            if attempt < self.max_retries - 1:
                # Reset client to force re-initialization
                self.client = None
                await asyncio.sleep(self.retry_delay)
                continue

        except Exception as e:
            # Log unexpected error
            last_error = e
            self._log_api_call("error", request_data, error=e)
            logger.error(f"Unexpected error in API call: {str(e)}")
            if attempt < self.max_retries - 1:
                await asyncio.sleep(self.retry_delay)
                continue

    # If we get here, all retries failed
    error_message = f"API call failed after {self.max_retries} attempts"
    if last_error:
        error_message += f": {str(last_error)}"

    logger.error(error_message)
    raise RuntimeError(error_message)
|
||||
|
||||
async def _handle_response(
    self, response: Any, messages: List[Dict[str, Any]], parsed_screen: Dict[str, Any]
) -> Tuple[bool, bool]:
    """Handle API response.

    Extracts a JSON action payload from the model's reply (Anthropic
    content blocks, OpenAI-style ``choices`` dicts, or an already-parsed
    dict), appends the assistant turn to ``messages``, and executes the
    requested action via ``_execute_action``.

    Returns:
        Tuple of (should_continue, action_screenshot_saved)
    """
    action_screenshot_saved = False
    try:
        # Handle Anthropic response format: a list of typed content blocks.
        if self.provider == LLMProvider.ANTHROPIC:
            if hasattr(response, "content") and isinstance(response.content, list):
                # Extract text from content blocks; only the first text block
                # that yields parseable JSON is acted on (each branch below
                # either returns or continues to the next block).
                for block in response.content:
                    if hasattr(block, "type") and block.type == "text":
                        content = block.text

                        # Try to find JSON in the content.
                        try:
                            # First look for a fenced ```json code block.
                            json_content = extract_data(content, "json")
                            parsed_content = json.loads(json_content)
                            logger.info("Successfully parsed JSON from code block")
                        except (json.JSONDecodeError, IndexError):
                            # If no JSON block, try to find a JSON object in the text.
                            try:
                                # NOTE(review): this pattern only matches a FLAT
                                # object (no nested braces) — TODO confirm the
                                # model never emits nested JSON here.
                                json_pattern = r"\{[^}]+\}"
                                json_match = re.search(json_pattern, content)
                                if json_match:
                                    json_str = json_match.group(0)
                                    parsed_content = json.loads(json_str)
                                    logger.info("Successfully parsed JSON from text")
                                else:
                                    logger.error(f"No JSON found in content: {content}")
                                    continue
                            except json.JSONDecodeError as e:
                                logger.error(f"Failed to parse JSON from text: {str(e)}")
                                continue

                        # Clean up Box ID format: "Box #12" -> "12".
                        if "Box ID" in parsed_content and isinstance(
                            parsed_content["Box ID"], str
                        ):
                            parsed_content["Box ID"] = parsed_content["Box ID"].replace(
                                "Box #", ""
                            )

                        # Add any explanatory text as reasoning if not present.
                        if "Explanation" not in parsed_content:
                            # Extract any text before the JSON as reasoning.
                            text_before_json = content.split("{")[0].strip()
                            if text_before_json:
                                parsed_content["Explanation"] = text_before_json

                        # Log the parsed content for debugging.
                        logger.info(f"Parsed content: {json.dumps(parsed_content, indent=2)}")

                        # Record the assistant turn as stringified JSON.
                        messages.append(
                            {"role": "assistant", "content": json.dumps(parsed_content)}
                        )

                        try:
                            # Execute action with current parsed screen info.
                            await self._execute_action(parsed_content, parsed_screen)
                            action_screenshot_saved = True
                        except Exception as e:
                            logger.error(f"Error executing action: {str(e)}")
                            # Surface the failure to the conversation and stop.
                            messages.append(
                                {
                                    "role": "assistant",
                                    "content": f"Error executing action: {str(e)}",
                                    "metadata": {"title": "❌ Error"},
                                }
                            )
                            return False, action_screenshot_saved

                        # "Action": "None" is the model's task-complete signal.
                        if parsed_content.get("Action") == "None":
                            return False, action_screenshot_saved
                        return True, action_screenshot_saved

                logger.warning("No text block found in Anthropic response")
                return True, action_screenshot_saved

        # Handle other providers' response formats (OpenAI-style dict, or a
        # raw string/dict payload).
        if isinstance(response, dict) and "choices" in response:
            content = response["choices"][0]["message"]["content"]
        else:
            content = response

        # Parse JSON content from a string reply.
        if isinstance(content, str):
            try:
                # First try to parse the whole content as JSON.
                parsed_content = json.loads(content)
            except json.JSONDecodeError:
                try:
                    # Then try a fenced ```json code block.
                    json_content = extract_data(content, "json")
                    parsed_content = json.loads(json_content)
                except (json.JSONDecodeError, IndexError):
                    try:
                        # Finally, look for a flat JSON object pattern
                        # (same nested-brace caveat as above).
                        json_pattern = r"\{[^}]+\}"
                        json_match = re.search(json_pattern, content)
                        if json_match:
                            json_str = json_match.group(0)
                            parsed_content = json.loads(json_str)
                        else:
                            logger.error(f"No JSON found in content: {content}")
                            return True, action_screenshot_saved
                    except json.JSONDecodeError as e:
                        logger.error(f"Failed to parse JSON from text: {str(e)}")
                        return True, action_screenshot_saved

            # Clean up Box ID format: "Box #12" -> "12".
            if "Box ID" in parsed_content and isinstance(parsed_content["Box ID"], str):
                parsed_content["Box ID"] = parsed_content["Box ID"].replace("Box #", "")

            # Add any explanatory text as reasoning if not present.
            if "Explanation" not in parsed_content:
                # Extract any text before the JSON as reasoning.
                text_before_json = content.split("{")[0].strip()
                if text_before_json:
                    parsed_content["Explanation"] = text_before_json

            # Add response to messages with stringified content.
            messages.append({"role": "assistant", "content": json.dumps(parsed_content)})

            try:
                # Execute action with current parsed screen info.
                await self._execute_action(parsed_content, parsed_screen)
                action_screenshot_saved = True
            except Exception as e:
                logger.error(f"Error executing action: {str(e)}")
                # Surface the failure to the conversation and stop.
                messages.append(
                    {
                        "role": "assistant",
                        "content": f"Error executing action: {str(e)}",
                        "metadata": {"title": "❌ Error"},
                    }
                )
                return False, action_screenshot_saved

            # "Action": "None" means the task is complete.
            if parsed_content.get("Action") == "None":
                return False, action_screenshot_saved

            return True, action_screenshot_saved
        elif isinstance(content, dict):
            # Handle case where content is already a dictionary.
            messages.append({"role": "assistant", "content": json.dumps(content)})

            try:
                # Execute action with current parsed screen info.
                await self._execute_action(content, parsed_screen)
                action_screenshot_saved = True
            except Exception as e:
                logger.error(f"Error executing action: {str(e)}")
                # Surface the failure to the conversation and stop.
                messages.append(
                    {
                        "role": "assistant",
                        "content": f"Error executing action: {str(e)}",
                        "metadata": {"title": "❌ Error"},
                    }
                )
                return False, action_screenshot_saved

            # "Action": "None" means the task is complete.
            if content.get("Action") == "None":
                return False, action_screenshot_saved

            return True, action_screenshot_saved

        # Unrecognized content type: keep the loop running.
        return True, action_screenshot_saved

    except Exception as e:
        logger.error(f"Error handling response: {str(e)}")
        messages.append(
            {
                "role": "assistant",
                "content": f"Error: {str(e)}",
                "metadata": {"title": "❌ Error"},
            }
        )
        raise
|
||||
|
||||
async def _get_parsed_screen_som(self, save_screenshot: bool = True) -> ParseResult:
    """Get parsed screen information with SOM.

    Args:
        save_screenshot: Whether to save the screenshot (set to False when
            screenshots will be saved elsewhere)

    Returns:
        ParseResult containing screen information and elements
    """
    try:
        # The parser takes the screenshot itself through the computer interface.
        result = await self.parser.parse_screen(computer=self.computer)

        # Log how many UI elements came back from parsing.
        element_count = len(result.elements) if result.elements else 0
        logger.info(f"Parsed screen with {element_count} elements")

        # Persist the annotated screenshot only when requested, trajectory
        # saving is enabled, and image data is actually present.
        wants_save = (
            save_screenshot and self.save_trajectory and result.annotated_image_base64
        )
        if wants_save:
            try:
                # Drop a possible "data:image/png;base64," style prefix.
                encoded = result.annotated_image_base64
                if "," in encoded:
                    encoded = encoded.split(",")[1]
                # "state" tags this capture as the plain current screen state.
                self._save_screenshot(encoded, action_type="state")
            except Exception as e:
                logger.error(f"Error saving screenshot: {str(e)}")

        return result

    except Exception as e:
        logger.error(f"Error getting parsed screen: {str(e)}")
        raise
|
||||
|
||||
async def _process_screen(
    self, parsed_screen: ParseResult, messages: List[Dict[str, Any]]
) -> None:
    """Process and add screen info to messages.

    Serializes the parsed UI elements to the current turn directory and
    appends the annotated screenshot to the conversation in the format
    the active provider expects. No-op for providers other than
    OpenAI/Anthropic or when no image is available.
    """
    try:
        # Only add a message if we have an image and the provider supports it.
        if self.provider in [LLMProvider.OPENAI, LLMProvider.ANTHROPIC]:
            image = parsed_screen.annotated_image_base64 or None
            if image:
                # Save screen info to the current turn directory.
                if self.current_turn_dir:
                    # Save elements as JSON for trajectory debugging.
                    elements_path = os.path.join(self.current_turn_dir, "elements.json")
                    with open(elements_path, "w") as f:
                        # Convert elements to dicts for JSON serialization.
                        elements_json = [elem.model_dump() for elem in parsed_screen.elements]
                        json.dump(elements_json, f, indent=2)
                    logger.info(f"Saved elements to {elements_path}")

                # Format the image content based on the provider.
                if self.provider == LLMProvider.ANTHROPIC:
                    # Compress the image before sending to Anthropic (5MB limit).
                    image_size = len(image)
                    logger.info(f"Image base64 is present, length: {image_size}")

                    # Anthropic has a 5MB limit - check against base64 string
                    # length, which is what matters for the API call payload.
                    # Use a slightly smaller limit (4.9MB) to account for
                    # request overhead.
                    max_size = int(4.9 * 1024 * 1024)  # 4.9MB

                    # Default media type (will be overridden if compression is needed).
                    media_type = "image/png"

                    # Check if the image already has a data-URL media type prefix.
                    if image.startswith("data:"):
                        parts = image.split(",", 1)
                        if len(parts) == 2 and "image/jpeg" in parts[0].lower():
                            media_type = "image/jpeg"
                        elif len(parts) == 2 and "image/png" in parts[0].lower():
                            media_type = "image/png"

                    if image_size > max_size:
                        logger.info(
                            f"Image size ({image_size} bytes) exceeds Anthropic limit ({max_size} bytes), compressing..."
                        )
                        image, media_type = compress_image_base64(image, max_size)
                        logger.info(
                            f"Image compressed to {len(image)} bytes with media_type {media_type}"
                        )

                    # Anthropic uses "type": "image" with a base64 source block.
                    # NOTE(review): when no compression happens, a "data:"
                    # prefix (if present) is NOT stripped before being placed
                    # in "data" — verify Anthropic accepts that, or that
                    # callers always pass bare base64.
                    screen_info_msg = {
                        "role": "user",
                        "content": [
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": media_type,
                                    "data": image,
                                },
                            }
                        ],
                    }
                else:
                    # OpenAI and others use "type": "image_url" with a data URL.
                    screen_info_msg = {
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": {"url": f"data:image/png;base64,{image}"},
                            }
                        ],
                    }
                messages.append(screen_info_msg)

    except Exception as e:
        logger.error(f"Error processing screen info: {str(e)}")
        raise
|
||||
|
||||
def _get_system_prompt(self) -> str:
    """Return the system prompt sent with every model request."""
    # Single module-level prompt shared across all providers.
    return SYSTEM_PROMPT
|
||||
|
||||
async def _execute_action(self, content: Dict[str, Any], parsed_screen: ParseResult) -> None:
    """Execute the action specified in the content using the tool manager.

    Dispatches on ``content["Action"]`` to the matching method on
    ``self.computer.interface``, resolving Box IDs to pixel coordinates
    where needed and saving before/after screenshots for the trajectory.
    All failures are logged and swallowed (the method always returns None).

    Args:
        content: Dictionary containing the action details
        parsed_screen: Current parsed screen information
    """
    try:
        action = content.get("Action", "").lower()
        if not action:
            return

        # Track if we saved an action-specific screenshot this call.
        action_screenshot_saved = False

        try:
            # Prepare kwargs based on action type.
            kwargs = {}

            if action in ["left_click", "right_click", "double_click", "move_cursor"]:
                try:
                    # int() raises ValueError for malformed Box IDs (caught below).
                    box_id = int(content["Box ID"])
                    logger.info(f"Processing Box ID: {box_id}")

                    # Calculate click coordinates.
                    x, y = await self._calculate_click_coordinates(box_id, parsed_screen)
                    logger.info(f"Calculated coordinates: x={x}, y={y}")

                    kwargs["x"] = x
                    kwargs["y"] = y

                    # Visualize action if a screenshot is available.
                    if parsed_screen.annotated_image_base64:
                        img_data = parsed_screen.annotated_image_base64
                        # Remove data URL prefix if present.
                        if img_data.startswith("data:image"):
                            img_data = img_data.split(",")[1]
                        # Only save visualization for coordinate-based actions.
                        self._visualize_action(x, y, img_data)
                        action_screenshot_saved = True

                except ValueError as e:
                    logger.error(f"Error processing Box ID: {str(e)}")
                    return

            elif action == "drag_to":
                try:
                    box_id = int(content["Box ID"])
                    x, y = await self._calculate_click_coordinates(box_id, parsed_screen)
                    kwargs.update(
                        {
                            "x": x,
                            "y": y,
                            "button": content.get("button", "left"),
                            "duration": float(content.get("duration", 0.5)),
                        }
                    )

                    # Visualize drag destination if a screenshot is available.
                    if parsed_screen.annotated_image_base64:
                        img_data = parsed_screen.annotated_image_base64
                        # Remove data URL prefix if present.
                        if img_data.startswith("data:image"):
                            img_data = img_data.split(",")[1]
                        # Only save visualization for coordinate-based actions.
                        self._visualize_action(x, y, img_data)
                        action_screenshot_saved = True

                except ValueError as e:
                    logger.error(f"Error processing drag coordinates: {str(e)}")
                    return

            elif action == "type_text":
                kwargs["text"] = content["Value"]
                # For type_text, store the value in the action type
                # (truncated for use in a filename).
                action_type = f"type_{content['Value'][:20]}"  # Truncate if too long
            elif action == "press_key":
                kwargs["key"] = content["Value"]
                action_type = f"press_{content['Value']}"
            elif action == "hotkey":
                # "Value" may be a list of keys or a "cmd+x" style string.
                if isinstance(content.get("Value"), list):
                    keys = content["Value"]
                    action_type = f"hotkey_{'_'.join(keys)}"
                else:
                    # Simply split string format like "command+space" into a list.
                    keys = [k.strip() for k in content["Value"].lower().split("+")]
                    action_type = f"hotkey_{content['Value'].replace('+', '_')}"
                logger.info(f"Preparing hotkey with keys: {keys}")
                # Hotkey is invoked here directly with *args instead of
                # falling through to the generic **kwargs call below.
                method = getattr(self.computer.interface, action)
                await method(*keys)  # Unpack the keys list as positional arguments
                logger.info(f"Tool execution completed successfully: {action}")

                # For hotkeys, take a screenshot after the action.
                try:
                    # Get a fresh screenshot and save it tagged with the action type.
                    new_parsed_screen = await self._get_parsed_screen_som(save_screenshot=False)
                    if new_parsed_screen and new_parsed_screen.annotated_image_base64:
                        img_data = new_parsed_screen.annotated_image_base64
                        # Remove data URL prefix if present.
                        if img_data.startswith("data:image"):
                            img_data = img_data.split(",")[1]
                        # Save as a post-action screenshot.
                        self._save_screenshot(img_data, action_type=action_type)
                        action_screenshot_saved = True
                except Exception as screenshot_error:
                    logger.error(
                        f"Error taking post-hotkey screenshot: {str(screenshot_error)}"
                    )

                # Hotkey path is fully handled; skip the generic execution below.
                return

            elif action in ["scroll_down", "scroll_up"]:
                clicks = int(content.get("amount", 1))
                kwargs["clicks"] = clicks
                action_type = f"scroll_{action.split('_')[1]}_{clicks}"

                # Visualize scrolling if a screenshot is available.
                if parsed_screen.annotated_image_base64:
                    img_data = parsed_screen.annotated_image_base64
                    # Remove data URL prefix if present.
                    if img_data.startswith("data:image"):
                        img_data = img_data.split(",")[1]
                    direction = "down" if action == "scroll_down" else "up"
                    # For scrolling, only the visualization is saved to avoid
                    # duplicate images.
                    self._visualize_scroll(direction, clicks, img_data)
                    action_screenshot_saved = True

            else:
                logger.warning(f"Unknown action: {action}")
                return

            # Execute the tool and handle the result (generic path for all
            # actions except hotkey, which returned above).
            try:
                method = getattr(self.computer.interface, action)
                logger.info(f"Found method for action '{action}': {method}")
                await method(**kwargs)
                logger.info(f"Tool execution completed successfully: {action}")

                # For non-coordinate-based actions that don't already have
                # visualizations, take a new screenshot after the action.
                if not action_screenshot_saved:
                    # Take a new screenshot.
                    try:
                        # Get a fresh screenshot and save it tagged with the action type.
                        new_parsed_screen = await self._get_parsed_screen_som(
                            save_screenshot=False
                        )
                        if new_parsed_screen and new_parsed_screen.annotated_image_base64:
                            img_data = new_parsed_screen.annotated_image_base64
                            # Remove data URL prefix if present.
                            if img_data.startswith("data:image"):
                                img_data = img_data.split(",")[1]
                            # HACK: action_type is only bound on some branches
                            # above (type_text/press_key/scroll), so a locals()
                            # probe guards against NameError; click/drag
                            # branches fall back to the raw action name.
                            if "action_type" in locals():
                                self._save_screenshot(img_data, action_type=action_type)
                            else:
                                self._save_screenshot(img_data, action_type=action)
                            # Update the action screenshot flag for this turn.
                            action_screenshot_saved = True
                    except Exception as screenshot_error:
                        logger.error(
                            f"Error taking post-action screenshot: {str(screenshot_error)}"
                        )

            except AttributeError as e:
                logger.error(f"Method not found for action '{action}': {str(e)}")
                return
            except Exception as tool_error:
                logger.error(f"Tool execution failed: {str(tool_error)}")
                return

        except Exception as e:
            logger.error(f"Error executing action {action}: {str(e)}")
            return

    except Exception as e:
        logger.error(f"Error in _execute_action: {str(e)}")
        return
|
||||
|
||||
async def _calculate_click_coordinates(
    self, box_id: int, parsed_screen: ParseResult
) -> Tuple[int, int]:
    """Calculate click coordinates based on box ID.

    Args:
        box_id: The ID of the box to click
        parsed_screen: The parsed screen information

    Returns:
        Tuple of (x, y) pixel coordinates. If the box ID is not found in
        the parsed elements, or its bbox resolves to (0, 0), the center
        of the screen is returned as a fallback — this method never
        raises for a missing box.
    """
    # Screen dimensions from parser metadata, with a common desktop
    # resolution as the fallback. Hoisted here so both the found and
    # not-found paths share one computation.
    width = parsed_screen.metadata.width if parsed_screen.metadata else 1920
    height = parsed_screen.metadata.height if parsed_screen.metadata else 1080

    logger.info(f"Elements count: {len(parsed_screen.elements)}")

    # Find the element whose ID matches the requested box.
    element = next((e for e in parsed_screen.elements if e.id == box_id), None)

    if element is not None:
        logger.info(f"Found element with ID {box_id}: {element}")
        bbox = element.bbox
        logger.info(f"Screen dimensions: width={width}, height={height}")

        # bbox coordinates are fractions of the screen (multiplied by the
        # pixel dimensions here); take the center of the box in pixels.
        center_x = int((bbox.x1 + bbox.x2) / 2 * width)
        center_y = int((bbox.y1 + bbox.y2) / 2 * height)
        logger.info(f"Calculated center: ({center_x}, {center_y})")

        # (0, 0) indicates a degenerate bbox; click the screen center
        # rather than the top-left corner.
        if center_x == 0 and center_y == 0:
            logger.warning("Got (0,0) coordinates, using fallback position")
            center_x = width // 2
            center_y = height // 2
            logger.info(f"Using fallback center: ({center_x}, {center_y})")

        return center_x, center_y

    # Box not found: fall back to the center of the screen.
    logger.error(
        f"Box ID {box_id} not found in structured elements (count={len(parsed_screen.elements)})"
    )
    logger.warning(f"Using fallback position in center of screen ({width//2}, {height//2})")
    return width // 2, height // 2
|
||||
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]:
    """Run the agent loop with provided messages.

    Repeats screenshot -> API call -> action execution turns until the
    model signals completion or ``max_attempts`` consecutive failures
    occur. Each turn yields the raw model response; failures yield an
    error dict instead.

    Args:
        messages: List of message objects

    Yields:
        Dict containing response data
    """
    # Keep track of conversation history (shallow copy: the caller's list
    # is not extended, but shared message dicts may still be mutated).
    conversation_history = messages.copy()

    # Continue running until explicitly told to stop.
    running = True
    turn_created = False
    # Track if an action-specific screenshot has been saved this turn.
    action_screenshot_saved = False

    # attempt counts CONSECUTIVE failures; it resets to 0 on any
    # successful turn, so max_attempts bounds retries per failure streak,
    # not total turns.
    attempt = 0
    max_attempts = 3

    while running and attempt < max_attempts:
        try:
            # Create a new turn directory if it's not already created.
            if not turn_created:
                self._create_turn_dir()
                turn_created = True

            # Ensure the API client is initialized (it may have been reset
            # to None by the retry logic in _make_api_call).
            if self.client is None:
                logger.info("Initializing client...")
                await self.initialize_client()
                if self.client is None:
                    raise RuntimeError("Failed to initialize client")
                logger.info("Client initialized successfully")

            # Get up-to-date screen information.
            parsed_screen = await self._get_parsed_screen_som()

            # Process screen info and update messages.
            await self._process_screen(parsed_screen, conversation_history)

            # Get system prompt.
            system_prompt = self._get_system_prompt()

            # Make API call with retries.
            response = await self._make_api_call(conversation_history, system_prompt)

            # Handle the response (may execute actions).
            # Returns: (should_continue, action_screenshot_saved)
            should_continue, new_screenshot_saved = await self._handle_response(
                response, conversation_history, parsed_screen
            )

            # Update whether an action screenshot was saved this turn.
            action_screenshot_saved = action_screenshot_saved or new_screenshot_saved

            # Yield the response to the caller.
            yield {"response": response}

            # Check if we should continue this conversation.
            running = should_continue

            # Arrange for a fresh turn directory if we're continuing.
            if running:
                turn_created = False

            # Reset attempt counter on success.
            attempt = 0

        except Exception as e:
            attempt += 1
            error_msg = f"Error in run method (attempt {attempt}/{max_attempts}): {str(e)}"
            logger.error(error_msg)

            # If this is our last attempt, provide more info about the error.
            if attempt >= max_attempts:
                logger.error(f"Maximum retry attempts reached. Last error was: {str(e)}")

            # Every failure is surfaced to the caller before retrying.
            yield {
                "error": str(e),
                "metadata": {"title": "❌ Error"},
            }

            # Create a brief delay before retrying.
            await asyncio.sleep(1)
|
||||
171
libs/agent/agent/providers/omni/messages.py
Normal file
171
libs/agent/agent/providers/omni/messages.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""Omni message manager implementation."""
|
||||
|
||||
import base64
|
||||
from typing import Any, Dict, List, Optional
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
from ...core.messages import BaseMessageManager, ImageRetentionConfig
|
||||
|
||||
|
||||
class OmniMessageManager(BaseMessageManager):
    """Message manager for multi-provider support."""

    def __init__(self, config: Optional[ImageRetentionConfig] = None):
        """Initialize the message manager.

        Args:
            config: Optional configuration for image retention
        """
        super().__init__(config)
        self.messages: List[Dict[str, Any]] = []
        self.config = config

    def add_user_message(self, content: str, images: Optional[List[bytes]] = None) -> None:
        """Add a user message to the history.

        Args:
            content: Message content
            images: Optional list of image data
        """
        if images:
            # Multimodal message: text part first, then one data-URL entry
            # per attached image.
            parts: List[Dict[str, Any]] = [{"type": "text", "text": content}]
            for raw in images:
                encoded = base64.b64encode(raw).decode()
                parts.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{encoded}"},
                    }
                )
            new_message: Dict[str, Any] = {"role": "user", "content": parts}
        else:
            # Plain text message.
            new_message = {"role": "user", "content": content}

        self.messages.append(new_message)

        # Enforce the image retention policy when one is configured.
        if self.config and self.config.num_images_to_keep:
            self._apply_image_retention_policy()

    def add_assistant_message(self, content: str) -> None:
        """Add an assistant message to the history.

        Args:
            content: Message content
        """
        self.messages.append({"role": "assistant", "content": content})

    def add_system_message(self, content: str) -> None:
        """Add a system message to the history.

        Args:
            content: Message content
        """
        self.messages.append({"role": "system", "content": content})

    def _apply_image_retention_policy(self) -> None:
        """Apply image retention policy to message history."""
        if not (self.config and self.config.num_images_to_keep):
            return

        limit = self.config.num_images_to_keep
        kept_images = 0

        # Walk newest-to-oldest so the most recent images survive; any
        # image beyond the limit (and any non-text/non-image part) is dropped.
        for message in reversed(self.messages):
            if message["role"] != "user":
                continue
            content = message["content"]
            if not isinstance(content, list):
                continue

            filtered = []
            for part in content:
                if part["type"] == "text":
                    filtered.append(part)
                elif part["type"] == "image_url" and kept_images < limit:
                    filtered.append(part)
                    kept_images += 1
            message["content"] = filtered

    def get_formatted_messages(self, provider: str) -> List[Dict[str, Any]]:
        """Get messages formatted for specific provider.

        Args:
            provider: Provider name to format messages for

        Returns:
            List of formatted messages
        """
        # Record the active provider before formatting.
        self.set_provider(provider)

        # Dispatch table instead of an if/elif chain.
        formatters = {
            "anthropic": self._format_for_anthropic,
            "openai": self._format_for_openai,
            "groq": self._format_for_groq,
            "qwen": self._format_for_qwen,
        }
        formatter = formatters.get(provider)
        if formatter is None:
            raise ValueError(f"Unsupported provider: {provider}")
        return formatter()

    def _format_for_anthropic(self) -> List[Dict[str, Any]]:
        """Format messages for Anthropic API."""
        converted = []
        for msg in self.messages:
            entry: Dict[str, Any] = {"role": msg["role"]}
            content = msg["content"]

            if isinstance(content, list):
                # Translate each multimodal part into Anthropic's block types.
                blocks = []
                for item in content:
                    if item["type"] == "text":
                        blocks.append({"type": "text", "text": item["text"]})
                    elif item["type"] == "image_url":
                        # Convert the data URL into a base64 image source block.
                        blocks.append(
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": "image/png",
                                    "data": item["image_url"]["url"].split(",")[1],
                                },
                            }
                        )
                entry["content"] = blocks
            else:
                entry["content"] = content

            converted.append(entry)
        return converted

    def _format_for_openai(self) -> List[Dict[str, Any]]:
        """Format messages for OpenAI API."""
        # Internal storage already follows the OpenAI schema.
        return self.messages

    def _format_for_groq(self) -> List[Dict[str, Any]]:
        """Format messages for Groq API."""
        # Groq uses the OpenAI-compatible format.
        return self.messages

    def _format_for_qwen(self) -> List[Dict[str, Any]]:
        """Format messages for Qwen API."""
        converted = []
        for msg in self.messages:
            content = msg["content"]
            if isinstance(content, list):
                # Text-only path: keep the first text part, drop images.
                text = next((p["text"] for p in content if p["type"] == "text"), "")
                converted.append({"role": msg["role"], "content": text})
            else:
                converted.append(msg)
        return converted
|
||||
252
libs/agent/agent/providers/omni/parser.py
Normal file
252
libs/agent/agent/providers/omni/parser.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""Parser implementation for the Omni provider."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
import base64
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import json
|
||||
import torch
|
||||
|
||||
# Import from the SOM package
|
||||
from som import OmniParser as OmniDetectParser
|
||||
from som.models import ParseResult, BoundingBox, UIElement, ImageData, ParserMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OmniParser:
|
||||
"""Parser for handling responses from multiple providers."""
|
||||
|
||||
# Class-level shared OmniDetectParser instance
|
||||
_shared_parser = None
|
||||
|
||||
def __init__(self, force_device: Optional[str] = None):
    """Initialize the OmniParser.

    Reuses a single class-level OmniDetectParser across all instances so
    the detection model is only loaded once per process.

    Args:
        force_device: Optional device to force for detection (cpu/cuda/mps)
    """
    # Buffer for accumulating responses; populated elsewhere.
    self.response_buffer = []

    # Use the shared parser if available, otherwise create a new one.
    # NOTE(review): this check-then-set on the class attribute has no
    # locking — concurrent first-time construction could build two
    # parsers. Presumably construction is single-threaded; confirm.
    if OmniParser._shared_parser is None:
        logger.info("Initializing shared OmniDetectParser...")

        # Determine the best device to use: explicit override, else
        # CUDA > Apple MPS > CPU.
        device = force_device
        if not device:
            if torch.cuda.is_available():
                device = "cuda"
            elif (
                # hasattr guards keep this working on torch builds
                # without MPS support compiled in.
                hasattr(torch, "backends")
                and hasattr(torch.backends, "mps")
                and torch.backends.mps.is_available()
            ):
                device = "mps"
            else:
                device = "cpu"

        logger.info(f"Using device: {device} for OmniDetectParser")
        self.detect_parser = OmniDetectParser(force_device=device)

        # Preload the detection model to avoid repeated loading; failure
        # here is non-fatal (the model can still load lazily later).
        try:
            # Access the detector to trigger model loading.
            detector = self.detect_parser.detector
            if detector.model is None:
                logger.info("Preloading detection model...")
                detector.load_model()
                logger.info("Detection model preloaded successfully")
        except Exception as e:
            logger.error(f"Error preloading detection model: {str(e)}")

        # Store as the shared instance for future OmniParser objects.
        OmniParser._shared_parser = self.detect_parser
    else:
        logger.info("Using existing shared OmniDetectParser")
        self.detect_parser = OmniParser._shared_parser
||||
|
||||
    async def parse_screen(self, computer: Any) -> ParseResult:
        """Parse a screenshot and extract screen information.

        Takes a screenshot via the computer interface, normalizes it to raw
        bytes if it arrives base64-encoded, then runs it through the shared
        OmniDetectParser. Never raises: on any failure a minimal but valid
        ParseResult describing the error is returned instead.

        Args:
            computer: Computer instance; its ``interface.screenshot()`` may
                return raw bytes or a base64-encoded string.

        Returns:
            ParseResult with screen elements and image data; an empty result
            with the error message in ``parsed_content_list`` on failure.
        """
        try:
            # Get screenshot from computer
            logger.info("Taking screenshot...")
            screenshot = await computer.interface.screenshot()

            # Log screenshot info (type varies across computer backends)
            logger.info(f"Screenshot type: {type(screenshot)}")
            logger.info(f"Screenshot is bytes: {isinstance(screenshot, bytes)}")
            logger.info(f"Screenshot is str: {isinstance(screenshot, str)}")
            logger.info(f"Screenshot length: {len(screenshot) if screenshot else 0}")

            # If screenshot is a string (likely base64), convert it to bytes
            if isinstance(screenshot, str):
                try:
                    screenshot = base64.b64decode(screenshot)
                    logger.info("Successfully converted base64 string to bytes")
                    logger.info(f"Decoded bytes length: {len(screenshot)}")
                except Exception as e:
                    # Decode failed: screenshot remains a str and is passed
                    # through as-is; the parse call below may then fail and
                    # be handled by the outer except.
                    logger.error(f"Error decoding base64: {str(e)}")
                    logger.error(f"First 100 chars of screenshot string: {screenshot[:100]}")

            # Pass screenshot to OmniDetectParser
            logger.info("Passing screenshot to OmniDetectParser...")
            # Thresholds are fixed here: box_threshold filters low-confidence
            # detections, iou_threshold controls box de-duplication.
            parse_result = self.detect_parser.parse(
                screenshot_data=screenshot, box_threshold=0.3, iou_threshold=0.1, use_ocr=True
            )
            logger.info("Screenshot parsed successfully")
            logger.info(f"Parse result has {len(parse_result.elements)} elements")

            # Log element IDs for debugging
            for i, elem in enumerate(parse_result.elements):
                logger.info(
                    f"Element {i+1} (ID: {elem.id}): {elem.type} with confidence {elem.confidence:.3f}"
                )

            return parse_result

        except Exception as e:
            logger.error(f"Error parsing screen: {str(e)}")
            import traceback

            logger.error(traceback.format_exc())

            # Create a minimal valid result for error cases
            return ParseResult(
                elements=[],
                annotated_image_base64="",
                parsed_content_list=[f"Error: {str(e)}"],
                metadata=ParserMetadata(
                    image_size=(0, 0),
                    num_icons=0,
                    num_text=0,
                    device="cpu",
                    ocr_enabled=False,
                    latency=0.0,
                ),
            )
|
||||
|
||||
def parse_tool_call(self, response: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""Parse a tool call from the response.
|
||||
|
||||
Args:
|
||||
response: Response from the provider
|
||||
|
||||
Returns:
|
||||
Parsed tool call or None if no tool call found
|
||||
"""
|
||||
try:
|
||||
# Handle Anthropic format
|
||||
if "tool_calls" in response:
|
||||
tool_call = response["tool_calls"][0]
|
||||
return {
|
||||
"name": tool_call["function"]["name"],
|
||||
"arguments": tool_call["function"]["arguments"],
|
||||
}
|
||||
|
||||
# Handle OpenAI format
|
||||
if "function_call" in response:
|
||||
return {
|
||||
"name": response["function_call"]["name"],
|
||||
"arguments": response["function_call"]["arguments"],
|
||||
}
|
||||
|
||||
# Handle Groq format (OpenAI-compatible)
|
||||
if "choices" in response and response["choices"]:
|
||||
choice = response["choices"][0]
|
||||
if "function_call" in choice:
|
||||
return {
|
||||
"name": choice["function_call"]["name"],
|
||||
"arguments": choice["function_call"]["arguments"],
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing tool call: {str(e)}")
|
||||
return None
|
||||
|
||||
def parse_response(self, response: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
"""Parse a response from any provider.
|
||||
|
||||
Args:
|
||||
response: Response from the provider
|
||||
|
||||
Returns:
|
||||
Tuple of (content, metadata)
|
||||
"""
|
||||
try:
|
||||
content = ""
|
||||
metadata = {}
|
||||
|
||||
# Handle Anthropic format
|
||||
if "content" in response and isinstance(response["content"], list):
|
||||
for item in response["content"]:
|
||||
if item["type"] == "text":
|
||||
content += item["text"]
|
||||
|
||||
# Handle OpenAI format
|
||||
elif "choices" in response and response["choices"]:
|
||||
content = response["choices"][0]["message"]["content"]
|
||||
|
||||
# Handle direct content
|
||||
elif isinstance(response.get("content"), str):
|
||||
content = response["content"]
|
||||
|
||||
# Extract metadata if present
|
||||
if "metadata" in response:
|
||||
metadata = response["metadata"]
|
||||
|
||||
return content, metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing response: {str(e)}")
|
||||
return str(e), {"error": True}
|
||||
|
||||
def format_for_provider(
|
||||
self, messages: List[Dict[str, Any]], provider: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Format messages for a specific provider.
|
||||
|
||||
Args:
|
||||
messages: List of messages to format
|
||||
provider: Provider to format for
|
||||
|
||||
Returns:
|
||||
Formatted messages
|
||||
"""
|
||||
try:
|
||||
formatted = []
|
||||
|
||||
for msg in messages:
|
||||
formatted_msg = {"role": msg["role"]}
|
||||
|
||||
# Handle content formatting
|
||||
if isinstance(msg["content"], list):
|
||||
# For providers that support multimodal
|
||||
if provider in ["anthropic", "openai"]:
|
||||
formatted_msg["content"] = msg["content"]
|
||||
else:
|
||||
# Extract text only for other providers
|
||||
text_content = next(
|
||||
(item["text"] for item in msg["content"] if item["type"] == "text"), ""
|
||||
)
|
||||
formatted_msg["content"] = text_content
|
||||
else:
|
||||
formatted_msg["content"] = msg["content"]
|
||||
|
||||
formatted.append(formatted_msg)
|
||||
|
||||
return formatted
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting messages: {str(e)}")
|
||||
return messages # Return original messages on error
|
||||
64
libs/agent/agent/providers/omni/prompts.py
Normal file
64
libs/agent/agent/providers/omni/prompts.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Prompts for the Omni agent."""
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
You are using a macOS device.
|
||||
You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot.
|
||||
|
||||
You may be given some history plan and actions, this is the response from the previous loop.
|
||||
You should carefully consider your plan base on the task, screenshot, and history actions.
|
||||
|
||||
Your available "Next Action" only include:
|
||||
- type_text: types a string of text.
|
||||
- left_click: move mouse to box id and left clicks.
|
||||
- right_click: move mouse to box id and right clicks.
|
||||
- double_click: move mouse to box id and double clicks.
|
||||
- move_cursor: move mouse to box id.
|
||||
- scroll_up: scrolls the screen up to view previous content.
|
||||
- scroll_down: scrolls the screen down, when the desired button is not visible, or you need to see more content.
|
||||
- hotkey: press a sequence of keys.
|
||||
- wait: waits for 1 second for the device to load or respond.
|
||||
|
||||
Based on the visual information from the screenshot image and the detected bounding boxes, please determine the next action, the Box ID you should operate on (if action is one of 'type', 'hover', 'scroll_up', 'scroll_down', 'wait', there should be no Box ID field), and the value (if the action is 'type') in order to complete the task.
|
||||
|
||||
Output format:
|
||||
{
|
||||
"Explanation": str, # describe what is in the current screen, taking into account the history, then describe your step-by-step thoughts on how to achieve the task, choose one action from available actions at a time.
|
||||
"Action": "action_type, action description" | "None" # one action at a time, describe it in short and precisely.
|
||||
"Box ID": n,
|
||||
"Value": "xxx" # only provide value field if the action is type, else don't include value key
|
||||
}
|
||||
|
||||
One Example:
|
||||
{
|
||||
"Explanation": "The current screen shows google result of amazon, in previous action I have searched amazon on google. Then I need to click on the first search results to go to amazon.com.",
|
||||
"Action": "left_click",
|
||||
"Box ID": 4
|
||||
}
|
||||
|
||||
Another Example:
|
||||
{
|
||||
"Explanation": "The current screen shows the front page of amazon. There is no previous action. Therefore I need to type "Apple watch" in the search bar.",
|
||||
"Action": "type_text",
|
||||
"Box ID": 2,
|
||||
"Value": "Apple watch"
|
||||
}
|
||||
|
||||
Another Example:
|
||||
{
|
||||
"Explanation": "I am starting a Spotlight search to find the Safari browser.",
|
||||
"Action": "hotkey",
|
||||
"Value": "command+space"
|
||||
}
|
||||
|
||||
IMPORTANT NOTES:
|
||||
1. You should only give a single action at a time.
|
||||
2. The Box ID is the id of the element you should operate on, it is a number. Its background color corresponds to the color of the bounding box of the element.
|
||||
3. You should give an analysis to the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task.
|
||||
4. Attach the next action prediction in the "Action" field.
|
||||
5. For starting applications, always use the "hotkey" action with command+space for starting a Spotlight search.
|
||||
6. When the task is completed, don't complete additional actions. You should say "Action": "None" in the json field.
|
||||
7. The tasks involve buying multiple products or navigating through multiple pages. You should break it into subgoals and complete each subgoal one by one in the order of the instructions.
|
||||
8. Avoid choosing the same action/elements multiple times in a row, if it happens, reflect to yourself, what may have gone wrong, and predict a different action.
|
||||
9. Reflect whether the element is clickable or not, for example reflect if it is an hyperlink or a button or a normal text.
|
||||
10. If you are prompted with login information page or captcha page, or you think it need user's permission to do the next action, you should say "Action": "None" in the json field.
|
||||
"""
|
||||
91
libs/agent/agent/providers/omni/tool_manager.py
Normal file
91
libs/agent/agent/providers/omni/tool_manager.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# """Omni tool manager implementation."""
|
||||
|
||||
# from typing import Dict, List, Type, Any
|
||||
|
||||
# from computer import Computer
|
||||
# from ...core.tools import BaseToolManager, BashTool, EditTool
|
||||
|
||||
# class OmniToolManager(BaseToolManager):
|
||||
# """Tool manager for multi-provider support."""
|
||||
|
||||
# def __init__(self, computer: Computer):
|
||||
# """Initialize Omni tool manager.
|
||||
|
||||
# Args:
|
||||
# computer: Computer instance for tools
|
||||
# """
|
||||
# super().__init__(computer)
|
||||
|
||||
# def get_anthropic_tools(self) -> List[Dict[str, Any]]:
|
||||
# """Get tools formatted for Anthropic API.
|
||||
|
||||
# Returns:
|
||||
# List of tool parameters in Anthropic format
|
||||
# """
|
||||
# tools: List[Dict[str, Any]] = []
|
||||
|
||||
# # Map base tools to Anthropic format
|
||||
# for tool in self.tools.values():
|
||||
# if isinstance(tool, BashTool):
|
||||
# tools.append({
|
||||
# "type": "bash_20241022",
|
||||
# "name": tool.name
|
||||
# })
|
||||
# elif isinstance(tool, EditTool):
|
||||
# tools.append({
|
||||
# "type": "text_editor_20241022",
|
||||
# "name": "str_replace_editor"
|
||||
# })
|
||||
|
||||
# return tools
|
||||
|
||||
# def get_openai_tools(self) -> List[Dict]:
|
||||
# """Get tools formatted for OpenAI API.
|
||||
|
||||
# Returns:
|
||||
# List of tool parameters in OpenAI format
|
||||
# """
|
||||
# tools = []
|
||||
|
||||
# # Map base tools to OpenAI format
|
||||
# for tool in self.tools.values():
|
||||
# tools.append({
|
||||
# "type": "function",
|
||||
# "function": tool.get_schema()
|
||||
# })
|
||||
|
||||
# return tools
|
||||
|
||||
# def get_groq_tools(self) -> List[Dict]:
|
||||
# """Get tools formatted for Groq API.
|
||||
|
||||
# Returns:
|
||||
# List of tool parameters in Groq format
|
||||
# """
|
||||
# tools = []
|
||||
|
||||
# # Map base tools to Groq format
|
||||
# for tool in self.tools.values():
|
||||
# tools.append({
|
||||
# "type": "function",
|
||||
# "function": tool.get_schema()
|
||||
# })
|
||||
|
||||
# return tools
|
||||
|
||||
# def get_qwen_tools(self) -> List[Dict]:
|
||||
# """Get tools formatted for Qwen API.
|
||||
|
||||
# Returns:
|
||||
# List of tool parameters in Qwen format
|
||||
# """
|
||||
# tools = []
|
||||
|
||||
# # Map base tools to Qwen format
|
||||
# for tool in self.tools.values():
|
||||
# tools.append({
|
||||
# "type": "function",
|
||||
# "function": tool.get_schema()
|
||||
# })
|
||||
|
||||
# return tools
|
||||
13
libs/agent/agent/providers/omni/tools/__init__.py
Normal file
13
libs/agent/agent/providers/omni/tools/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Omni provider tools - compatible with multiple LLM providers."""
|
||||
|
||||
from .bash import OmniBashTool
|
||||
from .computer import OmniComputerTool
|
||||
from .edit import OmniEditTool
|
||||
from .manager import OmniToolManager
|
||||
|
||||
__all__ = [
|
||||
"OmniBashTool",
|
||||
"OmniComputerTool",
|
||||
"OmniEditTool",
|
||||
"OmniToolManager",
|
||||
]
|
||||
69
libs/agent/agent/providers/omni/tools/bash.py
Normal file
69
libs/agent/agent/providers/omni/tools/bash.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Provider-agnostic implementation of the BashTool."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from computer.computer import Computer
|
||||
|
||||
from ....core.tools.bash import BaseBashTool
|
||||
from ....core.tools import ToolResult
|
||||
|
||||
|
||||
class OmniBashTool(BaseBashTool):
    """A provider-agnostic implementation of the bash tool."""

    # Tool name exposed to the provider's tool-calling API
    name = "bash"
    logger = logging.getLogger(__name__)

    def __init__(self, computer: Computer):
        """Initialize the BashTool.

        Args:
            computer: Computer instance, may be used for related operations
        """
        super().__init__(computer)

    def to_params(self) -> Dict[str, Any]:
        """Convert tool to provider-agnostic parameters.

        Returns:
            Dictionary with the tool name, description, and a JSON-schema-like
            description of the "command" and "restart" parameters.
        """
        return {
            "name": self.name,
            "description": "A tool that allows the agent to run bash commands",
            "parameters": {
                "command": {"type": "string", "description": "The bash command to execute"},
                "restart": {
                    "type": "boolean",
                    "description": "Whether to restart the bash session",
                },
            },
        }

    async def __call__(self, **kwargs) -> ToolResult:
        """Execute the bash tool with the provided arguments.

        Args:
            command: The bash command to execute
            restart: Whether to restart the bash session

        Returns:
            ToolResult with the command output; a non-zero exit code is
            reported via the result's error field, not raised.
        """
        command = kwargs.get("command")
        # NOTE(review): "restart" is read but never used below — the session
        # is not actually restarted. Confirm whether BaseBashTool is meant to
        # handle it or whether support is missing here.
        restart = kwargs.get("restart", False)

        if not command:
            return ToolResult(error="Command is required")

        self.logger.info(f"Executing bash command: {command}")
        exit_code, stdout, stderr = await self.run_command(command)

        output = stdout
        error = None

        if exit_code != 0:
            error = f"Command exited with code {exit_code}: {stderr}"

        return ToolResult(output=output, error=error)
|
||||
216
libs/agent/agent/providers/omni/tools/computer.py
Normal file
216
libs/agent/agent/providers/omni/tools/computer.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""Provider-agnostic implementation of the ComputerTool."""
|
||||
|
||||
import logging
|
||||
import base64
|
||||
import io
|
||||
from typing import Any, Dict
|
||||
|
||||
from PIL import Image
|
||||
from computer.computer import Computer
|
||||
|
||||
from ....core.tools.computer import BaseComputerTool
|
||||
from ....core.tools import ToolResult, ToolError
|
||||
|
||||
|
||||
class OmniComputerTool(BaseComputerTool):
    """A provider-agnostic implementation of the computer tool."""

    # Tool name exposed to the provider's tool-calling API
    name = "computer"
    logger = logging.getLogger(__name__)

    def __init__(self, computer: Computer):
        """Initialize the ComputerTool.

        Args:
            computer: Computer instance for screen interactions
        """
        super().__init__(computer)
        # Initialize dimensions to None, will be set in initialize_dimensions
        self.width = None
        self.height = None
        self.display_num = None

    def to_params(self) -> Dict[str, Any]:
        """Convert tool to provider-agnostic parameters.

        Returns:
            Dictionary with the tool name, description, and parameter schema,
            merged with any base-class options.
        """
        return {
            "name": self.name,
            "description": "A tool that allows the agent to interact with the screen, keyboard, and mouse",
            "parameters": {
                "action": {
                    "type": "string",
                    "enum": [
                        "key",
                        "type",
                        "mouse_move",
                        "left_click",
                        "left_click_drag",
                        "right_click",
                        "middle_click",
                        "double_click",
                        "screenshot",
                        "cursor_position",
                        "scroll",
                    ],
                    "description": "The action to perform on the computer",
                },
                "text": {
                    "type": "string",
                    "description": "Text to type or key to press, required for 'key' and 'type' actions",
                },
                "coordinate": {
                    "type": "array",
                    "items": {"type": "integer"},
                    "description": "X,Y coordinates for mouse actions like click and move",
                },
                "direction": {
                    "type": "string",
                    "enum": ["up", "down"],
                    "description": "Direction to scroll, used with the 'scroll' action",
                },
                "amount": {
                    "type": "integer",
                    "description": "Amount to scroll, used with the 'scroll' action",
                },
            },
            **self.options,
        }

    async def _result_with_screenshot(self, output: str) -> ToolResult:
        """Take a post-action screenshot and wrap it with *output*.

        Every mutating action returns a fresh screenshot so the agent can see
        the effect of what it just did; this helper centralizes the
        screenshot → resize → base64 → ToolResult sequence.
        """
        screenshot = await self.computer.interface.screenshot()
        screenshot = await self.resize_screenshot_if_needed(screenshot)
        return ToolResult(
            output=output,
            base64_image=base64.b64encode(screenshot).decode(),
        )

    async def __call__(self, **kwargs) -> ToolResult:
        """Execute the computer tool with the provided arguments.

        Args:
            action: The action to perform
            text: Text to type or key to press (for key/type actions)
            coordinate: X,Y coordinates (for mouse actions)
            direction: Direction to scroll (for scroll action)
            amount: Amount to scroll (for scroll action)

        Returns:
            ToolResult with the action output and optional screenshot;
            failures are reported via the result's error field, not raised.
        """
        # Ensure dimensions are initialized before any screen interaction
        if self.width is None or self.height is None:
            await self.initialize_dimensions()

        action = kwargs.get("action")
        text = kwargs.get("text")
        coordinate = kwargs.get("coordinate")
        direction = kwargs.get("direction", "down")
        amount = kwargs.get("amount", 10)

        self.logger.info(f"Executing computer action: {action}")

        try:
            if action == "screenshot":
                return await self.screenshot()
            elif action == "left_click" and coordinate:
                x, y = coordinate
                self.logger.info(f"Clicking at ({x}, {y})")
                await self.computer.interface.move_cursor(x, y)
                await self.computer.interface.left_click()
                return await self._result_with_screenshot(f"Performed left click at ({x}, {y})")
            elif action == "right_click" and coordinate:
                x, y = coordinate
                self.logger.info(f"Right clicking at ({x}, {y})")
                await self.computer.interface.move_cursor(x, y)
                await self.computer.interface.right_click()
                return await self._result_with_screenshot(f"Performed right click at ({x}, {y})")
            elif action == "double_click" and coordinate:
                x, y = coordinate
                self.logger.info(f"Double clicking at ({x}, {y})")
                await self.computer.interface.move_cursor(x, y)
                await self.computer.interface.double_click()
                return await self._result_with_screenshot(f"Performed double click at ({x}, {y})")
            elif action == "mouse_move" and coordinate:
                x, y = coordinate
                self.logger.info(f"Moving cursor to ({x}, {y})")
                await self.computer.interface.move_cursor(x, y)
                return await self._result_with_screenshot(f"Moved cursor to ({x}, {y})")
            elif action == "type" and text:
                self.logger.info(f"Typing text: {text}")
                await self.computer.interface.type_text(text)
                return await self._result_with_screenshot(f"Typed text: {text}")
            elif action == "key" and text:
                self.logger.info(f"Pressing key: {text}")

                # "a+b" means a chord (hotkey); a bare name is a single press
                if "+" in text:
                    keys = text.split("+")
                    await self.computer.interface.hotkey(*keys)
                else:
                    await self.computer.interface.press(text)

                return await self._result_with_screenshot(f"Pressed key: {text}")
            elif action == "cursor_position":
                pos = await self.computer.interface.get_cursor_position()
                return ToolResult(output=f"X={int(pos[0])},Y={int(pos[1])}")
            elif action == "scroll":
                # Scrolling is emulated with repeated fn+arrow chords
                if direction == "down":
                    self.logger.info(f"Scrolling down, amount: {amount}")
                    for _ in range(amount):
                        await self.computer.interface.hotkey("fn", "down")
                else:
                    self.logger.info(f"Scrolling up, amount: {amount}")
                    for _ in range(amount):
                        await self.computer.interface.hotkey("fn", "up")

                return await self._result_with_screenshot(f"Scrolled {direction} by {amount} steps")

            # Default to screenshot for unimplemented actions
            self.logger.warning(f"Action {action} not fully implemented, taking screenshot")
            return await self.screenshot()

        except Exception as e:
            self.logger.error(f"Error during computer action: {str(e)}")
            return ToolResult(error=f"Failed to perform {action}: {str(e)}")
|
||||
83
libs/agent/agent/providers/omni/tools/manager.py
Normal file
83
libs/agent/agent/providers/omni/tools/manager.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Omni tool manager implementation."""
|
||||
|
||||
from typing import Dict, List, Any
|
||||
from enum import Enum
|
||||
|
||||
from computer.computer import Computer
|
||||
|
||||
from ....core.tools import BaseToolManager
|
||||
from ....core.tools.collection import ToolCollection
|
||||
|
||||
from .bash import OmniBashTool
|
||||
from .computer import OmniComputerTool
|
||||
from .edit import OmniEditTool
|
||||
|
||||
|
||||
class ProviderType(Enum):
    """Supported provider types."""

    ANTHROPIC = "anthropic"
    OPENAI = "openai"
    # NOTE: these are separate members with distinct values, not true Enum
    # aliases — equality with ANTHROPIC/OPENAI does not hold.
    CLAUDE = "claude"  # Alias for Anthropic
    GPT = "gpt"  # Alias for OpenAI
|
||||
|
||||
|
||||
class OmniToolManager(BaseToolManager):
    """Tool manager for multi-provider support."""

    def __init__(self, computer: Computer):
        """Initialize Omni tool manager.

        Args:
            computer: Computer instance for tools
        """
        super().__init__(computer)
        # Initialize tools
        # NOTE(review): these attributes are created *after* super().__init__;
        # if the base constructor invoked _initialize_tools() eagerly, they
        # would not exist yet — confirm BaseToolManager defers that call.
        self.computer_tool = OmniComputerTool(self.computer)
        self.bash_tool = OmniBashTool(self.computer)
        self.edit_tool = OmniEditTool(self.computer)

    def _initialize_tools(self) -> ToolCollection:
        """Initialize all available tools."""
        return ToolCollection(self.computer_tool, self.bash_tool, self.edit_tool)

    async def _initialize_tools_specific(self) -> None:
        """Initialize provider-specific tool requirements."""
        # The computer tool must learn the live screen size before first use
        await self.computer_tool.initialize_dimensions()

    def get_tool_params(self) -> List[Dict[str, Any]]:
        """Get tool parameters for API calls.

        Returns:
            List of tool parameters in default format

        Raises:
            RuntimeError: If initialize() has not been called yet.
        """
        if self.tools is None:
            raise RuntimeError("Tools not initialized. Call initialize() first.")
        return self.tools.to_params()

    def get_provider_tools(self, provider: ProviderType) -> List[Dict[str, Any]]:
        """Get tools formatted for a specific provider.

        Args:
            provider: Provider type to format tools for

        Returns:
            List of tool parameters in provider-specific format

        Raises:
            RuntimeError: If initialize() has not been called yet.
        """
        if self.tools is None:
            raise RuntimeError("Tools not initialized. Call initialize() first.")

        # Default is the base implementation
        tools = self.tools.to_params()

        # Customize for each provider if needed (currently no-ops)
        if provider in [ProviderType.ANTHROPIC, ProviderType.CLAUDE]:
            # Format for Anthropic API
            # Additional adjustments can be made here
            pass
        elif provider in [ProviderType.OPENAI, ProviderType.GPT]:
            # Format for OpenAI API
            # Future implementation
            pass

        return tools
|
||||
46
libs/agent/agent/providers/omni/types.py
Normal file
46
libs/agent/agent/providers/omni/types.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Type definitions for the Omni provider."""
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class LLMProvider(StrEnum):
|
||||
"""Supported LLM providers."""
|
||||
|
||||
ANTHROPIC = "anthropic"
|
||||
OPENAI = "openai"
|
||||
|
||||
|
||||
LLMProvider
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLM:
|
||||
"""Configuration for LLM model and provider."""
|
||||
|
||||
provider: LLMProvider
|
||||
name: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Set default model name if not provided."""
|
||||
if self.name is None:
|
||||
self.name = PROVIDER_TO_DEFAULT_MODEL.get(self.provider)
|
||||
|
||||
|
||||
# For backward compatibility
|
||||
LLMModel = LLM
|
||||
Model = LLM
|
||||
|
||||
|
||||
# Default models for each provider
|
||||
PROVIDER_TO_DEFAULT_MODEL: Dict[LLMProvider, str] = {
|
||||
LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
|
||||
LLMProvider.OPENAI: "gpt-4o",
|
||||
}
|
||||
|
||||
# Environment variable names for each provider
|
||||
PROVIDER_TO_ENV_VAR: Dict[LLMProvider, str] = {
|
||||
LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
|
||||
LLMProvider.OPENAI: "OPENAI_API_KEY",
|
||||
}
|
||||
155
libs/agent/agent/providers/omni/utils.py
Normal file
155
libs/agent/agent/providers/omni/utils.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""Utility functions for Omni provider."""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
from typing import Tuple
|
||||
from PIL import Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _flatten_to_rgb(img: Image.Image) -> Image.Image:
    """Composite a transparent image onto a white RGB background.

    JPEG has no alpha channel, so RGBA/LA images (and palette images that
    carry transparency) must be flattened before a JPEG save. Opaque images
    are returned unchanged.
    """
    if img.mode in ("RGBA", "LA") or (img.mode == "P" and "transparency" in img.info):
        logger.info("Converting transparent image to RGB for JPEG compression")
        rgb_img = Image.new("RGB", img.size, (255, 255, 255))
        # Only RGBA has a usable alpha band for the paste mask
        rgb_img.paste(img, mask=img.split()[3] if img.mode == "RGBA" else None)
        return rgb_img
    return img


def _jpeg_base64(img: Image.Image, jpeg_quality: int) -> tuple[str, int]:
    """Encode *img* as an optimized JPEG; return (base64 string, raw byte size)."""
    buffer = io.BytesIO()
    _flatten_to_rgb(img).save(buffer, format="JPEG", quality=jpeg_quality, optimize=True)
    data = buffer.getvalue()
    return base64.b64encode(data).decode("utf-8"), len(data)


def compress_image_base64(
    base64_str: str, max_size_bytes: int = 5 * 1024 * 1024, quality: int = 90
) -> tuple[str, str]:
    """Compress a base64 encoded image to ensure it's below a certain size.

    Compression cascade:
      1. lossless PNG re-encode (keeps transparency),
      2. JPEG at decreasing quality (down to just above 20),
      3. JPEG at decreasing dimensions (down to ~30% scale, quality 70),
      4. last resort: half-size JPEG at quality 20 (may still exceed limit).

    Args:
        base64_str: Base64 encoded image string (with or without data URL prefix)
        max_size_bytes: Maximum size in bytes (default: 5MB); compared against
            the base64 string length, matching APIs that limit the encoded size
        quality: Initial JPEG quality (0-100)

    Returns:
        tuple[str, str]: (Compressed base64 encoded image, media_type)

    Raises:
        Exception: Any decode/encode failure is logged and re-raised.
    """
    # Handle data URL prefix if present (e.g., "data:image/png;base64,...")
    original_prefix = ""
    media_type = "image/png"  # Default media type

    if base64_str.startswith("data:"):
        parts = base64_str.split(",", 1)
        if len(parts) == 2:
            original_prefix = parts[0] + ","
            base64_str = parts[1]
            # Try to extract media type from the prefix
            if "image/jpeg" in original_prefix.lower():
                media_type = "image/jpeg"
            elif "image/png" in original_prefix.lower():
                media_type = "image/png"

    # Check if the base64 string is small enough already; only this early
    # return preserves the original data-URL prefix.
    if len(base64_str) <= max_size_bytes:
        logger.info(f"Image already within size limit: {len(base64_str)} bytes")
        return original_prefix + base64_str, media_type

    try:
        # Decode base64
        img_data = base64.b64decode(base64_str)
        logger.info(f"Original image size: {len(img_data)} bytes")

        # Open image
        img = Image.open(io.BytesIO(img_data))

        # First, try to compress as PNG (maintains transparency if present)
        buffer = io.BytesIO()
        img.save(buffer, format="PNG", optimize=True)
        compressed_data = buffer.getvalue()
        compressed_b64 = base64.b64encode(compressed_data).decode("utf-8")

        if len(compressed_b64) <= max_size_bytes:
            logger.info(f"Compressed to {len(compressed_data)} bytes as PNG")
            return compressed_b64, "image/png"

        # Strategy 1: Try reducing quality with JPEG format
        current_quality = quality
        while current_quality > 20:
            compressed_b64, n_bytes = _jpeg_base64(img, current_quality)
            if len(compressed_b64) <= max_size_bytes:
                logger.info(
                    f"Compressed to {n_bytes} bytes with JPEG quality {current_quality}"
                )
                return compressed_b64, "image/jpeg"
            # Reduce quality and try again
            current_quality -= 10

        # Strategy 2: If quality reduction isn't enough, reduce dimensions
        scale_factor = 0.8
        while scale_factor > 0.3:
            new_size = (int(img.width * scale_factor), int(img.height * scale_factor))
            resized = img.resize(new_size, Image.LANCZOS)
            compressed_b64, n_bytes = _jpeg_base64(resized, 70)
            if len(compressed_b64) <= max_size_bytes:
                logger.info(
                    f"Compressed to {n_bytes} bytes with scale {scale_factor} and JPEG quality 70"
                )
                return compressed_b64, "image/jpeg"
            # Reduce scale factor and try again
            scale_factor -= 0.1

        # If we get here, we couldn't compress enough
        logger.warning("Could not compress image below required size with quality preservation")

        # Last resort: Use minimum quality and half dimensions
        smallest = img.resize((int(img.width * 0.5), int(img.height * 0.5)), Image.LANCZOS)
        final_b64, _ = _jpeg_base64(smallest, 20)
        logger.warning(f"Final compressed size: {len(final_b64)} bytes (may still exceed limit)")
        return final_b64, "image/jpeg"

    except Exception as e:
        logger.error(f"Error compressing image: {str(e)}")
        raise
|
||||
130
libs/agent/agent/providers/omni/visualization.py
Normal file
130
libs/agent/agent/providers/omni/visualization.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Visualization utilities for the Cua provider."""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from typing import Tuple
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def visualize_click(x: int, y: int, img_base64: str) -> Image.Image:
    """Render a click marker onto a screenshot.

    Draws a small filled red dot surrounded by a larger red ring at the
    click location so the action is easy to spot.

    Args:
        x: X coordinate of the click
        y: Y coordinate of the click
        img_base64: Base64 encoded image to draw on

    Returns:
        PIL Image with visualization
    """
    try:
        # Decode the base64 payload into a PIL image.
        screenshot = Image.open(BytesIO(base64.b64decode(img_base64)))
        canvas = ImageDraw.Draw(screenshot)

        inner, outer = 10, 30

        # Solid dot marking the exact click point.
        canvas.ellipse(
            [(x - inner, y - inner), (x + inner, y + inner)],
            fill="red",
        )
        # Outlined ring around it to draw the eye.
        canvas.ellipse(
            [(x - outer, y - outer), (x + outer, y + outer)],
            outline="red",
            width=3,
        )

        return screenshot

    except Exception as e:
        logger.error(f"Error visualizing click: {str(e)}")
        # Fall back to a blank canvas rather than propagating the failure.
        return Image.new("RGB", (800, 600), color="white")
|
||||
|
||||
|
||||
def visualize_scroll(direction: str, clicks: int, img_base64: str) -> Image.Image:
    """Render a scroll indicator onto a screenshot.

    Draws a filled triangular arrow in the middle of the image (blue pointing
    up, green pointing down) plus a label showing the click count.

    Args:
        direction: 'up' or 'down'
        clicks: Number of scroll clicks
        img_base64: Base64 encoded image to draw on

    Returns:
        PIL Image with visualization
    """
    try:
        # Decode the base64 payload into a PIL image.
        screenshot = Image.open(BytesIO(base64.b64decode(img_base64)))
        width, height = screenshot.size
        canvas = ImageDraw.Draw(screenshot)

        # The arrow is centered on the screen; half_width spans its base.
        center_x = width // 2
        arrow_y = height // 2
        half_width = 100 // 2

        if direction.lower() == "up":
            # Apex on top, base below.
            points = [
                (center_x, arrow_y - 50),  # Top point
                (center_x - half_width, arrow_y + 50),  # Bottom left
                (center_x + half_width, arrow_y + 50),  # Bottom right
            ]
            color = "blue"
        else:
            # Anything that is not 'up' is rendered as a down arrow.
            points = [
                (center_x, arrow_y + 50),  # Bottom point
                (center_x - half_width, arrow_y - 50),  # Top left
                (center_x + half_width, arrow_y - 50),  # Top right
            ]
            color = "green"

        canvas.polygon(points, fill=color)

        # Place the click-count label past the arrow tip.
        # NOTE: keyed off `== "down"` (not the inverse of "up") to match the
        # arrow drawn above only for exactly 'up'/'down' inputs.
        text_y = arrow_y + 70 if direction.lower() == "down" else arrow_y - 70
        canvas.text((center_x - 40, text_y), f"{clicks} clicks", fill="black")

        return screenshot

    except Exception as e:
        logger.error(f"Error visualizing scroll: {str(e)}")
        # Fall back to a blank canvas rather than propagating the failure.
        return Image.new("RGB", (800, 600), color="white")
|
||||
|
||||
|
||||
def calculate_element_center(box: Tuple[int, int, int, int]) -> Tuple[int, int]:
    """Calculate the center coordinates of a bounding box.

    Args:
        box: Tuple of (left, top, right, bottom) coordinates

    Returns:
        Tuple of (center_x, center_y) coordinates
    """
    left, top, right, bottom = box
    # Integer midpoint of each axis (floor division, matching pixel coords).
    return (left + right) // 2, (top + bottom) // 2
|
||||
21
libs/agent/agent/telemetry.py
Normal file
21
libs/agent/agent/telemetry.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Telemetry support for Agent class."""
|
||||
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from core.telemetry import (
|
||||
record_event,
|
||||
is_telemetry_enabled,
|
||||
flush,
|
||||
get_telemetry_client,
|
||||
increment,
|
||||
)
|
||||
|
||||
# Static host details attached to telemetry events.
SYSTEM_INFO: Dict[str, str] = {
    "os": sys.platform,
    "python_version": platform.python_version(),
}
|
||||
26
libs/agent/agent/types/__init__.py
Normal file
26
libs/agent/agent/types/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Type definitions for the agent package."""
|
||||
|
||||
from .base import Provider, HostConfig, TaskResult, Annotation
|
||||
from .messages import Message, Request, Response, StepMessage, DisengageMessage
|
||||
from .tools import ToolInvocation, ToolInvocationState, ClientAttachment, ToolResult
|
||||
|
||||
__all__ = [
|
||||
# Base types
|
||||
"Provider",
|
||||
"HostConfig",
|
||||
"TaskResult",
|
||||
"Annotation",
|
||||
|
||||
# Message types
|
||||
"Message",
|
||||
"Request",
|
||||
"Response",
|
||||
"StepMessage",
|
||||
"DisengageMessage",
|
||||
|
||||
# Tool types
|
||||
"ToolInvocation",
|
||||
"ToolInvocationState",
|
||||
"ClientAttachment",
|
||||
"ToolResult",
|
||||
]
|
||||
53
libs/agent/agent/types/base.py
Normal file
53
libs/agent/agent/types/base.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Base type definitions."""
|
||||
|
||||
from enum import Enum, auto
|
||||
from typing import Dict, Any
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class Provider(str, Enum):
    """Enumeration of the supported AI providers."""

    # Placeholder used by the base class before a provider is chosen.
    UNKNOWN = "unknown"
    ANTHROPIC = "anthropic"
    OPENAI = "openai"
    OLLAMA = "ollama"
    OMNI = "omni"
    GROQ = "groq"
|
||||
|
||||
|
||||
class HostConfig(BaseModel):
    """Network location of a host (hostname plus port)."""

    model_config = ConfigDict(extra="forbid")
    hostname: str
    port: int

    @property
    def address(self) -> str:
        """Return the host formatted as ``hostname:port``."""
        return f"{self.hostname}:{self.port}"
|
||||
|
||||
|
||||
class TaskResult(BaseModel):
    """Outcome of a task execution."""

    model_config = ConfigDict(extra="forbid")
    # Free-form result text produced by the task.
    result: str
    # Password for the VNC session associated with the run.
    vnc_password: str
|
||||
|
||||
|
||||
class Annotation(BaseModel):
    """Annotation metadata attached to a message."""

    model_config = ConfigDict(extra="forbid")
    # Unique identifier of the annotation.
    id: str
    # URL of the VM the annotation refers to.
    vm_url: str
|
||||
|
||||
|
||||
class AgentLoop(Enum):
    """Enumeration of the available agent loop implementations."""

    ANTHROPIC = auto()  # Anthropic implementation
    OPENAI = auto()  # OpenAI implementation
    OMNI = auto()  # OmniLoop implementation
    # Add more loop types as needed
|
||||
36
libs/agent/agent/types/messages.py
Normal file
36
libs/agent/agent/types/messages.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Message-related type definitions."""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from .tools import ToolInvocation
|
||||
|
||||
class Message(BaseModel):
    """Base chat message exchanged with the agent."""

    model_config = ConfigDict(extra="forbid")
    # Sender role, e.g. "user" or "assistant".
    role: str
    content: str
    annotations: Optional[List[Dict[str, Any]]] = None
    toolInvocations: Optional[List[ToolInvocation]] = None
    data: Optional[List[Dict[str, Any]]] = None
    errors: Optional[List[str]] = None
|
||||
|
||||
class Request(BaseModel):
    """Inbound request carrying the conversation and model choice."""

    model_config = ConfigDict(extra="forbid")
    messages: List[Message]
    selectedModel: str
|
||||
|
||||
class Response(BaseModel):
    """Outbound response with the resulting messages and VM URL."""

    model_config = ConfigDict(extra="forbid")
    messages: List[Message]
    vm_url: str
|
||||
|
||||
class StepMessage(Message):
    """Message emitted for a single agent step; same shape as ``Message``."""

    pass
|
||||
|
||||
class DisengageMessage(BaseModel):
    """Marker message indicating the agent has disengaged; carries no data."""

    pass
|
||||
32
libs/agent/agent/types/tools.py
Normal file
32
libs/agent/agent/types/tools.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Tool-related type definitions."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
class ToolInvocationState(str, Enum):
    """Lifecycle states of a tool invocation."""

    CALL = "call"
    PARTIAL_CALL = "partial-call"
    RESULT = "result"
|
||||
|
||||
class ToolInvocation(BaseModel):
    """A single tool call made by the model."""

    model_config = ConfigDict(extra="forbid")
    # One of the ToolInvocationState values, when known.
    state: Optional[str] = None
    toolCallId: str
    toolName: Optional[str] = None
    args: Optional[Dict[str, Any]] = None
|
||||
|
||||
class ClientAttachment(BaseModel):
    """File attachment supplied by the client."""

    name: str
    contentType: str
    url: str
|
||||
|
||||
class ToolResult(BaseModel):
    """Result of a tool execution: either output or an error, plus metadata."""

    model_config = ConfigDict(extra="forbid")
    output: Optional[str] = None
    error: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None
|
||||
2
libs/agent/poetry.toml
Normal file
2
libs/agent/poetry.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
[virtualenvs]
|
||||
in-project = true
|
||||
103
libs/agent/pyproject.toml
Normal file
103
libs/agent/pyproject.toml
Normal file
@@ -0,0 +1,103 @@
|
||||
[build-system]
|
||||
requires = ["pdm-backend"]
|
||||
build-backend = "pdm.backend"
|
||||
|
||||
[project]
|
||||
name = "cua-agent"
|
||||
version = "0.1.0"
|
||||
description = "CUA (Computer Use) Agent for AI-driven computer interaction"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "TryCua", email = "gh@trycua.com" }
|
||||
]
|
||||
dependencies = [
|
||||
"httpx>=0.27.0,<0.29.0",
|
||||
"aiohttp>=3.9.3,<4.0.0",
|
||||
"asyncio",
|
||||
"anyio>=4.4.1,<5.0.0",
|
||||
"typing-extensions>=4.12.2,<5.0.0",
|
||||
"pydantic>=2.6.4,<3.0.0",
|
||||
"rich>=13.7.1,<14.0.0",
|
||||
"python-dotenv>=1.0.1,<2.0.0",
|
||||
"cua-computer>=0.1.0,<0.2.0",
|
||||
"cua-core>=0.1.0,<0.2.0",
|
||||
"certifi>=2024.2.2"
|
||||
]
|
||||
requires-python = ">=3.10,<3.13"
|
||||
|
||||
[project.optional-dependencies]
|
||||
anthropic = [
|
||||
"anthropic>=0.49.0",
|
||||
"boto3>=1.35.81,<2.0.0",
|
||||
]
|
||||
som = [
|
||||
"torch>=2.2.1",
|
||||
"torchvision>=0.17.1",
|
||||
"ultralytics>=8.0.0",
|
||||
"transformers>=4.38.2",
|
||||
"cua-som>=0.1.0,<0.2.0",
|
||||
# Include all provider dependencies
|
||||
"anthropic>=0.46.0,<0.47.0",
|
||||
"boto3>=1.35.81,<2.0.0",
|
||||
"openai>=1.14.0,<2.0.0",
|
||||
"groq>=0.4.0,<0.5.0",
|
||||
"dashscope>=1.13.0,<2.0.0",
|
||||
"requests>=2.31.0,<3.0.0"
|
||||
]
|
||||
all = [
|
||||
# Include all optional dependencies
|
||||
"torch>=2.2.1",
|
||||
"torchvision>=0.17.1",
|
||||
"ultralytics>=8.0.0",
|
||||
"transformers>=4.38.2",
|
||||
"cua-som>=0.1.0,<0.2.0",
|
||||
"anthropic>=0.46.0,<0.47.0",
|
||||
"boto3>=1.35.81,<2.0.0",
|
||||
"openai>=1.14.0,<2.0.0",
|
||||
"groq>=0.4.0,<0.5.0",
|
||||
"dashscope>=1.13.0,<2.0.0",
|
||||
"requests>=2.31.0,<3.0.0"
|
||||
]
|
||||
|
||||
[tool.pdm]
|
||||
distribution = true
|
||||
|
||||
[tool.pdm.build]
|
||||
includes = ["agent/"]
|
||||
source-includes = ["tests/", "README.md", "LICENSE"]
|
||||
|
||||
[tool.black]
|
||||
line-length = 100
|
||||
target-version = ["py310"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
target-version = "py310"
|
||||
select = ["E", "F", "B", "I"]
|
||||
fix = true
|
||||
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
|
||||
[tool.mypy]
|
||||
strict = true
|
||||
python_version = "3.10"
|
||||
ignore_missing_imports = true
|
||||
disallow_untyped_defs = true
|
||||
check_untyped_defs = true
|
||||
warn_return_any = true
|
||||
show_error_codes = true
|
||||
warn_unused_ignores = false
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests"]
|
||||
python_files = "test_*.py"
|
||||
[dependency-groups]
|
||||
cua-som = [
|
||||
"path",
|
||||
"develop",
|
||||
"optional",
|
||||
"groups",
|
||||
"anthropic>=0.49.0",
|
||||
]
|
||||
91
libs/agent/tests/test_agent.py
Normal file
91
libs/agent/tests/test_agent.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# """Basic tests for the agent package."""
|
||||
|
||||
# import pytest
|
||||
# from agent import OmniComputerAgent, LLMProvider
|
||||
# from agent.base.agent import BaseComputerAgent
|
||||
# from computer import Computer
|
||||
|
||||
# def test_agent_import():
|
||||
# """Test that we can import the OmniComputerAgent class."""
|
||||
# assert OmniComputerAgent is not None
|
||||
# assert LLMProvider is not None
|
||||
|
||||
# def test_agent_init():
|
||||
# """Test that we can create an OmniComputerAgent instance."""
|
||||
# agent = OmniComputerAgent(
|
||||
# provider=LLMProvider.OPENAI,
|
||||
# use_host_computer_server=True
|
||||
# )
|
||||
# assert agent is not None
|
||||
|
||||
# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed")
|
||||
# def test_computer_agent_anthropic():
|
||||
# """Test creating an Anthropic agent."""
|
||||
# agent = ComputerAgent(provider=Provider.ANTHROPIC)
|
||||
# assert isinstance(agent._agent, BaseComputerAgent)
|
||||
|
||||
# def test_computer_agent_invalid_provider():
|
||||
# """Test creating an agent with an invalid provider."""
|
||||
# with pytest.raises(ValueError, match="Unsupported provider"):
|
||||
# ComputerAgent(provider="invalid_provider")
|
||||
|
||||
# def test_computer_agent_uninstalled_provider():
|
||||
# """Test creating an agent with an uninstalled provider."""
|
||||
# with pytest.raises(NotImplementedError, match="OpenAI provider not yet implemented"):
|
||||
# # OpenAI provider is not implemented yet
|
||||
# ComputerAgent(provider=Provider.OPENAI)
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed")
|
||||
# async def test_agent_cleanup():
|
||||
# """Test agent cleanup."""
|
||||
# agent = ComputerAgent(provider=Provider.ANTHROPIC)
|
||||
# await agent.cleanup() # Should not raise any errors
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed")
|
||||
# async def test_agent_direct_initialization():
|
||||
# """Test direct initialization of the agent."""
|
||||
# # Create with default computer
|
||||
# agent = ComputerAgent(provider=Provider.ANTHROPIC)
|
||||
# try:
|
||||
# # Should not raise any errors
|
||||
# await agent.run("test task")
|
||||
# finally:
|
||||
# await agent.cleanup()
|
||||
|
||||
# # Create with custom computer
|
||||
# custom_computer = Computer(
|
||||
# display="1920x1080",
|
||||
# memory="8GB",
|
||||
# cpu="4",
|
||||
# os="macos",
|
||||
# use_host_computer_server=False,
|
||||
# )
|
||||
# agent = ComputerAgent(provider=Provider.ANTHROPIC, computer=custom_computer)
|
||||
# try:
|
||||
# # Should not raise any errors
|
||||
# await agent.run("test task")
|
||||
# finally:
|
||||
# await agent.cleanup()
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed")
|
||||
# async def test_agent_context_manager():
|
||||
# """Test context manager initialization of the agent."""
|
||||
# # Test with default computer
|
||||
# async with ComputerAgent(provider=Provider.ANTHROPIC) as agent:
|
||||
# # Should not raise any errors
|
||||
# await agent.run("test task")
|
||||
|
||||
# # Test with custom computer
|
||||
# custom_computer = Computer(
|
||||
# display="1920x1080",
|
||||
# memory="8GB",
|
||||
# cpu="4",
|
||||
# os="macos",
|
||||
# use_host_computer_server=False,
|
||||
# )
|
||||
# async with ComputerAgent(provider=Provider.ANTHROPIC, computer=custom_computer) as agent:
|
||||
# # Should not raise any errors
|
||||
# await agent.run("test task")
|
||||
38
libs/computer-server/README.md
Normal file
38
libs/computer-server/README.md
Normal file
@@ -0,0 +1,38 @@
|
||||
<div align="center">
|
||||
<h1>
|
||||
<div class="image-wrapper" style="display: inline-block;">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" alt="logo" height="150" srcset="../../img/logo_white.png" style="display: block; margin: auto;">
|
||||
<source media="(prefers-color-scheme: light)" alt="logo" height="150" srcset="../../img/logo_black.png" style="display: block; margin: auto;">
|
||||
<img alt="Shows my svg">
|
||||
</picture>
|
||||
</div>
|
||||
|
||||
[](#)
|
||||
[](#)
|
||||
[](https://discord.com/invite/mVnXXpdE85)
|
||||
[](https://pypi.org/project/cua-computer-server/)
|
||||
</h1>
|
||||
</div>
|
||||
|
||||
**Computer Server** is the server component for the Computer-Use Interface (CUI) framework powering Cua for interacting with local macOS and Linux sandboxes, PyAutoGUI-compatible, and pluggable with any AI agent systems (Cua, Langchain, CrewAI, AutoGen).
|
||||
|
||||
## Features
|
||||
|
||||
- WebSocket API for computer-use
|
||||
- Cross-platform support (macOS, Linux)
|
||||
- Integration with CUA computer library for screen control, keyboard/mouse automation, and accessibility
|
||||
|
||||
## Install
|
||||
|
||||
To install the Computer-Use Interface (CUI):
|
||||
|
||||
```bash
|
||||
pip install cua-computer-server
|
||||
```
|
||||
|
||||
## Run
|
||||
|
||||
Refer to this notebook for a step-by-step guide on how to use the Computer-Use Server on the host system or VM:
|
||||
|
||||
- [Computer-Use Server](../../notebooks/computer_server_nb.ipynb)
|
||||
20
libs/computer-server/computer_server/__init__.py
Normal file
20
libs/computer-server/computer_server/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
Computer API package.
|
||||
Provides a server interface for the Computer API.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__version__: str = "0.1.0"
|
||||
|
||||
# Explicitly export Server for static type checkers
|
||||
from .server import Server as Server # noqa: F401
|
||||
|
||||
__all__ = ["Server", "run_cli"]
|
||||
|
||||
|
||||
def run_cli() -> None:
    """Entry point for CLI"""
    # Imported lazily so that importing the package stays side-effect free.
    from .cli import main

    main()
|
||||
59
libs/computer-server/computer_server/cli.py
Normal file
59
libs/computer-server/computer_server/cli.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Command-line interface for the Computer API server.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
from .server import Server
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line arguments for the server.

    Args:
        args: Argument list to parse; when None, ``sys.argv[1:]`` is used.

    Returns:
        Namespace with ``host``, ``port`` and ``log_level`` attributes.
    """
    parser = argparse.ArgumentParser(description="Start the Computer API server")

    parser.add_argument(
        "--host", default="0.0.0.0", help="Host to bind the server to (default: 0.0.0.0)"
    )
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to bind the server to (default: 8000)"
    )
    parser.add_argument(
        "--log-level",
        choices=["debug", "info", "warning", "error", "critical"],
        default="info",
        help="Logging level (default: info)",
    )

    parsed = parser.parse_args(args)
    return parsed
|
||||
|
||||
|
||||
def main() -> None:
    """Main entry point for the CLI: configure logging and run the server."""
    args = parse_args()

    # Set up logging before anything else emits records.
    logging.basicConfig(
        level=getattr(logging, args.log_level.upper()),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Create and start the server.
    logger.info(f"Starting CUA Computer API server on {args.host}:{args.port}...")
    server = Server(host=args.host, port=args.port, log_level=args.log_level)

    try:
        server.start()
    except KeyboardInterrupt:
        # Ctrl-C is a normal shutdown path, not an error.
        logger.info("Server stopped by user")
        sys.exit(0)
    except Exception as e:
        logger.error(f"Error starting server: {e}")
        sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
120
libs/computer-server/computer_server/handlers/base.py
Normal file
120
libs/computer-server/computer_server/handlers/base.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
class BaseAccessibilityHandler(ABC):
    """Abstract base class for OS-specific accessibility handlers."""

    @abstractmethod
    async def get_accessibility_tree(self) -> Dict[str, Any]:
        """Return the accessibility tree of the current window."""
        ...

    @abstractmethod
    async def find_element(
        self,
        role: Optional[str] = None,
        title: Optional[str] = None,
        value: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Locate an element in the accessibility tree matching the given criteria."""
        ...
|
||||
|
||||
class BaseAutomationHandler(ABC):
    """Abstract base class for OS-specific automation handlers.

    Categories:
    - Mouse Actions: Methods for mouse control
    - Keyboard Actions: Methods for keyboard input
    - Scrolling Actions: Methods for scrolling
    - Screen Actions: Methods for screen interaction
    - Clipboard Actions: Methods for clipboard operations
    """

    # --- Mouse Actions ---

    @abstractmethod
    async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
        """Left-click at the given position, or at the current one if omitted."""
        ...

    @abstractmethod
    async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
        """Right-click at the given position, or at the current one if omitted."""
        ...

    @abstractmethod
    async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
        """Double-click at the given position, or at the current one if omitted."""
        ...

    @abstractmethod
    async def move_cursor(self, x: int, y: int) -> Dict[str, Any]:
        """Move the cursor to the specified position."""
        ...

    @abstractmethod
    async def drag_to(self, x: int, y: int, button: str = "left", duration: float = 0.5) -> Dict[str, Any]:
        """Drag the cursor from its current position to the given coordinates.

        Args:
            x: The x coordinate to drag to
            y: The y coordinate to drag to
            button: The mouse button to use ('left', 'middle', 'right')
            duration: How long the drag should take in seconds
        """
        ...

    # --- Keyboard Actions ---

    @abstractmethod
    async def type_text(self, text: str) -> Dict[str, Any]:
        """Type the specified text."""
        ...

    @abstractmethod
    async def press_key(self, key: str) -> Dict[str, Any]:
        """Press the specified key."""
        ...

    @abstractmethod
    async def hotkey(self, *keys: str) -> Dict[str, Any]:
        """Press a combination of keys together."""
        ...

    # --- Scrolling Actions ---

    @abstractmethod
    async def scroll_down(self, clicks: int = 1) -> Dict[str, Any]:
        """Scroll down by the specified number of clicks."""
        ...

    @abstractmethod
    async def scroll_up(self, clicks: int = 1) -> Dict[str, Any]:
        """Scroll up by the specified number of clicks."""
        ...

    # --- Screen Actions ---

    @abstractmethod
    async def screenshot(self) -> Dict[str, Any]:
        """Take a screenshot and return base64 encoded image data."""
        ...

    @abstractmethod
    async def get_screen_size(self) -> Dict[str, Any]:
        """Get the screen size of the VM."""
        ...

    @abstractmethod
    async def get_cursor_position(self) -> Dict[str, Any]:
        """Get the current cursor position."""
        ...

    # --- Clipboard Actions ---

    @abstractmethod
    async def copy_to_clipboard(self) -> Dict[str, Any]:
        """Get the current clipboard content."""
        ...

    @abstractmethod
    async def set_clipboard(self, text: str) -> Dict[str, Any]:
        """Set the clipboard content."""
        ...

    @abstractmethod
    async def run_command(self, command: str) -> Dict[str, Any]:
        """Run a command and return the output."""
        ...
|
||||
49
libs/computer-server/computer_server/handlers/factory.py
Normal file
49
libs/computer-server/computer_server/handlers/factory.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import platform
|
||||
import subprocess
|
||||
from typing import Tuple, Type
|
||||
from .base import BaseAccessibilityHandler, BaseAutomationHandler
|
||||
from .macos import MacOSAccessibilityHandler, MacOSAutomationHandler
|
||||
# from .linux import LinuxAccessibilityHandler, LinuxAutomationHandler
|
||||
|
||||
class HandlerFactory:
    """Factory for creating OS-specific handlers."""

    @staticmethod
    def _get_current_os() -> str:
        """Determine the current OS.

        Returns:
            str: The OS type ('darwin' for macOS or 'linux' for Linux)

        Raises:
            RuntimeError: If unable to determine the current OS
        """
        # platform.system() yields the same kernel name as `uname -s`
        # ("Darwin"/"Linux") without spawning a subprocess, and also works
        # where `uname` is unavailable. `platform` is already imported at
        # the top of this module (previously unused).
        os_name = platform.system().strip().lower()
        if not os_name:
            raise RuntimeError(
                "Failed to determine current OS: platform.system() returned an empty value"
            )
        return os_name

    @staticmethod
    def create_handlers() -> Tuple[BaseAccessibilityHandler, BaseAutomationHandler]:
        """Create and return appropriate handlers for the current OS.

        Returns:
            Tuple[BaseAccessibilityHandler, BaseAutomationHandler]: A tuple containing
            the appropriate accessibility and automation handlers for the current OS.

        Raises:
            NotImplementedError: If the current OS is not supported
            RuntimeError: If unable to determine the current OS
        """
        os_type = HandlerFactory._get_current_os()

        if os_type == "darwin":
            return MacOSAccessibilityHandler(), MacOSAutomationHandler()
        # elif os_type == 'linux':
        #     return LinuxAccessibilityHandler(), LinuxAutomationHandler()
        else:
            raise NotImplementedError(f"OS '{os_type}' is not supported")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user