diff --git a/.cursorignore b/.cursorignore new file mode 100644 index 00000000..12e8e403 --- /dev/null +++ b/.cursorignore @@ -0,0 +1,233 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +!libs/lume/scripts/build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Scripts +server/scripts/ + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Conda +.conda/ + +# Local environment +.env.local + +# macOS DS_Store +.DS_Store + +weights/ +weights/icon_detect/ +weights/icon_detect/model.pt +weights/icon_detect/model.pt.zip +weights/icon_detect/model.pt.zip.part* + +libs/omniparser/weights/icon_detect/model.pt + +# Example test data and output +examples/test_data/ +examples/output/ + +/screenshots/ + +/experiments/ + +/logs/ + +# Xcode +# +# gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore + +## User settings +xcuserdata/ + +## Obj-C/Swift specific +*.hmap + +## App packaging +*.ipa +*.dSYM.zip +*.dSYM + +## Playgrounds +timeline.xctimeline +playground.xcworkspace + +# Swift Package Manager +# +# Add this line if you want to avoid checking in source code from Swift Package Manager dependencies. +# Packages/ +# Package.pins +# Package.resolved +# *.xcodeproj +# +# Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata +# hence it is not needed unless you have added a package configuration file to your project +.swiftpm/ +.build/ + +# CocoaPods +# +# We recommend against adding the Pods directory to your .gitignore. 
However +# you should judge for yourself, the pros and cons are mentioned at: +# https://guides.cocoapods.org/using/using-cocoapods.html#should-i-check-the-pods-directory-into-source-control +# +# Pods/ +# +# Add this line if you want to avoid checking in source code from the Xcode workspace +# *.xcworkspace + +# Carthage +# +# Add this line if you want to avoid checking in source code from Carthage dependencies. +# Carthage/Checkouts +Carthage/Build/ + +# fastlane +# +# It is recommended to not store the screenshots in the git repo. +# Instead, use fastlane to re-generate the screenshots whenever they are needed. +# For more information about the recommended setup visit: +# https://docs.fastlane.tools/best-practices/source-control/#source-control +fastlane/report.xml +fastlane/Preview.html +fastlane/screenshots/**/*.png +fastlane/test_output + +# Ignore folder +ignore + +# .release +.release/ \ No newline at end of file diff --git a/.github/workflows/publish-agent.yml b/.github/workflows/publish-agent.yml new file mode 100644 index 00000000..1566880b --- /dev/null +++ b/.github/workflows/publish-agent.yml @@ -0,0 +1,162 @@ +name: Publish Agent Package + +on: + push: + tags: + - 'agent-v*' + workflow_dispatch: + inputs: + version: + description: 'Version to publish (without v prefix)' + required: true + default: '0.1.0' + workflow_call: + inputs: + version: + description: 'Version to publish' + required: true + type: string + +# Adding permissions at workflow level +permissions: + contents: write + +jobs: + prepare: + runs-on: macos-latest + outputs: + version: ${{ steps.get-version.outputs.version }} + computer_version: ${{ steps.update-deps.outputs.computer_version }} + som_version: ${{ steps.update-deps.outputs.som_version }} + core_version: ${{ steps.update-deps.outputs.core_version }} + steps: + - uses: actions/checkout@v4 + + - name: Determine version + id: get-version + run: | + if [ "${{ github.event_name }}" == "push" ]; then + # Extract version from tag 
(for package-specific tags) + if [[ "${{ github.ref }}" =~ ^refs/tags/agent-v([0-9]+\.[0-9]+\.[0-9]+) ]]; then + VERSION=${BASH_REMATCH[1]} + else + echo "Invalid tag format for agent" + exit 1 + fi + elif [ "${{ github.event_name }}" == "workflow_dispatch" ]; then + # Use version from workflow dispatch + VERSION=${{ github.event.inputs.version }} + else + # Use version from workflow_call + VERSION=${{ inputs.version }} + fi + echo "VERSION=$VERSION" + echo "version=$VERSION" >> $GITHUB_OUTPUT + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Update dependencies to latest versions + id: update-deps + run: | + cd libs/agent + + # Install required package for PyPI API access + pip install requests + + # Create a more robust Python script for PyPI version checking + cat > get_latest_versions.py << 'EOF' + import requests + import json + import sys + + def get_package_version(package_name, fallback="0.1.0"): + try: + response = requests.get(f'https://pypi.org/pypi/{package_name}/json') + print(f"API Response Status for {package_name}: {response.status_code}", file=sys.stderr) + + if response.status_code != 200: + print(f"API request failed for {package_name}, using fallback version", file=sys.stderr) + return fallback + + data = json.loads(response.text) + + if 'info' not in data: + print(f"Missing 'info' key in API response for {package_name}, using fallback version", file=sys.stderr) + return fallback + + return data['info']['version'] + except Exception as e: + print(f"Error fetching version for {package_name}: {str(e)}", file=sys.stderr) + return fallback + + # Get latest versions + print(get_package_version('cua-computer')) + print(get_package_version('cua-som')) + print(get_package_version('cua-core')) + EOF + + # Execute the script to get the versions + VERSIONS=($(python get_latest_versions.py)) + LATEST_COMPUTER=${VERSIONS[0]} + LATEST_SOM=${VERSIONS[1]} + LATEST_CORE=${VERSIONS[2]} + + echo "Latest 
cua-computer version: $LATEST_COMPUTER" + echo "Latest cua-som version: $LATEST_SOM" + echo "Latest cua-core version: $LATEST_CORE" + + # Output the versions for the next job + echo "computer_version=$LATEST_COMPUTER" >> $GITHUB_OUTPUT + echo "som_version=$LATEST_SOM" >> $GITHUB_OUTPUT + echo "core_version=$LATEST_CORE" >> $GITHUB_OUTPUT + + # Determine major version for version constraint + COMPUTER_MAJOR=$(echo $LATEST_COMPUTER | cut -d. -f1) + SOM_MAJOR=$(echo $LATEST_SOM | cut -d. -f1) + CORE_MAJOR=$(echo $LATEST_CORE | cut -d. -f1) + + NEXT_COMPUTER_MAJOR=$((COMPUTER_MAJOR + 1)) + NEXT_SOM_MAJOR=$((SOM_MAJOR + 1)) + NEXT_CORE_MAJOR=$((CORE_MAJOR + 1)) + + # Update dependencies in pyproject.toml + if [[ "$OSTYPE" == "darwin"* ]]; then + # macOS version of sed needs an empty string for -i + sed -i '' "s/\"cua-computer>=.*,<.*\"/\"cua-computer>=$LATEST_COMPUTER,<$NEXT_COMPUTER_MAJOR.0.0\"/" pyproject.toml + sed -i '' "s/\"cua-som>=.*,<.*\"/\"cua-som>=$LATEST_SOM,<$NEXT_SOM_MAJOR.0.0\"/" pyproject.toml + sed -i '' "s/\"cua-core>=.*,<.*\"/\"cua-core>=$LATEST_CORE,<$NEXT_CORE_MAJOR.0.0\"/" pyproject.toml + else + # Linux version + sed -i "s/\"cua-computer>=.*,<.*\"/\"cua-computer>=$LATEST_COMPUTER,<$NEXT_COMPUTER_MAJOR.0.0\"/" pyproject.toml + sed -i "s/\"cua-som>=.*,<.*\"/\"cua-som>=$LATEST_SOM,<$NEXT_SOM_MAJOR.0.0\"/" pyproject.toml + sed -i "s/\"cua-core>=.*,<.*\"/\"cua-core>=$LATEST_CORE,<$NEXT_CORE_MAJOR.0.0\"/" pyproject.toml + fi + + # Display the updated dependencies + echo "Updated dependencies in pyproject.toml:" + grep -E "cua-computer|cua-som|cua-core" pyproject.toml + + publish: + needs: prepare + uses: ./.github/workflows/reusable-publish.yml + with: + package_name: "agent" + package_dir: "libs/agent" + version: ${{ needs.prepare.outputs.version }} + is_lume_package: false + base_package_name: "cua-agent" + secrets: + PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} + + set-env-variables: + needs: [prepare, publish] + runs-on: macos-latest + steps: + - name: Set 
environment variables for use in other jobs + run: | + echo "COMPUTER_VERSION=${{ needs.prepare.outputs.computer_version }}" >> $GITHUB_ENV + echo "SOM_VERSION=${{ needs.prepare.outputs.som_version }}" >> $GITHUB_ENV + echo "CORE_VERSION=${{ needs.prepare.outputs.core_version }}" >> $GITHUB_ENV \ No newline at end of file diff --git a/.github/workflows/publish-all.yml b/.github/workflows/publish-all.yml new file mode 100644 index 00000000..f7e7e405 --- /dev/null +++ b/.github/workflows/publish-all.yml @@ -0,0 +1,227 @@ +name: Publish All Packages + +on: + push: + tags: + - 'v*' # For global releases (vX.Y.Z format) + workflow_dispatch: + inputs: + version: + description: 'Version to publish (without v prefix)' + required: true + default: '0.1.0' + packages: + description: 'Packages to publish (comma-separated)' + required: false + default: 'core,pylume,computer,som,agent,computer-server' + +# Adding permissions at workflow level +permissions: + contents: write + +jobs: + setup: + runs-on: macos-latest + outputs: + version: ${{ steps.get-version.outputs.version }} + publish_core: ${{ steps.set-packages.outputs.publish_core }} + publish_pylume: ${{ steps.set-packages.outputs.publish_pylume }} + publish_computer: ${{ steps.set-packages.outputs.publish_computer }} + publish_som: ${{ steps.set-packages.outputs.publish_som }} + publish_agent: ${{ steps.set-packages.outputs.publish_agent }} + publish_computer_server: ${{ steps.set-packages.outputs.publish_computer_server }} + steps: + - name: Determine version + id: get-version + run: | + if [ "${{ github.event_name }}" == "push" ]; then + # Extract version from tag (for global releases) + if [[ "${{ github.ref }}" =~ ^refs/tags/v([0-9]+\.[0-9]+\.[0-9]+) ]]; then + VERSION=${BASH_REMATCH[1]} + else + echo "Invalid tag format for global release" + exit 1 + fi + else + # Use version from workflow dispatch + VERSION=${{ github.event.inputs.version }} + fi + echo "version=$VERSION" >> $GITHUB_OUTPUT + echo "Using version: 
$VERSION" + + - name: Determine packages to publish + id: set-packages + run: | + # Default to all packages for tag-based releases + if [ "${{ github.event_name }}" == "push" ]; then + PACKAGES="core,pylume,computer,som,agent,computer-server" + else + # Use packages from workflow dispatch + PACKAGES="${{ github.event.inputs.packages }}" + fi + + # Set individual flags for each package + echo "publish_core=$(echo $PACKAGES | grep -q "core" && echo "true" || echo "false")" >> $GITHUB_OUTPUT + echo "publish_pylume=$(echo $PACKAGES | grep -q "pylume" && echo "true" || echo "false")" >> $GITHUB_OUTPUT + echo "publish_computer=$(echo $PACKAGES | grep -q "computer" && echo "true" || echo "false")" >> $GITHUB_OUTPUT + echo "publish_som=$(echo $PACKAGES | grep -q "som" && echo "true" || echo "false")" >> $GITHUB_OUTPUT + echo "publish_agent=$(echo $PACKAGES | grep -q "agent" && echo "true" || echo "false")" >> $GITHUB_OUTPUT + echo "publish_computer_server=$(echo $PACKAGES | grep -q "computer-server" && echo "true" || echo "false")" >> $GITHUB_OUTPUT + + echo "Publishing packages: $PACKAGES" + + publish-core: + needs: setup + if: ${{ needs.setup.outputs.publish_core == 'true' }} + uses: ./.github/workflows/publish-core.yml + with: + version: ${{ needs.setup.outputs.version }} + secrets: inherit + + # Add a delay to ensure PyPI has registered the new core version + wait-for-core: + needs: [setup, publish-core] + if: ${{ needs.setup.outputs.publish_core == 'true' && (needs.setup.outputs.publish_computer == 'true' || needs.setup.outputs.publish_agent == 'true') }} + runs-on: macos-latest + steps: + - name: Wait for PyPI to update + run: | + echo "Waiting for PyPI to register the new core version..." 
+ sleep 60 # Wait 60 seconds for PyPI to update its index + + publish-pylume: + needs: setup + if: ${{ needs.setup.outputs.publish_pylume == 'true' }} + uses: ./.github/workflows/publish-pylume.yml + with: + version: ${{ needs.setup.outputs.version }} + secrets: inherit + + # Add a delay to ensure PyPI has registered the new pylume version + wait-for-pylume: + needs: [setup, publish-pylume] + if: ${{ needs.setup.outputs.publish_pylume == 'true' && (needs.setup.outputs.publish_computer == 'true' || needs.setup.outputs.publish_som == 'true') }} + runs-on: macos-latest + steps: + - name: Wait for PyPI to update + run: | + echo "Waiting for PyPI to register the new pylume version..." + sleep 60 # Wait 60 seconds for PyPI to update its index + + publish-computer: + needs: [setup, publish-core, publish-pylume, wait-for-core, wait-for-pylume] + if: ${{ needs.setup.outputs.publish_computer == 'true' }} + uses: ./.github/workflows/publish-computer.yml + with: + version: ${{ needs.setup.outputs.version }} + secrets: inherit + + publish-som: + needs: [setup, publish-pylume, wait-for-pylume] + if: ${{ needs.setup.outputs.publish_som == 'true' }} + uses: ./.github/workflows/publish-som.yml + with: + version: ${{ needs.setup.outputs.version }} + secrets: inherit + + # Add a delay to ensure PyPI has registered the new computer and som versions + wait-for-deps: + needs: [setup, publish-computer, publish-som] + if: ${{ (needs.setup.outputs.publish_computer == 'true' || needs.setup.outputs.publish_som == 'true') && needs.setup.outputs.publish_agent == 'true' }} + runs-on: macos-latest + steps: + - name: Wait for PyPI to update + run: | + echo "Waiting for PyPI to register new dependency versions..." 
+ sleep 60 # Wait 60 seconds for PyPI to update its index + + publish-agent: + needs: [setup, publish-core, publish-computer, publish-som, wait-for-core, wait-for-deps] + if: ${{ needs.setup.outputs.publish_agent == 'true' }} + uses: ./.github/workflows/publish-agent.yml + with: + version: ${{ needs.setup.outputs.version }} + secrets: inherit + + publish-computer-server: + needs: [setup, publish-computer] + if: ${{ needs.setup.outputs.publish_computer_server == 'true' }} + uses: ./.github/workflows/publish-computer-server.yml + with: + version: ${{ needs.setup.outputs.version }} + secrets: inherit + + # Create a global release for the entire CUA project + create-global-release: + needs: [setup, publish-core, publish-pylume, publish-computer, publish-som, publish-agent, publish-computer-server] + if: ${{ github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') }} + runs-on: macos-latest + steps: + - uses: actions/checkout@v4 + + - name: Create Summary Release Notes + id: release_notes + run: | + VERSION=${{ needs.setup.outputs.version }} + + # Create the release notes file + echo "# CUA v${VERSION} Release" > release_notes.md + echo "" >> release_notes.md + echo "This is a global release of the Computer Universal Automation (CUA) project." 
>> release_notes.md + echo "" >> release_notes.md + + echo "## Released Packages" >> release_notes.md + echo "" >> release_notes.md + + # Add links to individual package releases + if [[ "${{ needs.setup.outputs.publish_core }}" == "true" ]]; then + echo "- [cua-core v${VERSION}](https://github.com/${{ github.repository }}/releases/tag/core-v${VERSION})" >> release_notes.md + fi + + if [[ "${{ needs.setup.outputs.publish_pylume }}" == "true" ]]; then + echo "- [pylume v${VERSION}](https://github.com/${{ github.repository }}/releases/tag/pylume-v${VERSION})" >> release_notes.md + fi + + if [[ "${{ needs.setup.outputs.publish_computer }}" == "true" ]]; then + echo "- [cua-computer v${VERSION}](https://github.com/${{ github.repository }}/releases/tag/computer-v${VERSION})" >> release_notes.md + fi + + if [[ "${{ needs.setup.outputs.publish_som }}" == "true" ]]; then + echo "- [cua-som v${VERSION}](https://github.com/${{ github.repository }}/releases/tag/som-v${VERSION})" >> release_notes.md + fi + + if [[ "${{ needs.setup.outputs.publish_agent }}" == "true" ]]; then + echo "- [cua-agent v${VERSION}](https://github.com/${{ github.repository }}/releases/tag/agent-v${VERSION})" >> release_notes.md + fi + + if [[ "${{ needs.setup.outputs.publish_computer_server }}" == "true" ]]; then + echo "- [cua-computer-server v${VERSION}](https://github.com/${{ github.repository }}/releases/tag/computer-server-v${VERSION})" >> release_notes.md + fi + + echo "" >> release_notes.md + echo "## Installation" >> release_notes.md + echo "" >> release_notes.md + echo "### Core Libraries" >> release_notes.md + echo "```bash" >> release_notes.md + echo "pip install cua-core==${VERSION}" >> release_notes.md + echo "pip install cua-computer==${VERSION}" >> release_notes.md + echo "pip install pylume==${VERSION}" >> release_notes.md + echo "```" >> release_notes.md + echo "" >> release_notes.md + echo "### Agent with SOM (Recommended)" >> release_notes.md + echo "```bash" >> release_notes.md + 
echo "pip install cua-agent[som]==${VERSION}" >> release_notes.md + echo "```" >> release_notes.md + + echo "Release notes created:" + cat release_notes.md + + - name: Create GitHub Global Release + uses: softprops/action-gh-release@v1 + with: + name: "CUA v${{ needs.setup.outputs.version }}" + body_path: release_notes.md + draft: false + prerelease: false + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/publish-computer-server.yml b/.github/workflows/publish-computer-server.yml new file mode 100644 index 00000000..15eca348 --- /dev/null +++ b/.github/workflows/publish-computer-server.yml @@ -0,0 +1,80 @@ +name: Publish Computer Server Package + +on: + push: + tags: + - 'computer-server-v*' + workflow_dispatch: + inputs: + version: + description: 'Version to publish (without v prefix)' + required: true + default: '0.1.0' + workflow_call: + inputs: + version: + description: 'Version to publish' + required: true + type: string + outputs: + version: + description: "The version that was published" + value: ${{ jobs.prepare.outputs.version }} + +# Adding permissions at workflow level +permissions: + contents: write + +jobs: + prepare: + runs-on: macos-latest + outputs: + version: ${{ steps.get-version.outputs.version }} + steps: + - uses: actions/checkout@v4 + + - name: Determine version + id: get-version + run: | + if [ "${{ github.event_name }}" == "push" ]; then + # Extract version from tag (for package-specific tags) + if [[ "${{ github.ref }}" =~ ^refs/tags/computer-server-v([0-9]+\.[0-9]+\.[0-9]+) ]]; then + VERSION=${BASH_REMATCH[1]} + else + echo "Invalid tag format for computer-server" + exit 1 + fi + elif [ "${{ github.event_name }}" == "workflow_dispatch" ]; then + # Use version from workflow dispatch + VERSION=${{ github.event.inputs.version }} + else + # Use version from workflow_call + VERSION=${{ inputs.version }} + fi + echo "VERSION=$VERSION" + echo "version=$VERSION" >> $GITHUB_OUTPUT + + - 
name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + publish: + needs: prepare + uses: ./.github/workflows/reusable-publish.yml + with: + package_name: "computer-server" + package_dir: "libs/computer-server" + version: ${{ needs.prepare.outputs.version }} + is_lume_package: false + base_package_name: "cua-computer-server" + secrets: + PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} + + set-env-variables: + needs: [prepare, publish] + runs-on: macos-latest + steps: + - name: Set environment variables for use in other jobs + run: | + echo "COMPUTER_VERSION=${{ needs.prepare.outputs.version }}" >> $GITHUB_ENV \ No newline at end of file diff --git a/.github/workflows/publish-computer.yml b/.github/workflows/publish-computer.yml new file mode 100644 index 00000000..9175907e --- /dev/null +++ b/.github/workflows/publish-computer.yml @@ -0,0 +1,148 @@ +name: Publish Computer Package + +on: + push: + tags: + - 'computer-v*' + workflow_dispatch: + inputs: + version: + description: 'Version to publish (without v prefix)' + required: true + default: '0.1.0' + workflow_call: + inputs: + version: + description: 'Version to publish' + required: true + type: string + +# Adding permissions at workflow level +permissions: + contents: write + +jobs: + prepare: + runs-on: macos-latest + outputs: + version: ${{ steps.get-version.outputs.version }} + pylume_version: ${{ steps.update-deps.outputs.pylume_version }} + core_version: ${{ steps.update-deps.outputs.core_version }} + steps: + - uses: actions/checkout@v4 + + - name: Determine version + id: get-version + run: | + if [ "${{ github.event_name }}" == "push" ]; then + # Extract version from tag (for package-specific tags) + if [[ "${{ github.ref }}" =~ ^refs/tags/computer-v([0-9]+\.[0-9]+\.[0-9]+) ]]; then + VERSION=${BASH_REMATCH[1]} + else + echo "Invalid tag format for computer" + exit 1 + fi + elif [ "${{ github.event_name }}" == "workflow_dispatch" ]; then + # Use version from workflow dispatch + 
VERSION=${{ github.event.inputs.version }} + else + # Use version from workflow_call + VERSION=${{ inputs.version }} + fi + echo "VERSION=$VERSION" + echo "version=$VERSION" >> $GITHUB_OUTPUT + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Update dependencies to latest versions + id: update-deps + run: | + cd libs/computer + # Install required package for PyPI API access + pip install requests + + # Create a more robust Python script for PyPI version checking + cat > get_latest_versions.py << 'EOF' + import requests + import json + import sys + + def get_package_version(package_name, fallback="0.1.0"): + try: + response = requests.get(f'https://pypi.org/pypi/{package_name}/json') + print(f"API Response Status for {package_name}: {response.status_code}", file=sys.stderr) + + if response.status_code != 200: + print(f"API request failed for {package_name}, using fallback version", file=sys.stderr) + return fallback + + data = json.loads(response.text) + + if 'info' not in data: + print(f"Missing 'info' key in API response for {package_name}, using fallback version", file=sys.stderr) + return fallback + + return data['info']['version'] + except Exception as e: + print(f"Error fetching version for {package_name}: {str(e)}", file=sys.stderr) + return fallback + + # Get latest versions + print(get_package_version('pylume')) + print(get_package_version('cua-core')) + EOF + + # Execute the script to get the versions + VERSIONS=($(python get_latest_versions.py)) + LATEST_PYLUME=${VERSIONS[0]} + LATEST_CORE=${VERSIONS[1]} + + echo "Latest pylume version: $LATEST_PYLUME" + echo "Latest cua-core version: $LATEST_CORE" + + # Output the versions for the next job + echo "pylume_version=$LATEST_PYLUME" >> $GITHUB_OUTPUT + echo "core_version=$LATEST_CORE" >> $GITHUB_OUTPUT + + # Determine major version for version constraint + CORE_MAJOR=$(echo $LATEST_CORE | cut -d. 
-f1) + NEXT_CORE_MAJOR=$((CORE_MAJOR + 1)) + + # Update dependencies in pyproject.toml + if [[ "$OSTYPE" == "darwin"* ]]; then + # macOS version of sed needs an empty string for -i + sed -i '' "s/\"pylume>=.*\"/\"pylume>=$LATEST_PYLUME\"/" pyproject.toml + sed -i '' "s/\"cua-core>=.*,<.*\"/\"cua-core>=$LATEST_CORE,<$NEXT_CORE_MAJOR.0.0\"/" pyproject.toml + else + # Linux version + sed -i "s/\"pylume>=.*\"/\"pylume>=$LATEST_PYLUME\"/" pyproject.toml + sed -i "s/\"cua-core>=.*,<.*\"/\"cua-core>=$LATEST_CORE,<$NEXT_CORE_MAJOR.0.0\"/" pyproject.toml + fi + + # Display the updated dependencies + echo "Updated dependencies in pyproject.toml:" + grep -E "pylume|cua-core" pyproject.toml + + publish: + needs: prepare + uses: ./.github/workflows/reusable-publish.yml + with: + package_name: "computer" + package_dir: "libs/computer" + version: ${{ needs.prepare.outputs.version }} + is_lume_package: false + base_package_name: "cua-computer" + secrets: + PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} + + set-env-variables: + needs: [prepare, publish] + runs-on: macos-latest + steps: + - name: Set environment variables for use in other jobs + run: | + echo "PYLUME_VERSION=${{ needs.prepare.outputs.pylume_version }}" >> $GITHUB_ENV + echo "CORE_VERSION=${{ needs.prepare.outputs.core_version }}" >> $GITHUB_ENV \ No newline at end of file diff --git a/.github/workflows/publish-core.yml b/.github/workflows/publish-core.yml new file mode 100644 index 00000000..4f868f26 --- /dev/null +++ b/.github/workflows/publish-core.yml @@ -0,0 +1,63 @@ +name: Publish Core Package + +on: + push: + tags: + - 'core-v*' + workflow_dispatch: + inputs: + version: + description: 'Version to publish (without v prefix)' + required: true + default: '0.1.0' + workflow_call: + inputs: + version: + description: 'Version to publish' + required: true + type: string + +# Adding permissions at workflow level +permissions: + contents: write + +jobs: + prepare: + runs-on: macos-latest + outputs: + version: ${{ 
steps.get-version.outputs.version }} + steps: + - uses: actions/checkout@v4 + + - name: Determine version + id: get-version + run: | + if [ "${{ github.event_name }}" == "push" ]; then + # Extract version from tag (for package-specific tags) + if [[ "${{ github.ref }}" =~ ^refs/tags/core-v([0-9]+\.[0-9]+\.[0-9]+) ]]; then + VERSION=${BASH_REMATCH[1]} + else + echo "Invalid tag format for core" + exit 1 + fi + elif [ "${{ github.event_name }}" == "workflow_dispatch" ]; then + # Use version from workflow dispatch + VERSION=${{ github.event.inputs.version }} + else + # Use version from workflow_call + VERSION=${{ inputs.version }} + fi + echo "VERSION=$VERSION" + echo "version=$VERSION" >> $GITHUB_OUTPUT + + publish: + needs: prepare + uses: ./.github/workflows/reusable-publish.yml + with: + package_name: "core" + package_dir: "libs/core" + version: ${{ needs.prepare.outputs.version }} + is_lume_package: false + base_package_name: "cua-core" + secrets: + PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/publish-omniparser.yml b/.github/workflows/publish-omniparser.yml new file mode 100644 index 00000000..8a8227f9 --- /dev/null +++ b/.github/workflows/publish-omniparser.yml @@ -0,0 +1,67 @@ +name: Publish OmniParser Package + +on: + push: + tags: + - 'omniparser-v*' + workflow_dispatch: + inputs: + version: + description: 'Version to publish (without v prefix)' + required: true + default: '0.1.0' + workflow_call: + inputs: + version: + description: 'Version to publish' + required: true + type: string + outputs: + version: + description: "The version that was published" + value: ${{ jobs.determine-version.outputs.version }} + +# Adding permissions at workflow level +permissions: + contents: write + +jobs: + determine-version: + runs-on: macos-latest + outputs: + version: ${{ steps.get-version.outputs.version }} + steps: + - uses: actions/checkout@v4 + + - name: Determine version + id: get-version + run: | + if [ "${{ 
github.event_name }}" == "push" ]; then + # Extract version from tag (for package-specific tags) + if [[ "${{ github.ref }}" =~ ^refs/tags/omniparser-v([0-9]+\.[0-9]+\.[0-9]+) ]]; then + VERSION=${BASH_REMATCH[1]} + else + echo "Invalid tag format for omniparser" + exit 1 + fi + elif [ "${{ github.event_name }}" == "workflow_dispatch" ]; then + # Use version from workflow dispatch + VERSION=${{ github.event.inputs.version }} + else + # Use version from workflow_call + VERSION=${{ inputs.version }} + fi + echo "VERSION=$VERSION" + echo "version=$VERSION" >> $GITHUB_OUTPUT + + publish: + needs: determine-version + uses: ./.github/workflows/reusable-publish.yml + with: + package_name: "omniparser" + package_dir: "libs/omniparser" + version: ${{ needs.determine-version.outputs.version }} + is_lume_package: false + base_package_name: "cua-omniparser" + secrets: + PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/publish-pylume.yml b/.github/workflows/publish-pylume.yml new file mode 100644 index 00000000..dddde233 --- /dev/null +++ b/.github/workflows/publish-pylume.yml @@ -0,0 +1,67 @@ +name: Publish Pylume Package + +on: + push: + tags: + - 'pylume-v*' + workflow_dispatch: + inputs: + version: + description: 'Version to publish (without v prefix)' + required: true + default: '0.1.0' + workflow_call: + inputs: + version: + description: 'Version to publish' + required: true + type: string + outputs: + version: + description: "The version that was published" + value: ${{ jobs.determine-version.outputs.version }} + +# Adding permissions at workflow level +permissions: + contents: write + +jobs: + determine-version: + runs-on: macos-latest + outputs: + version: ${{ steps.get-version.outputs.version }} + steps: + - uses: actions/checkout@v4 + + - name: Determine version + id: get-version + run: | + if [ "${{ github.event_name }}" == "push" ]; then + # Extract version from tag (for package-specific tags) + if [[ "${{ github.ref 
}}" =~ ^refs/tags/pylume-v([0-9]+\.[0-9]+\.[0-9]+) ]]; then + VERSION=${BASH_REMATCH[1]} + else + echo "Invalid tag format for pylume" + exit 1 + fi + elif [ "${{ github.event_name }}" == "workflow_dispatch" ]; then + # Use version from workflow dispatch + VERSION=${{ github.event.inputs.version }} + else + # Use version from workflow_call + VERSION=${{ inputs.version }} + fi + echo "VERSION=$VERSION" + echo "version=$VERSION" >> $GITHUB_OUTPUT + + publish: + needs: determine-version + uses: ./.github/workflows/reusable-publish.yml + with: + package_name: "pylume" + package_dir: "libs/pylume" + version: ${{ needs.determine-version.outputs.version }} + is_lume_package: true + base_package_name: "pylume" + secrets: + PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/publish-som.yml b/.github/workflows/publish-som.yml new file mode 100644 index 00000000..04ea9bb8 --- /dev/null +++ b/.github/workflows/publish-som.yml @@ -0,0 +1,67 @@ +name: Publish SOM Package + +on: + push: + tags: + - 'som-v*' + workflow_dispatch: + inputs: + version: + description: 'Version to publish (without v prefix)' + required: true + default: '0.1.0' + workflow_call: + inputs: + version: + description: 'Version to publish' + required: true + type: string + outputs: + version: + description: "The version that was published" + value: ${{ jobs.determine-version.outputs.version }} + +# Adding permissions at workflow level +permissions: + contents: write + +jobs: + determine-version: + runs-on: macos-latest + outputs: + version: ${{ steps.get-version.outputs.version }} + steps: + - uses: actions/checkout@v4 + + - name: Determine version + id: get-version + run: | + if [ "${{ github.event_name }}" == "push" ]; then + # Extract version from tag (for package-specific tags) + if [[ "${{ github.ref }}" =~ ^refs/tags/som-v([0-9]+\.[0-9]+\.[0-9]+) ]]; then + VERSION=${BASH_REMATCH[1]} + else + echo "Invalid tag format for som" + exit 1 + fi + elif [ "${{ 
github.event_name }}" == "workflow_dispatch" ]; then + # Use version from workflow dispatch + VERSION=${{ github.event.inputs.version }} + else + # Use version from workflow_call + VERSION=${{ inputs.version }} + fi + echo "VERSION=$VERSION" + echo "version=$VERSION" >> $GITHUB_OUTPUT + + publish: + needs: determine-version + uses: ./.github/workflows/reusable-publish.yml + with: + package_name: "som" + package_dir: "libs/som" # Updated to the new directory name + version: ${{ needs.determine-version.outputs.version }} + is_lume_package: false + base_package_name: "cua-som" + secrets: + PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/reusable-publish.yml b/.github/workflows/reusable-publish.yml new file mode 100644 index 00000000..02e2aa0c --- /dev/null +++ b/.github/workflows/reusable-publish.yml @@ -0,0 +1,269 @@ +name: Reusable Package Publish Workflow + +on: + workflow_call: + inputs: + package_name: + description: 'Name of the package (e.g. pylume, computer, agent)' + required: true + type: string + package_dir: + description: 'Directory containing the package relative to workspace root (e.g. libs/pylume)' + required: true + type: string + version: + description: 'Version to publish' + required: true + type: string + is_lume_package: + description: 'Whether this package includes the lume binary' + required: false + type: boolean + default: false + base_package_name: + description: 'PyPI package name (e.g. 
pylume, cua-agent)' + required: true + type: string + secrets: + PYPI_TOKEN: + required: true + outputs: + version: + description: "The version that was published" + value: ${{ jobs.build-and-publish.outputs.version }} + +jobs: + build-and-publish: + runs-on: macos-latest + permissions: + contents: write # This permission is needed for creating releases + outputs: + version: ${{ steps.set-version.outputs.version }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 # Full history for release creation + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Create root pdm.lock file + run: | + # Create an empty pdm.lock file in the root + touch pdm.lock + + - name: Install PDM + uses: pdm-project/setup-pdm@v3 + with: + python-version: '3.10' + cache: true + + - name: Set version + id: set-version + run: | + echo "VERSION=${{ inputs.version }}" >> $GITHUB_ENV + echo "version=${{ inputs.version }}" >> $GITHUB_OUTPUT + + - name: Initialize PDM in package directory + run: | + # Make sure we're working with a properly initialized PDM project + cd ${{ inputs.package_dir }} + + # Create pdm.lock if it doesn't exist + if [ ! -f "pdm.lock" ]; then + echo "No pdm.lock found, initializing PDM project..." 
+ pdm lock + fi + + - name: Set version in package + run: | + cd ${{ inputs.package_dir }} + # Replace pdm bump with direct edit of pyproject.toml + if [[ "$OSTYPE" == "darwin"* ]]; then + # macOS version of sed needs an empty string for -i + sed -i '' "s/version = \".*\"/version = \"$VERSION\"/" pyproject.toml + else + # Linux version + sed -i "s/version = \".*\"/version = \"$VERSION\"/" pyproject.toml + fi + # Verify version was updated + echo "Updated version in pyproject.toml:" + grep "version =" pyproject.toml + + # Conditional step for lume binary download (only for pylume package) + - name: Download and setup lume binary + if: inputs.is_lume_package + run: | + # Install required packages + pip install requests + + # Create a simple Python script for better error handling + cat > get_lume_version.py << 'EOF' + import requests, re, json, sys + try: + response = requests.get('https://api.github.com/repos/trycua/lume/releases/latest') + sys.stderr.write(f"API Status Code: {response.status_code}\n") + sys.stderr.write(f"API Response: {response.text}\n") + data = json.loads(response.text) + if 'tag_name' not in data: + sys.stderr.write(f"Warning: tag_name not found in API response. Keys: {list(data.keys())}\n") + print('0.1.9') + else: + tag_name = data['tag_name'] + match = re.match(r'v(\d+\.\d+\.\d+)', tag_name) + if match: + print(match.group(1)) + else: + sys.stderr.write("Error: Could not parse version from tag\n") + print('0.1.9') + except Exception as e: + sys.stderr.write(f"Error fetching release info: {str(e)}\n") + print('0.1.9') + EOF + + # Execute the script to get the version + LUME_VERSION=$(python get_lume_version.py) + echo "Using lume version: $LUME_VERSION" + + # Create a temporary directory for extraction + mkdir -p temp_lume + + # Download the lume release (silently) + echo "Downloading lume version v${LUME_VERSION}..." 
+ curl -sL "https://github.com/trycua/lume/releases/download/v${LUME_VERSION}/lume.tar.gz" -o temp_lume/lume.tar.gz + + # Extract the tar file (ignore ownership and suppress warnings) + cd temp_lume && tar --no-same-owner -xzf lume.tar.gz + + # Make the binary executable + chmod +x lume + + # Copy the lume binary to the correct location in the pylume package + mkdir -p "${GITHUB_WORKSPACE}/${{ inputs.package_dir }}/pylume" + cp lume "${GITHUB_WORKSPACE}/${{ inputs.package_dir }}/pylume/lume" + + # Verify the binary exists and is executable + test -x "${GITHUB_WORKSPACE}/${{ inputs.package_dir }}/pylume/lume" || { echo "lume binary not found or not executable"; exit 1; } + + # Cleanup + cd "${GITHUB_WORKSPACE}" && rm -rf temp_lume + + # Save the lume version for reference + echo "LUME_VERSION=${LUME_VERSION}" >> $GITHUB_ENV + + - name: Build and publish + env: + PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} + run: | + cd ${{ inputs.package_dir }} + # Build with PDM + pdm build + + # For pylume package, verify the binary is in the wheel + if [ "${{ inputs.is_lume_package }}" = "true" ]; then + python -m pip install wheel + wheel unpack dist/*.whl --dest temp_wheel + echo "Listing contents of wheel directory:" + find temp_wheel -type f + test -f temp_wheel/pylume-*/pylume/lume || { echo "lume binary not found in wheel"; exit 1; } + rm -rf temp_wheel + echo "Publishing ${{ inputs.base_package_name }} ${VERSION} with lume ${LUME_VERSION}" + else + echo "Publishing ${{ inputs.base_package_name }} ${VERSION}" + fi + + # Install and use twine directly instead of PDM publish + echo "Installing twine for direct publishing..." + pip install twine + + echo "Publishing to PyPI using twine..." 
+ TWINE_USERNAME="__token__" TWINE_PASSWORD="$PYPI_TOKEN" python -m twine upload dist/* + + # Save the wheel file path for the release + WHEEL_FILE=$(ls dist/*.whl | head -1) + echo "WHEEL_FILE=${WHEEL_FILE}" >> $GITHUB_ENV + + - name: Prepare Simple Release Notes + if: startsWith(github.ref, 'refs/tags/') + run: | + # Create release notes based on package type + echo "# ${{ inputs.base_package_name }} v${VERSION}" > release_notes.md + echo "" >> release_notes.md + + if [ "${{ inputs.package_name }}" = "pylume" ]; then + echo "## Python SDK for lume - run macOS and Linux VMs on Apple Silicon" >> release_notes.md + echo "" >> release_notes.md + echo "This package provides Python bindings for the lume virtualization tool." >> release_notes.md + echo "" >> release_notes.md + echo "## Dependencies" >> release_notes.md + echo "* lume binary: v${LUME_VERSION}" >> release_notes.md + elif [ "${{ inputs.package_name }}" = "computer" ]; then + echo "## Computer control library for the Computer Universal Automation (CUA) project" >> release_notes.md + echo "" >> release_notes.md + echo "## Dependencies" >> release_notes.md + echo "* pylume: ${PYLUME_VERSION:-latest}" >> release_notes.md + elif [ "${{ inputs.package_name }}" = "agent" ]; then + echo "## Dependencies" >> release_notes.md + echo "* cua-computer: ${COMPUTER_VERSION:-latest}" >> release_notes.md + echo "* cua-som: ${SOM_VERSION:-latest}" >> release_notes.md + echo "" >> release_notes.md + echo "## Installation Options" >> release_notes.md + echo "" >> release_notes.md + echo "### Basic installation with Anthropic" >> release_notes.md + echo '```bash' >> release_notes.md + echo "pip install cua-agent[anthropic]==${VERSION}" >> release_notes.md + echo '```' >> release_notes.md + echo "" >> release_notes.md + echo "### With SOM (recommended)" >> release_notes.md + echo '```bash' >> release_notes.md + echo "pip install cua-agent[som]==${VERSION}" >> release_notes.md + echo '```' >> release_notes.md + echo "" >> 
release_notes.md + echo "### All features" >> release_notes.md + echo '```bash' >> release_notes.md + echo "pip install cua-agent[all]==${VERSION}" >> release_notes.md + echo '```' >> release_notes.md + elif [ "${{ inputs.package_name }}" = "som" ]; then + echo "## Computer Vision and OCR library for detecting and analyzing UI elements" >> release_notes.md + echo "" >> release_notes.md + echo "This package provides enhanced UI understanding capabilities through computer vision and OCR." >> release_notes.md + elif [ "${{ inputs.package_name }}" = "computer-server" ]; then + echo "## Computer Server for the Computer Universal Automation (CUA) project" >> release_notes.md + echo "" >> release_notes.md + echo "A FastAPI-based server implementation for computer control." >> release_notes.md + echo "" >> release_notes.md + echo "## Dependencies" >> release_notes.md + echo "* cua-computer: ${COMPUTER_VERSION:-latest}" >> release_notes.md + echo "" >> release_notes.md + echo "## Usage" >> release_notes.md + echo '```bash' >> release_notes.md + echo "# Run the server" >> release_notes.md + echo "cua-computer-server" >> release_notes.md + echo '```' >> release_notes.md + fi + + # Add installation section if not agent (which has its own installation section) + if [ "${{ inputs.package_name }}" != "agent" ]; then + echo "" >> release_notes.md + echo "## Installation" >> release_notes.md + echo '```bash' >> release_notes.md + echo "pip install ${{ inputs.base_package_name }}==${VERSION}" >> release_notes.md + echo '```' >> release_notes.md + fi + + echo "Release notes created:" + cat release_notes.md + + - name: Create GitHub Release + uses: softprops/action-gh-release@v1 + if: startsWith(github.ref, 'refs/tags/') + with: + name: "${{ inputs.base_package_name }} v${{ env.VERSION }}" + body_path: release_notes.md + files: ${{ inputs.package_dir }}/${{ env.WHEEL_FILE }} + draft: false + prerelease: false + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of 
file diff --git a/.gitignore b/.gitignore index 244c29a3..0d151783 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,175 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +!libs/lume/scripts/build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Scripts +server/scripts/ + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Conda +.conda/ + +# Local environment +.env.local + +# macOS DS_Store +.DS_Store + +weights/ +weights/icon_detect/ +weights/icon_detect/model.pt +weights/icon_detect/model.pt.zip +weights/icon_detect/model.pt.zip.part* + +libs/omniparser/weights/icon_detect/model.pt + +# Example test data and output +examples/test_data/ +examples/output/ + +/screenshots/ + +/experiments/ + +/logs/ + # Xcode # # gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore -.DS_Store - ## User settings xcuserdata/ @@ -29,8 +195,7 @@ playground.xcworkspace # # Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata # hence it is not needed unless you have added a package configuration file to your project -# .swiftpm - +.swiftpm/ .build/ # CocoaPods @@ -48,7 +213,6 @@ playground.xcworkspace # # Add this line if you want to avoid checking in source code from Carthage dependencies. # Carthage/Checkouts - Carthage/Build/ # fastlane @@ -57,20 +221,22 @@ Carthage/Build/ # Instead, use fastlane to re-generate the screenshots whenever they are needed. 
# For more information about the recommended setup visit: # https://docs.fastlane.tools/best-practices/source-control/#source-control - fastlane/report.xml fastlane/Preview.html fastlane/screenshots/**/*.png fastlane/test_output -# Local environment variables -.env.local - # Ignore folder ignore # .release .release/ -# Swift Package Manager -.swiftpm/ +# Shared folder +shared + +# Trajectories +trajectories/ + +# Installation ID Storage +.storage/ \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index 4f5b8ce7..00000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,241 +0,0 @@ -{ - "configurations": [ - { - "type": "bashdb", - "request": "launch", - "name": "Bash-Debug (select script from list of sh files)", - "cwd": "${workspaceFolder}", - "program": "${command:SelectScriptName}", - "pathBash": "/opt/homebrew/bin/bash", - "args": [] - }, - { - "type": "lldb", - "request": "launch", - "sourceLanguages": [ - "swift" - ], - "args": [ - "serve" - ], - "cwd": "${workspaceFolder:lume}", - "name": "Debug lume serve", - "program": "${workspaceFolder:lume}/.build/debug/lume", - "preLaunchTask": "build-debug" - }, - { - "type": "lldb", - "request": "launch", - "sourceLanguages": [ - "swift" - ], - "args": [ - "create", - "macos-vm", - "--cpu", - "4", - "--memory", - "4GB", - "--disk-size", - "40GB", - "--ipsw", - "/Users//Downloads/UniversalMac_15.2_24C101_Restore.ipsw" - ], - "cwd": "${workspaceFolder:lume}", - "name": "Debug lume create --os macos --ipsw 'path/to/ipsw' (macos)", - "program": "${workspaceFolder:lume}/.build/debug/lume", - "preLaunchTask": "build-debug" - }, - { - "type": "lldb", - "request": "launch", - "sourceLanguages": [ - "swift" - ], - "args": [ - "create", - "macos-vm", - "--cpu", - "4", - "--memory", - "4GB", - "--disk-size", - "20GB", - "--ipsw", - "latest" - ], - "cwd": "${workspaceFolder:lume}", - "name": "Debug lume create --os macos --ipsw latest (macos)", - "program": 
"${workspaceFolder:lume}/.build/debug/lume", - "preLaunchTask": "build-debug" - }, - { - "type": "lldb", - "request": "launch", - "sourceLanguages": [ - "swift" - ], - "args": [ - "create", - "linux-vm", - "--os", - "linux", - "--cpu", - "4", - "--memory", - "4GB", - "--disk-size", - "20GB" - ], - "cwd": "${workspaceFolder:lume}", - "name": "Debug lume create --os linux (linux)", - "program": "${workspaceFolder:lume}/.build/debug/lume", - "preLaunchTask": "build-debug" - }, - { - "type": "lldb", - "request": "launch", - "sourceLanguages": [ - "swift" - ], - "args": [ - "pull", - "macos-sequoia-vanilla:15.2", - "--name", - "macos-vm-cloned" - ], - "cwd": "${workspaceFolder:lume}", - "name": "Debug lume pull (macos)", - "program": "${workspaceFolder:lume}/.build/debug/lume", - "preLaunchTask": "build-debug" - }, - { - "type": "lldb", - "request": "launch", - "sourceLanguages": [ - "swift" - ], - "args": [ - "run", - "macos-vm", - "--shared-dir", - "/Users//repos/trycua/lume/shared_folder:rw", - "--start-vnc" - ], - "cwd": "${workspaceFolder:lume}", - "name": "Debug lume run (macos)", - "program": "${workspaceFolder:lume}/.build/debug/lume", - "preLaunchTask": "build-debug" - }, - { - "type": "lldb", - "request": "launch", - "sourceLanguages": [ - "swift" - ], - "args": [ - "run", - "linux-vm", - "--start-vnc", - "--mount", - "/Users//Downloads/ubuntu-24.04.1-live-server-arm64.iso" - ], - "cwd": "${workspaceFolder:lume}", - "name": "Debug lume run with setup mount (linux)", - "program": "${workspaceFolder:lume}/.build/debug/lume", - "preLaunchTask": "build-debug" - }, - { - "type": "lldb", - "request": "launch", - "sourceLanguages": [ - "swift" - ], - "args": [ - "run", - "linux-vm", - "--start-vnc" - ], - "cwd": "${workspaceFolder:lume}", - "name": "Debug lume run (linux)", - "program": "${workspaceFolder:lume}/.build/debug/lume", - "preLaunchTask": "build-debug" - }, - { - "type": "lldb", - "request": "launch", - "sourceLanguages": [ - "swift" - ], - "args": [ - 
"get", - "macos-vm" - ], - "cwd": "${workspaceFolder:lume}", - "name": "Debug lume get (macos)", - "program": "${workspaceFolder:lume}/.build/debug/lume", - "preLaunchTask": "build-debug" - }, - { - "type": "lldb", - "request": "launch", - "sourceLanguages": [ - "swift" - ], - "args": [ - "ls" - ], - "cwd": "${workspaceFolder:lume}", - "name": "Debug lume ls", - "program": "${workspaceFolder:lume}/.build/debug/lume", - "preLaunchTask": "build-debug" - }, - { - "type": "lldb", - "request": "launch", - "sourceLanguages": [ - "swift" - ], - "args": [ - "images" - ], - "cwd": "${workspaceFolder:lume}", - "name": "Debug lume images", - "program": "${workspaceFolder:lume}/.build/debug/lume", - "preLaunchTask": "build-debug" - }, - { - "type": "lldb", - "request": "launch", - "sourceLanguages": [ - "swift" - ], - "args": [ - "stop", - "macos-vm" - ], - "cwd": "${workspaceFolder:lume}", - "name": "Debug lume stop (macos)", - "program": "${workspaceFolder:lume}/.build/debug/lume", - "preLaunchTask": "build-debug" - }, - { - "type": "lldb", - "request": "launch", - "args": [], - "cwd": "${workspaceFolder:lume}", - "name": "Debug lume", - "program": "${workspaceFolder:lume}/.build/debug/lume", - "preLaunchTask": "swift: Build Debug lume" - }, - { - "type": "lldb", - "request": "launch", - "args": [], - "cwd": "${workspaceFolder:lume}", - "name": "Release lume", - "program": "${workspaceFolder:lume}/.build/release/lume", - "preLaunchTask": "swift: Build Release lume" - } - ] -} \ No newline at end of file diff --git a/.vscode/lume.code-workspace b/.vscode/lume.code-workspace new file mode 100644 index 00000000..a6d28d7e --- /dev/null +++ b/.vscode/lume.code-workspace @@ -0,0 +1,319 @@ +{ + "folders": [ + { + "name": "lume", + "path": "../libs/lume" + } + ], + "settings": { + "files.exclude": { + "**/.git": true, + "**/.svn": true, + "**/.hg": true, + "**/CVS": true, + "**/.DS_Store": true + }, + "swift.path.swift_driver_bin": "/usr/bin/swift", + "swift.enableLanguageServer": 
true, + "files.associations": { + "*.swift": "swift" + }, + "[swift]": { + "editor.formatOnSave": true, + "editor.detectIndentation": true, + "editor.tabSize": 4 + }, + "swift.path": "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin", + "swift.swiftEnvironmentVariables": { + "DEVELOPER_DIR": "/Applications/Xcode.app" + }, + "lldb.library": "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/LLDB", + "lldb.launch.expressions": "native" + }, + "tasks": { + "version": "2.0.0", + "tasks": [ + { + "label": "build-debug", + "type": "shell", + "command": "${workspaceFolder:lume}/scripts/build/build-debug.sh", + "options": { + "cwd": "${workspaceFolder:lume}" + }, + "group": { + "kind": "build", + "isDefault": true + }, + "presentation": { + "reveal": "silent", + "panel": "shared" + }, + "problemMatcher": [] + }, + { + "label": "swift: Build Debug lume", + "type": "shell", + "command": "${workspaceFolder:lume}/scripts/build/build-debug.sh", + "options": { + "cwd": "${workspaceFolder:lume}" + }, + "group": "build", + "presentation": { + "reveal": "silent", + "panel": "shared" + }, + "problemMatcher": [] + } + ] + }, + "launch": { + "configurations": [ + { + "type": "bashdb", + "request": "launch", + "name": "Bash-Debug (select script from list of sh files)", + "cwd": "${workspaceFolder}", + "program": "${command:SelectScriptName}", + "pathBash": "/opt/homebrew/bin/bash", + "args": [] + }, + { + "type": "lldb", + "request": "launch", + "sourceLanguages": [ + "swift" + ], + "args": [ + "serve" + ], + "cwd": "${workspaceFolder:lume}", + "name": "Debug lume serve", + "program": "${workspaceFolder:lume}/.build/debug/lume", + "preLaunchTask": "build-debug" + }, + { + "type": "lldb", + "request": "launch", + "sourceLanguages": [ + "swift" + ], + "args": [ + "create", + "macos-vm", + "--cpu", + "4", + "--memory", + "4GB", + "--disk-size", + "40GB", + "--ipsw", + "/Users//Downloads/UniversalMac_15.2_24C101_Restore.ipsw" 
+ ], + "cwd": "${workspaceFolder:lume}", + "name": "Debug lume create --os macos --ipsw 'path/to/ipsw' (macos)", + "program": "${workspaceFolder:lume}/.build/debug/lume", + "preLaunchTask": "build-debug" + }, + { + "type": "lldb", + "request": "launch", + "sourceLanguages": [ + "swift" + ], + "args": [ + "create", + "macos-vm", + "--cpu", + "4", + "--memory", + "4GB", + "--disk-size", + "20GB", + "--ipsw", + "latest" + ], + "cwd": "${workspaceFolder:lume}", + "name": "Debug lume create --os macos --ipsw latest (macos)", + "program": "${workspaceFolder:lume}/.build/debug/lume", + "preLaunchTask": "build-debug" + }, + { + "type": "lldb", + "request": "launch", + "sourceLanguages": [ + "swift" + ], + "args": [ + "create", + "linux-vm", + "--os", + "linux", + "--cpu", + "4", + "--memory", + "4GB", + "--disk-size", + "20GB" + ], + "cwd": "${workspaceFolder:lume}", + "name": "Debug lume create --os linux (linux)", + "program": "${workspaceFolder:lume}/.build/debug/lume", + "preLaunchTask": "build-debug" + }, + { + "type": "lldb", + "request": "launch", + "sourceLanguages": [ + "swift" + ], + "args": [ + "pull", + "macos-sequoia-vanilla:15.2", + "--name", + "macos-vm-cloned" + ], + "cwd": "${workspaceFolder:lume}", + "name": "Debug lume pull (macos)", + "program": "${workspaceFolder:lume}/.build/debug/lume", + "preLaunchTask": "build-debug" + }, + { + "type": "lldb", + "request": "launch", + "sourceLanguages": [ + "swift" + ], + "args": [ + "run", + "macos-vm", + "--shared-dir", + "/Users//repos/trycua/lume/shared_folder:rw", + "--start-vnc" + ], + "cwd": "${workspaceFolder:lume}", + "name": "Debug lume run (macos)", + "program": "${workspaceFolder:lume}/.build/debug/lume", + "preLaunchTask": "build-debug" + }, + { + "type": "lldb", + "request": "launch", + "sourceLanguages": [ + "swift" + ], + "args": [ + "run", + "linux-vm", + "--start-vnc", + "--mount", + "/Users//Downloads/ubuntu-24.04.1-live-server-arm64.iso" + ], + "cwd": "${workspaceFolder:lume}", + "name": "Debug 
lume run with setup mount (linux)", + "program": "${workspaceFolder:lume}/.build/debug/lume", + "preLaunchTask": "build-debug" + }, + { + "type": "lldb", + "request": "launch", + "sourceLanguages": [ + "swift" + ], + "args": [ + "run", + "linux-vm", + "--start-vnc" + ], + "cwd": "${workspaceFolder:lume}", + "name": "Debug lume run (linux)", + "program": "${workspaceFolder:lume}/.build/debug/lume", + "preLaunchTask": "build-debug" + }, + { + "type": "lldb", + "request": "launch", + "sourceLanguages": [ + "swift" + ], + "args": [ + "get", + "macos-vm" + ], + "cwd": "${workspaceFolder:lume}", + "name": "Debug lume get (macos)", + "program": "${workspaceFolder:lume}/.build/debug/lume", + "preLaunchTask": "build-debug" + }, + { + "type": "lldb", + "request": "launch", + "sourceLanguages": [ + "swift" + ], + "args": [ + "ls" + ], + "cwd": "${workspaceFolder:lume}", + "name": "Debug lume ls", + "program": "${workspaceFolder:lume}/.build/debug/lume", + "preLaunchTask": "build-debug" + }, + { + "type": "lldb", + "request": "launch", + "sourceLanguages": [ + "swift" + ], + "args": [ + "images" + ], + "cwd": "${workspaceFolder:lume}", + "name": "Debug lume images", + "program": "${workspaceFolder:lume}/.build/debug/lume", + "preLaunchTask": "build-debug" + }, + { + "type": "lldb", + "request": "launch", + "sourceLanguages": [ + "swift" + ], + "args": [ + "stop", + "macos-vm" + ], + "cwd": "${workspaceFolder:lume}", + "name": "Debug lume stop (macos)", + "program": "${workspaceFolder:lume}/.build/debug/lume", + "preLaunchTask": "build-debug" + }, + { + "type": "lldb", + "request": "launch", + "args": [], + "cwd": "${workspaceFolder:lume}", + "name": "Debug lume", + "program": "${workspaceFolder:lume}/.build/debug/lume", + "preLaunchTask": "swift: Build Debug lume" + }, + { + "type": "lldb", + "request": "launch", + "args": [], + "cwd": "${workspaceFolder:lume}", + "name": "Release lume", + "program": "${workspaceFolder:lume}/.build/release/lume", + "preLaunchTask": "swift: 
Build Release lume" + }, + { + "type": "bashdb", + "request": "launch", + "name": "Bash-Debug (select script)", + "cwd": "${workspaceFolder:lume}", + "program": "${command:SelectScriptName}", + "pathBash": "/opt/homebrew/bin/bash", + "args": [] + } + ] + } +} \ No newline at end of file diff --git a/.vscode/py.code-workspace b/.vscode/py.code-workspace new file mode 100644 index 00000000..6adef38b --- /dev/null +++ b/.vscode/py.code-workspace @@ -0,0 +1,292 @@ +{ + "folders": [ + { + "name": "cua-root", + "path": ".." + }, + { + "name": "computer", + "path": "../libs/computer" + }, + { + "name": "agent", + "path": "../libs/agent" + }, + { + "name": "som", + "path": "../libs/som" + }, + { + "name": "computer-server", + "path": "../libs/computer-server" + }, + { + "name": "pylume", + "path": "../libs/pylume" + }, + { + "name": "core", + "path": "../libs/core" + } + ], + "settings": { + "files.exclude": { + "**/.git": true, + "**/.svn": true, + "**/.hg": true, + "**/CVS": true, + "**/.DS_Store": true, + "**/__pycache__": true, + "**/.pytest_cache": true, + "**/*.pyc": true + }, + "python.testing.pytestEnabled": true, + "python.testing.unittestEnabled": false, + "python.testing.nosetestsEnabled": false, + "python.testing.pytestArgs": [ + "libs" + ], + "python.analysis.extraPaths": [ + "${workspaceFolder:cua-root}/libs/core", + "${workspaceFolder:cua-root}/libs/computer", + "${workspaceFolder:cua-root}/libs/agent", + "${workspaceFolder:cua-root}/libs/som", + "${workspaceFolder:cua-root}/libs/pylume", + "${workspaceFolder:cua-root}/.vscode/typings" + ], + "python.envFile": "${workspaceFolder:cua-root}/.env", + "python.defaultInterpreterPath": "${workspaceFolder:cua-root}/.venv/bin/python", + "python.analysis.diagnosticMode": "workspace", + "python.analysis.typeCheckingMode": "basic", + "python.analysis.autoSearchPaths": true, + "python.analysis.stubPath": "${workspaceFolder:cua-root}/.vscode/typings", + "python.analysis.indexing": false, + "python.analysis.exclude": [ + 
"**/node_modules/**", + "**/__pycache__/**", + "**/.*/**", + "**/venv/**", + "**/.venv/**", + "**/dist/**", + "**/build/**", + ".pdm-build/**", + "**/.git/**", + "examples/**", + "notebooks/**", + "logs/**", + "screenshots/**" + ], + "python.analysis.packageIndexDepths": [ + { + "name": "computer", + "depth": 2 + }, + { + "name": "agent", + "depth": 2 + }, + { + "name": "som", + "depth": 2 + }, + { + "name": "pylume", + "depth": 2 + }, + { + "name": "core", + "depth": 2 + } + ], + "python.autoComplete.extraPaths": [ + "${workspaceFolder:cua-root}/libs/core", + "${workspaceFolder:cua-root}/libs/computer", + "${workspaceFolder:cua-root}/libs/agent", + "${workspaceFolder:cua-root}/libs/som", + "${workspaceFolder:cua-root}/libs/pylume" + ], + "python.languageServer": "Pylance", + "[python]": { + "editor.formatOnSave": true, + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.codeActionsOnSave": { + "source.organizeImports": "explicit" + } + }, + "files.associations": { + "examples/computer_examples.py": "python", + "examples/agent_examples.py": "python" + }, + "python.interpreterPaths": { + "examples/computer_examples.py": "${workspaceFolder}/libs/computer/.venv/bin/python", + "examples/agent_examples.py": "${workspaceFolder}/libs/agent/.venv/bin/python" + } + }, + "tasks": { + "version": "2.0.0", + "tasks": [ + { + "label": "Build Dependencies", + "type": "shell", + "command": "${workspaceFolder}/scripts/build.sh", + "presentation": { + "reveal": "always", + "panel": "new", + "clear": true + }, + "group": { + "kind": "build", + "isDefault": true + }, + "options": { + "shell": { + "executable": "/bin/bash", + "args": ["-l", "-c"] + } + }, + "problemMatcher": [] + } + ] + }, + "launch": { + "version": "0.2.0", + "configurations": [ + { + "name": "Run Computer Examples", + "type": "debugpy", + "request": "launch", + "program": "examples/computer_examples.py", + "console": "integratedTerminal", + "justMyCode": true, + "python": 
"${workspaceFolder:cua-root}/.venv/bin/python", + "cwd": "${workspaceFolder:cua-root}", + "env": { + "PYTHONPATH": "${workspaceFolder:cua-root}/libs/core:${workspaceFolder:cua-root}/libs/computer:${workspaceFolder:cua-root}/libs/agent:${workspaceFolder:cua-root}/libs/som:${workspaceFolder:cua-root}/libs/pylume" + } + }, + { + "name": "Run Agent Examples", + "type": "debugpy", + "request": "launch", + "program": "examples/agent_examples.py", + "console": "integratedTerminal", + "justMyCode": false, + "python": "${workspaceFolder:cua-root}/.venv/bin/python", + "cwd": "${workspaceFolder:cua-root}", + "env": { + "PYTHONPATH": "${workspaceFolder:cua-root}/libs/core:${workspaceFolder:cua-root}/libs/computer:${workspaceFolder:cua-root}/libs/agent:${workspaceFolder:cua-root}/libs/som:${workspaceFolder:cua-root}/libs/pylume" + } + }, + { + "name": "Run PyLume Examples", + "type": "debugpy", + "request": "launch", + "program": "examples/pylume_examples.py", + "console": "integratedTerminal", + "justMyCode": true, + "python": "${workspaceFolder:cua-root}/.venv/bin/python", + "cwd": "${workspaceFolder:cua-root}", + "env": { + "PYTHONPATH": "${workspaceFolder:cua-root}/libs/core:${workspaceFolder:cua-root}/libs/computer:${workspaceFolder:cua-root}/libs/agent:${workspaceFolder:cua-root}/libs/som:${workspaceFolder:cua-root}/libs/pylume" + } + }, + { + "name": "SOM: Run Experiments (No OCR)", + "type": "debugpy", + "request": "launch", + "program": "examples/som_examples.py", + "args": [ + "examples/test_data", + "--output-dir", "examples/output", + "--ocr", "none", + "--mode", "experiment" + ], + "console": "integratedTerminal", + "justMyCode": false, + "python": "${workspaceFolder:cua-root}/.venv/bin/python", + "cwd": "${workspaceFolder:cua-root}", + "env": { + "PYTHONPATH": "${workspaceFolder:cua-root}/libs/core:${workspaceFolder:cua-root}/libs/computer:${workspaceFolder:cua-root}/libs/agent:${workspaceFolder:cua-root}/libs/som:${workspaceFolder:cua-root}/libs/pylume" + } + }, 
+ { + "name": "SOM: Run Experiments (EasyOCR)", + "type": "debugpy", + "request": "launch", + "program": "examples/som_examples.py", + "args": [ + "examples/test_data", + "--output-dir", "examples/output", + "--ocr", "easyocr", + "--mode", "experiment" + ], + "console": "integratedTerminal", + "justMyCode": false, + "python": "${workspaceFolder:cua-root}/.venv/bin/python", + "cwd": "${workspaceFolder:cua-root}", + "env": { + "PYTHONPATH": "${workspaceFolder:cua-root}/libs/core:${workspaceFolder:cua-root}/libs/computer:${workspaceFolder:cua-root}/libs/agent:${workspaceFolder:cua-root}/libs/som:${workspaceFolder:cua-root}/libs/pylume" + } + }, + { + "name": "Run Computer Server", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/libs/computer-server/run_server.py", + "console": "integratedTerminal", + "justMyCode": true, + "python": "${workspaceFolder:cua-root}/.venv/bin/python", + "cwd": "${workspaceFolder:cua-root}", + "env": { + "PYTHONPATH": "${workspaceFolder:cua-root}/libs/core:${workspaceFolder:cua-root}/libs/computer:${workspaceFolder:cua-root}/libs/agent:${workspaceFolder:cua-root}/libs/som:${workspaceFolder:cua-root}/libs/pylume" + } + }, + { + "name": "Run Computer Server with Args", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/libs/computer-server/run_server.py", + "args": [ + "--host", "0.0.0.0", + "--port", "8000", + "--log-level", "debug" + ], + "console": "integratedTerminal", + "justMyCode": false, + "python": "${workspaceFolder:cua-root}/.venv/bin/python", + "cwd": "${workspaceFolder:cua-root}", + "env": { + "PYTHONPATH": "${workspaceFolder:cua-root}/libs/core:${workspaceFolder:cua-root}/libs/computer-server" + } + } + ] + }, + "compounds": [ + { + "name": "Run Computer Examples + Server", + "configurations": ["Run Computer Examples", "Run Computer Server"], + "stopAll": true, + "presentation": { + "group": "Computer", + "order": 1 + } + }, + { + "name": "Run Server with Keep-Alive Client", 
+ "configurations": ["Run Computer Server", "Test Server Connection (Keep Alive)"], + "stopAll": true, + "presentation": { + "group": "Computer", + "order": 2 + } + } + ], + "inputs": [ + { + "id": "imagePath", + "type": "promptString", + "description": "Path to the image file or directory for icon detection", + "default": "${workspaceFolder}/examples/test_data" + } + ] +} \ No newline at end of file diff --git a/.vscode/tasks.json b/.vscode/tasks.json deleted file mode 100644 index 0640efba..00000000 --- a/.vscode/tasks.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "version": "2.0.0", - "tasks": [ - { - "label": "build-debug", - "type": "shell", - "command": "${workspaceFolder:lume}/scripts/build/build-debug.sh", - "group": { - "kind": "build", - "isDefault": true - }, - "presentation": { - "reveal": "silent" - }, - "problemMatcher": [] - } - ] -} \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6c51a416..5a479add 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,6 +1,6 @@ -# Contributing to lume +# Contributing to cua -We deeply appreciate your interest in contributing to lume! Whether you're reporting bugs, suggesting enhancements, improving docs, or submitting pull requests, your contributions help improve the project for everyone. +We deeply appreciate your interest in contributing to cua! Whether you're reporting bugs, suggesting enhancements, improving docs, or submitting pull requests, your contributions help improve the project for everyone. ## Reporting Bugs @@ -34,6 +34,6 @@ Documentation improvements are always welcome. You can: - Improve API documentation - Add tutorials or guides -For detailed instructions on setting up your development environment and submitting code contributions, please see our [Development.md](docs/Development.md) guide. +For detailed instructions on setting up your development environment and submitting code contributions, please see our [Developer-Guide.md](docs/Developer-Guide.md) guide. 
Feel free to join our [Discord community](https://discord.com/invite/mVnXXpdE85) to discuss ideas or get help with your contributions. \ No newline at end of file diff --git a/LICENSE b/LICENSE.md similarity index 97% rename from LICENSE rename to LICENSE.md index dc94027f..3ae3445c 100644 --- a/LICENSE +++ b/LICENSE.md @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2024 trycua +Copyright (c) 2025 trycua Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 438380a8..f390d869 100644 --- a/README.md +++ b/README.md @@ -1,145 +1,82 @@
-

-
- - - - Shows my svg - -
+ + + + Cua logo + - [![Swift 6](https://img.shields.io/badge/Swift_6-F54A2A?logo=swift&logoColor=white&labelColor=F54A2A)](#) + + + [![Python](https://img.shields.io/badge/Python-333333?logo=python&logoColor=white&labelColor=333333)](#) + [![Swift](https://img.shields.io/badge/Swift-F05138?logo=swift&logoColor=white)](#) [![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](#) - [![Homebrew](https://img.shields.io/badge/Homebrew-FBB040?logo=homebrew&logoColor=fff)](#install) [![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.com/invite/mVnXXpdE85) - [![All Contributors](https://img.shields.io/github/all-contributors/trycua/lume?color=ee8449&style=flat-square)](#contributors) -

+# Cua -**lume** is a lightweight Command Line Interface and local API server to create, run and manage macOS and Linux virtual machines (VMs) with near-native performance on Apple Silicon, using Apple's `Virtualization.Framework`. +Create and run high-performance macOS and Linux VMs on Apple Silicon, with built-in support for AI agents. -### Run prebuilt macOS images in just 1 step +## Libraries -
-lume cli -
+| Library | Description | Installation | Version | +|---------|-------------|--------------|---------| +| [**Lume**](./libs/lume/README.md) | CLI for running macOS/Linux VMs with near-native performance using Apple's `Virtualization.Framework`. | `brew install lume` | [![brew](https://img.shields.io/badge/brew-0.1.10-333333)](https://formulae.brew.sh/formula/lume) | +| [**Computer**](./libs/computer/README.md) | Computer-Use Interface (CUI) framework for interacting with macOS/Linux sandboxes | `pip install cua-computer` | [![PyPI](https://img.shields.io/pypi/v/cua-computer?color=333333)](https://pypi.org/project/cua-computer/) | +| [**Agent**](./libs/agent/README.md) | Computer-Use Agent (CUA) framework for running agentic workflows in macOS/Linux dedicated sandboxes | `pip install cua-agent` | [![PyPI](https://img.shields.io/pypi/v/cua-agent?color=333333)](https://pypi.org/project/cua-agent/) | +## Lume -```bash -lume run macos-sequoia-vanilla:latest -``` - -For a python interface, check out [pylume](https://github.com/trycua/pylume). 
- -## Usage - -```bash -lume - -Commands: - lume create Create a new macOS or Linux VM - lume run Run a VM - lume ls List all VMs - lume get Get detailed information about a VM - lume set Modify VM configuration - lume stop Stop a running VM - lume delete Delete a VM - lume pull Pull a macOS image from container registry - lume clone Clone an existing VM - lume images List available macOS images in local cache - lume ipsw Get the latest macOS restore image URL - lume prune Remove cached images - lume serve Start the API server - -Options: - --help Show help [boolean] - --version Show version number [boolean] - -Command Options: - create: - --os Operating system to install (macOS or linux, default: macOS) - --cpu Number of CPU cores (default: 4) - --memory Memory size, e.g., 8GB (default: 4GB) - --disk-size Disk size, e.g., 50GB (default: 40GB) - --display Display resolution (default: 1024x768) - --ipsw Path to IPSW file or 'latest' for macOS VMs - - run: - --no-display Do not start the VNC client app - --shared-dir Share directory with VM (format: path[:ro|rw]) - --mount For Linux VMs only, attach a read-only disk image - --registry Container registry URL (default: ghcr.io) - --organization Organization to pull from (default: trycua) - --vnc-port Port to use for the VNC server (default: 0 for auto-assign) - --recovery-mode For MacOS VMs only, start VM in recovery mode (default: false) - - set: - --cpu New number of CPU cores (e.g., 4) - --memory New memory size (e.g., 8192MB or 8GB) - --disk-size New disk size (e.g., 40960MB or 40GB) - --display New display resolution in format WIDTHxHEIGHT (e.g., 1024x768) - - delete: - --force Force deletion without confirmation - - pull: - --registry Container registry URL (default: ghcr.io) - --organization Organization to pull from (default: trycua) - - serve: - --port Port to listen on (default: 3000) -``` - -## Install - -```bash -brew tap trycua/lume -brew install lume -``` - -You can also download the `lume.pkg.tar.gz` 
archive from the [latest release](https://github.com/trycua/lume/releases), extract it, and install the package manually. - -## Prebuilt Images - -Pre-built images are available in the registry [ghcr.io/trycua](https://github.com/orgs/trycua/packages). -These images come with an SSH server pre-configured and auto-login enabled. - -For the security of your VM, change the default password `lume` immediately after your first login. - -| Image | Tag | Description | Size | -|-------|------------|-------------|------| -| `macos-sequoia-vanilla` | `latest`, `15.2` | macOS Sequoia 15.2 | 40GB | -| `macos-sequoia-xcode` | `latest`, `15.2` | macOS Sequoia 15.2 with Xcode command line tools | 50GB | -| `ubuntu-noble-vanilla` | `latest`, `24.04.1` | [Ubuntu Server for ARM 24.04.1 LTS](https://ubuntu.com/download/server/arm) with Ubuntu Desktop | 20GB | - -For additional disk space, resize the VM disk after pulling the image using the `lume set --disk-size ` command. - -## Local API Server - -`lume` exposes a local HTTP API server that listens on `http://localhost:3000/lume`, enabling automated management of VMs. - -```bash -lume serve -``` - -For detailed API documentation, please refer to [API Reference](docs/API-Reference.md). +**Originally looking for Lume?** If you're here for the original Lume project, it's now part of this monorepo. Simply install with `brew` and refer to its [documentation](./libs/lume/README.md). ## Docs -- [API Reference](docs/API-Reference.md) -- [Development](docs/Development.md) -- [FAQ](docs/FAQ.md) +For optimal onboarding, we recommend starting with the [Computer](./libs/computer/README.md) documentation to cover the core functionality of the Computer sandbox, then exploring the [Agent](./libs/agent/README.md) documentation to understand Cua's AI agent capabilities, and finally working through the Notebook examples to try out the Computer-Use interface and agent. 
+ +- [Computer](./libs/computer/README.md) +- [Agent](./libs/agent/README.md) +- [Notebooks](./notebooks/) + +## Demos + +Demos of the Computer-Use Agent in action. Share your most impressive demos in Cua's [Discord community](https://discord.com/invite/mVnXXpdE85)! + +
+AI-Gradio: multi-app workflow requiring browser, VS Code and terminal access +
+
+ +
+ +
+ +
+Notebook: Fix GitHub issue in Cursor +
+
+ +
+ +
+
+## Accessory Libraries
+
+| Library | Description | Installation | Version |
+|---------|-------------|--------------|---------|
+| [**Core**](./libs/core/README.md) | Core functionality and utilities used by other Cua packages | `pip install cua-core` | [![PyPI](https://img.shields.io/pypi/v/cua-core?color=333333)](https://pypi.org/project/cua-core/) |
+| [**PyLume**](./libs/pylume/README.md) | Python bindings for Lume | `pip install pylume` | [![PyPI](https://img.shields.io/pypi/v/pylume?color=333333)](https://pypi.org/project/pylume/) |
+| [**Computer Server**](./libs/computer-server/README.md) | Server component for the Computer-Use Interface (CUI) framework | `pip install cua-computer-server` | [![PyPI](https://img.shields.io/pypi/v/cua-computer-server?color=333333)](https://pypi.org/project/cua-computer-server/) |
+| [**SOM**](./libs/som/README.md) | Set-of-Mark library for Agent | `pip install cua-som` | [![PyPI](https://img.shields.io/pypi/v/cua-som?color=333333)](https://pypi.org/project/cua-som/) |
 
 ## Contributing
 
-We welcome and greatly appreciate contributions to lume! Whether you're improving documentation, adding new features, fixing bugs, or adding new VM images, your efforts help make lume better for everyone. For detailed instructions on how to contribute, please refer to our [Contributing Guidelines](CONTRIBUTING.md).
+We welcome and greatly appreciate contributions to Cua! Whether you're improving documentation, adding new features, fixing bugs, or adding new VM images, your efforts help make Cua better for everyone. For detailed instructions on how to contribute, please refer to our [Contributing Guidelines](CONTRIBUTING.md).
 
 Join our [Discord community](https://discord.com/invite/mVnXXpdE85) to discuss ideas or get assistance.
 
 ## License
 
-lume is open-sourced under the MIT License - see the [LICENSE](LICENSE) file for details.
+Cua is open-sourced under the MIT License - see the [LICENSE](LICENSE.md) file for details.
## Trademarks @@ -147,7 +84,7 @@ Apple, macOS, and Apple Silicon are trademarks of Apple Inc. Ubuntu and Canonica ## Stargazers over time -[![Stargazers over time](https://starchart.cc/trycua/lume.svg?variant=adaptive)](https://starchart.cc/trycua/lume) +[![Stargazers over time](https://starchart.cc/trycua/cua.svg?variant=adaptive)](https://starchart.cc/trycua/cua) ## Contributors diff --git a/docs/Developer-Guide.md b/docs/Developer-Guide.md new file mode 100644 index 00000000..6fd95fce --- /dev/null +++ b/docs/Developer-Guide.md @@ -0,0 +1,147 @@ +## Developer Guide + +### Project Structure + +The project is organized as a monorepo with these main packages: +- `libs/core/` - Base package with telemetry support +- `libs/pylume/` - Python bindings for Lume +- `libs/computer/` - Core computer interaction library +- `libs/agent/` - AI agent library with multi-provider support +- `libs/som/` - Computer vision and NLP processing library (formerly omniparser) +- `libs/computer-server/` - Server implementation for computer control +- `libs/lume/` - Swift implementation for enhanced macOS integration + +Each package has its own virtual environment and dependencies, managed through PDM. + +### Local Development Setup + +1. Clone the repository: +```bash +git clone https://github.com/trycua/cua.git +cd cua +``` + +2. Create a `.env.local` file in the root directory with your API keys: +```bash +# Required for Anthropic provider +ANTHROPIC_API_KEY=your_anthropic_key_here + +# Required for OpenAI provider +OPENAI_API_KEY=your_openai_key_here +``` + +3. Run the build script to set up all packages: +```bash +./scripts/build.sh +``` + +This will: +- Create a virtual environment for the project +- Install all packages in development mode +- Set up the correct Python path +- Install development tools + +4. 
Open the workspace in VSCode or Cursor: +```bash +# Using VSCode or Cursor +code .vscode/py.code-workspace + +# For Lume (Swift) development +code .vscode/lume.code-workspace +``` + +Using the workspace file is strongly recommended as it: +- Sets up correct Python environments for each package +- Configures proper import paths +- Enables debugging configurations +- Maintains consistent settings across packages + +### Cleanup and Reset + +If you need to clean up the environment and start fresh: + +```bash +./scripts/cleanup.sh +``` + +This will: +- Remove all virtual environments +- Clean Python cache files and directories +- Remove build artifacts +- Clean PDM-related files +- Reset environment configurations + +### Package Virtual Environments + +The build script creates a shared virtual environment for all packages. The workspace configuration automatically handles import paths with the correct Python path settings. + +### Running Examples + +The Python workspace includes launch configurations for all packages: + +- "Run Computer Examples" - Runs computer examples +- "Run Computer API Server" - Runs the computer-server +- "Run Omni Agent Examples" - Runs agent examples +- "SOM" configurations - Various settings for running SOM + +To run examples: +1. Open the workspace file (`.vscode/py.code-workspace`) +2. Press F5 or use the Run/Debug view +3. Select the desired configuration + +The workspace also includes compound launch configurations: +- "Run Computer Examples + Server" - Runs both the Computer Examples and Server simultaneously + +## Release and Publishing Process + +This monorepo contains multiple Python packages that can be published to PyPI. The packages +have dependencies on each other in the following order: + +1. `pylume` - Base package for VM management +2. `cua-computer` - Computer control interface (depends on pylume) +3. `cua-som` - Parser for UI elements (independent, formerly omniparser) +4. 
`cua-agent` - AI agent (depends on cua-computer and optionally cua-som) +5. `computer-server` - Server component installed on the sandbox + +#### Workflow Structure + +The publishing process is managed by these GitHub workflow files: + +- **Package-specific workflows**: + - `.github/workflows/publish-pylume.yml` + - `.github/workflows/publish-computer.yml` + - `.github/workflows/publish-som.yml` + - `.github/workflows/publish-agent.yml` + - `.github/workflows/publish-computer-server.yml` + +- **Coordinator workflow**: + - `.github/workflows/publish-all.yml` - Manages global releases and manual selections + +### Version Management + +#### Special Considerations for Pylume + +The `pylume` package requires special handling as it incorporates the binary executable from the [lume repository](https://github.com/trycua/lume): + +- When releasing `pylume`, ensure the version matches a corresponding release in the lume repository +- The workflow automatically downloads the matching lume binary and includes it in the pylume package +- If you need to release a new version of pylume, make sure to coordinate with a matching lume release + +## Development Workspaces + +This monorepo includes multiple VS Code workspace configurations to optimize the development experience based on which components you're working with: + +### Available Workspace Files + +- **[py.code-workspace](.vscode/py.code-workspace)**: For Python package development (Computer, Agent, SOM, etc.) +- **[lume.code-workspace](.vscode/lume.code-workspace)**: For Swift-based Lume development + +To open a specific workspace: + +```bash +# For Python development +code .vscode/py.code-workspace + +# For Lume (Swift) development +code .vscode/lume.code-workspace +``` \ No newline at end of file diff --git a/docs/FAQ.md b/docs/FAQ.md index 9150fbb5..a342f6c0 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -1,55 +1,50 @@ # FAQs -### Where are the VMs stored? +### Why a local sandbox? -VMs are stored in `~/.lume`. 
+A local sandbox is a dedicated environment that is isolated from the rest of the system. As AI agents rapidly evolve towards 70-80% success rates on average tasks, having a controlled and secure environment becomes crucial. Cua's Computer-Use AI agents run in a local sandbox to ensure reliability, safety, and controlled execution. -### How are images cached? +Benefits of using a local sandbox rather than running the Computer-Use AI agent in the host system: -Images are cached in `~/.lume/cache`. When doing `lume pull `, it will check if the image is already cached. If not, it will download the image and cache it, removing any older versions. +- **Reliability**: The sandbox provides a reproducible environment - critical for benchmarking and debugging agent behavior. Frameworks like [OSWorld](https://github.com/xlang-ai/OSWorld), [Simular AI](https://github.com/simular-ai/Agent-S), Microsoft's [OmniTool](https://github.com/microsoft/OmniParser/tree/master/omnitool), [WindowsAgentArena](https://github.com/microsoft/WindowsAgentArena) and more are using Computer-Use AI agents running in local sandboxes. +- **Safety & Isolation**: The sandbox is isolated from the rest of the system, protecting sensitive data and system resources. As CUA agent capabilities grow, this isolation becomes increasingly important for preventing potential safety breaches. +- **Control**: The sandbox can be easily monitored and terminated if needed, providing oversight for autonomous agent operation. -### Are VM disks taking up all the disk space? +### Where are the sandbox images stored? + +Sandbox are stored in `~/.lume`, and cached images are stored in `~/.lume/cache`. + +### Which image is Computer using? + +Computer uses an optimized macOS image for Computer-Use interactions, with pre-installed apps and settings for optimal performance. +The image is available on our [ghcr registry](https://github.com/orgs/trycua/packages/container/package/macos-sequoia-cua). 
+ +### Are Sandbox disks taking up all the disk space? No, macOS uses sparse files, which only allocate space as needed. For example, VM disks totaling 50 GB may only use 20 GB on disk. -### How do I get the latest macOS restore image URL? - -```bash -lume ipsw -``` - ### How do I delete a VM? ```bash lume delete ``` -### How to Install macOS from an IPSW Image +### How do I troubleshoot Computer not connecting to lume daemon? -#### Create a new macOS VM using the latest supported IPSW image: -Run the following command to create a new macOS virtual machine using the latest available IPSW image: +If you're experiencing connection issues between Computer and the lume daemon, it could be because the port 3000 (used by lume) is already in use by an orphaned process. You can diagnose this issue with: ```bash -lume create --os macos --ipsw latest +sudo lsof -i :3000 ``` -#### Create a new macOS VM using a specific IPSW image: -To create a macOS virtual machine from an older or specific IPSW file, first download the desired IPSW (UniversalMac) from a trusted source. - -Then, use the downloaded IPSW path: +This command will show all processes using port 3000. If you see a lume process already running, you can terminate it with: ```bash -lume create --os macos --ipsw +kill ``` -### How do I install a custom Linux image? +Where `` is the process ID shown in the output of the `lsof` command. After terminating the process, run `lume serve` again to start the lume daemon. -The process for creating a custom Linux image differs than macOS, with IPSW restore files not being used. You need to create a linux VM first, then mount a setup image file to the VM for the first boot. +### What information does Cua track? -```bash -lume create --os linux - -lume run --mount - -lume run -``` +Cua tracks anonymized usage and error report statistics; we ascribe to Posthog's approach as detailed [here](https://posthog.com/blog/open-source-telemetry-ethical). 
If you would like to opt out of sending anonymized info, you can set `telemetry_enabled` to false in the Computer or Agent constructor. Check out our [Telemetry](Telemetry.md) documentation for more details. diff --git a/docs/Telemetry.md b/docs/Telemetry.md new file mode 100644 index 00000000..01731287 --- /dev/null +++ b/docs/Telemetry.md @@ -0,0 +1,74 @@ +# Telemetry in CUA + +This document explains how telemetry works in CUA libraries and how you can control it. + +CUA tracks anonymized usage and error report statistics; we ascribe to Posthog's approach as detailed [here](https://posthog.com/blog/open-source-telemetry-ethical). If you would like to opt out of sending anonymized info, you can set `telemetry_enabled` to false. + +## What telemetry data we collect + +CUA libraries collect minimal anonymous usage data to help improve our software. The telemetry data we collect is specifically limited to: + +- Basic system information: + - Operating system (e.g., 'darwin', 'win32', 'linux') + - Python version (e.g., '3.10.0') +- Module initialization events: + - When a module (like 'computer' or 'agent') is imported + - Version of the module being used + +We do NOT collect: +- Personal information +- Contents of files +- Specific text being typed +- Actual screenshots or screen contents +- User-specific identifiers +- API keys +- File contents +- Application data or content +- User interactions with the computer +- Information about files being accessed + +## Controlling Telemetry + +We are committed to transparency and user control over telemetry. There are two ways to control telemetry: + +## 1. Environment Variable (Global Control) + +Telemetry is enabled by default. 
To disable telemetry, set the `CUA_TELEMETRY_ENABLED` environment variable to a falsy value (`0`, `false`, `no`, or `off`):
+
+```bash
+# Disable telemetry before running your script
+export CUA_TELEMETRY_ENABLED=false
+
+# Or as part of the command
+CUA_TELEMETRY_ENABLED=0 python your_script.py
+
+```
+Or from Python:
+```python
+import os
+os.environ["CUA_TELEMETRY_ENABLED"] = "false"
+```
+
+## 2. Instance-Level Control
+
+You can control telemetry for specific CUA instances by setting `telemetry_enabled` when creating them:
+
+```python
+# Disable telemetry for a specific Computer instance
+computer = Computer(telemetry_enabled=False)
+
+# Enable telemetry for a specific Agent instance
+agent = ComputerAgent(telemetry_enabled=True)
+```
+
+You can check if telemetry is enabled for an instance:
+
+```python
+print(computer.telemetry_enabled) # Will print True or False
+```
+
+Note that telemetry settings must be configured during initialization and cannot be changed after the object is created.
+
+## Transparency
+
+We believe in being transparent about the data we collect. If you have any questions about our telemetry practices, please open an issue on our GitHub repository.
\ No newline at end of file diff --git a/examples/agent_examples.py b/examples/agent_examples.py new file mode 100644 index 00000000..ebcb1070 --- /dev/null +++ b/examples/agent_examples.py @@ -0,0 +1,99 @@ +"""Example demonstrating the ComputerAgent capabilities with the Omni provider.""" + +import os +import asyncio +import logging +import traceback +from pathlib import Path +from datetime import datetime +import signal + +from computer import Computer + +# Import the unified agent class and types +from agent import ComputerAgent, AgentLoop, LLMProvider, LLM + +# Import utility functions +from utils import load_dotenv_files, handle_sigint + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def run_omni_agent_example(): + """Run example of using the ComputerAgent with OpenAI and Omni provider.""" + print(f"\n=== Example: ComputerAgent with OpenAI and Omni provider ===") + try: + # Create Computer instance with default parameters + computer = Computer(verbosity=logging.DEBUG) + + # Create agent with loop and provider + agent = ComputerAgent( + computer=computer, + # loop=AgentLoop.OMNI, + loop=AgentLoop.ANTHROPIC, + # model=LLM(provider=LLMProvider.OPENAI, name="gpt-4.5-preview"), + model=LLM(provider=LLMProvider.ANTHROPIC, name="claude-3-7-sonnet-20250219"), + save_trajectory=True, + trajectory_dir=str(Path("trajectories")), + only_n_most_recent_images=3, + verbosity=logging.INFO, + ) + + tasks = [ + """ +1. Look for a repository named trycua/lume on GitHub. +2. Check the open issues, open the most recent one and read it. +3. Clone the repository in users/lume/projects if it doesn't exist yet. +4. Open the repository with an app named Cursor (on the dock, black background and white cube icon). +5. From Cursor, open Composer if not already open. +6. Focus on the Composer text area, then write and submit a task to help resolve the GitHub issue. 
+""" + ] + + async with agent: + for i, task in enumerate(tasks, 1): + print(f"\nExecuting task {i}/{len(tasks)}: {task}") + async for result in agent.run(task): + # Check if result has the expected structure + if "role" in result and "content" in result and "metadata" in result: + title = result["metadata"].get("title", "Screen Analysis") + content = result["content"] + else: + title = result.get("metadata", {}).get("title", "Screen Analysis") + content = result.get("content", str(result)) + + print(f"\n{title}") + print(content) + print(f"Task {i} completed") + + except Exception as e: + logger.error(f"Error in run_anthropic_agent_example: {e}") + traceback.print_exc() + raise + finally: + # Clean up resources + if computer and computer._initialized: + try: + await computer.stop() + except Exception as e: + logger.warning(f"Error stopping computer: {e}") + + +def main(): + """Run the Anthropic agent example.""" + try: + load_dotenv_files() + + # Register signal handler for graceful exit + signal.signal(signal.SIGINT, handle_sigint) + + asyncio.run(run_omni_agent_example()) + except Exception as e: + print(f"Error running example: {e}") + traceback.print_exc() + + +if __name__ == "__main__": + main() diff --git a/examples/computer_examples.py b/examples/computer_examples.py new file mode 100644 index 00000000..b5e9fc84 --- /dev/null +++ b/examples/computer_examples.py @@ -0,0 +1,97 @@ +import os +import asyncio +from pathlib import Path +import sys +import json +import traceback + +# Load environment variables from .env file +project_root = Path(__file__).parent.parent +env_file = project_root / ".env" +print(f"Loading environment from: {env_file}") +from dotenv import load_dotenv + +load_dotenv(env_file) + +# Add paths to sys.path if needed +pythonpath = os.environ.get("PYTHONPATH", "") +for path in pythonpath.split(":"): + if path and path not in sys.path: + sys.path.append(path) + print(f"Added to sys.path: {path}") + +from computer.computer import Computer 
+from computer.logger import LogLevel +from computer.utils import get_image_size + + +async def main(): + try: + print("\n=== Using direct initialization ===") + computer = Computer( + display="1024x768", # Higher resolution + memory="8GB", # More memory + cpu="4", # More CPU cores + os="macos", + verbosity=LogLevel.NORMAL, # Use QUIET to suppress most logs + use_host_computer_server=False, + ) + try: + await computer.run() + + await computer.interface.hotkey("command", "space") + + # res = await computer.interface.run_command("touch ./Downloads/empty_file") + # print(f"Run command result: {res}") + + accessibility_tree = await computer.interface.get_accessibility_tree() + print(f"Accessibility tree: {accessibility_tree}") + + # Screen Actions Examples + print("\n=== Screen Actions ===") + screenshot = await computer.interface.screenshot() + with open("screenshot_direct.png", "wb") as f: + f.write(screenshot) + + screen_size = await computer.interface.get_screen_size() + print(f"Screen size: {screen_size}") + + # Demonstrate coordinate conversion + center_x, center_y = 733, 736 + print(f"Center in screen coordinates: ({center_x}, {center_y})") + + screenshot_center = await computer.to_screenshot_coordinates(center_x, center_y) + print(f"Center in screenshot coordinates: {screenshot_center}") + + screen_center = await computer.to_screen_coordinates(*screenshot_center) + print(f"Back to screen coordinates: {screen_center}") + + # Mouse Actions Examples + print("\n=== Mouse Actions ===") + await computer.interface.move_cursor(100, 100) + await computer.interface.left_click() + await computer.interface.right_click(300, 300) + await computer.interface.double_click(400, 400) + + # Keyboard Actions Examples + print("\n=== Keyboard Actions ===") + await computer.interface.type_text("Hello, World!") + await computer.interface.press_key("enter") + + # Clipboard Actions Examples + print("\n=== Clipboard Actions ===") + await computer.interface.set_clipboard("Test clipboard") 
+ content = await computer.interface.copy_to_clipboard() + print(f"Clipboard content: {content}") + + finally: + # Important to clean up resources + pass + # await computer.stop() + except Exception as e: + print(f"Error in main: {e}") + traceback.print_exc() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/pylume_examples.py b/examples/pylume_examples.py new file mode 100644 index 00000000..6700c99b --- /dev/null +++ b/examples/pylume_examples.py @@ -0,0 +1,98 @@ +import asyncio +from pylume import ( + PyLume, + ImageRef, + VMRunOpts, + SharedDirectory, + VMConfig, + VMUpdateOpts +) + +async def main(): + """Example usage of PyLume.""" + async with PyLume(port=3000, use_existing_server=False, debug=True) as pylume: + + # Get latest IPSW URL + print("\n=== Getting Latest IPSW URL ===") + url = await pylume.get_latest_ipsw_url() + print("Latest IPSW URL:", url) + + # Create a new VM + print("\n=== Creating a new VM ===") + vm_config = VMConfig( + name="lume-vm-new", + os="macOS", + cpu=2, + memory="4GB", + disk_size="64GB", + display="1024x768", + ipsw="latest" + ) + await pylume.create_vm(vm_config) + + # Get latest IPSW URL + print("\n=== Getting Latest IPSW URL ===") + url = await pylume.get_latest_ipsw_url() + print("Latest IPSW URL:", url) + + # List available images + print("\n=== Listing Available Images ===") + images = await pylume.get_images() + print("Available Images:", images) + + # List all VMs to verify creation + print("\n=== Listing All VMs ===") + vms = await pylume.list_vms() + print("VMs:", vms) + + # Get specific VM details + print("\n=== Getting VM Details ===") + vm = await pylume.get_vm("lume-vm") + print("VM Details:", vm) + + # Update VM settings + print("\n=== Updating VM Settings ===") + update_opts = VMUpdateOpts( + cpu=8, + memory="4GB" + ) + await pylume.update_vm("lume-vm", update_opts) + + # Pull an image + image_ref = ImageRef( + image="macos-sequoia-vanilla", + tag="latest", + registry="ghcr.io", + 
organization="trycua" + ) + await pylume.pull_image(image_ref, name="lume-vm-pulled") + + # Run with shared directory + run_opts = VMRunOpts( + no_display=False, + shared_directories=[ + SharedDirectory( + host_path="~/shared", + read_only=False + ) + ] + ) + await pylume.run_vm("lume-vm", run_opts) + + # Or simpler: + await pylume.run_vm("lume-vm") + + # Clone VM + print("\n=== Cloning VM ===") + await pylume.clone_vm("lume-vm", "lume-vm-cloned") + + # Stop VM + print("\n=== Stopping VM ===") + await pylume.stop_vm("lume-vm") + + # Delete VM + print("\n=== Deleting VM ===") + await pylume.delete_vm("lume-vm-cloned") + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/examples/som_examples.py b/examples/som_examples.py new file mode 100644 index 00000000..75b798ac --- /dev/null +++ b/examples/som_examples.py @@ -0,0 +1,437 @@ +#!/usr/bin/env python3 +""" +Example script demonstrating the usage of OmniParser's UI element detection functionality. +This script shows how to: +1. Initialize the OmniParser +2. Load and process images +3. Visualize detection results +4. 
Compare performance between CPU and MPS (Apple Silicon) +""" + +import argparse +import logging +import sys +from pathlib import Path +import time +from PIL import Image +from typing import Dict, Any, List, Optional +import numpy as np +import io +import base64 +import glob +import os + +# Load environment variables from .env file +project_root = Path(__file__).parent.parent +env_file = project_root / ".env" +print(f"Loading environment from: {env_file}") +from dotenv import load_dotenv + +load_dotenv(env_file) + +# Add paths to sys.path if needed +pythonpath = os.environ.get("PYTHONPATH", "") +for path in pythonpath.split(":"): + if path and path not in sys.path: + sys.path.append(path) + print(f"Added to sys.path: {path}") + +# Add the libs directory to the path to find som +libs_path = project_root / "libs" +if str(libs_path) not in sys.path: + sys.path.append(str(libs_path)) + print(f"Added to sys.path: {libs_path}") + +from som import OmniParser, ParseResult, IconElement, TextElement +from som.models import UIElement, ParserMetadata, BoundingBox + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +logger = logging.getLogger(__name__) + + +def setup_logging(): + """Configure logging with a nice format.""" + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + +class Timer: + """Enhanced context manager for timing code blocks.""" + + def __init__(self, name: str, logger): + self.name = name + self.logger = logger + self.start_time: float = 0.0 + self.elapsed_time: float = 0.0 + + def __enter__(self): + self.start_time = time.time() + return self + + def __exit__(self, *args): + self.elapsed_time = time.time() - self.start_time + self.logger.info(f"{self.name}: {self.elapsed_time:.3f}s") + return False + + +def image_to_bytes(image: Image.Image) -> bytes: + """Convert 
def process_image(
    parser: OmniParser, image_path: str, output_dir: Path, use_ocr: bool = False
) -> None:
    """Parse one screenshot, annotate it, and write the result to disk.

    Loads the image, runs the parser (optionally with OCR), decodes the
    base64 annotated image from the result, saves it next to the input
    stem as ``<stem>_analyzed.png``, and logs every detected element.
    Errors are logged rather than raised.
    """
    try:
        logger.info(f"Processing image: {image_path}")
        source = Image.open(image_path).convert("RGB")
        logger.info(f"Image loaded successfully, size: {source.size}")

        # Output file mirrors the input name with an "_analyzed" suffix.
        stem = Path(image_path).stem
        target = output_dir / f"{stem}_analyzed.png"
        png_bytes = image_to_bytes(source)

        # Time the parse itself.
        with Timer(f"Processing {stem}", logger):
            result = parser.parse(png_bytes, use_ocr=use_ocr)
            logger.info(
                f"Found {result.metadata.num_icons} icons and {result.metadata.num_text} text elements"
            )

        logger.info(f"Saving annotated image to: {target}")
        try:
            # The annotated image comes back base64-encoded.
            annotated = Image.open(io.BytesIO(base64.b64decode(result.annotated_image_base64)))
            annotated.save(target)

            logger.info("\nDetected Elements:")
            for element in result.elements:
                if isinstance(element, IconElement):
                    logger.info(
                        f"Icon: confidence={element.confidence:.3f}, bbox={element.bbox.coordinates}"
                    )
                elif isinstance(element, TextElement):
                    logger.info(
                        f"Text: '{element.content}', confidence={element.confidence:.3f}, bbox={element.bbox.coordinates}"
                    )

            # Sanity-check that the file actually landed on disk.
            if target.exists():
                logger.info(
                    f"Successfully saved image. File size: {target.stat().st_size} bytes"
                )
            else:
                logger.error(f"Failed to verify file at {target}")
        except Exception as save_err:
            logger.error(f"Error saving image: {str(save_err)}", exc_info=True)

    except Exception as err:
        logger.error(f"Error processing image {image_path}: {str(err)}", exc_info=True)
def _save_annotated_result(result, output_path: Path) -> None:
    """Decode one parse result's annotated image, save it, and log elements.

    Shared by the single-run and benchmark paths; previously this body was
    duplicated verbatim inside run_detection_benchmark. Save failures are
    logged, not raised.
    """
    try:
        # The annotated image arrives base64-encoded in the parse result.
        img_data = base64.b64decode(result.annotated_image_base64)
        img = Image.open(io.BytesIO(img_data))
        img.save(output_path)

        logger.info("\nDetected Elements:")
        for elem in result.elements:
            if isinstance(elem, IconElement):
                logger.info(
                    f"Icon: confidence={elem.confidence:.3f}, bbox={elem.bbox.coordinates}"
                )
            elif isinstance(elem, TextElement):
                logger.info(
                    f"Text: '{elem.content}', confidence={elem.confidence:.3f}, bbox={elem.bbox.coordinates}"
                )

        # Verify the file landed on disk and report its size.
        if output_path.exists():
            logger.info(
                f"Successfully saved image. File size: {output_path.stat().st_size} bytes"
            )
        else:
            logger.error(f"Failed to verify file at {output_path}")
    except Exception as e:
        logger.error(f"Error saving image: {str(e)}", exc_info=True)


def run_detection_benchmark(
    input_path: str,
    output_dir: Path,
    use_ocr: bool = False,
    box_threshold: float = 0.01,
    iou_threshold: float = 0.1,
):
    """Run detection benchmark on images.

    Args:
        input_path: A .png file, or a directory whose *.png files are processed.
        output_dir: Directory for the annotated output images (created if needed).
        use_ocr: Whether to also run OCR text detection.
        box_threshold: Confidence threshold forwarded to the parser.
        iou_threshold: NMS IOU threshold forwarded to the parser.

    Per-image failures are logged and skipped; setup failures re-raise.
    """
    logger.info(
        f"Starting benchmark with OCR enabled: {use_ocr}, box_threshold: {box_threshold}, iou_threshold: {iou_threshold}"
    )

    try:
        logger.info("Initializing OmniParser...")
        parser = OmniParser()

        output_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Output directory created at: {output_dir}")

        # Accept either a directory of PNGs or a single image path.
        if os.path.isdir(input_path):
            image_files = glob.glob(os.path.join(input_path, "*.png"))
        else:
            image_files = [input_path]

        logger.info(f"Found {len(image_files)} images to process")

        for image_path in image_files:
            try:
                logger.info(f"Processing image: {image_path}")
                image = Image.open(image_path).convert("RGB")
                logger.info(f"Image loaded successfully, size: {image.size}")

                input_filename = Path(image_path).stem
                output_path = output_dir / f"{input_filename}_analyzed.png"
                image_bytes = image_to_bytes(image)

                with Timer(f"Processing {input_filename}", logger):
                    result = parser.parse(
                        image_bytes,
                        use_ocr=use_ocr,
                        box_threshold=box_threshold,
                        iou_threshold=iou_threshold,
                    )
                    logger.info(
                        f"Found {result.metadata.num_icons} icons and {result.metadata.num_text} text elements"
                    )

                logger.info(f"Saving annotated image to: {output_path}")
                # Reuse the shared save/log helper instead of the former
                # copy-pasted block.
                _save_annotated_result(result, output_path)

            except Exception as e:
                logger.error(f"Error processing image {image_path}: {str(e)}", exc_info=True)

    except Exception as e:
        logger.error(f"Benchmark failed: {str(e)}", exc_info=True)
        raise
def run_experiments(input_path: str, output_dir: Path, use_ocr: bool = False):
    """Run experiments with different threshold combinations.

    Sweeps a fixed grid of box/IOU thresholds, writes per-combination
    annotated images and detail files under a timestamped experiment
    directory, and appends one summary row per combination to
    ``results_summary.txt``.

    Args:
        input_path: A .png file, or a directory whose *.png files are processed.
        output_dir: Root directory under which the experiment directory is created.
        use_ocr: Whether to also run OCR text detection.
    """
    # Define threshold values to test (fixed grid, 4x4 = 16 combinations).
    box_thresholds = [0.01, 0.05, 0.1, 0.3]
    iou_thresholds = [0.05, 0.1, 0.2, 0.5]

    logger.info("Starting threshold experiments...")
    logger.info("Box thresholds to test: %s", box_thresholds)
    logger.info("IOU thresholds to test: %s", iou_thresholds)

    # Create results directory for this experiment; the timestamp keeps
    # repeated runs from clobbering each other.
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    ocr_suffix = "_ocr" if use_ocr else "_no_ocr"
    exp_dir = output_dir / f"experiment_{timestamp}{ocr_suffix}"
    exp_dir.mkdir(parents=True, exist_ok=True)

    # Create a summary file; it stays open for the whole sweep so each
    # combination appends its row as soon as it finishes.
    summary_file = exp_dir / "results_summary.txt"
    with open(summary_file, "w") as f:
        f.write("Threshold Experiments Results\n")
        f.write("==========================\n\n")
        f.write(f"Input: {input_path}\n")
        f.write(f"OCR Enabled: {use_ocr}\n")
        f.write(f"Date: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        f.write("Results:\n")
        f.write("-" * 80 + "\n")
        f.write(
            f"{'Box Thresh':^10} | {'IOU Thresh':^10} | {'Num Icons':^10} | {'Num Text':^10} | {'Time (s)':^10}\n"
        )
        f.write("-" * 80 + "\n")

        # Initialize parser once for all experiments (model load is expensive).
        parser = OmniParser()

        # Run experiments with each combination
        for box_thresh in box_thresholds:
            for iou_thresh in iou_thresholds:
                logger.info(f"\nTesting box_threshold={box_thresh}, iou_threshold={iou_thresh}")

                # Create directory for this combination
                combo_dir = exp_dir / f"box_{box_thresh}_iou_{iou_thresh}"
                combo_dir.mkdir(exist_ok=True)

                try:
                    # Accept either a directory of PNGs or a single image.
                    if os.path.isdir(input_path):
                        image_files = glob.glob(os.path.join(input_path, "*.png"))
                    else:
                        image_files = [input_path]

                    # Totals aggregated over all images for the summary row.
                    total_icons = 0
                    total_text = 0
                    total_time = 0

                    for image_path in image_files:
                        # Load and process image
                        image = Image.open(image_path).convert("RGB")
                        image_bytes = image_to_bytes(image)

                        # Process with current thresholds; Timer records the
                        # elapsed seconds used for the average below.
                        with Timer(f"Processing {Path(image_path).stem}", logger) as t:
                            result = parser.parse(
                                image_bytes,
                                use_ocr=use_ocr,
                                box_threshold=box_thresh,
                                iou_threshold=iou_thresh,
                            )

                        # Save annotated image (decoded from base64).
                        output_path = combo_dir / f"{Path(image_path).stem}_analyzed.png"
                        img_data = base64.b64decode(result.annotated_image_base64)
                        img = Image.open(io.BytesIO(img_data))
                        img.save(output_path)

                        # Update totals
                        total_icons += result.metadata.num_icons
                        total_text += result.metadata.num_text
                        total_time += t.elapsed_time

                        # Log detailed per-element results to a side file.
                        detail_file = combo_dir / f"{Path(image_path).stem}_details.txt"
                        with open(detail_file, "w") as detail_f:
                            detail_f.write(f"Results for {Path(image_path).name}\n")
                            detail_f.write("-" * 40 + "\n")
                            detail_f.write(f"Number of icons: {result.metadata.num_icons}\n")
                            detail_f.write(
                                f"Number of text elements: {result.metadata.num_text}\n\n"
                            )

                            detail_f.write("Icon Detections:\n")
                            # Box numbering convention: icons first, then text
                            # boxes continue the count after the last icon.
                            icon_count = 1
                            text_count = (
                                result.metadata.num_icons + 1
                            )  # Text boxes start after icons

                            # First list all icons
                            for elem in result.elements:
                                if isinstance(elem, IconElement):
                                    detail_f.write(f"Box #{icon_count}: Icon\n")
                                    detail_f.write(f" - Confidence: {elem.confidence:.3f}\n")
                                    detail_f.write(
                                        f" - Coordinates: {elem.bbox.coordinates}\n"
                                    )
                                    icon_count += 1

                            if use_ocr:
                                detail_f.write("\nText Detections:\n")
                                for elem in result.elements:
                                    if isinstance(elem, TextElement):
                                        detail_f.write(f"Box #{text_count}: Text\n")
                                        detail_f.write(f" - Content: '{elem.content}'\n")
                                        detail_f.write(
                                            f" - Confidence: {elem.confidence:.3f}\n"
                                        )
                                        detail_f.write(
                                            f" - Coordinates: {elem.bbox.coordinates}\n"
                                        )
                                        text_count += 1

                    # Write summary for this combination
                    # NOTE(review): average is over image count, not total;
                    # ZeroDivisionError if image_files is empty — confirm inputs.
                    avg_time = total_time / len(image_files)
                    f.write(
                        f"{box_thresh:^10.3f} | {iou_thresh:^10.3f} | {total_icons:^10d} | {total_text:^10d} | {avg_time:^10.3f}\n"
                    )

                except Exception as e:
                    # A failed combination writes an ERROR row; the sweep continues.
                    logger.error(
                        f"Error in experiment box={box_thresh}, iou={iou_thresh}: {str(e)}"
                    )
                    f.write(
                        f"{box_thresh:^10.3f} | {iou_thresh:^10.3f} | {'ERROR':^10s} | {'ERROR':^10s} | {'ERROR':^10s}\n"
                    )

        # Write summary footer
        f.write("-" * 80 + "\n")
        f.write("\nExperiment completed successfully!\n")

    logger.info(f"\nExperiment results saved to {exp_dir}")
    logger.info(f"Summary file: {summary_file}")
def load_env_file(path: Path) -> bool:
    """Load environment variables from a file.

    Lines are expected as ``KEY=VALUE``. Blank lines and ``#`` comments are
    skipped, as are malformed lines without an ``=`` (previously such a line
    raised ValueError and aborted the whole load). Keys and values are
    stripped of surrounding whitespace so ``KEY = value`` does not create an
    environment key with a trailing space.

    Args:
        path: Path to the .env file

    Returns:
        True if file was loaded successfully, False otherwise
    """
    if not path.exists():
        return False

    print(f"Loading environment from {path}")
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            # Skip blanks, comments, and lines that cannot be KEY=VALUE.
            if not line or line.startswith("#") or "=" not in line:
                continue

            # Split only on the first '=' so values may themselves contain '='.
            key, value = line.split("=", 1)
            os.environ[key.strip()] = value.strip()

    return True
def load_dotenv_files():
    """Load environment variables from .env files.

    Prefers ``.env.local``; falls back to ``.env`` only when the local
    override could not be loaded.
    """
    root = Path(__file__).parent.parent

    # First candidate that loads successfully wins.
    for candidate in (root / ".env.local", root / ".env"):
        if load_env_file(candidate):
            break


def handle_sigint(signum, frame):
    """Handle SIGINT (Ctrl+C) gracefully."""
    print("\nExiting gracefully...")
    sys.exit(0)
+

+
+ + + + Shows my svg + +
+ + [![Python](https://img.shields.io/badge/Python-333333?logo=python&logoColor=white&labelColor=333333)](#) + [![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](#) + [![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.com/invite/mVnXXpdE85) + [![PyPI](https://img.shields.io/pypi/v/cua-computer?color=333333)](https://pypi.org/project/cua-computer/) +

+
+ +**Agent** is a Computer Use (CUA) framework for running multi-app agentic workflows targeting macOS and Linux sandbox, supporting local (Ollama) and cloud model providers (OpenAI, Anthropic, Groq, DeepSeek, Qwen). The framework integrates with Microsoft's OmniParser for enhanced UI understanding and interaction. + +### Get started with Agent + +```python +from agent import ComputerAgent, AgentLoop, LLMProvider +from computer import Computer + +computer = Computer(verbosity=logging.INFO) + +agent = ComputerAgent( + computer=computer, + loop=AgentLoop.ANTHROPIC, + # loop=AgentLoop.OMNI, + model=LLM(provider=LLMProvider.ANTHROPIC, name="claude-3-7-sonnet-20250219"), + # model=LLM(provider=LLMProvider.OPENAI, name="gpt-4.5-preview"), + save_trajectory=True, + trajectory_dir=str(Path("trajectories")), + only_n_most_recent_images=3, + verbosity=logging.INFO, +) + +tasks = [ +""" +Please help me with the following task: +1. Open Safari browser +2. Go to Wikipedia.org +3. Search for "Claude AI" +4. Summarize the main points you find about Claude AI +""" +] + +async with agent: + for i, task in enumerate(tasks, 1): + print(f"\nExecuting task {i}/{len(tasks)}: {task}") + async for result in agent.run(task): + print(result) + print(f"Task {i} completed") +``` + +## Install + +### cua-agent + +```bash +pip install "cua-agent[all]" + +# or install specific loop providers +pip install "cua-agent[anthropic]" +pip install "cua-agent[omni]" +``` + +## Run + +Refer to these notebooks for step-by-step guides on how to use the Computer-Use Agent (CUA): + +- [Agent Notebook](../../notebooks/agent_nb.ipynb) - Complete examples and workflows \ No newline at end of file diff --git a/libs/agent/agent/README.md b/libs/agent/agent/README.md new file mode 100644 index 00000000..d1712bd6 --- /dev/null +++ b/libs/agent/agent/README.md @@ -0,0 +1,63 @@ +# Agent Package Structure + +## Overview +The agent package provides a modular and extensible framework for AI-powered computer agents. 
+ +## Directory Structure +``` +agent/ +├── __init__.py # Package exports +├── core/ # Core functionality +│ ├── __init__.py +│ ├── computer_agent.py # Main entry point +│ └── factory.py # Provider factory +├── base/ # Base implementations +│ ├── __init__.py +│ ├── agent.py # Base agent class +│ ├── core/ # Core components +│ │ ├── callbacks.py +│ │ ├── loop.py +│ │ └── messages.py +│ └── tools/ # Tool implementations +├── providers/ # Provider implementations +│ ├── __init__.py +│ ├── anthropic/ # Anthropic provider +│ │ ├── agent.py +│ │ ├── loop.py +│ │ └── tool_manager.py +│ └── omni/ # Omni provider +│ ├── agent.py +│ ├── loop.py +│ └── tool_manager.py +└── types/ # Type definitions + ├── __init__.py + ├── base.py # Core types + ├── messages.py # Message types + ├── tools.py # Tool types + └── providers/ # Provider-specific types + ├── anthropic.py + └── omni.py +``` + +## Key Components + +### Core +- `computer_agent.py`: Main entry point for creating and using agents +- `factory.py`: Factory for creating provider-specific implementations + +### Base +- `agent.py`: Base agent implementation with shared functionality +- `core/`: Core components used across providers +- `tools/`: Shared tool implementations + +### Providers +Each provider follows the same structure: +- `agent.py`: Provider-specific agent implementation +- `loop.py`: Provider-specific message loop +- `tool_manager.py`: Tool management for provider + +### Types +- `base.py`: Core type definitions +- `messages.py`: Message-related types +- `tools.py`: Tool-related types +- `providers/`: Provider-specific type definitions diff --git a/libs/agent/agent/__init__.py b/libs/agent/agent/__init__.py new file mode 100644 index 00000000..cbc46bf6 --- /dev/null +++ b/libs/agent/agent/__init__.py @@ -0,0 +1,56 @@ +"""CUA (Computer Use) Agent for AI-driven computer interaction.""" + +import sys +import logging + +__version__ = "0.1.0" + +# Initialize logging +logger = logging.getLogger("cua.agent") + +# 
Initialize telemetry when the package is imported +try: + # Import from core telemetry for basic functions + from core.telemetry import ( + is_telemetry_enabled, + flush, + record_event, + ) + + # Import set_dimension from our own telemetry module + from .core.telemetry import set_dimension + + # Check if telemetry is enabled + if is_telemetry_enabled(): + logger.info("Telemetry is enabled") + + # Record package initialization + record_event( + "module_init", + { + "module": "agent", + "version": __version__, + "python_version": sys.version, + }, + ) + + # Set the package version as a dimension + set_dimension("agent_version", __version__) + + # Flush events to ensure they're sent + flush() + else: + logger.info("Telemetry is disabled") +except ImportError as e: + # Telemetry not available + logger.warning(f"Telemetry not available: {e}") +except Exception as e: + # Other issues with telemetry + logger.warning(f"Error initializing telemetry: {e}") + +from .core.factory import AgentFactory +from .core.agent import ComputerAgent +from .providers.omni.types import LLMProvider, LLM +from .types.base import Provider, AgentLoop + +__all__ = ["AgentFactory", "Provider", "ComputerAgent", "AgentLoop", "LLMProvider", "LLM"] diff --git a/libs/agent/agent/core/README.md b/libs/agent/agent/core/README.md new file mode 100644 index 00000000..9916f7b5 --- /dev/null +++ b/libs/agent/agent/core/README.md @@ -0,0 +1,101 @@ +# Unified ComputerAgent + +The `ComputerAgent` class provides a unified implementation that consolidates the previously separate agent implementations (AnthropicComputerAgent and OmniComputerAgent) into a single, configurable class. + +## Features + +- **Multiple Loop Types**: Switch between different agentic loop implementations using the `loop_type` parameter (Anthropic or Omni). +- **Provider Support**: Use different AI providers (OpenAI, Anthropic, etc.) with the appropriate loop. 
+- **Trajectory Saving**: Control whether to save screenshots and logs with the `save_trajectory` parameter. +- **Consistent Interface**: Maintains a consistent interface regardless of the underlying loop implementation. + +## API Key Requirements + +To use the ComputerAgent, you'll need API keys for the providers you want to use: + +- For **OpenAI**: Set the `OPENAI_API_KEY` environment variable or pass it directly as `api_key`. +- For **Anthropic**: Set the `ANTHROPIC_API_KEY` environment variable or pass it directly as `api_key`. +- For **Groq**: Set the `GROQ_API_KEY` environment variable or pass it directly as `api_key`. + +You can set environment variables in several ways: + +```bash +# In your terminal before running the code +export OPENAI_API_KEY=your_api_key_here + +# Or in a .env file +OPENAI_API_KEY=your_api_key_here +``` + +## Usage + +Here's how to use the unified ComputerAgent: + +```python +from agent.core.agent import ComputerAgent +from agent.types.base import AgenticLoop +from agent.providers.omni.types import LLMProvider +from computer import Computer + +# Create a Computer instance +computer = Computer() + +# Create an agent with the OMNI loop and OpenAI provider +agent = ComputerAgent( + computer=computer, + loop_type=AgenticLoop.OMNI, + provider=LLMProvider.OPENAI, + model="gpt-4o", + api_key="your_api_key_here", # Can also use OPENAI_API_KEY environment variable + save_trajectory=True, + only_n_most_recent_images=5 +) + +# Create an agent with the ANTHROPIC loop +agent = ComputerAgent( + computer=computer, + loop_type=AgenticLoop.ANTHROPIC, + model="claude-3-7-sonnet-20250219", + api_key="your_api_key_here", # Can also use ANTHROPIC_API_KEY environment variable + save_trajectory=True, + only_n_most_recent_images=5 +) + +# Use the agent +async with agent: + async for result in agent.run("Your task description here"): + # Process the result + title = result["metadata"].get("title", "Screen Analysis") + content = result["content"] + 
print(f"\n{title}") + print(content) +``` + +## Parameters + +- `computer`: Computer instance to control +- `loop_type`: The type of loop to use (AgenticLoop.ANTHROPIC or AgenticLoop.OMNI) +- `provider`: AI provider to use (required for Omni loop) +- `api_key`: Optional API key (will use environment variable if not provided) +- `model`: Optional model name (will use provider default if not specified) +- `save_trajectory`: Whether to save screenshots and logs +- `only_n_most_recent_images`: Only keep N most recent images +- `max_retries`: Maximum number of retry attempts + +## Directory Structure + +When `save_trajectory` is enabled, the agent will create the following directory structure: + +``` +experiments/ + ├── screenshots/ # Screenshots captured during agent execution + └── logs/ # API call logs and other logging information +``` + +## Extending with New Loop Types + +To add a new loop type: + +1. Implement a new loop class +2. Add a new value to the `AgenticLoop` enum +3. Update the `_initialize_loop` method in `ComputerAgent` to handle the new loop type \ No newline at end of file diff --git a/libs/agent/agent/core/__init__.py b/libs/agent/agent/core/__init__.py new file mode 100644 index 00000000..68deb67e --- /dev/null +++ b/libs/agent/agent/core/__init__.py @@ -0,0 +1,34 @@ +"""Core agent components.""" + +from .base_agent import BaseComputerAgent +from .loop import BaseLoop +from .messages import ( + create_user_message, + create_assistant_message, + create_system_message, + create_image_message, + create_screen_message, + BaseMessageManager, + ImageRetentionConfig, +) +from .callbacks import ( + CallbackManager, + CallbackHandler, + BaseCallbackManager, + ContentCallback, + ToolCallback, + APICallback, +) + +__all__ = [ + "BaseComputerAgent", + "BaseLoop", + "CallbackManager", + "CallbackHandler", + "BaseMessageManager", + "ImageRetentionConfig", + "BaseCallbackManager", + "ContentCallback", + "ToolCallback", + "APICallback", +] diff --git 
a/libs/agent/agent/core/agent.py b/libs/agent/agent/core/agent.py new file mode 100644 index 00000000..f737f8ce --- /dev/null +++ b/libs/agent/agent/core/agent.py @@ -0,0 +1,252 @@ +"""Unified computer agent implementation that supports multiple loops.""" + +import os +import logging +import asyncio +import time +import uuid +from typing import Any, AsyncGenerator, Dict, List, Optional, TYPE_CHECKING, Union, cast +from datetime import datetime +from enum import Enum + +from computer import Computer + +from ..types.base import Provider, AgentLoop +from .base_agent import BaseComputerAgent +from ..core.telemetry import record_agent_initialization + +# Only import types for type checking to avoid circular imports +if TYPE_CHECKING: + from ..providers.anthropic.loop import AnthropicLoop + from ..providers.omni.loop import OmniLoop + from ..providers.omni.parser import OmniParser + +# Import the provider types +from ..providers.omni.types import LLMProvider, LLM, Model, LLMModel + +logger = logging.getLogger(__name__) + +# Default models for different providers +DEFAULT_MODELS = { + LLMProvider.OPENAI: "gpt-4o", + LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219", +} + +# Map providers to their environment variable names +ENV_VARS = { + LLMProvider.OPENAI: "OPENAI_API_KEY", + LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY", +} + + +class ComputerAgent(BaseComputerAgent): + """Unified implementation of the computer agent supporting multiple loop types. + + This class consolidates the previous AnthropicComputerAgent and OmniComputerAgent + into a single implementation with configurable loop type. 
class ComputerAgent(BaseComputerAgent):
    """Unified implementation of the computer agent supporting multiple loop types.

    This class consolidates the previous AnthropicComputerAgent and OmniComputerAgent
    into a single implementation with configurable loop type.
    """

    def __init__(
        self,
        computer: Computer,
        loop: AgentLoop = AgentLoop.OMNI,
        model: Optional[Union[LLM, Dict[str, str], str]] = None,
        api_key: Optional[str] = None,
        save_trajectory: bool = True,
        trajectory_dir: Optional[str] = "trajectories",
        only_n_most_recent_images: Optional[int] = None,
        max_retries: int = 3,
        verbosity: int = logging.INFO,
        telemetry_enabled: bool = True,
        **kwargs,
    ):
        """Initialize a ComputerAgent instance.

        Args:
            computer: The Computer instance to control
            loop: The agent loop to use: ANTHROPIC or OMNI
            model: The model to use. Can be a string, dict or LLM object.
                Defaults to LLM for the loop type.
            api_key: The API key to use. If None, will use environment variables.
            save_trajectory: Whether to save the trajectory.
            trajectory_dir: The directory to save trajectories to.
            only_n_most_recent_images: Only keep this many most recent images.
            max_retries: Maximum number of retries for failed requests.
            verbosity: Logging level (standard Python logging levels).
            telemetry_enabled: Whether to enable telemetry tracking. Defaults to True.
            **kwargs: Additional keyword arguments to pass to the loop.
        """
        # BUG FIX: BaseComputerAgent.__init__'s first positional parameter is
        # max_retries, not computer. The old call super().__init__(computer)
        # stored the Computer object as max_retries and caused the base class
        # to construct a second, unused Computer(). Pass both by keyword.
        super().__init__(max_retries=max_retries, computer=computer)
        self._configure_logging(verbosity)
        logger.info(f"Initializing ComputerAgent with {loop} loop")

        # Store telemetry preference
        self.telemetry_enabled = telemetry_enabled

        # Process the model configuration
        self.model = self._process_model_config(model, loop)
        self.loop_type = loop
        self.api_key = api_key

        # Store computer
        self.computer = computer

        # Save trajectory settings
        self.save_trajectory = save_trajectory
        self.trajectory_dir = trajectory_dir
        self.only_n_most_recent_images = only_n_most_recent_images

        # Store the max retries setting
        self.max_retries = max_retries

        # Initialize message history
        self.messages = []

        # Extra kwargs for the loop
        self.loop_kwargs = kwargs

        # Initialize the actual loop implementation
        self.loop = self._init_loop()

        # Record initialization in telemetry if enabled
        if telemetry_enabled:
            record_agent_initialization()

    def _process_model_config(
        self, model_input: Optional[Union[LLM, Dict[str, str], str]], loop: AgentLoop
    ) -> LLM:
        """Process and normalize model configuration.

        Args:
            model_input: Input model configuration (LLM, dict, string, or None)
            loop: The loop type being used

        Returns:
            Normalized LLM instance

        Raises:
            ValueError: If model_input is of an unsupported type.
        """
        # None -> pick the natural provider for the loop, default model name.
        if model_input is None:
            default_provider = (
                LLMProvider.ANTHROPIC if loop == AgentLoop.ANTHROPIC else LLMProvider.OPENAI
            )
            return LLM(provider=default_provider)

        # Already an LLM (or one of its aliases) -> use as-is.
        if isinstance(model_input, (LLM, Model, LLMModel)):
            return model_input

        # Dict -> read provider (string or enum) and optional name.
        if isinstance(model_input, dict):
            provider = model_input.get("provider", LLMProvider.OPENAI)
            if isinstance(provider, str):
                provider = LLMProvider(provider)
            return LLM(provider=provider, name=model_input.get("name"))

        # Bare string -> treat as a model name for the loop's default provider.
        if isinstance(model_input, str):
            default_provider = (
                LLMProvider.ANTHROPIC if loop == AgentLoop.ANTHROPIC else LLMProvider.OPENAI
            )
            return LLM(provider=default_provider, name=model_input)

        raise ValueError(f"Unsupported model configuration: {model_input}")

    def _configure_logging(self, verbosity: int):
        """Configure logging based on verbosity level."""
        # Use the logging level directly without mapping
        logger.setLevel(verbosity)
        logging.getLogger("agent").setLevel(verbosity)

        # Log the verbosity level that was set
        if verbosity <= logging.DEBUG:
            logger.info("Agent logging set to DEBUG level (full debug information)")
        elif verbosity <= logging.INFO:
            logger.info("Agent logging set to INFO level (standard output)")
        elif verbosity <= logging.WARNING:
            logger.warning("Agent logging set to WARNING level (warnings and errors only)")
        elif verbosity <= logging.ERROR:
            logger.warning("Agent logging set to ERROR level (errors only)")
        elif verbosity <= logging.CRITICAL:
            logger.warning("Agent logging set to CRITICAL level (critical errors only)")

    def _init_loop(self) -> Any:
        """Initialize the loop based on the loop_type.

        Returns:
            Initialized loop instance
        """
        # Lazy import OmniLoop and OmniParser to avoid circular imports
        from ..providers.omni.loop import OmniLoop
        from ..providers.omni.parser import OmniParser

        if self.loop_type == AgentLoop.ANTHROPIC:
            from ..providers.anthropic.loop import AnthropicLoop

            # Ensure we always have a valid model name
            model_name = self.model.name or DEFAULT_MODELS[LLMProvider.ANTHROPIC]

            return AnthropicLoop(
                api_key=self.api_key,
                model=model_name,
                computer=self.computer,
                save_trajectory=self.save_trajectory,
                base_dir=self.trajectory_dir,
                only_n_most_recent_images=self.only_n_most_recent_images,
                **self.loop_kwargs,
            )

        # Default path: OmniLoop. Supply a parser unless the caller passed one.
        if "parser" not in self.loop_kwargs:
            self.loop_kwargs["parser"] = OmniParser()

        # Ensure we always have a valid model name
        model_name = self.model.name or DEFAULT_MODELS[self.model.provider]

        return OmniLoop(
            provider=self.model.provider,
            api_key=self.api_key,
            model=model_name,
            computer=self.computer,
            save_trajectory=self.save_trajectory,
            base_dir=self.trajectory_dir,
            only_n_most_recent_images=self.only_n_most_recent_images,
            **self.loop_kwargs,
        )

    async def _execute_task(self, task: str) -> AsyncGenerator[Dict[str, Any], None]:
        """Execute a task using the configured agent loop.

        Args:
            task: The task to execute

        Yields:
            Output dictionaries produced by the loop
        """
        logger.info(f"Executing task: {task}")

        try:
            # Append the task to the existing history without mutating it.
            task_message = {"role": "user", "content": task}
            messages_with_task = self.messages + [task_message]

            # Delegate to the loop implementation.
            async for output in self.loop.run(messages_with_task):
                yield output
        except Exception as e:
            logger.error(f"Error executing task: {e}")
            raise

    async def _execute_action(self, action_type: str, **action_params) -> Any:
        """Execute an action, logging any failure with its traceback."""
        try:
            return await super()._execute_action(action_type, **action_params)
        except Exception as e:
            logger.exception(f"Error executing action {action_type}: {e}")
            raise
+ + Args: + max_retries: Maximum number of retry attempts + computer: Optional Computer instance + screenshot_dir: Directory to save screenshots + log_dir: Directory to save logs (set to None to disable logging to files) + **kwargs: Additional provider-specific arguments + """ + self.max_retries = max_retries + self.computer = computer or Computer() + self.queue = asyncio.Queue() + self.screenshot_dir = screenshot_dir + self.log_dir = log_dir + self._retry_count = 0 + self.provider = Provider.UNKNOWN + + # Setup logging + if self.log_dir: + os.makedirs(self.log_dir, exist_ok=True) + logger.info(f"Created logs directory: {self.log_dir}") + + # Setup screenshots directory + if self.screenshot_dir: + os.makedirs(self.screenshot_dir, exist_ok=True) + logger.info(f"Created screenshots directory: {self.screenshot_dir}") + + logger.info("BaseComputerAgent initialized") + + async def run(self, task: str) -> AsyncGenerator[Dict[str, Any], None]: + """Run a task using the computer agent. + + Args: + task: Task description + + Yields: + Task execution updates + """ + try: + logger.info(f"Running task: {task}") + + # Initialize the computer if needed + await self._init_if_needed() + + # Execute the task and yield results + # The _execute_task method should be implemented to yield results + async for result in self._execute_task(task): + yield result + + except Exception as e: + logger.error(f"Error in agent run method: {str(e)}") + yield { + "role": "assistant", + "content": f"Error: {str(e)}", + "metadata": {"title": "❌ Error"}, + } + + async def _init_if_needed(self): + """Initialize the computer interface if it hasn't been initialized yet.""" + if not self.computer._initialized: + logger.info("Computer not initialized, initializing now...") + try: + # Call run directly without setting the flag first + await self.computer.run() + logger.info("Computer interface initialized successfully") + except Exception as e: + logger.error(f"Error initializing computer interface: 
{str(e)}") + raise + + async def __aenter__(self): + """Initialize the agent when used as a context manager.""" + logger.info("Entering BaseComputerAgent context") + + # In case the computer wasn't initialized + try: + # Initialize the computer only if not already initialized + logger.info("Checking if computer is already initialized...") + if not self.computer._initialized: + logger.info("Initializing computer in __aenter__...") + # Use the computer's __aenter__ directly instead of calling run() + # This avoids the circular dependency + await self.computer.__aenter__() + logger.info("Computer initialized in __aenter__") + else: + logger.info("Computer already initialized, skipping initialization") + + # Take a test screenshot to verify the computer is working + logger.info("Testing computer with a screenshot...") + try: + test_screenshot = await self.computer.interface.screenshot() + # Determine the screenshot size based on its type + if isinstance(test_screenshot, bytes): + size = len(test_screenshot) + else: + # Assume it's an object with base64_image attribute + try: + size = len(test_screenshot.base64_image) + except AttributeError: + size = "unknown" + logger.info(f"Screenshot test successful, size: {size}") + except Exception as e: + logger.error(f"Screenshot test failed: {str(e)}") + # Even though screenshot failed, we continue since some tests might not need it + except Exception as e: + logger.error(f"Error initializing computer in __aenter__: {str(e)}") + raise + + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Cleanup computer resources if needed.""" + logger.info("Cleaning up agent resources") + + # Do any necessary cleanup + # We're not shutting down the computer here as it might be shared + # Just log that we're exiting + if exc_type: + logger.error(f"Exiting agent context with error: {exc_type.__name__}: {exc_val}") + else: + logger.info("Exiting agent context normally") + + # If we have a queue, make sure to signal it's 
done + if hasattr(self, "queue") and self.queue: + await self.queue.put(None) # Signal that we're done + + @abstractmethod + async def _execute_task(self, task: str) -> AsyncGenerator[Dict[str, Any], None]: + """Execute a task. Must be implemented by subclasses. + + This is an async method that returns an AsyncGenerator. Implementations + should use 'yield' statements to produce results asynchronously. + """ + yield { + "role": "assistant", + "content": "Base class method called", + "metadata": {"title": "Error"}, + } + raise NotImplementedError("Subclasses must implement _execute_task") diff --git a/libs/agent/agent/core/callbacks.py b/libs/agent/agent/core/callbacks.py new file mode 100644 index 00000000..70eca5ad --- /dev/null +++ b/libs/agent/agent/core/callbacks.py @@ -0,0 +1,147 @@ +"""Callback handlers for agent.""" + +import json +import logging +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Dict, List, Optional, Protocol + +logger = logging.getLogger(__name__) + +class ContentCallback(Protocol): + """Protocol for content callbacks.""" + def __call__(self, content: Dict[str, Any]) -> None: ... + +class ToolCallback(Protocol): + """Protocol for tool callbacks.""" + def __call__(self, result: Any, tool_id: str) -> None: ... + +class APICallback(Protocol): + """Protocol for API callbacks.""" + def __call__(self, request: Any, response: Any, error: Optional[Exception] = None) -> None: ... + +class BaseCallbackManager(ABC): + """Base class for callback managers.""" + + def __init__( + self, + content_callback: ContentCallback, + tool_callback: ToolCallback, + api_callback: APICallback, + ): + """Initialize the callback manager. 
+ + Args: + content_callback: Callback for content updates + tool_callback: Callback for tool execution results + api_callback: Callback for API interactions + """ + self.content_callback = content_callback + self.tool_callback = tool_callback + self.api_callback = api_callback + + @abstractmethod + def on_content(self, content: Any) -> None: + """Handle content updates.""" + raise NotImplementedError + + @abstractmethod + def on_tool_result(self, result: Any, tool_id: str) -> None: + """Handle tool execution results.""" + raise NotImplementedError + + @abstractmethod + def on_api_interaction( + self, + request: Any, + response: Any, + error: Optional[Exception] = None + ) -> None: + """Handle API interactions.""" + raise NotImplementedError + + +class CallbackManager: + """Manager for callback handlers.""" + + def __init__(self, handlers: Optional[List["CallbackHandler"]] = None): + """Initialize with optional handlers. + + Args: + handlers: List of callback handlers + """ + self.handlers = handlers or [] + + def add_handler(self, handler: "CallbackHandler") -> None: + """Add a callback handler. + + Args: + handler: Callback handler to add + """ + self.handlers.append(handler) + + async def on_action_start(self, action: str, **kwargs) -> None: + """Called when an action starts. + + Args: + action: Action name + **kwargs: Additional data + """ + for handler in self.handlers: + await handler.on_action_start(action, **kwargs) + + async def on_action_end(self, action: str, success: bool, **kwargs) -> None: + """Called when an action ends. + + Args: + action: Action name + success: Whether the action was successful + **kwargs: Additional data + """ + for handler in self.handlers: + await handler.on_action_end(action, success, **kwargs) + + async def on_error(self, error: Exception, **kwargs) -> None: + """Called when an error occurs. 
class CallbackHandler(ABC):
    """Abstract interface for objects that observe agent lifecycle events."""

    @abstractmethod
    async def on_action_start(self, action: str, **kwargs) -> None:
        """Handle the start of an action.

        Args:
            action: Action name
            **kwargs: Additional data
        """
        ...

    @abstractmethod
    async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
        """Handle the end of an action.

        Args:
            action: Action name
            success: Whether the action was successful
            **kwargs: Additional data
        """
        ...

    @abstractmethod
    async def on_error(self, error: Exception, **kwargs) -> None:
        """Handle an error raised during agent execution.

        Args:
            error: Exception that occurred
            **kwargs: Additional data
        """
        ...
+ **kwargs: Additional provider-specific arguments + """ + self.provider = provider + self._computer = computer + self._kwargs = kwargs + self._agent = None + self._initialized = False + self._in_context = False + + # Create provider-specific agent using factory + self._agent = AgentFactory.create(provider=provider, computer=computer, **kwargs) + + async def __aenter__(self): + """Enter the async context manager.""" + self._in_context = True + await self.initialize() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Exit the async context manager.""" + self._in_context = False + + async def initialize(self) -> None: + """Initialize the agent and its components.""" + if not self._initialized: + if not self._in_context and self._computer: + # If not in context manager but have a computer, initialize it + await self._computer.run() + self._initialized = True + + async def run(self, task: str) -> AsyncGenerator[Dict[str, Any], None]: + """Run the agent with a given task.""" + if not self._initialized: + await self.initialize() + + if self._agent is None: + logger.error("Agent not initialized properly") + yield {"error": "Agent not initialized properly"} + return + + async for result in self._agent.run(task): + yield result + + @property + def computer(self) -> Optional[Computer]: + """Get the underlying computer instance.""" + return self._agent.computer if self._agent else None diff --git a/libs/agent/agent/core/experiment.py b/libs/agent/agent/core/experiment.py new file mode 100644 index 00000000..c5162e78 --- /dev/null +++ b/libs/agent/agent/core/experiment.py @@ -0,0 +1,232 @@ +"""Core experiment management for agents.""" + +import os +import logging +import base64 +from io import BytesIO +from datetime import datetime +from typing import Any, Dict, List, Optional +from PIL import Image +import json +import re + +logger = logging.getLogger(__name__) + + +class ExperimentManager: + """Manages experiment directories and logging for the 
agent.""" + + def __init__( + self, + base_dir: Optional[str] = None, + only_n_most_recent_images: Optional[int] = None, + ): + """Initialize the experiment manager. + + Args: + base_dir: Base directory for saving experiment data + only_n_most_recent_images: Maximum number of recent screenshots to include in API requests + """ + self.base_dir = base_dir + self.only_n_most_recent_images = only_n_most_recent_images + self.run_dir = None + self.current_turn_dir = None + self.turn_count = 0 + self.screenshot_count = 0 + # Track all screenshots for potential API request inclusion + self.screenshot_paths = [] + + # Set up experiment directories if base_dir is provided + if self.base_dir: + self.setup_experiment_dirs() + + def setup_experiment_dirs(self) -> None: + """Setup the experiment directory structure.""" + if not self.base_dir: + return + + # Create base experiments directory if it doesn't exist + os.makedirs(self.base_dir, exist_ok=True) + + # Create timestamped run directory + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + self.run_dir = os.path.join(self.base_dir, timestamp) + os.makedirs(self.run_dir, exist_ok=True) + logger.info(f"Created run directory: {self.run_dir}") + + # Create first turn directory + self.create_turn_dir() + + def create_turn_dir(self) -> None: + """Create a new directory for the current turn.""" + if not self.run_dir: + logger.warning("Cannot create turn directory: run_dir not set") + return + + # Increment turn counter + self.turn_count += 1 + + # Create turn directory with padded number + turn_name = f"turn_{self.turn_count:03d}" + self.current_turn_dir = os.path.join(self.run_dir, turn_name) + os.makedirs(self.current_turn_dir, exist_ok=True) + logger.info(f"Created turn directory: {self.current_turn_dir}") + + def sanitize_log_data(self, data: Any) -> Any: + """Sanitize log data by replacing large binary data with placeholders. 
+ + Args: + data: Data to sanitize + + Returns: + Sanitized copy of the data + """ + if isinstance(data, dict): + result = {} + for k, v in data.items(): + result[k] = self.sanitize_log_data(v) + return result + elif isinstance(data, list): + return [self.sanitize_log_data(item) for item in data] + elif isinstance(data, str) and len(data) > 1000 and "base64" in data.lower(): + return f"[BASE64_DATA_LENGTH_{len(data)}]" + else: + return data + + def save_screenshot(self, img_base64: str, action_type: str = "") -> None: + """Save a screenshot to the experiment directory. + + Args: + img_base64: Base64 encoded screenshot + action_type: Type of action that triggered the screenshot + """ + if not self.current_turn_dir: + return + + try: + # Increment screenshot counter + self.screenshot_count += 1 + + # Sanitize action_type to ensure valid filename + # Replace characters that are not safe for filenames + sanitized_action = "" + if action_type: + # Replace invalid filename characters with underscores + sanitized_action = re.sub(r'[\\/*?:"<>|]', "_", action_type) + # Limit the length to avoid excessively long filenames + sanitized_action = sanitized_action[:50] + + # Create a descriptive filename + timestamp = int(datetime.now().timestamp() * 1000) + action_suffix = f"_{sanitized_action}" if sanitized_action else "" + filename = f"screenshot_{self.screenshot_count:03d}{action_suffix}_{timestamp}.png" + + # Save directly to the turn directory + filepath = os.path.join(self.current_turn_dir, filename) + + # Save the screenshot + img_data = base64.b64decode(img_base64) + with open(filepath, "wb") as f: + f.write(img_data) + + # Keep track of the file path + self.screenshot_paths.append(filepath) + + return filepath + except Exception as e: + logger.error(f"Error saving screenshot: {str(e)}") + return None + + def save_action_visualization( + self, img: Image.Image, action_name: str, details: str = "" + ) -> str: + """Save a visualization of an action. 
+ + Args: + img: Image to save + action_name: Name of the action + details: Additional details about the action + + Returns: + Path to the saved image + """ + if not self.current_turn_dir: + return "" + + try: + # Create a descriptive filename + timestamp = int(datetime.now().timestamp() * 1000) + details_suffix = f"_{details}" if details else "" + filename = f"vis_{action_name}{details_suffix}_{timestamp}.png" + + # Save directly to the turn directory + filepath = os.path.join(self.current_turn_dir, filename) + + # Save the image + img.save(filepath) + + # Keep track of the file path + self.screenshot_paths.append(filepath) + + return filepath + except Exception as e: + logger.error(f"Error saving action visualization: {str(e)}") + return "" + + def log_api_call( + self, + call_type: str, + request: Any, + provider: str = "unknown", + model: str = "unknown", + response: Any = None, + error: Optional[Exception] = None, + ) -> None: + """Log API call details to file. + + Args: + call_type: Type of API call (request, response, error) + request: Request data + provider: API provider name + model: Model name + response: Response data (for response logs) + error: Error information (for error logs) + """ + if not self.current_turn_dir: + logger.warning("Cannot log API call: current_turn_dir not set") + return + + try: + # Create a timestamp for the log file + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + # Create filename based on log type + filename = f"api_call_{timestamp}_{call_type}.json" + filepath = os.path.join(self.current_turn_dir, filename) + + # Sanitize data before logging + sanitized_request = self.sanitize_log_data(request) + sanitized_response = self.sanitize_log_data(response) if response is not None else None + + # Prepare log data + log_data = { + "timestamp": timestamp, + "provider": provider, + "model": model, + "type": call_type, + "request": sanitized_request, + } + + if sanitized_response is not None: + log_data["response"] = 
sanitized_response + if error is not None: + log_data["error"] = str(error) + + # Write to file + with open(filepath, "w") as f: + json.dump(log_data, f, indent=2, default=str) + + logger.info(f"Logged API {call_type} to {filepath}") + + except Exception as e: + logger.error(f"Error logging API call: {str(e)}") diff --git a/libs/agent/agent/core/factory.py b/libs/agent/agent/core/factory.py new file mode 100644 index 00000000..e2454134 --- /dev/null +++ b/libs/agent/agent/core/factory.py @@ -0,0 +1,102 @@ +"""Factory for creating provider-specific agents.""" + +from typing import Optional, Dict, Any, List + +from computer import Computer +from ..types.base import Provider +from .base_agent import BaseComputerAgent + +# Import provider-specific implementations +_ANTHROPIC_AVAILABLE = False +_OPENAI_AVAILABLE = False +_OLLAMA_AVAILABLE = False +_OMNI_AVAILABLE = False + +# Try importing providers +try: + import anthropic + from ..providers.anthropic.agent import AnthropicComputerAgent + + _ANTHROPIC_AVAILABLE = True +except ImportError: + pass + +try: + import openai + + _OPENAI_AVAILABLE = True +except ImportError: + pass + +try: + from ..providers.omni.agent import OmniComputerAgent + + _OMNI_AVAILABLE = True +except ImportError: + pass + + +class AgentFactory: + """Factory for creating provider-specific agent implementations.""" + + @staticmethod + def create( + provider: Provider, computer: Optional[Computer] = None, **kwargs: Any + ) -> BaseComputerAgent: + """Create an agent based on the specified provider. 
+ + Args: + provider: The AI provider to use + computer: Optional Computer instance + **kwargs: Additional provider-specific arguments + + Returns: + A provider-specific agent implementation + + Raises: + ImportError: If provider dependencies are not installed + ValueError: If provider is not supported + """ + # Create a Computer instance if none is provided + if computer is None: + computer = Computer() + + if provider == Provider.ANTHROPIC: + if not _ANTHROPIC_AVAILABLE: + raise ImportError( + "Anthropic provider requires additional dependencies. " + "Install them with: pip install cua-agent[anthropic]" + ) + return AnthropicComputerAgent(max_retries=3, computer=computer, **kwargs) + elif provider == Provider.OPENAI: + if not _OPENAI_AVAILABLE: + raise ImportError( + "OpenAI provider requires additional dependencies. " + "Install them with: pip install cua-agent[openai]" + ) + raise NotImplementedError("OpenAI provider not yet implemented") + elif provider == Provider.OLLAMA: + if not _OLLAMA_AVAILABLE: + raise ImportError( + "Ollama provider requires additional dependencies. " + "Install them with: pip install cua-agent[ollama]" + ) + # Only import ollama when actually creating an Ollama agent + try: + import ollama + from ..providers.ollama.agent import OllamaComputerAgent + + return OllamaComputerAgent(max_retries=3, computer=computer, **kwargs) + except ImportError: + raise ImportError( + "Failed to import ollama package. " "Install it with: pip install ollama" + ) + elif provider == Provider.OMNI: + if not _OMNI_AVAILABLE: + raise ImportError( + "Omni provider requires additional dependencies. 
" + "Install them with: pip install cua-agent[omni]" + ) + return OmniComputerAgent(max_retries=3, computer=computer, **kwargs) + else: + raise ValueError(f"Unsupported provider: {provider}") diff --git a/libs/agent/agent/core/loop.py b/libs/agent/agent/core/loop.py new file mode 100644 index 00000000..81b41f6e --- /dev/null +++ b/libs/agent/agent/core/loop.py @@ -0,0 +1,244 @@ +"""Base agent loop implementation.""" + +import logging +import asyncio +import json +import os +from abc import ABC, abstractmethod +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple +from datetime import datetime +import base64 + +from computer import Computer +from .experiment import ExperimentManager + +logger = logging.getLogger(__name__) + + +class BaseLoop(ABC): + """Base class for agent loops that handle message processing and tool execution.""" + + def __init__( + self, + computer: Computer, + model: str, + api_key: str, + max_tokens: int = 4096, + max_retries: int = 3, + retry_delay: float = 1.0, + base_dir: Optional[str] = "trajectories", + save_trajectory: bool = True, + only_n_most_recent_images: Optional[int] = 2, + **kwargs, + ): + """Initialize base agent loop. 
+ + Args: + computer: Computer instance to control + model: Model name to use + api_key: API key for provider + max_tokens: Maximum tokens to generate + max_retries: Maximum number of retries + retry_delay: Delay between retries in seconds + base_dir: Base directory for saving experiment data + save_trajectory: Whether to save trajectory data + only_n_most_recent_images: Maximum number of recent screenshots to include in API requests + **kwargs: Additional provider-specific arguments + """ + self.computer = computer + self.model = model + self.api_key = api_key + self.max_tokens = max_tokens + self.max_retries = max_retries + self.retry_delay = retry_delay + self.base_dir = base_dir + self.save_trajectory = save_trajectory + self.only_n_most_recent_images = only_n_most_recent_images + self._kwargs = kwargs + self.message_history = [] + # self.tool_manager = BaseToolManager(computer) + + # Initialize experiment manager + if self.save_trajectory and self.base_dir: + self.experiment_manager = ExperimentManager( + base_dir=self.base_dir, + only_n_most_recent_images=only_n_most_recent_images, + ) + # Track directories for convenience + self.run_dir = self.experiment_manager.run_dir + self.current_turn_dir = self.experiment_manager.current_turn_dir + else: + self.experiment_manager = None + self.run_dir = None + self.current_turn_dir = None + + # Initialize basic tracking + self.turn_count = 0 + + def _setup_experiment_dirs(self) -> None: + """Setup the experiment directory structure.""" + if self.experiment_manager: + # Use the experiment manager to set up directories + self.experiment_manager.setup_experiment_dirs() + + # Update local tracking variables + self.run_dir = self.experiment_manager.run_dir + self.current_turn_dir = self.experiment_manager.current_turn_dir + + def _create_turn_dir(self) -> None: + """Create a new directory for the current turn.""" + if self.experiment_manager: + # Use the experiment manager to create the turn directory + 
self.experiment_manager.create_turn_dir() + + # Update local tracking variables + self.current_turn_dir = self.experiment_manager.current_turn_dir + self.turn_count = self.experiment_manager.turn_count + + def _log_api_call( + self, call_type: str, request: Any, response: Any = None, error: Optional[Exception] = None + ) -> None: + """Log API call details to file. + + Args: + call_type: Type of API call (e.g., 'request', 'response', 'error') + request: The API request data + response: Optional API response data + error: Optional error information + """ + if self.experiment_manager: + # Use the experiment manager to log the API call + provider = getattr(self, "provider", "unknown") + provider_str = str(provider) if provider else "unknown" + + self.experiment_manager.log_api_call( + call_type=call_type, + request=request, + provider=provider_str, + model=self.model, + response=response, + error=error, + ) + + def _save_screenshot(self, img_base64: str, action_type: str = "") -> None: + """Save a screenshot to the experiment directory. + + Args: + img_base64: Base64 encoded screenshot + action_type: Type of action that triggered the screenshot + """ + if self.experiment_manager: + self.experiment_manager.save_screenshot(img_base64, action_type) + + async def initialize(self) -> None: + """Initialize both the API client and computer interface with retries.""" + for attempt in range(self.max_retries): + try: + logger.info( + f"Starting initialization (attempt {attempt + 1}/{self.max_retries})..." + ) + + # Initialize API client + await self.initialize_client() + + # Initialize computer + await self.computer.initialize() + + logger.info("Initialization complete.") + return + except Exception as e: + if attempt < self.max_retries - 1: + logger.warning( + f"Initialization failed (attempt {attempt + 1}/{self.max_retries}): {str(e)}. Retrying..." 
+ ) + await asyncio.sleep(self.retry_delay) + else: + logger.error( + f"Initialization failed after {self.max_retries} attempts: {str(e)}" + ) + raise RuntimeError(f"Failed to initialize: {str(e)}") + + async def _get_parsed_screen_som(self) -> Dict[str, Any]: + """Get parsed screen information. + + Returns: + Dict containing screen information + """ + try: + # Take screenshot + screenshot = await self.computer.interface.screenshot() + + # Initialize with default values + width, height = 1024, 768 + base64_image = "" + + # Handle different types of screenshot returns + if isinstance(screenshot, bytes): + # Raw bytes screenshot + base64_image = base64.b64encode(screenshot).decode("utf-8") + elif hasattr(screenshot, "base64_image"): + # Object-style screenshot with attributes + base64_image = screenshot.base64_image + if hasattr(screenshot, "width") and hasattr(screenshot, "height"): + width = screenshot.width + height = screenshot.height + + # Create parsed screen data + parsed_screen = { + "width": width, + "height": height, + "parsed_content_list": [], + "timestamp": datetime.now().isoformat(), + "screenshot_base64": base64_image, + } + + # Save screenshot if requested + if self.save_trajectory and self.experiment_manager: + try: + img_data = base64_image + if "," in img_data: + img_data = img_data.split(",")[1] + self._save_screenshot(img_data, action_type="state") + except Exception as e: + logger.error(f"Error saving screenshot: {str(e)}") + + return parsed_screen + except Exception as e: + logger.error(f"Error taking screenshot: {str(e)}") + return { + "width": 1024, + "height": 768, + "parsed_content_list": [], + "timestamp": datetime.now().isoformat(), + "error": f"Error taking screenshot: {str(e)}", + "screenshot_base64": "", + } + + @abstractmethod + async def initialize_client(self) -> None: + """Initialize the API client and any provider-specific components.""" + raise NotImplementedError + + @abstractmethod + async def run(self, messages: List[Dict[str, 
@dataclass
class ImageRetentionConfig:
    """Configuration for image retention in messages."""

    # How many of the most recent images to keep; None disables filtering.
    num_images_to_keep: Optional[int] = None
    # Images are only removed in multiples of this threshold.
    min_removal_threshold: int = 1
    # Whether prompt-caching hints may be injected.
    enable_caching: bool = True

    def should_retain_images(self) -> bool:
        """Return True when a positive retention count is configured."""
        keep = self.num_images_to_keep
        return keep is not None and keep > 0
+ + Args: + image_retention_config: Configuration for image retention + """ + self.image_retention_config = image_retention_config or ImageRetentionConfig() + if self.image_retention_config.min_removal_threshold < 1: + raise ValueError("min_removal_threshold must be at least 1") + + # Track provider for message formatting + self.provider = "openai" # Default provider + + def set_provider(self, provider: str) -> None: + """Set the current provider to format messages for. + + Args: + provider: Provider name (e.g., 'openai', 'anthropic') + """ + self.provider = provider.lower() + + def prepare_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Prepare messages by applying image retention and caching as configured. + + Args: + messages: List of messages to prepare + + Returns: + Prepared messages + """ + if self.image_retention_config.should_retain_images(): + self._filter_images(messages) + if self.image_retention_config.enable_caching: + self._inject_caching(messages) + return messages + + def _filter_images(self, messages: List[Dict[str, Any]]) -> None: + """Filter messages to retain only the specified number of most recent images. 
+ + Args: + messages: Messages to filter + """ + # Find all tool result blocks that contain images + tool_results = [ + item + for message in messages + for item in (message["content"] if isinstance(message["content"], list) else []) + if isinstance(item, dict) and item.get("type") == "tool_result" + ] + + # Count total images + total_images = sum( + 1 + for result in tool_results + for content in result.get("content", []) + if isinstance(content, dict) and content.get("type") == "image" + ) + + # Calculate how many images to remove + images_to_remove = total_images - (self.image_retention_config.num_images_to_keep or 0) + images_to_remove -= images_to_remove % self.image_retention_config.min_removal_threshold + + # Remove oldest images first + for result in tool_results: + if isinstance(result.get("content"), list): + new_content = [] + for content in result["content"]: + if isinstance(content, dict) and content.get("type") == "image": + if images_to_remove > 0: + images_to_remove -= 1 + continue + new_content.append(content) + result["content"] = new_content + + def _inject_caching(self, messages: List[Dict[str, Any]]) -> None: + """Inject caching control for recent message turns. + + Args: + messages: Messages to inject caching into + """ + # Only apply cache_control for Anthropic API, not OpenAI + if self.provider != "anthropic": + return + + # Default to caching last 3 turns + turns_to_cache = 3 + for message in reversed(messages): + if message["role"] == "user" and isinstance(content := message["content"], list): + if turns_to_cache: + turns_to_cache -= 1 + content[-1]["cache_control"] = {"type": "ephemeral"} + else: + content[-1].pop("cache_control", None) + break + + +def create_user_message(text: str) -> Dict[str, str]: + """Create a user message. 
def create_image_message(
    image_base64: Optional[str] = None,
    image_path: Optional[str] = None,
    image_obj: Optional[Image.Image] = None,
) -> Dict[str, Union[str, List[Dict[str, Any]]]]:
    """Create a user message containing a single image.

    Exactly one image source is required; a base64 string takes precedence
    over the file-path and PIL-object alternatives.

    Args:
        image_base64: Base64 encoded image
        image_path: Path to image file
        image_obj: PIL Image object

    Returns:
        Message dictionary with content list

    Raises:
        ValueError: If no image source is provided
    """
    if not (image_base64 or image_path or image_obj):
        raise ValueError("Must provide one of image_base64, image_path, or image_obj")

    # Normalize whichever source was supplied into a base64 string.
    if not image_base64:
        if image_path:
            with open(image_path, "rb") as fh:
                image_base64 = base64.b64encode(fh.read()).decode("utf-8")
        else:
            buf = BytesIO()
            image_obj.save(buf, format="PNG")
            image_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")

    return {
        "role": "user",
        "content": [
            {"type": "image", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
        ],
    }
+ + Args: + parsed_screen: Dictionary containing parsed screen info + include_raw: Whether to include raw screenshot base64 + + Returns: + Message dictionary with content + """ + if include_raw and "screenshot_base64" in parsed_screen: + # Create content list with both image and text + return { + "role": "user", + "content": [ + { + "type": "image", + "image_url": { + "url": f"data:image/png;base64,{parsed_screen['screenshot_base64']}" + }, + }, + { + "type": "text", + "text": f"Screen dimensions: {parsed_screen['width']}x{parsed_screen['height']}", + }, + ], + } + else: + # Create text-only message with screen info + return { + "role": "user", + "content": f"Screen dimensions: {parsed_screen['width']}x{parsed_screen['height']}", + } diff --git a/libs/agent/agent/core/telemetry.py b/libs/agent/agent/core/telemetry.py new file mode 100644 index 00000000..39865f55 --- /dev/null +++ b/libs/agent/agent/core/telemetry.py @@ -0,0 +1,130 @@ +"""Agent telemetry for tracking anonymous usage and feature usage.""" + +import logging +import os +import platform +import sys +from typing import Dict, Any + +# Import the core telemetry module +TELEMETRY_AVAILABLE = False + +try: + from core.telemetry import ( + record_event, + increment, + get_telemetry_client, + flush, + is_telemetry_enabled, + is_telemetry_globally_disabled, + ) + + def increment_counter(counter_name: str, value: int = 1) -> None: + """Wrapper for increment to maintain backward compatibility.""" + if is_telemetry_enabled(): + increment(counter_name, value) + + def set_dimension(name: str, value: Any) -> None: + """Set a dimension that will be attached to all events.""" + logger = logging.getLogger("cua.agent.telemetry") + logger.debug(f"Setting dimension {name}={value}") + + TELEMETRY_AVAILABLE = True + logger = logging.getLogger("cua.agent.telemetry") + logger.info("Successfully imported telemetry") +except ImportError as e: + logger = logging.getLogger("cua.agent.telemetry") + logger.warning(f"Could not import 
telemetry: {e}") + TELEMETRY_AVAILABLE = False + + +# Local fallbacks in case core telemetry isn't available +def _noop(*args: Any, **kwargs: Any) -> None: + """No-op function for when telemetry is not available.""" + pass + + +logger = logging.getLogger("cua.agent.telemetry") + +# If telemetry isn't available, use no-op functions +if not TELEMETRY_AVAILABLE: + logger.debug("Telemetry not available, using no-op functions") + record_event = _noop # type: ignore + increment_counter = _noop # type: ignore + set_dimension = _noop # type: ignore + get_telemetry_client = lambda: None # type: ignore + flush = _noop # type: ignore + is_telemetry_enabled = lambda: False # type: ignore + is_telemetry_globally_disabled = lambda: True # type: ignore + +# Get system info once to use in telemetry +SYSTEM_INFO = { + "os": platform.system().lower(), + "os_version": platform.release(), + "python_version": platform.python_version(), +} + + +def enable_telemetry() -> bool: + """Enable telemetry if available. 
+ + Returns: + bool: True if telemetry was successfully enabled, False otherwise + """ + global TELEMETRY_AVAILABLE + + # Check if globally disabled using core function + if TELEMETRY_AVAILABLE and is_telemetry_globally_disabled(): + logger.info("Telemetry is globally disabled via environment variable - cannot enable") + return False + + # Already enabled + if TELEMETRY_AVAILABLE: + return True + + # Try to import and enable + try: + from core.telemetry import ( + record_event, + increment, + get_telemetry_client, + flush, + is_telemetry_globally_disabled, + ) + + # Check again after import + if is_telemetry_globally_disabled(): + logger.info("Telemetry is globally disabled via environment variable - cannot enable") + return False + + TELEMETRY_AVAILABLE = True + logger.info("Telemetry successfully enabled") + return True + except ImportError as e: + logger.warning(f"Could not enable telemetry: {e}") + return False + + +def is_telemetry_enabled() -> bool: + """Check if telemetry is enabled. 
+ + Returns: + bool: True if telemetry is enabled, False otherwise + """ + # Use the core function if available, otherwise use our local flag + if TELEMETRY_AVAILABLE: + from core.telemetry import is_telemetry_enabled as core_is_enabled + + return core_is_enabled() + return False + + +def record_agent_initialization() -> None: + """Record when an agent instance is initialized.""" + if TELEMETRY_AVAILABLE and is_telemetry_enabled(): + record_event("agent_initialized", SYSTEM_INFO) + + # Set dimensions that will be attached to all events + set_dimension("os", SYSTEM_INFO["os"]) + set_dimension("os_version", SYSTEM_INFO["os_version"]) + set_dimension("python_version", SYSTEM_INFO["python_version"]) diff --git a/libs/agent/agent/core/tools/__init__.py b/libs/agent/agent/core/tools/__init__.py new file mode 100644 index 00000000..29c017e2 --- /dev/null +++ b/libs/agent/agent/core/tools/__init__.py @@ -0,0 +1,21 @@ +"""Core tools package.""" + +from .base import BaseTool, ToolResult, ToolError, ToolFailure, CLIResult +from .bash import BaseBashTool +from .collection import ToolCollection +from .computer import BaseComputerTool +from .edit import BaseEditTool +from .manager import BaseToolManager + +__all__ = [ + "BaseTool", + "ToolResult", + "ToolError", + "ToolFailure", + "CLIResult", + "BaseBashTool", + "BaseComputerTool", + "BaseEditTool", + "ToolCollection", + "BaseToolManager", +] diff --git a/libs/agent/agent/core/tools/base.py b/libs/agent/agent/core/tools/base.py new file mode 100644 index 00000000..84f32f0b --- /dev/null +++ b/libs/agent/agent/core/tools/base.py @@ -0,0 +1,74 @@ +"""Abstract base classes for tools that can be used with any provider.""" + +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass, fields, replace +from typing import Any, Dict + + +class BaseTool(metaclass=ABCMeta): + """Abstract base class for provider-agnostic tools.""" + + name: str + + @abstractmethod + async def __call__(self, **kwargs) -> Any: + """Executes 
the tool with the given arguments.""" + ... + + @abstractmethod + def to_params(self) -> Dict[str, Any]: + """Convert tool to provider-specific API parameters. + + Returns: + Dictionary with tool parameters specific to the LLM provider + """ + raise NotImplementedError + + +@dataclass(kw_only=True, frozen=True) +class ToolResult: + """Represents the result of a tool execution.""" + + output: str | None = None + error: str | None = None + base64_image: str | None = None + system: str | None = None + content: list[dict] | None = None + + def __bool__(self): + return any(getattr(self, field.name) for field in fields(self)) + + def __add__(self, other: "ToolResult"): + def combine_fields(field: str | None, other_field: str | None, concatenate: bool = True): + if field and other_field: + if concatenate: + return field + other_field + raise ValueError("Cannot combine tool results") + return field or other_field + + return ToolResult( + output=combine_fields(self.output, other.output), + error=combine_fields(self.error, other.error), + base64_image=combine_fields(self.base64_image, other.base64_image, False), + system=combine_fields(self.system, other.system), + content=self.content or other.content, # Use first non-None content + ) + + def replace(self, **kwargs): + """Returns a new ToolResult with the given fields replaced.""" + return replace(self, **kwargs) + + +class CLIResult(ToolResult): + """A ToolResult that can be rendered as a CLI output.""" + + +class ToolFailure(ToolResult): + """A ToolResult that represents a failure.""" + + +class ToolError(Exception): + """Raised when a tool encounters an error.""" + + def __init__(self, message): + self.message = message diff --git a/libs/agent/agent/core/tools/bash.py b/libs/agent/agent/core/tools/bash.py new file mode 100644 index 00000000..00c171ac --- /dev/null +++ b/libs/agent/agent/core/tools/bash.py @@ -0,0 +1,52 @@ +"""Abstract base bash/shell tool implementation.""" + +import asyncio +import logging +from abc 
import abstractmethod +from typing import Any, Dict, Tuple + +from computer.computer import Computer + +from .base import BaseTool, ToolResult + + +class BaseBashTool(BaseTool): + """Base class for bash/shell command execution tools across different providers.""" + + name = "bash" + logger = logging.getLogger(__name__) + computer: Computer + + def __init__(self, computer: Computer): + """Initialize the BashTool. + + Args: + computer: Computer instance, may be used for related operations + """ + self.computer = computer + + async def run_command(self, command: str) -> Tuple[int, str, str]: + """Run a shell command and return exit code, stdout, and stderr. + + Args: + command: Shell command to execute + + Returns: + Tuple containing (exit_code, stdout, stderr) + """ + try: + process = await asyncio.create_subprocess_shell( + command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await process.communicate() + return process.returncode or 0, stdout.decode(), stderr.decode() + except Exception as e: + self.logger.error(f"Error running command: {str(e)}") + return 1, "", str(e) + + @abstractmethod + async def __call__(self, **kwargs) -> ToolResult: + """Execute the tool with the provided arguments.""" + raise NotImplementedError diff --git a/libs/agent/agent/core/tools/collection.py b/libs/agent/agent/core/tools/collection.py new file mode 100644 index 00000000..d14b35d5 --- /dev/null +++ b/libs/agent/agent/core/tools/collection.py @@ -0,0 +1,46 @@ +"""Collection classes for managing multiple tools.""" + +from typing import Any, Dict, List, Type + +from .base import ( + BaseTool, + ToolError, + ToolFailure, + ToolResult, +) + + +class ToolCollection: + """A collection of tools that can be used with any provider.""" + + def __init__(self, *tools: BaseTool): + self.tools = tools + self.tool_map = {tool.name: tool for tool in tools} + + def to_params(self) -> List[Dict[str, Any]]: + """Convert all tools to provider-specific 
parameters. + + Returns: + List of dictionaries with tool parameters + """ + return [tool.to_params() for tool in self.tools] + + async def run(self, *, name: str, tool_input: Dict[str, Any]) -> ToolResult: + """Run a tool with the given input. + + Args: + name: Name of the tool to run + tool_input: Input parameters for the tool + + Returns: + Result of the tool execution + """ + tool = self.tool_map.get(name) + if not tool: + return ToolFailure(error=f"Tool {name} is invalid") + try: + return await tool(**tool_input) + except ToolError as e: + return ToolFailure(error=e.message) + except Exception as e: + return ToolFailure(error=f"Unexpected error in tool {name}: {str(e)}") diff --git a/libs/agent/agent/core/tools/computer.py b/libs/agent/agent/core/tools/computer.py new file mode 100644 index 00000000..b1f1334e --- /dev/null +++ b/libs/agent/agent/core/tools/computer.py @@ -0,0 +1,113 @@ +"""Abstract base computer tool implementation.""" + +import asyncio +import base64 +import io +import logging +from abc import abstractmethod +from typing import Any, Dict, Optional, Tuple + +from PIL import Image +from computer.computer import Computer + +from .base import BaseTool, ToolError, ToolResult + + +class BaseComputerTool(BaseTool): + """Base class for computer interaction tools across different providers.""" + + name = "computer" + logger = logging.getLogger(__name__) + + width: Optional[int] = None + height: Optional[int] = None + display_num: Optional[int] = None + computer: Computer + + _screenshot_delay = 1.0 # Default delay for most platforms + _scaling_enabled = True + + def __init__(self, computer: Computer): + """Initialize the ComputerTool. 
+ + Args: + computer: Computer instance for screen interactions + """ + self.computer = computer + + async def initialize_dimensions(self): + """Initialize screen dimensions from the computer interface.""" + display_size = await self.computer.interface.get_screen_size() + self.width = display_size["width"] + self.height = display_size["height"] + self.logger.info(f"Initialized screen dimensions to {self.width}x{self.height}") + + @property + def options(self) -> Dict[str, Any]: + """Get the options for the tool. + + Returns: + Dictionary with tool options + """ + if self.width is None or self.height is None: + raise RuntimeError( + "Screen dimensions not initialized. Call initialize_dimensions() first." + ) + return { + "display_width_px": self.width, + "display_height_px": self.height, + "display_number": self.display_num, + } + + async def resize_screenshot_if_needed(self, screenshot: bytes) -> bytes: + """Resize a screenshot to match the expected dimensions. + + Args: + screenshot: Raw screenshot data + + Returns: + Resized screenshot data + """ + if self.width is None or self.height is None: + raise ToolError("Screen dimensions not initialized") + + try: + img = Image.open(io.BytesIO(screenshot)) + if img.mode in ("RGBA", "LA") or (img.mode == "P" and "transparency" in img.info): + img = img.convert("RGB") + + # Resize if dimensions don't match + if img.size != (self.width, self.height): + self.logger.info( + f"Scaling image from {img.size} to {self.width}x{self.height} to match screen dimensions" + ) + img = img.resize((self.width, self.height), Image.Resampling.LANCZOS) + + # Save back to bytes + buffer = io.BytesIO() + img.save(buffer, format="PNG") + return buffer.getvalue() + + return screenshot + except Exception as e: + self.logger.error(f"Error during screenshot resizing: {str(e)}") + raise ToolError(f"Failed to resize screenshot: {str(e)}") + + async def screenshot(self) -> ToolResult: + """Take a screenshot and return it as a ToolResult with 
base64-encoded image. + + Returns: + ToolResult with the screenshot + """ + try: + screenshot = await self.computer.interface.screenshot() + screenshot = await self.resize_screenshot_if_needed(screenshot) + return ToolResult(base64_image=base64.b64encode(screenshot).decode()) + except Exception as e: + self.logger.error(f"Error taking screenshot: {str(e)}") + return ToolResult(error=f"Failed to take screenshot: {str(e)}") + + @abstractmethod + async def __call__(self, **kwargs) -> ToolResult: + """Execute the tool with the provided arguments.""" + raise NotImplementedError diff --git a/libs/agent/agent/core/tools/edit.py b/libs/agent/agent/core/tools/edit.py new file mode 100644 index 00000000..f57ba026 --- /dev/null +++ b/libs/agent/agent/core/tools/edit.py @@ -0,0 +1,67 @@ +"""Abstract base edit tool implementation.""" + +import asyncio +import logging +import os +from abc import abstractmethod +from pathlib import Path +from typing import Any, Dict, Optional + +from computer.computer import Computer + +from .base import BaseTool, ToolError, ToolResult + + +class BaseEditTool(BaseTool): + """Base class for text editor tools across different providers.""" + + name = "edit" + logger = logging.getLogger(__name__) + computer: Computer + + def __init__(self, computer: Computer): + """Initialize the EditTool. + + Args: + computer: Computer instance, may be used for related operations + """ + self.computer = computer + + async def read_file(self, path: str) -> str: + """Read a file and return its contents. + + Args: + path: Path to the file to read + + Returns: + File contents as a string + """ + try: + path_obj = Path(path) + if not path_obj.exists(): + raise ToolError(f"File does not exist: {path}") + return path_obj.read_text() + except Exception as e: + self.logger.error(f"Error reading file: {str(e)}") + raise ToolError(f"Failed to read file: {str(e)}") + + async def write_file(self, path: str, content: str) -> None: + """Write content to a file. 
+ + Args: + path: Path to the file to write + content: Content to write to the file + """ + try: + path_obj = Path(path) + # Create parent directories if they don't exist + path_obj.parent.mkdir(parents=True, exist_ok=True) + path_obj.write_text(content) + except Exception as e: + self.logger.error(f"Error writing file: {str(e)}") + raise ToolError(f"Failed to write file: {str(e)}") + + @abstractmethod + async def __call__(self, **kwargs) -> ToolResult: + """Execute the tool with the provided arguments.""" + raise NotImplementedError diff --git a/libs/agent/agent/core/tools/manager.py b/libs/agent/agent/core/tools/manager.py new file mode 100644 index 00000000..ab614af0 --- /dev/null +++ b/libs/agent/agent/core/tools/manager.py @@ -0,0 +1,56 @@ +"""Tool manager for initializing and running tools.""" + +from abc import ABC, abstractmethod +from typing import Any, Dict, List + +from computer.computer import Computer + +from .base import BaseTool, ToolResult +from .collection import ToolCollection + + +class BaseToolManager(ABC): + """Base class for tool managers across different providers.""" + + def __init__(self, computer: Computer): + """Initialize the tool manager. + + Args: + computer: Computer instance for computer-related tools + """ + self.computer = computer + self.tools: ToolCollection | None = None + + @abstractmethod + def _initialize_tools(self) -> ToolCollection: + """Initialize all available tools.""" + ... + + async def initialize(self) -> None: + """Initialize tool-specific requirements and create tool collection.""" + await self._initialize_tools_specific() + self.tools = self._initialize_tools() + + @abstractmethod + async def _initialize_tools_specific(self) -> None: + """Initialize provider-specific tool requirements.""" + ... + + @abstractmethod + def get_tool_params(self) -> List[Dict[str, Any]]: + """Get tool parameters for API calls.""" + ... 
+ + async def execute_tool(self, name: str, tool_input: Dict[str, Any]) -> ToolResult: + """Execute a tool with the given input. + + Args: + name: Name of the tool to execute + tool_input: Input parameters for the tool + + Returns: + Result of the tool execution + """ + if self.tools is None: + raise RuntimeError("Tools not initialized. Call initialize() first.") + return await self.tools.run(name=name, tool_input=tool_input) diff --git a/libs/agent/agent/providers/__init__.py b/libs/agent/agent/providers/__init__.py new file mode 100644 index 00000000..99caa419 --- /dev/null +++ b/libs/agent/agent/providers/__init__.py @@ -0,0 +1,4 @@ +"""Provider implementations for different AI services.""" + +# Import specific providers only when needed to avoid circular imports +__all__ = [] # Let each provider module handle its own exports \ No newline at end of file diff --git a/libs/agent/agent/providers/anthropic/__init__.py b/libs/agent/agent/providers/anthropic/__init__.py new file mode 100644 index 00000000..10732ea7 --- /dev/null +++ b/libs/agent/agent/providers/anthropic/__init__.py @@ -0,0 +1,6 @@ +"""Anthropic provider implementation.""" + +from .loop import AnthropicLoop +from .types import LLMProvider + +__all__ = ["AnthropicLoop", "LLMProvider"] diff --git a/libs/agent/agent/providers/anthropic/api/client.py b/libs/agent/agent/providers/anthropic/api/client.py new file mode 100644 index 00000000..f5a21e6f --- /dev/null +++ b/libs/agent/agent/providers/anthropic/api/client.py @@ -0,0 +1,219 @@ +from typing import Any +import httpx +import asyncio +from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex +from anthropic.types.beta import BetaMessage, BetaMessageParam, BetaToolUnionParam +from ..types import LLMProvider +from .logging import log_api_interaction +import random +import logging + +logger = logging.getLogger(__name__) + + +class APIConnectionError(Exception): + """Error raised when there are connection issues with the API.""" + + pass + + 
+class BaseAnthropicClient: + """Base class for Anthropic API clients.""" + + MAX_RETRIES = 10 + INITIAL_RETRY_DELAY = 1.0 + MAX_RETRY_DELAY = 60.0 + JITTER_FACTOR = 0.1 + + async def create_message( + self, + *, + messages: list[BetaMessageParam], + system: list[Any], + tools: list[BetaToolUnionParam], + max_tokens: int, + betas: list[str], + ) -> BetaMessage: + """Create a message using the Anthropic API.""" + raise NotImplementedError + + async def _make_api_call_with_retries(self, api_call): + """Make an API call with exponential backoff retry logic. + + Args: + api_call: Async function that makes the actual API call + + Returns: + API response + + Raises: + APIConnectionError: If all retries fail + """ + retry_count = 0 + last_error = None + + while retry_count < self.MAX_RETRIES: + try: + return await api_call() + except Exception as e: + last_error = e + retry_count += 1 + + if retry_count == self.MAX_RETRIES: + break + + # Calculate delay with exponential backoff and jitter + delay = min( + self.INITIAL_RETRY_DELAY * (2 ** (retry_count - 1)), self.MAX_RETRY_DELAY + ) + # Add jitter to avoid thundering herd + jitter = delay * self.JITTER_FACTOR * (2 * random.random() - 1) + final_delay = delay + jitter + + logger.info( + f"Retrying request (attempt {retry_count}/{self.MAX_RETRIES}) " + f"in {final_delay:.2f} seconds after error: {str(e)}" + ) + await asyncio.sleep(final_delay) + + raise APIConnectionError( + f"Failed after {self.MAX_RETRIES} retries. 
" f"Last error: {str(last_error)}" + ) + + +class AnthropicDirectClient(BaseAnthropicClient): + """Direct Anthropic API client implementation.""" + + def __init__(self, api_key: str, model: str): + self.model = model + self.client = Anthropic(api_key=api_key, http_client=self._create_http_client()) + + def _create_http_client(self) -> httpx.Client: + """Create an HTTP client with appropriate settings.""" + return httpx.Client( + verify=True, + timeout=httpx.Timeout(connect=30.0, read=300.0, write=30.0, pool=30.0), + transport=httpx.HTTPTransport( + retries=3, + verify=True, + limits=httpx.Limits(max_keepalive_connections=5, max_connections=10), + ), + ) + + async def create_message( + self, + *, + messages: list[BetaMessageParam], + system: list[Any], + tools: list[BetaToolUnionParam], + max_tokens: int, + betas: list[str], + ) -> BetaMessage: + """Create a message using the direct Anthropic API with retry logic.""" + + async def api_call(): + response = self.client.beta.messages.with_raw_response.create( + max_tokens=max_tokens, + messages=messages, + model=self.model, + system=system, + tools=tools, + betas=betas, + ) + log_api_interaction(response.http_response.request, response.http_response, None) + return response.parse() + + try: + return await self._make_api_call_with_retries(api_call) + except Exception as e: + log_api_interaction(None, None, e) + raise + + +class AnthropicVertexClient(BaseAnthropicClient): + """Google Cloud Vertex AI implementation of Anthropic client.""" + + def __init__(self, model: str): + self.model = model + self.client = AnthropicVertex() + + async def create_message( + self, + *, + messages: list[BetaMessageParam], + system: list[Any], + tools: list[BetaToolUnionParam], + max_tokens: int, + betas: list[str], + ) -> BetaMessage: + """Create a message using Vertex AI with retry logic.""" + + async def api_call(): + response = self.client.beta.messages.with_raw_response.create( + max_tokens=max_tokens, + messages=messages, + 
model=self.model, + system=system, + tools=tools, + betas=betas, + ) + log_api_interaction(response.http_response.request, response.http_response, None) + return response.parse() + + try: + return await self._make_api_call_with_retries(api_call) + except Exception as e: + log_api_interaction(None, None, e) + raise + + +class AnthropicBedrockClient(BaseAnthropicClient): + """AWS Bedrock implementation of Anthropic client.""" + + def __init__(self, model: str): + self.model = model + self.client = AnthropicBedrock() + + async def create_message( + self, + *, + messages: list[BetaMessageParam], + system: list[Any], + tools: list[BetaToolUnionParam], + max_tokens: int, + betas: list[str], + ) -> BetaMessage: + """Create a message using AWS Bedrock with retry logic.""" + + async def api_call(): + response = self.client.beta.messages.with_raw_response.create( + max_tokens=max_tokens, + messages=messages, + model=self.model, + system=system, + tools=tools, + betas=betas, + ) + log_api_interaction(response.http_response.request, response.http_response, None) + return response.parse() + + try: + return await self._make_api_call_with_retries(api_call) + except Exception as e: + log_api_interaction(None, None, e) + raise + + +class AnthropicClientFactory: + """Factory for creating appropriate Anthropic client implementations.""" + + @staticmethod + def create_client(provider: LLMProvider, api_key: str, model: str) -> BaseAnthropicClient: + """Create an appropriate client based on the provider.""" + if provider == LLMProvider.ANTHROPIC: + return AnthropicDirectClient(api_key, model) + elif provider == LLMProvider.VERTEX: + return AnthropicVertexClient(model) + elif provider == LLMProvider.BEDROCK: + return AnthropicBedrockClient(model) + raise ValueError(f"Unsupported provider: {provider}") diff --git a/libs/agent/agent/providers/anthropic/api/logging.py b/libs/agent/agent/providers/anthropic/api/logging.py new file mode 100644 index 00000000..80584411 --- /dev/null +++ 
b/libs/agent/agent/providers/anthropic/api/logging.py @@ -0,0 +1,150 @@ +"""API logging functionality.""" + +import json +import logging +from datetime import datetime +from pathlib import Path +import httpx +from typing import Any + +logger = logging.getLogger(__name__) + +def _filter_base64_images(content: Any) -> Any: + """Filter out base64 image data from content. + + Args: + content: Content to filter + + Returns: + Filtered content with base64 data replaced by placeholder + """ + if isinstance(content, dict): + filtered = {} + for key, value in content.items(): + if ( + isinstance(value, dict) + and value.get("type") == "image" + and value.get("source", {}).get("type") == "base64" + ): + # Replace base64 data with placeholder + filtered[key] = { + **value, + "source": { + **value["source"], + "data": "" + } + } + else: + filtered[key] = _filter_base64_images(value) + return filtered + elif isinstance(content, list): + return [_filter_base64_images(item) for item in content] + return content + +def log_api_interaction( + request: httpx.Request | None, + response: httpx.Response | object | None, + error: Exception | None, + log_dir: Path = Path("/tmp/claude_logs") +) -> None: + """Log API request, response, and any errors in a structured way. 
+ + Args: + request: The HTTP request if available + response: The HTTP response or response object + error: Any error that occurred + log_dir: Directory to store log files + """ + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") + + # Helper function to safely decode JSON content + def safe_json_decode(content): + if not content: + return None + try: + if isinstance(content, bytes): + return json.loads(content.decode()) + elif isinstance(content, str): + return json.loads(content) + elif isinstance(content, dict): + return content + return None + except json.JSONDecodeError: + return {"error": "Could not decode JSON", "raw": str(content)} + + # Process request content + request_content = None + if request and request.content: + request_content = safe_json_decode(request.content) + request_content = _filter_base64_images(request_content) + + # Process response content + response_content = None + if response: + if isinstance(response, httpx.Response): + try: + response_content = response.json() + except json.JSONDecodeError: + response_content = {"error": "Could not decode JSON", "raw": response.text} + else: + response_content = safe_json_decode(response) + response_content = _filter_base64_images(response_content) + + log_entry = { + "timestamp": timestamp, + "request": { + "method": request.method if request else None, + "url": str(request.url) if request else None, + "headers": dict(request.headers) if request else None, + "content": request_content, + } if request else None, + "response": { + "status_code": response.status_code if isinstance(response, httpx.Response) else None, + "headers": dict(response.headers) if isinstance(response, httpx.Response) else None, + "content": response_content, + } if response else None, + "error": { + "type": type(error).__name__ if error else None, + "message": str(error) if error else None, + } if error else None + } + + # Log to file with timestamp in filename + log_dir.mkdir(exist_ok=True) + log_file = log_dir / 
f"claude_api_{timestamp.replace(' ', '_').replace(':', '-')}.json" + + with open(log_file, 'w') as f: + json.dump(log_entry, f, indent=2) + + # Also log a summary to the console + if error: + logger.error(f"API Error at {timestamp}: {error}") + else: + logger.info( + f"API Call at {timestamp}: " + f"{request.method if request else 'No request'} -> " + f"{response.status_code if isinstance(response, httpx.Response) else 'No response'}" + ) + + # Log if there are any images in the content + if response_content: + image_count = count_images(response_content) + if image_count > 0: + logger.info(f"Response contains {image_count} images") + +def count_images(content: dict | list | Any) -> int: + """Count the number of images in the content. + + Args: + content: Content to search for images + + Returns: + Number of images found + """ + if isinstance(content, dict): + if content.get("type") == "image": + return 1 + return sum(count_images(v) for v in content.values()) + elif isinstance(content, list): + return sum(count_images(item) for item in content) + return 0 \ No newline at end of file diff --git a/libs/agent/agent/providers/anthropic/callbacks/manager.py b/libs/agent/agent/providers/anthropic/callbacks/manager.py new file mode 100644 index 00000000..04a2ec2c --- /dev/null +++ b/libs/agent/agent/providers/anthropic/callbacks/manager.py @@ -0,0 +1,55 @@ +from typing import Callable, Protocol +import httpx +from anthropic.types.beta import BetaContentBlockParam +from ..tools import ToolResult + +class APICallback(Protocol): + """Protocol for API callbacks.""" + def __call__(self, request: httpx.Request | None, + response: httpx.Response | object | None, + error: Exception | None) -> None: ... + +class ContentCallback(Protocol): + """Protocol for content callbacks.""" + def __call__(self, content: BetaContentBlockParam) -> None: ... + +class ToolCallback(Protocol): + """Protocol for tool callbacks.""" + def __call__(self, result: ToolResult, tool_id: str) -> None: ... 
+ +class CallbackManager: + """Manages various callbacks for the agent system.""" + + def __init__( + self, + content_callback: ContentCallback, + tool_callback: ToolCallback, + api_callback: APICallback, + ): + """Initialize the callback manager. + + Args: + content_callback: Callback for content updates + tool_callback: Callback for tool execution results + api_callback: Callback for API interactions + """ + self.content_callback = content_callback + self.tool_callback = tool_callback + self.api_callback = api_callback + + def on_content(self, content: BetaContentBlockParam) -> None: + """Handle content updates.""" + self.content_callback(content) + + def on_tool_result(self, result: ToolResult, tool_id: str) -> None: + """Handle tool execution results.""" + self.tool_callback(result, tool_id) + + def on_api_interaction( + self, + request: httpx.Request | None, + response: httpx.Response | object | None, + error: Exception | None + ) -> None: + """Handle API interactions.""" + self.api_callback(request, response, error) \ No newline at end of file diff --git a/libs/agent/agent/providers/anthropic/loop.py b/libs/agent/agent/providers/anthropic/loop.py new file mode 100644 index 00000000..de6d5133 --- /dev/null +++ b/libs/agent/agent/providers/anthropic/loop.py @@ -0,0 +1,521 @@ +"""Anthropic-specific agent loop implementation.""" + +import logging +import asyncio +import json +import os +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, cast +import base64 +from datetime import datetime +from httpx import ConnectError, ReadTimeout + +# Anthropic-specific imports +from anthropic import AsyncAnthropic +from anthropic.types.beta import ( + BetaMessage, + BetaMessageParam, + BetaTextBlock, + BetaTextBlockParam, + BetaToolUseBlockParam, +) + +# Computer +from computer import Computer + +# Base imports +from ...core.loop import BaseLoop +from ...core.messages import ImageRetentionConfig + +# Anthropic provider-specific imports +from .api.client import 
AnthropicClientFactory, BaseAnthropicClient +from .tools.manager import ToolManager +from .messages.manager import MessageManager +from .callbacks.manager import CallbackManager +from .prompts import SYSTEM_PROMPT +from .types import LLMProvider +from .tools import ToolResult + +# Constants +COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24" +PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31" + +logger = logging.getLogger(__name__) + + +class AnthropicLoop(BaseLoop): + """Anthropic-specific implementation of the agent loop.""" + + def __init__( + self, + api_key: str, + model: str = "claude-3-7-sonnet-20250219", # Fixed model + computer: Optional[Computer] = None, + only_n_most_recent_images: Optional[int] = 2, + base_dir: Optional[str] = "trajectories", + max_retries: int = 3, + retry_delay: float = 1.0, + save_trajectory: bool = True, + **kwargs, + ): + """Initialize the Anthropic loop. + + Args: + api_key: Anthropic API key + model: Model name (fixed to claude-3-7-sonnet-20250219) + computer: Computer instance + only_n_most_recent_images: Maximum number of recent screenshots to include in API requests + base_dir: Base directory for saving experiment data + max_retries: Maximum number of retries for API calls + retry_delay: Delay between retries in seconds + save_trajectory: Whether to save trajectory data + """ + # Initialize base class + super().__init__( + computer=computer, + model=model, + api_key=api_key, + max_retries=max_retries, + retry_delay=retry_delay, + base_dir=base_dir, + save_trajectory=save_trajectory, + only_n_most_recent_images=only_n_most_recent_images, + **kwargs, + ) + + # Ensure model is always the fixed one + self.model = "claude-3-7-sonnet-20250219" + + # Anthropic-specific attributes + self.provider = LLMProvider.ANTHROPIC + self.client = None + self.retry_count = 0 + self.tool_manager = None + self.message_manager = None + self.callback_manager = None + + # Configure image retention + self.image_retention_config = 
ImageRetentionConfig( + num_images_to_keep=only_n_most_recent_images + ) + + # Message history + self.message_history = [] + + async def initialize_client(self) -> None: + """Initialize the Anthropic API client and tools.""" + try: + logger.info(f"Initializing Anthropic client with model {self.model}...") + + # Initialize client + self.client = AnthropicClientFactory.create_client( + provider=self.provider, api_key=self.api_key, model=self.model + ) + + # Initialize message manager + self.message_manager = MessageManager( + ImageRetentionConfig( + num_images_to_keep=self.only_n_most_recent_images, enable_caching=True + ) + ) + + # Initialize callback manager + self.callback_manager = CallbackManager( + content_callback=self._handle_content, + tool_callback=self._handle_tool_result, + api_callback=self._handle_api_interaction, + ) + + # Initialize tool manager + self.tool_manager = ToolManager(self.computer) + await self.tool_manager.initialize() + + logger.info(f"Initialized Anthropic client with model {self.model}") + except Exception as e: + logger.error(f"Error initializing Anthropic client: {str(e)}") + self.client = None + raise RuntimeError(f"Failed to initialize Anthropic client: {str(e)}") + + async def _process_screen( + self, parsed_screen: Dict[str, Any], messages: List[Dict[str, Any]] + ) -> None: + """Process screen information and add to messages. 
+ + Args: + parsed_screen: Dictionary containing parsed screen info + messages: List of messages to update + """ + try: + # Extract screenshot from parsed screen + screenshot_base64 = parsed_screen.get("screenshot_base64") + + if screenshot_base64: + # Remove data URL prefix if present + if "," in screenshot_base64: + screenshot_base64 = screenshot_base64.split(",")[1] + + # Create Anthropic-compatible message with image + screen_info_msg = { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": screenshot_base64, + }, + } + ], + } + + # Add screen info message to messages + messages.append(screen_info_msg) + + except Exception as e: + logger.error(f"Error processing screen info: {str(e)}") + raise + + async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]: + """Run the agent loop with provided messages. + + Args: + messages: List of message objects + + Yields: + Dict containing response data + """ + try: + logger.info("Starting Anthropic loop run") + + # Reset message history and add new messages + self.message_history = [] + self.message_history.extend(messages) + + # Create queue for response streaming + queue = asyncio.Queue() + + # Ensure client is initialized + if self.client is None or self.tool_manager is None: + logger.info("Initializing client...") + await self.initialize_client() + if self.client is None: + raise RuntimeError("Failed to initialize client") + logger.info("Client initialized successfully") + + # Start loop in background task + loop_task = asyncio.create_task(self._run_loop(queue)) + + # Process and yield messages as they arrive + while True: + try: + item = await queue.get() + if item is None: # Stop signal + break + yield item + queue.task_done() + except Exception as e: + logger.error(f"Error processing queue item: {str(e)}") + continue + + # Wait for loop to complete + await loop_task + + # Send completion message + yield { + 
"role": "assistant", + "content": "Task completed successfully.", + "metadata": {"title": "✅ Complete"}, + } + + except Exception as e: + logger.error(f"Error executing task: {str(e)}") + yield { + "role": "assistant", + "content": f"Error: {str(e)}", + "metadata": {"title": "❌ Error"}, + } + + async def _run_loop(self, queue: asyncio.Queue) -> None: + """Run the agent loop with current message history. + + Args: + queue: Queue for response streaming + """ + try: + while True: + # Get up-to-date screen information + parsed_screen = await self._get_parsed_screen_som() + + # Process screen info and update messages + await self._process_screen(parsed_screen, self.message_history) + + # Prepare messages and make API call + prepared_messages = self.message_manager.prepare_messages( + cast(List[BetaMessageParam], self.message_history.copy()) + ) + + # Create new turn directory for this API call + self._create_turn_dir() + + # Make API call + response = await self._make_api_call(prepared_messages) + + # Handle the response + if not await self._handle_response(response, self.message_history): + break + + # Signal completion + await queue.put(None) + + except Exception as e: + logger.error(f"Error in _run_loop: {str(e)}") + await queue.put( + { + "role": "assistant", + "content": f"Error in agent loop: {str(e)}", + "metadata": {"title": "❌ Error"}, + } + ) + await queue.put(None) + + async def _make_api_call(self, messages: List[BetaMessageParam]) -> BetaMessage: + """Make API call to Anthropic with retry logic. 
+ + Args: + messages: List of messages to send to the API + + Returns: + API response + """ + last_error = None + + for attempt in range(self.max_retries): + try: + # Log request + request_data = { + "messages": messages, + "max_tokens": self.max_tokens, + "system": SYSTEM_PROMPT, + } + self._log_api_call("request", request_data) + + # Setup betas and system + system = BetaTextBlockParam( + type="text", + text=SYSTEM_PROMPT, + ) + + betas = [COMPUTER_USE_BETA_FLAG] + # Temporarily disable prompt caching due to "A maximum of 4 blocks with cache_control may be provided" error + # if self.message_manager.image_retention_config.enable_caching: + # betas.append(PROMPT_CACHING_BETA_FLAG) + # system["cache_control"] = {"type": "ephemeral"} + + # Make API call + response = await self.client.create_message( + messages=messages, + system=[system], + tools=self.tool_manager.get_tool_params(), + max_tokens=self.max_tokens, + betas=betas, + ) + + # Log success response + self._log_api_call("response", request_data, response) + + return response + except Exception as e: + last_error = e + logger.error( + f"Error in API call (attempt {attempt + 1}/{self.max_retries}): {str(e)}" + ) + self._log_api_call("error", {"messages": messages}, error=e) + + if attempt < self.max_retries - 1: + await asyncio.sleep(self.retry_delay * (attempt + 1)) # Exponential backoff + continue + + # If we get here, all retries failed + error_message = f"API call failed after {self.max_retries} attempts" + if last_error: + error_message += f": {str(last_error)}" + + logger.error(error_message) + raise RuntimeError(error_message) + + async def _handle_response(self, response: BetaMessage, messages: List[Dict[str, Any]]) -> bool: + """Handle the Anthropic API response. 
+ + Args: + response: API response + messages: List of messages to update + + Returns: + True if the loop should continue, False otherwise + """ + try: + # Convert response to parameter format + response_params = self._response_to_params(response) + + # Add response to messages + messages.append( + { + "role": "assistant", + "content": response_params, + } + ) + + # Handle tool use blocks and collect results + tool_result_content = [] + for content_block in response_params: + # Notify callback of content + self.callback_manager.on_content(content_block) + + # Handle tool use + if content_block.get("type") == "tool_use": + result = await self.tool_manager.execute_tool( + name=content_block["name"], + tool_input=cast(Dict[str, Any], content_block["input"]), + ) + + # Create tool result and add to content + tool_result = self._make_tool_result(result, content_block["id"]) + tool_result_content.append(tool_result) + + # Notify callback of tool result + self.callback_manager.on_tool_result(result, content_block["id"]) + + # If no tool results, we're done + if not tool_result_content: + # Signal completion + self.callback_manager.on_content({"type": "text", "text": ""}) + return False + + # Add tool results to message history + messages.append({"content": tool_result_content, "role": "user"}) + return True + + except Exception as e: + logger.error(f"Error handling response: {str(e)}") + messages.append( + { + "role": "assistant", + "content": f"Error: {str(e)}", + } + ) + return False + + def _response_to_params( + self, + response: BetaMessage, + ) -> List[Dict[str, Any]]: + """Convert API response to message parameters. 
+ + Args: + response: API response message + + Returns: + List of content blocks + """ + result = [] + for block in response.content: + if isinstance(block, BetaTextBlock): + result.append({"type": "text", "text": block.text}) + else: + result.append(cast(Dict[str, Any], block.model_dump())) + return result + + def _make_tool_result(self, result: ToolResult, tool_use_id: str) -> Dict[str, Any]: + """Convert a tool result to API format. + + Args: + result: Tool execution result + tool_use_id: ID of the tool use + + Returns: + Formatted tool result + """ + if result.content: + return { + "type": "tool_result", + "content": result.content, + "tool_use_id": tool_use_id, + "is_error": bool(result.error), + } + + tool_result_content = [] + is_error = False + + if result.error: + is_error = True + tool_result_content = [ + { + "type": "text", + "text": self._maybe_prepend_system_tool_result(result, result.error), + } + ] + else: + if result.output: + tool_result_content.append( + { + "type": "text", + "text": self._maybe_prepend_system_tool_result(result, result.output), + } + ) + if result.base64_image: + tool_result_content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": result.base64_image, + }, + } + ) + + return { + "type": "tool_result", + "content": tool_result_content, + "tool_use_id": tool_use_id, + "is_error": is_error, + } + + def _maybe_prepend_system_tool_result(self, result: ToolResult, result_text: str) -> str: + """Prepend system information to tool result if available. 
+ + Args: + result: Tool execution result + result_text: Text to prepend to + + Returns: + Text with system information prepended if available + """ + if result.system: + result_text = f"{result.system}\n{result_text}" + return result_text + + def _handle_content(self, content: Dict[str, Any]) -> None: + """Handle content updates from the assistant.""" + if content.get("type") == "text": + text = content.get("text", "") + if text == "": + return + + logger.info(f"Assistant: {text}") + + def _handle_tool_result(self, result: ToolResult, tool_id: str) -> None: + """Handle tool execution results.""" + if result.error: + logger.error(f"Tool {tool_id} error: {result.error}") + else: + logger.info(f"Tool {tool_id} output: {result.output}") + + def _handle_api_interaction( + self, request: Any, response: Any, error: Optional[Exception] + ) -> None: + """Handle API interactions.""" + if error: + logger.error(f"API error: {error}") + else: + logger.debug(f"API request: {request}") diff --git a/libs/agent/agent/providers/anthropic/messages/manager.py b/libs/agent/agent/providers/anthropic/messages/manager.py new file mode 100644 index 00000000..c5136135 --- /dev/null +++ b/libs/agent/agent/providers/anthropic/messages/manager.py @@ -0,0 +1,110 @@ +from dataclasses import dataclass +from typing import cast +from anthropic.types.beta import ( + BetaMessageParam, + BetaCacheControlEphemeralParam, + BetaToolResultBlockParam, +) + + +@dataclass +class ImageRetentionConfig: + """Configuration for image retention in messages.""" + + num_images_to_keep: int | None = None + min_removal_threshold: int = 1 + enable_caching: bool = True + + def should_retain_images(self) -> bool: + """Check if image retention is enabled.""" + return self.num_images_to_keep is not None and self.num_images_to_keep > 0 + + +class MessageManager: + """Manages message preparation, including image retention and caching.""" + + def __init__(self, image_retention_config: ImageRetentionConfig): + """Initialize 
the message manager. + + Args: + image_retention_config: Configuration for image retention + """ + if image_retention_config.min_removal_threshold < 1: + raise ValueError("min_removal_threshold must be at least 1") + self.image_retention_config = image_retention_config + + def prepare_messages(self, messages: list[BetaMessageParam]) -> list[BetaMessageParam]: + """Prepare messages by applying image retention and caching as configured.""" + if self.image_retention_config.should_retain_images(): + self._filter_images(messages) + if self.image_retention_config.enable_caching: + self._inject_caching(messages) + return messages + + def _filter_images(self, messages: list[BetaMessageParam]) -> None: + """Filter messages to retain only the specified number of most recent images.""" + tool_result_blocks = cast( + list[BetaToolResultBlockParam], + [ + item + for message in messages + for item in (message["content"] if isinstance(message["content"], list) else []) + if isinstance(item, dict) and item.get("type") == "tool_result" + ], + ) + + total_images = sum( + 1 + for tool_result in tool_result_blocks + for content in tool_result.get("content", []) + if isinstance(content, dict) and content.get("type") == "image" + ) + + images_to_remove = total_images - (self.image_retention_config.num_images_to_keep or 0) + # Round down to nearest min_removal_threshold for better cache behavior + images_to_remove -= images_to_remove % self.image_retention_config.min_removal_threshold + + # Remove oldest images first + for tool_result in tool_result_blocks: + if isinstance(tool_result.get("content"), list): + new_content = [] + for content in tool_result.get("content", []): + if isinstance(content, dict) and content.get("type") == "image": + if images_to_remove > 0: + images_to_remove -= 1 + continue + new_content.append(content) + tool_result["content"] = new_content + + def _inject_caching(self, messages: list[BetaMessageParam]) -> None: + """Inject caching control for the most recent 
turns, limited to 3 blocks max to avoid API errors.""" + # Anthropic API allows a maximum of 4 blocks with cache_control + # We use 3 here to be safe, as the system block may also have cache_control + blocks_with_cache_control = 0 + max_cache_control_blocks = 3 + + for message in reversed(messages): + if message["role"] == "user" and isinstance(content := message["content"], list): + # Only add cache control to the latest message in each turn + if blocks_with_cache_control < max_cache_control_blocks: + blocks_with_cache_control += 1 + # Add cache control to the last content block only + if content and len(content) > 0: + content[-1]["cache_control"] = {"type": "ephemeral"} + else: + # Remove any existing cache control + if content and len(content) > 0: + content[-1].pop("cache_control", None) + + # Ensure we're not exceeding the limit by checking the total + if blocks_with_cache_control > max_cache_control_blocks: + # If we somehow exceeded the limit, remove excess cache controls + excess = blocks_with_cache_control - max_cache_control_blocks + for message in messages: + if excess <= 0: + break + + if message["role"] == "user" and isinstance(content := message["content"], list): + if content and len(content) > 0 and "cache_control" in content[-1]: + content[-1].pop("cache_control", None) + excess -= 1 diff --git a/libs/agent/agent/providers/anthropic/prompts.py b/libs/agent/agent/providers/anthropic/prompts.py new file mode 100644 index 00000000..90e35e7a --- /dev/null +++ b/libs/agent/agent/providers/anthropic/prompts.py @@ -0,0 +1,20 @@ +"""System prompts for Anthropic provider.""" + +from datetime import datetime +import platform + +SYSTEM_PROMPT = f""" +* You are utilising a macOS virtual machine using ARM architecture with internet access and Safari as default browser. +* You can feel free to install macOS applications with your bash tool. Use curl instead of wget. +* Using bash tool you can start GUI applications. 
 GUI apps run with bash tool will appear within your desktop environment, but they may take some time to appear. Take a screenshot to confirm it did. +* When using your bash tool with commands that are expected to output very large quantities of text, redirect into a tmp file and use str_replace_editor or `grep -n -B <lines before> -A <lines after> <query> <filename>` to confirm output. +* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available. +* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. +* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}. + + + +* Plan at maximum 1 step each time, and evaluate the result of each step before proceeding. Hold back if you're not sure about the result of the step. +* If you're not sure about the location of an application, start the app using the bash tool. +* If the item you are looking at is a pdf, if after taking a single screenshot of the pdf it seems that you want to read the entire document instead of trying to continue to read the pdf from your screenshots + navigation, determine the URL, use curl to download the pdf, install and use pdftotext to convert it to a text file, and then read that text file directly with your StrReplaceEditTool. 
+""" diff --git a/libs/agent/agent/providers/anthropic/tools/__init__.py b/libs/agent/agent/providers/anthropic/tools/__init__.py new file mode 100644 index 00000000..93ef41bc --- /dev/null +++ b/libs/agent/agent/providers/anthropic/tools/__init__.py @@ -0,0 +1,33 @@ +"""Anthropic-specific tools for agent.""" + +from .base import ( + BaseAnthropicTool, + ToolResult, + ToolError, + ToolFailure, + CLIResult, + AnthropicToolResult, + AnthropicToolError, + AnthropicToolFailure, + AnthropicCLIResult, +) +from .bash import BashTool +from .computer import ComputerTool +from .edit import EditTool +from .manager import ToolManager + +__all__ = [ + "BaseAnthropicTool", + "ToolResult", + "ToolError", + "ToolFailure", + "CLIResult", + "AnthropicToolResult", + "AnthropicToolError", + "AnthropicToolFailure", + "AnthropicCLIResult", + "BashTool", + "ComputerTool", + "EditTool", + "ToolManager", +] diff --git a/libs/agent/agent/providers/anthropic/tools/base.py b/libs/agent/agent/providers/anthropic/tools/base.py new file mode 100644 index 00000000..2edbfeff --- /dev/null +++ b/libs/agent/agent/providers/anthropic/tools/base.py @@ -0,0 +1,88 @@ +"""Anthropic-specific tool base classes.""" + +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass, fields, replace +from typing import Any, Dict + +from anthropic.types.beta import BetaToolUnionParam + +from ....core.tools.base import BaseTool, ToolError, ToolResult, ToolFailure, CLIResult + + +class BaseAnthropicTool(BaseTool, metaclass=ABCMeta): + """Abstract base class for Anthropic-defined tools.""" + + def __init__(self): + """Initialize the base Anthropic tool.""" + # No specific initialization needed yet, but included for future extensibility + pass + + @abstractmethod + async def __call__(self, **kwargs) -> Any: + """Executes the tool with the given arguments.""" + ... + + @abstractmethod + def to_params(self) -> Dict[str, Any]: + """Convert tool to Anthropic-specific API parameters. 
+ + Returns: + Dictionary with tool parameters for Anthropic API + """ + raise NotImplementedError + + +@dataclass(kw_only=True, frozen=True) +class ToolResult: + """Represents the result of a tool execution.""" + + output: str | None = None + error: str | None = None + base64_image: str | None = None + system: str | None = None + content: list[dict] | None = None + + def __bool__(self): + return any(getattr(self, field.name) for field in fields(self)) + + def __add__(self, other: "ToolResult"): + def combine_fields(field: str | None, other_field: str | None, concatenate: bool = True): + if field and other_field: + if concatenate: + return field + other_field + raise ValueError("Cannot combine tool results") + return field or other_field + + return ToolResult( + output=combine_fields(self.output, other.output), + error=combine_fields(self.error, other.error), + base64_image=combine_fields(self.base64_image, other.base64_image, False), + system=combine_fields(self.system, other.system), + content=self.content or other.content, # Use first non-None content + ) + + def replace(self, **kwargs): + """Returns a new ToolResult with the given fields replaced.""" + return replace(self, **kwargs) + + +class CLIResult(ToolResult): + """A ToolResult that can be rendered as a CLI output.""" + + +class ToolFailure(ToolResult): + """A ToolResult that represents a failure.""" + + +class ToolError(Exception): + """Raised when a tool encounters an error.""" + + def __init__(self, message): + self.message = message + + +# Re-export the core tool classes with Anthropic-specific names for backward compatibility +AnthropicToolResult = ToolResult +AnthropicToolError = ToolError +AnthropicToolFailure = ToolFailure +AnthropicCLIResult = CLIResult diff --git a/libs/agent/agent/providers/anthropic/tools/bash.py b/libs/agent/agent/providers/anthropic/tools/bash.py new file mode 100644 index 00000000..00bdd572 --- /dev/null +++ b/libs/agent/agent/providers/anthropic/tools/bash.py @@ -0,0 
+1,163 @@ +import asyncio +import os +from typing import ClassVar, Literal, Dict, Any +from computer.computer import Computer + +from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult +from ....core.tools.bash import BaseBashTool + + +class _BashSession: + """A session of a bash shell.""" + + _started: bool + _process: asyncio.subprocess.Process + + command: str = "/bin/bash" + _output_delay: float = 0.2 # seconds + _timeout: float = 120.0 # seconds + _sentinel: str = "<>" + + def __init__(self): + self._started = False + self._timed_out = False + + async def start(self): + if self._started: + return + + self._process = await asyncio.create_subprocess_shell( + self.command, + preexec_fn=os.setsid, + shell=True, + bufsize=0, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + self._started = True + + def stop(self): + """Terminate the bash shell.""" + if not self._started: + raise ToolError("Session has not started.") + if self._process.returncode is not None: + return + self._process.terminate() + + async def run(self, command: str): + """Execute a command in the bash shell.""" + if not self._started: + raise ToolError("Session has not started.") + if self._process.returncode is not None: + return ToolResult( + system="tool must be restarted", + error=f"bash has exited with returncode {self._process.returncode}", + ) + if self._timed_out: + raise ToolError( + f"timed out: bash has not returned in {self._timeout} seconds and must be restarted", + ) + + # we know these are not None because we created the process with PIPEs + assert self._process.stdin + assert self._process.stdout + assert self._process.stderr + + # send command to the process + self._process.stdin.write(command.encode() + f"; echo '{self._sentinel}'\n".encode()) + await self._process.stdin.drain() + + # read output from the process, until the sentinel is found + try: + async with asyncio.timeout(self._timeout): + while True: + 
await asyncio.sleep(self._output_delay) + # if we read directly from stdout/stderr, it will wait forever for + # EOF. use the StreamReader buffer directly instead. + output = ( + self._process.stdout._buffer.decode() + ) # pyright: ignore[reportAttributeAccessIssue] + if self._sentinel in output: + # strip the sentinel and break + output = output[: output.index(self._sentinel)] + break + except asyncio.TimeoutError: + self._timed_out = True + raise ToolError( + f"timed out: bash has not returned in {self._timeout} seconds and must be restarted", + ) from None + + if output.endswith("\n"): + output = output[:-1] + + error = self._process.stderr._buffer.decode() # pyright: ignore[reportAttributeAccessIssue] + if error.endswith("\n"): + error = error[:-1] + + # clear the buffers so that the next output can be read correctly + self._process.stdout._buffer.clear() # pyright: ignore[reportAttributeAccessIssue] + self._process.stderr._buffer.clear() # pyright: ignore[reportAttributeAccessIssue] + + return CLIResult(output=output, error=error) + + +class BashTool(BaseBashTool, BaseAnthropicTool): + """ + A tool that allows the agent to run bash commands. + The tool parameters are defined by Anthropic and are not editable. + """ + + name: ClassVar[Literal["bash"]] = "bash" + api_type: ClassVar[Literal["bash_20250124"]] = "bash_20250124" + _timeout: float = 120.0 # seconds + + def __init__(self, computer: Computer): + """Initialize the bash tool. + + Args: + computer: Computer instance for executing commands + """ + # Initialize the base bash tool first + BaseBashTool.__init__(self, computer) + # Then initialize the Anthropic tool + BaseAnthropicTool.__init__(self) + # Initialize bash session + self._session = _BashSession() + + async def __call__(self, command: str | None = None, restart: bool = False, **kwargs): + """Execute a bash command. 
+ + Args: + command: The command to execute + restart: Whether to restart the shell (not used with computer interface) + + Returns: + Tool execution result + + Raises: + ToolError: If command execution fails + """ + if restart: + return ToolResult(system="Restart not needed with computer interface.") + + if command is None: + raise ToolError("no command provided.") + + try: + async with asyncio.timeout(self._timeout): + stdout, stderr = await self.computer.interface.run_command(command) + return CLIResult(output=stdout or "", error=stderr or "") + except asyncio.TimeoutError as e: + raise ToolError(f"Command timed out after {self._timeout} seconds") from e + except Exception as e: + raise ToolError(f"Failed to execute command: {str(e)}") + + def to_params(self) -> Dict[str, Any]: + """Convert tool to API parameters. + + Returns: + Dictionary with tool parameters + """ + return {"name": self.name, "type": self.api_type} diff --git a/libs/agent/agent/providers/anthropic/tools/collection.py b/libs/agent/agent/providers/anthropic/tools/collection.py new file mode 100644 index 00000000..c4e8c95c --- /dev/null +++ b/libs/agent/agent/providers/anthropic/tools/collection.py @@ -0,0 +1,34 @@ +"""Collection classes for managing multiple tools.""" + +from typing import Any + +from anthropic.types.beta import BetaToolUnionParam + +from .base import ( + BaseAnthropicTool, + ToolError, + ToolFailure, + ToolResult, +) + + +class ToolCollection: + """A collection of anthropic-defined tools.""" + + def __init__(self, *tools: BaseAnthropicTool): + self.tools = tools + self.tool_map = {tool.to_params()["name"]: tool for tool in tools} + + def to_params( + self, + ) -> list[BetaToolUnionParam]: + return [tool.to_params() for tool in self.tools] + + async def run(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult: + tool = self.tool_map.get(name) + if not tool: + return ToolFailure(error=f"Tool {name} is invalid") + try: + return await tool(**tool_input) + except ToolError 
as e: + return ToolFailure(error=e.message) diff --git a/libs/agent/agent/providers/anthropic/tools/computer.py b/libs/agent/agent/providers/anthropic/tools/computer.py new file mode 100644 index 00000000..2d00b3c6 --- /dev/null +++ b/libs/agent/agent/providers/anthropic/tools/computer.py @@ -0,0 +1,550 @@ +import asyncio +import base64 +import io +import logging +from enum import StrEnum +from pathlib import Path +from typing import Literal, TypedDict, Any, Dict +import subprocess +from PIL import Image +from datetime import datetime + +from computer.computer import Computer + +from .base import BaseAnthropicTool, ToolError, ToolResult +from .run import run +from ....core.tools.computer import BaseComputerTool + +TYPING_DELAY_MS = 12 +TYPING_GROUP_SIZE = 50 + +Action = Literal[ + "key", + "type", + "mouse_move", + "left_click", + "left_click_drag", + "right_click", + "middle_click", + "double_click", + "screenshot", + "cursor_position", + "scroll", +] + + +class Resolution(TypedDict): + width: int + height: int + + +class ScalingSource(StrEnum): + COMPUTER = "computer" + API = "api" + + +class ComputerToolOptions(TypedDict): + display_height_px: int + display_width_px: int + display_number: int | None + + +def chunks(s: str, chunk_size: int) -> list[str]: + return [s[i : i + chunk_size] for i in range(0, len(s), chunk_size)] + + +class ComputerTool(BaseComputerTool, BaseAnthropicTool): + """ + A tool that allows the agent to interact with the screen, keyboard, and mouse of the current macOS computer. + The tool parameters are defined by Anthropic and are not editable. 
+ """ + + name: Literal["computer"] = "computer" + api_type: Literal["computer_20250124"] = "computer_20250124" + width: int | None + height: int | None + display_num: int | None + computer: Computer # The CUA Computer instance + logger = logging.getLogger(__name__) + + _screenshot_delay = 1.0 # macOS is generally faster than X11 + _scaling_enabled = True + + @property + def options(self) -> ComputerToolOptions: + if self.width is None or self.height is None: + raise RuntimeError( + "Screen dimensions not initialized. Call initialize_dimensions() first." + ) + return { + "display_width_px": self.width, + "display_height_px": self.height, + "display_number": self.display_num, + } + + def to_params(self) -> Dict[str, Any]: + """Convert tool to API parameters. + + Returns: + Dictionary with tool parameters + """ + return {"name": self.name, "type": self.api_type, **self.options} + + def __init__(self, computer): + # Initialize the base computer tool first + BaseComputerTool.__init__(self, computer) + # Then initialize the Anthropic tool + BaseAnthropicTool.__init__(self) + + # Additional initialization + self.width = None # Will be initialized from computer interface + self.height = None # Will be initialized from computer interface + self.display_num = None + + async def initialize_dimensions(self): + """Initialize screen dimensions from the computer interface.""" + display_size = await self.computer.interface.get_screen_size() + self.width = display_size["width"] + self.height = display_size["height"] + self.logger.info(f"Initialized screen dimensions to {self.width}x{self.height}") + + async def __call__( + self, + *, + action: Action, + text: str | None = None, + coordinate: tuple[int, int] | None = None, + **kwargs, + ): + try: + # Ensure dimensions are initialized + if self.width is None or self.height is None: + await self.initialize_dimensions() + except Exception as e: + raise ToolError(f"Failed to initialize dimensions: {e}") + + if action in ("mouse_move", 
"left_click_drag"): + if coordinate is None: + raise ToolError(f"coordinate is required for {action}") + if text is not None: + raise ToolError(f"text is not accepted for {action}") + if not isinstance(coordinate, (list, tuple)) or len(coordinate) != 2: + raise ToolError(f"{coordinate} must be a tuple of length 2") + if not all(isinstance(i, int) and i >= 0 for i in coordinate): + raise ToolError(f"{coordinate} must be a tuple of non-negative ints") + + try: + x, y = coordinate + self.logger.info(f"Handling {action} action:") + self.logger.info(f" Coordinates: ({x}, {y})") + + # Take pre-action screenshot to get current dimensions + pre_screenshot = await self.computer.interface.screenshot() + pre_img = Image.open(io.BytesIO(pre_screenshot)) + + # Scale image to match screen dimensions if needed + if pre_img.size != (self.width, self.height): + self.logger.info( + f"Scaling image from {pre_img.size} to {self.width}x{self.height} to match screen dimensions" + ) + pre_img = pre_img.resize((self.width, self.height), Image.Resampling.LANCZOS) + + self.logger.info(f" Current dimensions: {pre_img.width}x{pre_img.height}") + + if action == "mouse_move": + self.logger.info(f"Moving cursor to ({x}, {y})") + await self.computer.interface.move_cursor(x, y) + elif action == "left_click_drag": + self.logger.info(f"Dragging from ({x}, {y})") + # First move to the position + await self.computer.interface.move_cursor(x, y) + # Then perform drag operation - check if drag_to exists or we need to use other methods + try: + if hasattr(self.computer.interface, "drag_to"): + await self.computer.interface.drag_to(x, y) + else: + # Alternative approach: press mouse down, move, release + await self.computer.interface.mouse_down() + await asyncio.sleep(0.2) + await self.computer.interface.move_cursor(x, y) + await asyncio.sleep(0.2) + await self.computer.interface.mouse_up() + except Exception as e: + self.logger.error(f"Error during drag operation: {str(e)}") + raise ToolError(f"Failed to 
perform drag: {str(e)}") + + # Wait briefly for any UI changes + await asyncio.sleep(0.5) + + # Take post-action screenshot + post_screenshot = await self.computer.interface.screenshot() + post_img = Image.open(io.BytesIO(post_screenshot)) + + # Scale post-action image if needed + if post_img.size != (self.width, self.height): + self.logger.info( + f"Scaling post-action image from {post_img.size} to {self.width}x{self.height}" + ) + post_img = post_img.resize((self.width, self.height), Image.Resampling.LANCZOS) + buffer = io.BytesIO() + post_img.save(buffer, format="PNG") + post_screenshot = buffer.getvalue() + + return ToolResult( + output=f"{'Moved cursor to' if action == 'mouse_move' else 'Dragged to'} {x},{y}", + base64_image=base64.b64encode(post_screenshot).decode(), + ) + except Exception as e: + self.logger.error(f"Error during {action} action: {str(e)}") + raise ToolError(f"Failed to perform {action}: {str(e)}") + + elif action in ("left_click", "right_click", "double_click"): + if coordinate: + x, y = coordinate + self.logger.info(f"Handling {action} action:") + self.logger.info(f" Coordinates: ({x}, {y})") + + try: + # Take pre-action screenshot to get current dimensions + pre_screenshot = await self.computer.interface.screenshot() + pre_img = Image.open(io.BytesIO(pre_screenshot)) + + # Scale image to match screen dimensions if needed + if pre_img.size != (self.width, self.height): + self.logger.info( + f"Scaling image from {pre_img.size} to {self.width}x{self.height} to match screen dimensions" + ) + pre_img = pre_img.resize( + (self.width, self.height), Image.Resampling.LANCZOS + ) + # Save the scaled image back to bytes + buffer = io.BytesIO() + pre_img.save(buffer, format="PNG") + pre_screenshot = buffer.getvalue() + + self.logger.info(f" Current dimensions: {pre_img.width}x{pre_img.height}") + + # Perform the click action + if action == "left_click": + self.logger.info(f"Clicking at ({x}, {y})") + await self.computer.interface.move_cursor(x, y) + 
await self.computer.interface.left_click() + elif action == "right_click": + self.logger.info(f"Right clicking at ({x}, {y})") + await self.computer.interface.move_cursor(x, y) + await self.computer.interface.right_click() + elif action == "double_click": + self.logger.info(f"Double clicking at ({x}, {y})") + await self.computer.interface.move_cursor(x, y) + await self.computer.interface.double_click() + + # Wait briefly for any UI changes + await asyncio.sleep(0.5) + + # Take and save post-action screenshot + post_screenshot = await self.computer.interface.screenshot() + post_img = Image.open(io.BytesIO(post_screenshot)) + + # Scale post-action image if needed + if post_img.size != (self.width, self.height): + self.logger.info( + f"Scaling post-action image from {post_img.size} to {self.width}x{self.height}" + ) + post_img = post_img.resize( + (self.width, self.height), Image.Resampling.LANCZOS + ) + buffer = io.BytesIO() + post_img.save(buffer, format="PNG") + post_screenshot = buffer.getvalue() + + return ToolResult( + output=f"Performed {action} at ({x}, {y})", + base64_image=base64.b64encode(post_screenshot).decode(), + ) + except Exception as e: + self.logger.error(f"Error during {action} action: {str(e)}") + raise ToolError(f"Failed to perform {action}: {str(e)}") + else: + try: + # Take pre-action screenshot + pre_screenshot = await self.computer.interface.screenshot() + pre_img = Image.open(io.BytesIO(pre_screenshot)) + + # Scale image if needed + if pre_img.size != (self.width, self.height): + self.logger.info( + f"Scaling image from {pre_img.size} to {self.width}x{self.height}" + ) + pre_img = pre_img.resize( + (self.width, self.height), Image.Resampling.LANCZOS + ) + + # Perform the click action + if action == "left_click": + self.logger.info("Performing left click at current position") + await self.computer.interface.left_click() + elif action == "right_click": + self.logger.info("Performing right click at current position") + await 
self.computer.interface.right_click() + elif action == "double_click": + self.logger.info("Performing double click at current position") + await self.computer.interface.double_click() + + # Wait briefly for any UI changes + await asyncio.sleep(0.5) + + # Take post-action screenshot + post_screenshot = await self.computer.interface.screenshot() + post_img = Image.open(io.BytesIO(post_screenshot)) + + # Scale post-action image if needed + if post_img.size != (self.width, self.height): + self.logger.info( + f"Scaling post-action image from {post_img.size} to {self.width}x{self.height}" + ) + post_img = post_img.resize( + (self.width, self.height), Image.Resampling.LANCZOS + ) + buffer = io.BytesIO() + post_img.save(buffer, format="PNG") + post_screenshot = buffer.getvalue() + + return ToolResult( + output=f"Performed {action} at current position", + base64_image=base64.b64encode(post_screenshot).decode(), + ) + except Exception as e: + self.logger.error(f"Error during {action} action: {str(e)}") + raise ToolError(f"Failed to perform {action}: {str(e)}") + + elif action in ("key", "type"): + if text is None: + raise ToolError(f"text is required for {action}") + if coordinate is not None: + raise ToolError(f"coordinate is not accepted for {action}") + if not isinstance(text, str): + raise ToolError(f"{text} must be a string") + + try: + # Take pre-action screenshot + pre_screenshot = await self.computer.interface.screenshot() + pre_img = Image.open(io.BytesIO(pre_screenshot)) + + # Scale image if needed + if pre_img.size != (self.width, self.height): + self.logger.info( + f"Scaling image from {pre_img.size} to {self.width}x{self.height}" + ) + pre_img = pre_img.resize((self.width, self.height), Image.Resampling.LANCZOS) + + if action == "key": + # Special handling for page up/down on macOS + if text.lower() in ["pagedown", "page_down", "page down"]: + self.logger.info("Converting page down to fn+down for macOS") + await self.computer.interface.hotkey("fn", "down") + 
output_text = "fn+down" + elif text.lower() in ["pageup", "page_up", "page up"]: + self.logger.info("Converting page up to fn+up for macOS") + await self.computer.interface.hotkey("fn", "up") + output_text = "fn+up" + elif text == "fn+down": + self.logger.info("Using fn+down combination") + await self.computer.interface.hotkey("fn", "down") + output_text = text + elif text == "fn+up": + self.logger.info("Using fn+up combination") + await self.computer.interface.hotkey("fn", "up") + output_text = text + elif "+" in text: + # Handle hotkey combinations + keys = text.split("+") + self.logger.info(f"Pressing hotkey combination: {text}") + await self.computer.interface.hotkey(*keys) + output_text = text + else: + # Handle single key press + self.logger.info(f"Pressing key: {text}") + try: + await self.computer.interface.press(text) + output_text = text + except ValueError as e: + raise ToolError(f"Invalid key: {text}. {str(e)}") + + # Wait briefly for UI changes + await asyncio.sleep(0.5) + + # Take post-action screenshot + post_screenshot = await self.computer.interface.screenshot() + post_img = Image.open(io.BytesIO(post_screenshot)) + + # Scale post-action image if needed + if post_img.size != (self.width, self.height): + self.logger.info( + f"Scaling post-action image from {post_img.size} to {self.width}x{self.height}" + ) + post_img = post_img.resize( + (self.width, self.height), Image.Resampling.LANCZOS + ) + buffer = io.BytesIO() + post_img.save(buffer, format="PNG") + post_screenshot = buffer.getvalue() + + return ToolResult( + output=f"Pressed key: {output_text}", + base64_image=base64.b64encode(post_screenshot).decode(), + ) + + elif action == "type": + self.logger.info(f"Typing text: {text}") + await self.computer.interface.type_text(text) + + # Wait briefly for UI changes + await asyncio.sleep(0.5) + + # Take post-action screenshot + post_screenshot = await self.computer.interface.screenshot() + post_img = Image.open(io.BytesIO(post_screenshot)) + + # Scale 
post-action image if needed + if post_img.size != (self.width, self.height): + self.logger.info( + f"Scaling post-action image from {post_img.size} to {self.width}x{self.height}" + ) + post_img = post_img.resize( + (self.width, self.height), Image.Resampling.LANCZOS + ) + buffer = io.BytesIO() + post_img.save(buffer, format="PNG") + post_screenshot = buffer.getvalue() + + return ToolResult( + output=f"Typed text: {text}", + base64_image=base64.b64encode(post_screenshot).decode(), + ) + except Exception as e: + self.logger.error(f"Error during {action} action: {str(e)}") + raise ToolError(f"Failed to perform {action}: {str(e)}") + + elif action in ("screenshot", "cursor_position"): + if text is not None: + raise ToolError(f"text is not accepted for {action}") + if coordinate is not None: + raise ToolError(f"coordinate is not accepted for {action}") + + try: + if action == "screenshot": + # Take screenshot + screenshot = await self.computer.interface.screenshot() + img = Image.open(io.BytesIO(screenshot)) + + # Scale image if needed + if img.size != (self.width, self.height): + self.logger.info( + f"Scaling image from {img.size} to {self.width}x{self.height}" + ) + img = img.resize((self.width, self.height), Image.Resampling.LANCZOS) + buffer = io.BytesIO() + img.save(buffer, format="PNG") + screenshot = buffer.getvalue() + + return ToolResult(base64_image=base64.b64encode(screenshot).decode()) + + elif action == "cursor_position": + pos = await self.computer.interface.get_cursor_position() + return ToolResult(output=f"X={int(pos[0])},Y={int(pos[1])}") + + except Exception as e: + self.logger.error(f"Error during {action} action: {str(e)}") + raise ToolError(f"Failed to perform {action}: {str(e)}") + + elif action == "scroll": + # Implement scroll action + direction = kwargs.get("direction", "down") + amount = kwargs.get("amount", 10) + + if direction not in ["up", "down"]: + raise ToolError(f"Invalid scroll direction: {direction}. 
Must be 'up' or 'down'.") + + try: + if direction == "down": + # Scroll down (Page Down on macOS) + self.logger.info(f"Scrolling down, amount: {amount}") + # Use fn+down for page down on macOS + for _ in range(amount): + await self.computer.interface.hotkey("fn", "down") + await asyncio.sleep(0.1) + else: + # Scroll up (Page Up on macOS) + self.logger.info(f"Scrolling up, amount: {amount}") + # Use fn+up for page up on macOS + for _ in range(amount): + await self.computer.interface.hotkey("fn", "up") + await asyncio.sleep(0.1) + + # Wait briefly for UI changes + await asyncio.sleep(0.5) + + # Take post-action screenshot + post_screenshot = await self.computer.interface.screenshot() + post_img = Image.open(io.BytesIO(post_screenshot)) + + # Scale post-action image if needed + if post_img.size != (self.width, self.height): + self.logger.info( + f"Scaling post-action image from {post_img.size} to {self.width}x{self.height}" + ) + post_img = post_img.resize((self.width, self.height), Image.Resampling.LANCZOS) + buffer = io.BytesIO() + post_img.save(buffer, format="PNG") + post_screenshot = buffer.getvalue() + + return ToolResult( + output=f"Scrolled {direction} by {amount} steps", + base64_image=base64.b64encode(post_screenshot).decode(), + ) + except Exception as e: + self.logger.error(f"Error during scroll action: {str(e)}") + raise ToolError(f"Failed to perform scroll: {str(e)}") + + raise ToolError(f"Invalid action: {action}") + + async def screenshot(self): + """Take a screenshot and return it as a base64-encoded string.""" + try: + screenshot = await self.computer.interface.screenshot() + img = Image.open(io.BytesIO(screenshot)) + + # Scale image if needed + if img.size != (self.width, self.height): + self.logger.info(f"Scaling image from {img.size} to {self.width}x{self.height}") + img = img.resize((self.width, self.height), Image.Resampling.LANCZOS) + buffer = io.BytesIO() + img.save(buffer, format="PNG") + screenshot = buffer.getvalue() + + return 
ToolResult(base64_image=base64.b64encode(screenshot).decode()) + except Exception as e: + self.logger.error(f"Error taking screenshot: {str(e)}") + return ToolResult(error=f"Failed to take screenshot: {str(e)}") + + async def shell(self, command: str, take_screenshot=False) -> ToolResult: + """Run a shell command and return the output, error, and optionally a screenshot.""" + try: + _, stdout, stderr = await run(command) + base64_image = None + + if take_screenshot: + # delay to let things settle before taking a screenshot + await asyncio.sleep(self._screenshot_delay) + screenshot_result = await self.screenshot() + if screenshot_result.error: + return ToolResult( + output=stdout, + error=f"{stderr}\nScreenshot error: {screenshot_result.error}", + ) + base64_image = screenshot_result.base64_image + + return ToolResult(output=stdout, error=stderr, base64_image=base64_image) + + except Exception as e: + return ToolResult(error=f"Shell command failed: {str(e)}") diff --git a/libs/agent/agent/providers/anthropic/tools/edit.py b/libs/agent/agent/providers/anthropic/tools/edit.py new file mode 100644 index 00000000..e4da1f85 --- /dev/null +++ b/libs/agent/agent/providers/anthropic/tools/edit.py @@ -0,0 +1,326 @@ +from collections import defaultdict +from pathlib import Path +from typing import Literal, get_args, Dict, Any +from computer.computer import Computer + +from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult +from ....core.tools.edit import BaseEditTool +from .run import maybe_truncate + +Command = Literal[ + "view", + "create", + "str_replace", + "insert", + "undo_edit", +] +SNIPPET_LINES: int = 4 + + +class EditTool(BaseEditTool, BaseAnthropicTool): + """ + An filesystem editor tool that allows the agent to view, create, and edit files. + The tool parameters are defined by Anthropic and are not editable. 
+ """ + + api_type: Literal["text_editor_20250124"] = "text_editor_20250124" + name: Literal["str_replace_editor"] = "str_replace_editor" + _timeout: float = 30.0 # seconds + + def __init__(self, computer: Computer): + """Initialize the edit tool. + + Args: + computer: Computer instance for file operations + """ + # Initialize the base edit tool first + BaseEditTool.__init__(self, computer) + # Then initialize the Anthropic tool + BaseAnthropicTool.__init__(self) + + # Edit history for the current session + self.edit_history = defaultdict(list) + + async def __call__( + self, + *, + command: Command, + path: str, + file_text: str | None = None, + view_range: list[int] | None = None, + old_str: str | None = None, + new_str: str | None = None, + insert_line: int | None = None, + **kwargs, + ): + _path = Path(path) + await self.validate_path(command, _path) + + if command == "view": + return await self.view(_path, view_range) + elif command == "create": + if file_text is None: + raise ToolError("Parameter `file_text` is required for command: create") + await self.write_file(_path, file_text) + self.edit_history[_path].append(file_text) + return ToolResult(output=f"File created successfully at: {_path}") + elif command == "str_replace": + if old_str is None: + raise ToolError("Parameter `old_str` is required for command: str_replace") + return await self.str_replace(_path, old_str, new_str) + elif command == "insert": + if insert_line is None: + raise ToolError("Parameter `insert_line` is required for command: insert") + if new_str is None: + raise ToolError("Parameter `new_str` is required for command: insert") + return await self.insert(_path, insert_line, new_str) + elif command == "undo_edit": + return await self.undo_edit(_path) + + raise ToolError( + f'Unrecognized command {command}. 
The allowed commands for the {self.name} tool are: {", ".join(get_args(Command))}' + ) + + async def validate_path(self, command: str, path: Path): + """Check that the path/command combination is valid.""" + # Check if its an absolute path + if not path.is_absolute(): + suggested_path = Path("") / path + raise ToolError( + f"The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?" + ) + + # Check if path exists using bash commands + try: + result = await self.computer.interface.run_command( + f'[ -e "{str(path)}" ] && echo "exists" || echo "not exists"' + ) + exists = result[0].strip() == "exists" + + if exists: + result = await self.computer.interface.run_command( + f'[ -d "{str(path)}" ] && echo "dir" || echo "file"' + ) + is_dir = result[0].strip() == "dir" + else: + is_dir = False + + # Check path validity + if not exists and command != "create": + raise ToolError(f"The path {path} does not exist. Please provide a valid path.") + if exists and command == "create": + raise ToolError( + f"File already exists at: {path}. Cannot overwrite files using command `create`." + ) + if is_dir and command != "view": + raise ToolError( + f"The path {path} is a directory and only the `view` command can be used on directories" + ) + except Exception as e: + raise ToolError(f"Failed to validate path: {str(e)}") + + async def view(self, path: Path, view_range: list[int] | None = None): + """Implement the view command""" + try: + # Check if path is a directory + result = await self.computer.interface.run_command( + f'[ -d "{str(path)}" ] && echo "dir" || echo "file"' + ) + is_dir = result[0].strip() == "dir" + + if is_dir: + if view_range: + raise ToolError( + "The `view_range` parameter is not allowed when `path` points to a directory." 
+ ) + + # List directory contents using ls + result = await self.computer.interface.run_command(f'ls -la "{str(path)}"') + contents = result[0] + if contents: + stdout = f"Here's the files and directories in {path}:\n{contents}\n" + else: + stdout = f"Directory {path} is empty\n" + return CLIResult(output=stdout) + + # Read file content using cat + file_content = await self.read_file(path) + init_line = 1 + + if view_range: + if len(view_range) != 2 or not all(isinstance(i, int) for i in view_range): + raise ToolError("Invalid `view_range`. It should be a list of two integers.") + + file_lines = file_content.split("\n") + n_lines_file = len(file_lines) + init_line, final_line = view_range + + if init_line < 1 or init_line > n_lines_file: + raise ToolError( + f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}" + ) + if final_line > n_lines_file: + raise ToolError( + f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`" + ) + if final_line != -1 and final_line < init_line: + raise ToolError( + f"Invalid `view_range`: {view_range}. 
Its second element `{final_line}` should be larger or equal than its first `{init_line}`" + ) + + if final_line == -1: + file_content = "\n".join(file_lines[init_line - 1 :]) + else: + file_content = "\n".join(file_lines[init_line - 1 : final_line]) + + return CLIResult(output=self._make_output(file_content, str(path), init_line=init_line)) + except Exception as e: + raise ToolError(f"Failed to view path: {str(e)}") + + async def str_replace(self, path: Path, old_str: str, new_str: str | None): + """Implement the str_replace command""" + # Read the file content + file_content = await self.read_file(path) + file_content = file_content.expandtabs() + old_str = old_str.expandtabs() + new_str = new_str.expandtabs() if new_str is not None else "" + + # Check if old_str is unique in the file + occurrences = file_content.count(old_str) + if occurrences == 0: + raise ToolError( + f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}." + ) + elif occurrences > 1: + file_content_lines = file_content.split("\n") + lines = [idx + 1 for idx, line in enumerate(file_content_lines) if old_str in line] + raise ToolError( + f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique" + ) + + # Replace old_str with new_str + new_file_content = file_content.replace(old_str, new_str) + + # Write the new content to the file + await self.write_file(path, new_file_content) + + # Save the content to history + self.edit_history[path].append(file_content) + + # Create a snippet of the edited section + replacement_line = file_content.split(old_str)[0].count("\n") + start_line = max(0, replacement_line - SNIPPET_LINES) + end_line = replacement_line + SNIPPET_LINES + new_str.count("\n") + snippet = "\n".join(new_file_content.split("\n")[start_line : end_line + 1]) + + # Prepare the success message + success_msg = f"The file {path} has been edited. 
" + success_msg += self._make_output(snippet, f"a snippet of {path}", start_line + 1) + success_msg += "Review the changes and make sure they are as expected. Edit the file again if necessary." + + return CLIResult(output=success_msg) + + async def insert(self, path: Path, insert_line: int, new_str: str): + """Implement the insert command""" + file_text = await self.read_file(path) + file_text = file_text.expandtabs() + new_str = new_str.expandtabs() + file_text_lines = file_text.split("\n") + n_lines_file = len(file_text_lines) + + if insert_line < 0 or insert_line > n_lines_file: + raise ToolError( + f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}" + ) + + new_str_lines = new_str.split("\n") + new_file_text_lines = ( + file_text_lines[:insert_line] + new_str_lines + file_text_lines[insert_line:] + ) + snippet_lines = ( + file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line] + + new_str_lines + + file_text_lines[insert_line : insert_line + SNIPPET_LINES] + ) + + new_file_text = "\n".join(new_file_text_lines) + snippet = "\n".join(snippet_lines) + + await self.write_file(path, new_file_text) + self.edit_history[path].append(file_text) + + success_msg = f"The file {path} has been edited. " + success_msg += self._make_output( + snippet, "a snippet of the edited file", max(1, insert_line - SNIPPET_LINES + 1) + ) + success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary." + return CLIResult(output=success_msg) + + async def undo_edit(self, path: Path): + """Implement the undo_edit command""" + if not self.edit_history[path]: + raise ToolError(f"No edit history found for {path}.") + + old_text = self.edit_history[path].pop() + await self.write_file(path, old_text) + + return CLIResult( + output=f"Last edit to {path} undone successfully. 
{self._make_output(old_text, str(path))}" + ) + + async def read_file(self, path: Path) -> str: + """Read the content of a file using cat command.""" + try: + result = await self.computer.interface.run_command(f'cat "{str(path)}"') + if result[1]: # If there's stderr output + raise ToolError(f"Error reading file: {result[1]}") + return result[0] + except Exception as e: + raise ToolError(f"Failed to read {path}: {str(e)}") + + async def write_file(self, path: Path, content: str): + """Write content to a file using echo and redirection.""" + try: + # Create parent directories if they don't exist + parent = path.parent + if parent != Path("/"): + await self.computer.interface.run_command(f'mkdir -p "{str(parent)}"') + + # Write content to file using echo and heredoc to preserve formatting + cmd = f"""cat > "{str(path)}" << 'EOFCUA' +{content} +EOFCUA""" + result = await self.computer.interface.run_command(cmd) + if result[1]: # If there's stderr output + raise ToolError(f"Error writing file: {result[1]}") + except Exception as e: + raise ToolError(f"Failed to write to {path}: {str(e)}") + + def _make_output( + self, + file_content: str, + file_descriptor: str, + init_line: int = 1, + expand_tabs: bool = True, + ) -> str: + """Generate output for the CLI based on the content of a file.""" + file_content = maybe_truncate(file_content) + if expand_tabs: + file_content = file_content.expandtabs() + file_content = "\n".join( + [f"{i + init_line:6}\t{line}" for i, line in enumerate(file_content.split("\n"))] + ) + return ( + f"Here's the result of running `cat -n` on {file_descriptor}:\n" + file_content + "\n" + ) + + def to_params(self) -> Dict[str, Any]: + """Convert tool to API parameters. 
+ + Returns: + Dictionary with tool parameters + """ + return { + "name": self.name, + "type": self.api_type, + } diff --git a/libs/agent/agent/providers/anthropic/tools/manager.py b/libs/agent/agent/providers/anthropic/tools/manager.py new file mode 100644 index 00000000..6e8857d1 --- /dev/null +++ b/libs/agent/agent/providers/anthropic/tools/manager.py @@ -0,0 +1,54 @@ +from typing import Any, Dict, List +from anthropic.types.beta import BetaToolUnionParam +from computer.computer import Computer + +from ....core.tools import BaseToolManager, ToolResult +from ....core.tools.collection import ToolCollection + +from .bash import BashTool +from .computer import ComputerTool +from .edit import EditTool + + +class ToolManager(BaseToolManager): + """Manages Anthropic-specific tool initialization and execution.""" + + def __init__(self, computer: Computer): + """Initialize the tool manager. + + Args: + computer: Computer instance for computer-related tools + """ + super().__init__(computer) + # Initialize Anthropic-specific tools + self.computer_tool = ComputerTool(self.computer) + self.bash_tool = BashTool(self.computer) + self.edit_tool = EditTool(self.computer) + + def _initialize_tools(self) -> ToolCollection: + """Initialize all available tools.""" + return ToolCollection(self.computer_tool, self.bash_tool, self.edit_tool) + + async def _initialize_tools_specific(self) -> None: + """Initialize Anthropic-specific tool requirements.""" + await self.computer_tool.initialize_dimensions() + + def get_tool_params(self) -> List[BetaToolUnionParam]: + """Get tool parameters for Anthropic API calls.""" + if self.tools is None: + raise RuntimeError("Tools not initialized. Call initialize() first.") + return self.tools.to_params() + + async def execute_tool(self, name: str, tool_input: dict[str, Any]) -> ToolResult: + """Execute a tool with the given input. 
+ + Args: + name: Name of the tool to execute + tool_input: Input parameters for the tool + + Returns: + Result of the tool execution + """ + if self.tools is None: + raise RuntimeError("Tools not initialized. Call initialize() first.") + return await self.tools.run(name=name, tool_input=tool_input) diff --git a/libs/agent/agent/providers/anthropic/tools/run.py b/libs/agent/agent/providers/anthropic/tools/run.py new file mode 100644 index 00000000..89db980a --- /dev/null +++ b/libs/agent/agent/providers/anthropic/tools/run.py @@ -0,0 +1,42 @@ +"""Utility to run shell commands asynchronously with a timeout.""" + +import asyncio + +TRUNCATED_MESSAGE: str = "To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for." +MAX_RESPONSE_LEN: int = 16000 + + +def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN): + """Truncate content and append a notice if content exceeds the specified length.""" + return ( + content + if not truncate_after or len(content) <= truncate_after + else content[:truncate_after] + TRUNCATED_MESSAGE + ) + + +async def run( + cmd: str, + timeout: float | None = 120.0, # seconds + truncate_after: int | None = MAX_RESPONSE_LEN, +): + """Run a shell command asynchronously with a timeout.""" + process = await asyncio.create_subprocess_shell( + cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + ) + + try: + stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout) + return ( + process.returncode or 0, + maybe_truncate(stdout.decode(), truncate_after=truncate_after), + maybe_truncate(stderr.decode(), truncate_after=truncate_after), + ) + except asyncio.TimeoutError as exc: + try: + process.kill() + except ProcessLookupError: + pass + raise TimeoutError( + f"Command '{cmd}' timed out after {timeout} seconds" + ) from exc diff --git 
a/libs/agent/agent/providers/anthropic/types.py b/libs/agent/agent/providers/anthropic/types.py new file mode 100644 index 00000000..c2d80fdb --- /dev/null +++ b/libs/agent/agent/providers/anthropic/types.py @@ -0,0 +1,16 @@ +from enum import StrEnum + + +class LLMProvider(StrEnum): + """Enum for supported API providers.""" + + ANTHROPIC = "anthropic" + BEDROCK = "bedrock" + VERTEX = "vertex" + + +PROVIDER_TO_DEFAULT_MODEL_NAME: dict[LLMProvider, str] = { + LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219", + LLMProvider.BEDROCK: "anthropic.claude-3-7-sonnet-20250219-v2:0", + LLMProvider.VERTEX: "claude-3-5-sonnet-v2@20241022", +} diff --git a/libs/agent/agent/providers/omni/__init__.py b/libs/agent/agent/providers/omni/__init__.py new file mode 100644 index 00000000..8706c658 --- /dev/null +++ b/libs/agent/agent/providers/omni/__init__.py @@ -0,0 +1,27 @@ +"""Omni provider implementation.""" + +# The OmniComputerAgent has been replaced by the unified ComputerAgent +# which can be found in agent.core.agent +from .types import LLMProvider +from .experiment import ExperimentManager +from .visualization import visualize_click, visualize_scroll, calculate_element_center +from .image_utils import ( + decode_base64_image, + encode_image_base64, + clean_base64_data, + extract_base64_from_text, + get_image_dimensions, +) + +__all__ = [ + "LLMProvider", + "ExperimentManager", + "visualize_click", + "visualize_scroll", + "calculate_element_center", + "decode_base64_image", + "encode_image_base64", + "clean_base64_data", + "extract_base64_from_text", + "get_image_dimensions", +] diff --git a/libs/agent/agent/providers/omni/callbacks.py b/libs/agent/agent/providers/omni/callbacks.py new file mode 100644 index 00000000..b25aa05d --- /dev/null +++ b/libs/agent/agent/providers/omni/callbacks.py @@ -0,0 +1,78 @@ +"""Omni callback manager implementation.""" + +import logging +from typing import Any, Dict, Optional, Set + +from ...core.callbacks import BaseCallbackManager, 
ContentCallback, ToolCallback, APICallback +from ...types.tools import ToolResult + +logger = logging.getLogger(__name__) + +class OmniCallbackManager(BaseCallbackManager): + """Callback manager for multi-provider support.""" + + def __init__( + self, + content_callback: ContentCallback, + tool_callback: ToolCallback, + api_callback: APICallback, + ): + """Initialize Omni callback manager. + + Args: + content_callback: Callback for content updates + tool_callback: Callback for tool execution results + api_callback: Callback for API interactions + """ + super().__init__( + content_callback=content_callback, + tool_callback=tool_callback, + api_callback=api_callback + ) + self._active_tools: Set[str] = set() + + def on_content(self, content: Any) -> None: + """Handle content updates. + + Args: + content: Content update data + """ + logger.debug(f"Content update: {content}") + self.content_callback(content) + + def on_tool_result(self, result: ToolResult, tool_id: str) -> None: + """Handle tool execution results. + + Args: + result: Tool execution result + tool_id: ID of the tool + """ + logger.debug(f"Tool result for {tool_id}: {result}") + self.tool_callback(result, tool_id) + + def on_api_interaction( + self, + request: Any, + response: Any, + error: Optional[Exception] = None + ) -> None: + """Handle API interactions. + + Args: + request: API request data + response: API response data + error: Optional error that occurred + """ + if error: + logger.error(f"API error: {str(error)}") + else: + logger.debug(f"API interaction - Request: {request}, Response: {response}") + self.api_callback(request, response, error) + + def get_active_tools(self) -> Set[str]: + """Get currently active tools. 
+ + Returns: + Set of active tool names + """ + return self._active_tools.copy() diff --git a/libs/agent/agent/providers/omni/clients/anthropic.py b/libs/agent/agent/providers/omni/clients/anthropic.py new file mode 100644 index 00000000..6d835277 --- /dev/null +++ b/libs/agent/agent/providers/omni/clients/anthropic.py @@ -0,0 +1,99 @@ +"""Anthropic API client implementation.""" + +import logging +from typing import Any, Dict, List, Optional, Tuple, cast +import asyncio +from httpx import ConnectError, ReadTimeout + +from anthropic import AsyncAnthropic, Anthropic +from anthropic.types import MessageParam +from .base import BaseOmniClient + +logger = logging.getLogger(__name__) + + +class AnthropicClient(BaseOmniClient): + """Client for making calls to Anthropic API.""" + + def __init__(self, api_key: str, model: str, max_retries: int = 3, retry_delay: float = 1.0): + """Initialize the Anthropic client. + + Args: + api_key: Anthropic API key + model: Anthropic model name (e.g. "claude-3-opus-20240229") + max_retries: Maximum number of retries for API calls + retry_delay: Base delay between retries in seconds + """ + if not model: + raise ValueError("Model name must be provided") + + self.client = AsyncAnthropic(api_key=api_key) + self.model: str = model # Add explicit type annotation + self.max_retries = max_retries + self.retry_delay = retry_delay + + def _convert_message_format(self, messages: List[Dict[str, Any]]) -> List[MessageParam]: + """Convert messages from standard format to Anthropic format. 
+ + Args: + messages: Messages in standard format + + Returns: + Messages in Anthropic format + """ + anthropic_messages = [] + + for message in messages: + if message["role"] == "user": + anthropic_messages.append({"role": "user", "content": message["content"]}) + elif message["role"] == "assistant": + anthropic_messages.append({"role": "assistant", "content": message["content"]}) + + # Cast the list to the correct type expected by Anthropic + return cast(List[MessageParam], anthropic_messages) + + async def run_interleaved( + self, messages: List[Dict[str, Any]], system: str, max_tokens: int + ) -> Any: + """Run model with interleaved conversation format. + + Args: + messages: List of messages to process + system: System prompt + max_tokens: Maximum tokens to generate + + Returns: + Model response + """ + last_error = None + + for attempt in range(self.max_retries): + try: + # Convert messages to Anthropic format + anthropic_messages = self._convert_message_format(messages) + + response = await self.client.messages.create( + model=self.model, + max_tokens=max_tokens, + temperature=0, + system=system, + messages=anthropic_messages, + ) + + return response + + except (ConnectError, ReadTimeout) as e: + last_error = e + logger.warning( + f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}" + ) + if attempt < self.max_retries - 1: + await asyncio.sleep(self.retry_delay * (attempt + 1)) # Exponential backoff + continue + + except Exception as e: + logger.error(f"Unexpected error in Anthropic API call: {str(e)}") + raise RuntimeError(f"Anthropic API call failed: {str(e)}") + + # If we get here, all retries failed + raise RuntimeError(f"Connection error after {self.max_retries} retries: {str(last_error)}") diff --git a/libs/agent/agent/providers/omni/clients/base.py b/libs/agent/agent/providers/omni/clients/base.py new file mode 100644 index 00000000..77f2c69d --- /dev/null +++ b/libs/agent/agent/providers/omni/clients/base.py @@ -0,0 +1,44 @@ 
+"""Base client implementation for Omni providers.""" + +import os +import logging +from typing import Dict, List, Optional, Any, Tuple +import aiohttp +import json + +logger = logging.getLogger(__name__) + +class BaseOmniClient: + """Base class for provider-specific clients.""" + + def __init__( + self, + api_key: Optional[str] = None, + model: Optional[str] = None + ): + """Initialize base client. + + Args: + api_key: Optional API key + model: Optional model name + """ + self.api_key = api_key + self.model = model + + async def run_interleaved( + self, + messages: List[Dict[str, Any]], + system: str, + max_tokens: Optional[int] = None + ) -> Dict[str, Any]: + """Run interleaved chat completion. + + Args: + messages: List of message dicts + system: System prompt + max_tokens: Optional max tokens override + + Returns: + Response dict + """ + raise NotImplementedError diff --git a/libs/agent/agent/providers/omni/clients/groq.py b/libs/agent/agent/providers/omni/clients/groq.py new file mode 100644 index 00000000..a7d6776b --- /dev/null +++ b/libs/agent/agent/providers/omni/clients/groq.py @@ -0,0 +1,101 @@ +"""Groq client implementation.""" + +import os +import logging +from typing import Dict, List, Optional, Any, Tuple + +from groq import Groq +import re +from .utils import is_image_path +from .base import BaseOmniClient + +logger = logging.getLogger(__name__) + + +class GroqClient(BaseOmniClient): + """Client for making Groq API calls.""" + + def __init__( + self, + api_key: Optional[str] = None, + model: str = "deepseek-r1-distill-llama-70b", + max_tokens: int = 4096, + temperature: float = 0.6, + ): + """Initialize Groq client. 
+ + Args: + api_key: Groq API key (if not provided, will try to get from env) + model: Model name to use + max_tokens: Maximum tokens to generate + temperature: Temperature for sampling + """ + super().__init__(api_key=api_key, model=model) + self.api_key = api_key or os.getenv("GROQ_API_KEY") + if not self.api_key: + raise ValueError("No Groq API key provided") + + self.max_tokens = max_tokens + self.temperature = temperature + self.client = Groq(api_key=self.api_key) + self.model: str = model # Add explicit type annotation + + def run_interleaved( + self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None + ) -> tuple[str, int]: + """Run interleaved chat completion. + + Args: + messages: List of message dicts + system: System prompt + max_tokens: Optional max tokens override + + Returns: + Tuple of (response text, token usage) + """ + # Avoid using system messages for R1 + final_messages = [{"role": "user", "content": system}] + + # Process messages + if isinstance(messages, list): + for item in messages: + if isinstance(item, dict): + # For dict items, concatenate all text content, ignoring images + text_contents = [] + for cnt in item["content"]: + if isinstance(cnt, str): + if not is_image_path(cnt): # Skip image paths + text_contents.append(cnt) + else: + text_contents.append(str(cnt)) + + if text_contents: # Only add if there's text content + message = {"role": "user", "content": " ".join(text_contents)} + final_messages.append(message) + else: # str + message = {"role": "user", "content": item} + final_messages.append(message) + + elif isinstance(messages, str): + final_messages.append({"role": "user", "content": messages}) + + try: + completion = self.client.chat.completions.create( # type: ignore + model=self.model, + messages=final_messages, # type: ignore + temperature=self.temperature, + max_tokens=max_tokens or self.max_tokens, + top_p=0.95, + stream=False, + ) + + response = completion.choices[0].message.content + 
final_answer = response.split("\n")[-1] if "" in response else response + final_answer = final_answer.replace("", "").replace("", "") + token_usage = completion.usage.total_tokens + + return final_answer, token_usage + + except Exception as e: + logger.error(f"Error in Groq API call: {e}") + raise diff --git a/libs/agent/agent/providers/omni/clients/openai.py b/libs/agent/agent/providers/omni/clients/openai.py new file mode 100644 index 00000000..83ebc18c --- /dev/null +++ b/libs/agent/agent/providers/omni/clients/openai.py @@ -0,0 +1,159 @@ +"""OpenAI client implementation.""" + +import os +import logging +from typing import Dict, List, Optional, Any +import aiohttp +import base64 +import re +import json +import ssl +import certifi +from datetime import datetime +from .base import BaseOmniClient + +logger = logging.getLogger(__name__) + +# OpenAI specific client for the OmniLoop + + +class OpenAIClient(BaseOmniClient): + """OpenAI vision API client implementation.""" + + def __init__( + self, + api_key: Optional[str] = None, + model: str = "gpt-4o", + provider_base_url: str = "https://api.openai.com/v1", + max_tokens: int = 4096, + temperature: float = 0.0, + ): + """Initialize the OpenAI client. 
+ + Args: + api_key: OpenAI API key + model: Model to use + provider_base_url: API endpoint + max_tokens: Maximum tokens to generate + temperature: Generation temperature + """ + super().__init__(api_key=api_key, model=model) + self.api_key = api_key or os.getenv("OPENAI_API_KEY") + if not self.api_key: + raise ValueError("No OpenAI API key provided") + + self.model = model + self.provider_base_url = provider_base_url + self.max_tokens = max_tokens + self.temperature = temperature + + def _extract_base64_image(self, text: str) -> Optional[str]: + """Extract base64 image data from an HTML img tag.""" + pattern = r'data:image/[^;]+;base64,([^"]+)' + match = re.search(pattern, text) + return match.group(1) if match else None + + def _get_loggable_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Create a loggable version of messages with image data truncated.""" + loggable_messages = [] + for msg in messages: + if isinstance(msg.get("content"), list): + new_content = [] + for content in msg["content"]: + if content.get("type") == "image": + new_content.append( + {"type": "image", "image_url": {"url": "[BASE64_IMAGE_DATA]"}} + ) + else: + new_content.append(content) + loggable_messages.append({"role": msg["role"], "content": new_content}) + else: + loggable_messages.append(msg) + return loggable_messages + + async def run_interleaved( + self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None + ) -> Dict[str, Any]: + """Run interleaved chat completion. 
+ + Args: + messages: List of message dicts + system: System prompt + max_tokens: Optional max tokens override + + Returns: + Response dict + """ + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} + + final_messages = [{"role": "system", "content": system}] + + # Process messages + for item in messages: + if isinstance(item, dict): + if isinstance(item["content"], list): + # Content is already in the correct format + final_messages.append(item) + else: + # Single string content, check for image + base64_img = self._extract_base64_image(item["content"]) + if base64_img: + message = { + "role": item["role"], + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"}, + } + ], + } + else: + message = { + "role": item["role"], + "content": [{"type": "text", "text": item["content"]}], + } + final_messages.append(message) + else: + # String content, check for image + base64_img = self._extract_base64_image(item) + if base64_img: + message = { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"}, + } + ], + } + else: + message = {"role": "user", "content": [{"type": "text", "text": item}]} + final_messages.append(message) + + payload = {"model": self.model, "messages": final_messages, "temperature": self.temperature} + + if "o1" in self.model or "o3-mini" in self.model: + payload["reasoning_effort"] = "low" + payload["max_completion_tokens"] = max_tokens or self.max_tokens + else: + payload["max_tokens"] = max_tokens or self.max_tokens + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.provider_base_url}/chat/completions", headers=headers, json=payload + ) as response: + response_json = await response.json() + + if response.status != 200: + error_msg = response_json.get("error", {}).get( + "message", str(response_json) + ) + logger.error(f"Error in OpenAI API call: 
{error_msg}") + raise Exception(f"OpenAI API error: {error_msg}") + + return response_json + + except Exception as e: + logger.error(f"Error in OpenAI API call: {str(e)}") + raise diff --git a/libs/agent/agent/providers/omni/clients/utils.py b/libs/agent/agent/providers/omni/clients/utils.py new file mode 100644 index 00000000..fb4bcfc4 --- /dev/null +++ b/libs/agent/agent/providers/omni/clients/utils.py @@ -0,0 +1,25 @@ +import base64 + +def is_image_path(text: str) -> bool: + """Check if a text string is an image file path. + + Args: + text: Text string to check + + Returns: + True if text ends with image extension, False otherwise + """ + image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif") + return text.endswith(image_extensions) + +def encode_image(image_path: str) -> str: + """Encode image file to base64. + + Args: + image_path: Path to image file + + Returns: + Base64 encoded image string + """ + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") \ No newline at end of file diff --git a/libs/agent/agent/providers/omni/experiment.py b/libs/agent/agent/providers/omni/experiment.py new file mode 100644 index 00000000..347eb74a --- /dev/null +++ b/libs/agent/agent/providers/omni/experiment.py @@ -0,0 +1,273 @@ +"""Experiment management for the Cua provider.""" + +import os +import logging +import copy +import base64 +from io import BytesIO +from datetime import datetime +from typing import Any, Dict, List, Optional +from PIL import Image +import json +import time + +logger = logging.getLogger(__name__) + + +class ExperimentManager: + """Manages experiment directories and logging for the agent.""" + + def __init__( + self, + base_dir: Optional[str] = None, + only_n_most_recent_images: Optional[int] = None, + ): + """Initialize the experiment manager. 
+ + Args: + base_dir: Base directory for saving experiment data + only_n_most_recent_images: Maximum number of recent screenshots to include in API requests + """ + self.base_dir = base_dir + self.only_n_most_recent_images = only_n_most_recent_images + self.run_dir = None + self.current_turn_dir = None + self.turn_count = 0 + self.screenshot_count = 0 + # Track all screenshots for potential API request inclusion + self.screenshot_paths = [] + + # Set up experiment directories if base_dir is provided + if self.base_dir: + self.setup_experiment_dirs() + + def setup_experiment_dirs(self) -> None: + """Setup the experiment directory structure.""" + if not self.base_dir: + return + + # Create base experiments directory if it doesn't exist + os.makedirs(self.base_dir, exist_ok=True) + + # Use the base_dir directly as the run_dir + self.run_dir = self.base_dir + logger.info(f"Using directory for experiment: {self.run_dir}") + + # Create first turn directory + self.create_turn_dir() + + def create_turn_dir(self) -> None: + """Create a new directory for the current turn.""" + if not self.run_dir: + return + + self.turn_count += 1 + self.current_turn_dir = os.path.join(self.run_dir, f"turn_{self.turn_count:03d}") + os.makedirs(self.current_turn_dir, exist_ok=True) + logger.info(f"Created turn directory: {self.current_turn_dir}") + + def sanitize_log_data(self, data: Any) -> Any: + """Sanitize data for logging by removing large base64 strings. 
+ + Args: + data: Data to sanitize (dict, list, or primitive) + + Returns: + Sanitized copy of the data + """ + if isinstance(data, dict): + result = copy.deepcopy(data) + + # Handle nested dictionaries and lists + for key, value in result.items(): + # Process content arrays that contain image data + if key == "content" and isinstance(value, list): + for i, item in enumerate(value): + if isinstance(item, dict): + # Handle Anthropic format + if item.get("type") == "image" and isinstance(item.get("source"), dict): + source = item["source"] + if "data" in source and isinstance(source["data"], str): + # Replace base64 data with a placeholder and length info + data_len = len(source["data"]) + source["data"] = f"[BASE64_IMAGE_DATA_LENGTH_{data_len}]" + + # Handle OpenAI format + elif item.get("type") == "image_url" and isinstance( + item.get("image_url"), dict + ): + url_dict = item["image_url"] + if "url" in url_dict and isinstance(url_dict["url"], str): + url = url_dict["url"] + if url.startswith("data:"): + # Replace base64 data with placeholder + data_len = len(url) + url_dict["url"] = f"[BASE64_IMAGE_URL_LENGTH_{data_len}]" + + # Handle other nested structures recursively + if isinstance(value, dict): + result[key] = self.sanitize_log_data(value) + elif isinstance(value, list): + result[key] = [self.sanitize_log_data(item) for item in value] + + return result + elif isinstance(data, list): + return [self.sanitize_log_data(item) for item in data] + else: + return data + + def save_debug_image(self, image_data: str, filename: str) -> None: + """Save a debug image to the experiment directory. + + Args: + image_data: Base64 encoded image data + filename: Filename to save the image as + """ + # Since we no longer want to use the images/ folder, we'll skip this functionality + return + + def save_screenshot(self, img_base64: str, action_type: str = "") -> None: + """Save a screenshot to the experiment directory. 
+ + Args: + img_base64: Base64 encoded screenshot + action_type: Type of action that triggered the screenshot + """ + if not self.current_turn_dir: + return + + try: + # Increment screenshot counter + self.screenshot_count += 1 + + # Create a descriptive filename + timestamp = int(time.time() * 1000) + action_suffix = f"_{action_type}" if action_type else "" + filename = f"screenshot_{self.screenshot_count:03d}{action_suffix}_{timestamp}.png" + + # Save directly to the turn directory (no screenshots subdirectory) + filepath = os.path.join(self.current_turn_dir, filename) + + # Save the screenshot + img_data = base64.b64decode(img_base64) + with open(filepath, "wb") as f: + f.write(img_data) + + # Keep track of the file path for reference + self.screenshot_paths.append(filepath) + + return filepath + except Exception as e: + logger.error(f"Error saving screenshot: {str(e)}") + return None + + def should_save_debug_image(self) -> bool: + """Determine if debug images should be saved. + + Returns: + Boolean indicating if debug images should be saved + """ + # We no longer need to save debug images, so always return False + return False + + def save_action_visualization( + self, img: Image.Image, action_name: str, details: str = "" + ) -> str: + """Save a visualization of an action. 
+ + Args: + img: Image to save + action_name: Name of the action + details: Additional details about the action + + Returns: + Path to the saved image + """ + if not self.current_turn_dir: + return "" + + try: + # Create a descriptive filename + timestamp = int(time.time() * 1000) + details_suffix = f"_{details}" if details else "" + filename = f"vis_{action_name}{details_suffix}_{timestamp}.png" + + # Save directly to the turn directory (no visualizations subdirectory) + filepath = os.path.join(self.current_turn_dir, filename) + + # Save the image + img.save(filepath) + + # Keep track of the file path for cleanup + self.screenshot_paths.append(filepath) + + return filepath + except Exception as e: + logger.error(f"Error saving action visualization: {str(e)}") + return "" + + def extract_and_save_images(self, data: Any, prefix: str) -> None: + """Extract and save images from response data. + + Args: + data: Response data to extract images from + prefix: Prefix for saved image filenames + """ + # Since we no longer want to save extracted images separately, + # we'll skip this functionality entirely + return + + def log_api_call( + self, + call_type: str, + request: Any, + provider: str, + model: str, + response: Any = None, + error: Optional[Exception] = None, + ) -> None: + """Log API call details to file. 
+ + Args: + call_type: Type of API call (e.g., 'request', 'response', 'error') + request: The API request data + provider: The AI provider used + model: The AI model used + response: Optional API response data + error: Optional error information + """ + if not self.current_turn_dir: + return + + try: + # Create a unique filename with timestamp + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"api_call_{timestamp}_{call_type}.json" + filepath = os.path.join(self.current_turn_dir, filename) + + # Sanitize data to remove large base64 strings + sanitized_request = self.sanitize_log_data(request) + sanitized_response = self.sanitize_log_data(response) if response is not None else None + + # Prepare log data + log_data = { + "timestamp": timestamp, + "provider": provider, + "model": model, + "type": call_type, + "request": sanitized_request, + } + + if sanitized_response is not None: + log_data["response"] = sanitized_response + if error is not None: + log_data["error"] = str(error) + + # Write to file + with open(filepath, "w") as f: + json.dump(log_data, f, indent=2, default=str) + + logger.info(f"Logged API {call_type} to {filepath}") + + except Exception as e: + logger.error(f"Error logging API call: {str(e)}") diff --git a/libs/agent/agent/providers/omni/image_utils.py b/libs/agent/agent/providers/omni/image_utils.py new file mode 100644 index 00000000..37c68705 --- /dev/null +++ b/libs/agent/agent/providers/omni/image_utils.py @@ -0,0 +1,106 @@ +"""Image processing utilities for the Cua provider.""" + +import base64 +import logging +import re +from io import BytesIO +from typing import Optional, Tuple +from PIL import Image + +logger = logging.getLogger(__name__) + + +def decode_base64_image(img_base64: str) -> Optional[Image.Image]: + """Decode a base64 encoded image to a PIL Image. 
+ + Args: + img_base64: Base64 encoded image, may include data URL prefix + + Returns: + PIL Image or None if decoding fails + """ + try: + # Remove data URL prefix if present + if img_base64.startswith("data:image"): + img_base64 = img_base64.split(",")[1] + + # Decode base64 to bytes + img_data = base64.b64decode(img_base64) + + # Convert bytes to PIL Image + return Image.open(BytesIO(img_data)) + except Exception as e: + logger.error(f"Error decoding base64 image: {str(e)}") + return None + + +def encode_image_base64(img: Image.Image, format: str = "PNG") -> str: + """Encode a PIL Image to base64. + + Args: + img: PIL Image to encode + format: Image format (PNG, JPEG, etc.) + + Returns: + Base64 encoded image string + """ + try: + buffered = BytesIO() + img.save(buffered, format=format) + return base64.b64encode(buffered.getvalue()).decode("utf-8") + except Exception as e: + logger.error(f"Error encoding image to base64: {str(e)}") + return "" + + +def clean_base64_data(img_base64: str) -> str: + """Clean base64 image data by removing data URL prefix. + + Args: + img_base64: Base64 encoded image, may include data URL prefix + + Returns: + Clean base64 string without prefix + """ + if img_base64.startswith("data:image"): + return img_base64.split(",")[1] + return img_base64 + + +def extract_base64_from_text(text: str) -> Optional[str]: + """Extract base64 image data from a text string. 
+ + Args: + text: Text potentially containing base64 image data + + Returns: + Base64 string or None if not found + """ + # Look for data URL pattern + data_url_pattern = r"data:image/[^;]+;base64,([a-zA-Z0-9+/=]+)" + match = re.search(data_url_pattern, text) + if match: + return match.group(1) + + # Look for plain base64 pattern (basic heuristic) + base64_pattern = r"([a-zA-Z0-9+/=]{100,})" + match = re.search(base64_pattern, text) + if match: + return match.group(1) + + return None + + +def get_image_dimensions(img_base64: str) -> Tuple[int, int]: + """Get the dimensions of a base64 encoded image. + + Args: + img_base64: Base64 encoded image + + Returns: + Tuple of (width, height) or (0, 0) if decoding fails + """ + img = decode_base64_image(img_base64) + if img: + return img.size + return (0, 0) diff --git a/libs/agent/agent/providers/omni/loop.py b/libs/agent/agent/providers/omni/loop.py new file mode 100644 index 00000000..3901b530 --- /dev/null +++ b/libs/agent/agent/providers/omni/loop.py @@ -0,0 +1,965 @@ +"""Omni-specific agent loop implementation.""" + +import logging +from typing import Any, Dict, List, Optional, Tuple, AsyncGenerator, Union +import base64 +from PIL import Image +from io import BytesIO +import json +import re +import os +from datetime import datetime +import asyncio +from httpx import ConnectError, ReadTimeout +import shutil +import copy + +from .parser import OmniParser, ParseResult, ParserMetadata, UIElement +from ...core.loop import BaseLoop +from computer import Computer +from .types import LLMProvider +from .clients.base import BaseOmniClient +from .clients.openai import OpenAIClient +from .clients.groq import GroqClient +from .clients.anthropic import AnthropicClient +from .prompts import SYSTEM_PROMPT +from .utils import compress_image_base64 +from .visualization import visualize_click, visualize_scroll, calculate_element_center +from .image_utils import decode_base64_image, clean_base64_data +from ...core.messages import 
ImageRetentionConfig +from .messages import OmniMessageManager + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def extract_data(input_string: str, data_type: str) -> str: + """Extract content from code blocks.""" + pattern = f"```{data_type}" + r"(.*?)(```|$)" + matches = re.findall(pattern, input_string, re.DOTALL) + return matches[0][0].strip() if matches else input_string + + +class OmniLoop(BaseLoop): + """Omni-specific implementation of the agent loop.""" + + def __init__( + self, + parser: OmniParser, + provider: LLMProvider, + api_key: str, + model: str, + computer: Computer, + only_n_most_recent_images: Optional[int] = 2, + base_dir: Optional[str] = "trajectories", + max_retries: int = 3, + retry_delay: float = 1.0, + save_trajectory: bool = True, + **kwargs, + ): + """Initialize the loop. + + Args: + parser: Parser instance + provider: API provider + api_key: API key + model: Model name + computer: Computer instance + only_n_most_recent_images: Maximum number of recent screenshots to include in API requests + base_dir: Base directory for saving experiment data + max_retries: Maximum number of retries for API calls + retry_delay: Delay between retries in seconds + save_trajectory: Whether to save trajectory data + """ + # Set parser and provider before initializing base class + self.parser = parser + self.provider = provider + + # Initialize message manager with image retention config + image_retention_config = ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images) + self.message_manager = OmniMessageManager(config=image_retention_config) + + # Initialize base class (which will set up experiment manager) + super().__init__( + computer=computer, + model=model, + api_key=api_key, + max_retries=max_retries, + retry_delay=retry_delay, + base_dir=base_dir, + save_trajectory=save_trajectory, + only_n_most_recent_images=only_n_most_recent_images, + **kwargs, + ) + + # Set API client attributes + self.client = None 
+ self.retry_count = 0 + + def _should_save_debug_image(self) -> bool: + """Check if debug images should be saved. + + Returns: + bool: Always returns False as debug image saving has been disabled. + """ + # Debug image saving functionality has been removed + return False + + def _extract_and_save_images(self, data: Any, prefix: str) -> None: + """Extract and save images from API data. + + This method is now a no-op as image extraction functionality has been removed. + + Args: + data: Data to extract images from + prefix: Prefix for the extracted image filenames + """ + # Image extraction functionality has been removed + return + + def _save_debug_image(self, image_data: str, filename: str) -> None: + """Save a debug image to the current turn directory. + + This method is now a no-op as debug image saving functionality has been removed. + + Args: + image_data: Base64 encoded image data + filename: Name to use for the saved image + """ + # Debug image saving functionality has been removed + return + + def _visualize_action(self, x: int, y: int, img_base64: str) -> None: + """Visualize an action by drawing on the screenshot.""" + if ( + not self.save_trajectory + or not hasattr(self, "experiment_manager") + or not self.experiment_manager + ): + return + + try: + # Use the visualization utility + img = visualize_click(x, y, img_base64) + + # Save the visualization + self.experiment_manager.save_action_visualization(img, "click", f"x{x}_y{y}") + except Exception as e: + logger.error(f"Error visualizing action: {str(e)}") + + def _visualize_scroll(self, direction: str, clicks: int, img_base64: str) -> None: + """Visualize a scroll action by drawing arrows on the screenshot.""" + if ( + not self.save_trajectory + or not hasattr(self, "experiment_manager") + or not self.experiment_manager + ): + return + + try: + # Use the visualization utility + img = visualize_scroll(direction, clicks, img_base64) + + # Save the visualization + 
self.experiment_manager.save_action_visualization( + img, "scroll", f"{direction}_{clicks}" + ) + except Exception as e: + logger.error(f"Error visualizing scroll: {str(e)}") + + def _save_action_visualization( + self, img: Image.Image, action_name: str, details: str = "" + ) -> str: + """Save a visualization of an action.""" + if hasattr(self, "experiment_manager") and self.experiment_manager: + return self.experiment_manager.save_action_visualization(img, action_name, details) + return "" + + async def initialize_client(self) -> None: + """Initialize the appropriate client based on provider.""" + try: + logger.info(f"Initializing {self.provider} client with model {self.model}...") + + if self.provider == LLMProvider.OPENAI: + self.client = OpenAIClient(api_key=self.api_key, model=self.model) + elif self.provider == LLMProvider.GROQ: + self.client = GroqClient(api_key=self.api_key, model=self.model) + elif self.provider == LLMProvider.ANTHROPIC: + self.client = AnthropicClient( + api_key=self.api_key, + model=self.model, + max_retries=self.max_retries, + retry_delay=self.retry_delay, + ) + else: + raise ValueError(f"Unsupported provider: {self.provider}") + + logger.info(f"Initialized {self.provider} client with model {self.model}") + except Exception as e: + logger.error(f"Error initializing client: {str(e)}") + self.client = None + raise RuntimeError(f"Failed to initialize client: {str(e)}") + + async def _make_api_call(self, messages: List[Dict[str, Any]], system_prompt: str) -> Any: + """Make API call to provider with retry logic.""" + # Create new turn directory for this API call + self._create_turn_dir() + + request_data = None + last_error = None + + for attempt in range(self.max_retries): + try: + # Ensure client is initialized + if self.client is None: + logger.info( + f"Client not initialized in _make_api_call (attempt {attempt+1}), initializing now..." 
+ ) + await self.initialize_client() + if self.client is None: + raise RuntimeError("Failed to initialize client") + + # Set the provider in message manager based on current provider + provider_name = str(self.provider).split(".")[-1].lower() # Extract name from enum + self.message_manager.set_provider(provider_name) + + # Apply image retention and prepare messages + # This will limit the number of images based on only_n_most_recent_images + prepared_messages = self.message_manager.get_formatted_messages(provider_name) + + # Filter out system messages for Anthropic + if self.provider == LLMProvider.ANTHROPIC: + filtered_messages = [ + msg for msg in prepared_messages if msg["role"] != "system" + ] + else: + filtered_messages = prepared_messages + + # Log request + request_data = {"messages": filtered_messages, "max_tokens": self.max_tokens} + + if self.provider == LLMProvider.ANTHROPIC: + request_data["system"] = self._get_system_prompt() + else: + request_data["system"] = system_prompt + + self._log_api_call("request", request_data) + + # Make API call with appropriate parameters + if self.client is None: + raise RuntimeError("Client not initialized. Call initialize_client() first.") + + # Check if the method is async by inspecting the client implementation + run_method = self.client.run_interleaved + is_async = asyncio.iscoroutinefunction(run_method) + + if is_async: + # For async implementations (AnthropicClient) + if self.provider == LLMProvider.ANTHROPIC: + response = await run_method( + messages=filtered_messages, + system=self._get_system_prompt(), + max_tokens=self.max_tokens, + ) + else: + response = await run_method( + messages=messages, + system=system_prompt, + max_tokens=self.max_tokens, + ) + else: + # For non-async implementations (GroqClient, etc.) 
+ if self.provider == LLMProvider.ANTHROPIC: + response = run_method( + messages=filtered_messages, + system=self._get_system_prompt(), + max_tokens=self.max_tokens, + ) + else: + response = run_method( + messages=messages, + system=system_prompt, + max_tokens=self.max_tokens, + ) + + # Log success response + self._log_api_call("response", request_data, response) + + return response + + except (ConnectError, ReadTimeout) as e: + last_error = e + logger.warning( + f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}" + ) + if attempt < self.max_retries - 1: + await asyncio.sleep(self.retry_delay * (attempt + 1)) # Exponential backoff + # Reset client on connection errors to force re-initialization + self.client = None + continue + + except RuntimeError as e: + # Handle client initialization errors specifically + last_error = e + self._log_api_call("error", request_data, error=e) + logger.error( + f"Client initialization error (attempt {attempt + 1}/{self.max_retries}): {str(e)}" + ) + if attempt < self.max_retries - 1: + # Reset client to force re-initialization + self.client = None + await asyncio.sleep(self.retry_delay) + continue + + except Exception as e: + # Log unexpected error + last_error = e + self._log_api_call("error", request_data, error=e) + logger.error(f"Unexpected error in API call: {str(e)}") + if attempt < self.max_retries - 1: + await asyncio.sleep(self.retry_delay) + continue + + # If we get here, all retries failed + error_message = f"API call failed after {self.max_retries} attempts" + if last_error: + error_message += f": {str(last_error)}" + + logger.error(error_message) + raise RuntimeError(error_message) + + async def _handle_response( + self, response: Any, messages: List[Dict[str, Any]], parsed_screen: Dict[str, Any] + ) -> Tuple[bool, bool]: + """Handle API response. 
+ + Returns: + Tuple of (should_continue, action_screenshot_saved) + """ + action_screenshot_saved = False + try: + # Handle Anthropic response format + if self.provider == LLMProvider.ANTHROPIC: + if hasattr(response, "content") and isinstance(response.content, list): + # Extract text from content blocks + for block in response.content: + if hasattr(block, "type") and block.type == "text": + content = block.text + + # Try to find JSON in the content + try: + # First look for JSON block + json_content = extract_data(content, "json") + parsed_content = json.loads(json_content) + logger.info("Successfully parsed JSON from code block") + except (json.JSONDecodeError, IndexError): + # If no JSON block, try to find JSON object in the text + try: + # Look for JSON object pattern + json_pattern = r"\{[^}]+\}" + json_match = re.search(json_pattern, content) + if json_match: + json_str = json_match.group(0) + parsed_content = json.loads(json_str) + logger.info("Successfully parsed JSON from text") + else: + logger.error(f"No JSON found in content: {content}") + continue + except json.JSONDecodeError as e: + logger.error(f"Failed to parse JSON from text: {str(e)}") + continue + + # Clean up Box ID format + if "Box ID" in parsed_content and isinstance( + parsed_content["Box ID"], str + ): + parsed_content["Box ID"] = parsed_content["Box ID"].replace( + "Box #", "" + ) + + # Add any explanatory text as reasoning if not present + if "Explanation" not in parsed_content: + # Extract any text before the JSON as reasoning + text_before_json = content.split("{")[0].strip() + if text_before_json: + parsed_content["Explanation"] = text_before_json + + # Log the parsed content for debugging + logger.info(f"Parsed content: {json.dumps(parsed_content, indent=2)}") + + # Add response to messages + messages.append( + {"role": "assistant", "content": json.dumps(parsed_content)} + ) + + try: + # Execute action with current parsed screen info + await self._execute_action(parsed_content, 
parsed_screen) + action_screenshot_saved = True + except Exception as e: + logger.error(f"Error executing action: {str(e)}") + # Add error message to conversation + messages.append( + { + "role": "assistant", + "content": f"Error executing action: {str(e)}", + "metadata": {"title": "❌ Error"}, + } + ) + return False, action_screenshot_saved + + # Check if task is complete + if parsed_content.get("Action") == "None": + return False, action_screenshot_saved + return True, action_screenshot_saved + + logger.warning("No text block found in Anthropic response") + return True, action_screenshot_saved + + # Handle other providers' response formats + if isinstance(response, dict) and "choices" in response: + content = response["choices"][0]["message"]["content"] + else: + content = response + + # Parse JSON content + if isinstance(content, str): + try: + # First try to parse the whole content as JSON + parsed_content = json.loads(content) + except json.JSONDecodeError: + try: + # Try to find JSON block + json_content = extract_data(content, "json") + parsed_content = json.loads(json_content) + except (json.JSONDecodeError, IndexError): + try: + # Look for JSON object pattern + json_pattern = r"\{[^}]+\}" + json_match = re.search(json_pattern, content) + if json_match: + json_str = json_match.group(0) + parsed_content = json.loads(json_str) + else: + logger.error(f"No JSON found in content: {content}") + return True, action_screenshot_saved + except json.JSONDecodeError as e: + logger.error(f"Failed to parse JSON from text: {str(e)}") + return True, action_screenshot_saved + + # Clean up Box ID format + if "Box ID" in parsed_content and isinstance(parsed_content["Box ID"], str): + parsed_content["Box ID"] = parsed_content["Box ID"].replace("Box #", "") + + # Add any explanatory text as reasoning if not present + if "Explanation" not in parsed_content: + # Extract any text before the JSON as reasoning + text_before_json = content.split("{")[0].strip() + if text_before_json: 
+ parsed_content["Explanation"] = text_before_json + + # Add response to messages with stringified content + messages.append({"role": "assistant", "content": json.dumps(parsed_content)}) + + try: + # Execute action with current parsed screen info + await self._execute_action(parsed_content, parsed_screen) + action_screenshot_saved = True + except Exception as e: + logger.error(f"Error executing action: {str(e)}") + # Add error message to conversation + messages.append( + { + "role": "assistant", + "content": f"Error executing action: {str(e)}", + "metadata": {"title": "❌ Error"}, + } + ) + return False, action_screenshot_saved + + # Check if task is complete + if parsed_content.get("Action") == "None": + return False, action_screenshot_saved + + return True, action_screenshot_saved + elif isinstance(content, dict): + # Handle case where content is already a dictionary + messages.append({"role": "assistant", "content": json.dumps(content)}) + + try: + # Execute action with current parsed screen info + await self._execute_action(content, parsed_screen) + action_screenshot_saved = True + except Exception as e: + logger.error(f"Error executing action: {str(e)}") + # Add error message to conversation + messages.append( + { + "role": "assistant", + "content": f"Error executing action: {str(e)}", + "metadata": {"title": "❌ Error"}, + } + ) + return False, action_screenshot_saved + + # Check if task is complete + if content.get("Action") == "None": + return False, action_screenshot_saved + + return True, action_screenshot_saved + + return True, action_screenshot_saved + + except Exception as e: + logger.error(f"Error handling response: {str(e)}") + messages.append( + { + "role": "assistant", + "content": f"Error: {str(e)}", + "metadata": {"title": "❌ Error"}, + } + ) + raise + + async def _get_parsed_screen_som(self, save_screenshot: bool = True) -> ParseResult: + """Get parsed screen information with SOM. 
+ + Args: + save_screenshot: Whether to save the screenshot (set to False when screenshots will be saved elsewhere) + + Returns: + ParseResult containing screen information and elements + """ + try: + # Use the parser's parse_screen method which handles the screenshot internally + parsed_screen = await self.parser.parse_screen(computer=self.computer) + + # Log information about the parsed results + logger.info( + f"Parsed screen with {len(parsed_screen.elements) if parsed_screen.elements else 0} elements" + ) + + # Save screenshot if requested and if we have image data + if save_screenshot and self.save_trajectory and parsed_screen.annotated_image_base64: + try: + # Extract just the image data (remove data:image/png;base64, prefix) + img_data = parsed_screen.annotated_image_base64 + if "," in img_data: + img_data = img_data.split(",")[1] + # Save with a generic "state" action type to indicate this is the current screen state + self._save_screenshot(img_data, action_type="state") + except Exception as e: + logger.error(f"Error saving screenshot: {str(e)}") + + return parsed_screen + + except Exception as e: + logger.error(f"Error getting parsed screen: {str(e)}") + raise + + async def _process_screen( + self, parsed_screen: ParseResult, messages: List[Dict[str, Any]] + ) -> None: + """Process and add screen info to messages.""" + try: + # Only add message if we have an image and provider supports it + if self.provider in [LLMProvider.OPENAI, LLMProvider.ANTHROPIC]: + image = parsed_screen.annotated_image_base64 or None + if image: + # Save screen info to current turn directory + if self.current_turn_dir: + # Save elements as JSON + elements_path = os.path.join(self.current_turn_dir, "elements.json") + with open(elements_path, "w") as f: + # Convert elements to dicts for JSON serialization + elements_json = [elem.model_dump() for elem in parsed_screen.elements] + json.dump(elements_json, f, indent=2) + logger.info(f"Saved elements to {elements_path}") + + # Format 
the image content based on the provider + if self.provider == LLMProvider.ANTHROPIC: + # Compress the image before sending to Anthropic (5MB limit) + image_size = len(image) + logger.info(f"Image base64 is present, length: {image_size}") + + # Anthropic has a 5MB limit - check against base64 string length + # which is what matters for the API call payload + # Use slightly smaller limit (4.9MB) to account for request overhead + max_size = int(4.9 * 1024 * 1024) # 4.9MB + + # Default media type (will be overridden if compression is needed) + media_type = "image/png" + + # Check if the image already has a media type prefix + if image.startswith("data:"): + parts = image.split(",", 1) + if len(parts) == 2 and "image/jpeg" in parts[0].lower(): + media_type = "image/jpeg" + elif len(parts) == 2 and "image/png" in parts[0].lower(): + media_type = "image/png" + + if image_size > max_size: + logger.info( + f"Image size ({image_size} bytes) exceeds Anthropic limit ({max_size} bytes), compressing..." + ) + image, media_type = compress_image_base64(image, max_size) + logger.info( + f"Image compressed to {len(image)} bytes with media_type {media_type}" + ) + + # Anthropic uses "type": "image" + screen_info_msg = { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": media_type, + "data": image, + }, + } + ], + } + else: + # OpenAI and others use "type": "image_url" + screen_info_msg = { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image}"}, + } + ], + } + messages.append(screen_info_msg) + + except Exception as e: + logger.error(f"Error processing screen info: {str(e)}") + raise + + def _get_system_prompt(self) -> str: + """Get the system prompt for the model.""" + return SYSTEM_PROMPT + + async def _execute_action(self, content: Dict[str, Any], parsed_screen: ParseResult) -> None: + """Execute the action specified in the content using the tool manager. 
+ + Args: + content: Dictionary containing the action details + parsed_screen: Current parsed screen information + """ + try: + action = content.get("Action", "").lower() + if not action: + return + + # Track if we saved an action-specific screenshot + action_screenshot_saved = False + + try: + # Prepare kwargs based on action type + kwargs = {} + + if action in ["left_click", "right_click", "double_click", "move_cursor"]: + try: + box_id = int(content["Box ID"]) + logger.info(f"Processing Box ID: {box_id}") + + # Calculate click coordinates + x, y = await self._calculate_click_coordinates(box_id, parsed_screen) + logger.info(f"Calculated coordinates: x={x}, y={y}") + + kwargs["x"] = x + kwargs["y"] = y + + # Visualize action if screenshot is available + if parsed_screen.annotated_image_base64: + img_data = parsed_screen.annotated_image_base64 + # Remove data URL prefix if present + if img_data.startswith("data:image"): + img_data = img_data.split(",")[1] + # Only save visualization for coordinate-based actions + self._visualize_action(x, y, img_data) + action_screenshot_saved = True + + except ValueError as e: + logger.error(f"Error processing Box ID: {str(e)}") + return + + elif action == "drag_to": + try: + box_id = int(content["Box ID"]) + x, y = await self._calculate_click_coordinates(box_id, parsed_screen) + kwargs.update( + { + "x": x, + "y": y, + "button": content.get("button", "left"), + "duration": float(content.get("duration", 0.5)), + } + ) + + # Visualize drag destination if screenshot is available + if parsed_screen.annotated_image_base64: + img_data = parsed_screen.annotated_image_base64 + # Remove data URL prefix if present + if img_data.startswith("data:image"): + img_data = img_data.split(",")[1] + # Only save visualization for coordinate-based actions + self._visualize_action(x, y, img_data) + action_screenshot_saved = True + + except ValueError as e: + logger.error(f"Error processing drag coordinates: {str(e)}") + return + + elif action == 
"type_text": + kwargs["text"] = content["Value"] + # For type_text, store the value in the action type + action_type = f"type_{content['Value'][:20]}" # Truncate if too long + elif action == "press_key": + kwargs["key"] = content["Value"] + action_type = f"press_{content['Value']}" + elif action == "hotkey": + if isinstance(content.get("Value"), list): + keys = content["Value"] + action_type = f"hotkey_{'_'.join(keys)}" + else: + # Simply split string format like "command+space" into a list + keys = [k.strip() for k in content["Value"].lower().split("+")] + action_type = f"hotkey_{content['Value'].replace('+', '_')}" + logger.info(f"Preparing hotkey with keys: {keys}") + # Get the method but call it with *args instead of **kwargs + method = getattr(self.computer.interface, action) + await method(*keys) # Unpack the keys list as positional arguments + logger.info(f"Tool execution completed successfully: {action}") + + # For hotkeys, take a screenshot after the action + try: + # Get a new screenshot after the action and save it with the action type + new_parsed_screen = await self._get_parsed_screen_som(save_screenshot=False) + if new_parsed_screen and new_parsed_screen.annotated_image_base64: + img_data = new_parsed_screen.annotated_image_base64 + # Remove data URL prefix if present + if img_data.startswith("data:image"): + img_data = img_data.split(",")[1] + # Save with action type to indicate this is a post-action screenshot + self._save_screenshot(img_data, action_type=action_type) + action_screenshot_saved = True + except Exception as screenshot_error: + logger.error( + f"Error taking post-hotkey screenshot: {str(screenshot_error)}" + ) + + return + + elif action in ["scroll_down", "scroll_up"]: + clicks = int(content.get("amount", 1)) + kwargs["clicks"] = clicks + action_type = f"scroll_{action.split('_')[1]}_{clicks}" + + # Visualize scrolling if screenshot is available + if parsed_screen.annotated_image_base64: + img_data = 
parsed_screen.annotated_image_base64 + # Remove data URL prefix if present + if img_data.startswith("data:image"): + img_data = img_data.split(",")[1] + direction = "down" if action == "scroll_down" else "up" + # For scrolling, we only save the visualization to avoid duplicate images + self._visualize_scroll(direction, clicks, img_data) + action_screenshot_saved = True + + else: + logger.warning(f"Unknown action: {action}") + return + + # Execute tool and handle result + try: + method = getattr(self.computer.interface, action) + logger.info(f"Found method for action '{action}': {method}") + await method(**kwargs) + logger.info(f"Tool execution completed successfully: {action}") + + # For non-coordinate based actions that don't already have visualizations, + # take a new screenshot after the action + if not action_screenshot_saved: + # Take a new screenshot + try: + # Get a new screenshot after the action and save it with the action type + new_parsed_screen = await self._get_parsed_screen_som( + save_screenshot=False + ) + if new_parsed_screen and new_parsed_screen.annotated_image_base64: + img_data = new_parsed_screen.annotated_image_base64 + # Remove data URL prefix if present + if img_data.startswith("data:image"): + img_data = img_data.split(",")[1] + # Save with action type to indicate this is a post-action screenshot + if "action_type" in locals(): + self._save_screenshot(img_data, action_type=action_type) + else: + self._save_screenshot(img_data, action_type=action) + # Update the action screenshot flag for this turn + action_screenshot_saved = True + except Exception as screenshot_error: + logger.error( + f"Error taking post-action screenshot: {str(screenshot_error)}" + ) + + except AttributeError as e: + logger.error(f"Method not found for action '{action}': {str(e)}") + return + except Exception as tool_error: + logger.error(f"Tool execution failed: {str(tool_error)}") + return + + except Exception as e: + logger.error(f"Error executing action {action}: 
{str(e)}") + return + + except Exception as e: + logger.error(f"Error in _execute_action: {str(e)}") + return + + async def _calculate_click_coordinates( + self, box_id: int, parsed_screen: ParseResult + ) -> Tuple[int, int]: + """Calculate click coordinates based on box ID. + + Args: + box_id: The ID of the box to click + parsed_screen: The parsed screen information + + Returns: + Tuple of (x, y) coordinates + + Raises: + ValueError: If box_id is invalid or missing from parsed screen + """ + # First try to use structured elements data + logger.info(f"Elements count: {len(parsed_screen.elements)}") + + # Try to find element with matching ID + for element in parsed_screen.elements: + if element.id == box_id: + logger.info(f"Found element with ID {box_id}: {element}") + bbox = element.bbox + + # Get screen dimensions from the metadata if available, or fallback + width = parsed_screen.metadata.width if parsed_screen.metadata else 1920 + height = parsed_screen.metadata.height if parsed_screen.metadata else 1080 + logger.info(f"Screen dimensions: width={width}, height={height}") + + # Calculate center of the box in pixels + center_x = int((bbox.x1 + bbox.x2) / 2 * width) + center_y = int((bbox.y1 + bbox.y2) / 2 * height) + logger.info(f"Calculated center: ({center_x}, {center_y})") + + # Validate coordinates - if they're (0,0) or unreasonably small, + # use a default position in the center of the screen + if center_x == 0 and center_y == 0: + logger.warning("Got (0,0) coordinates, using fallback position") + center_x = width // 2 + center_y = height // 2 + logger.info(f"Using fallback center: ({center_x}, {center_y})") + + return center_x, center_y + + # If we couldn't find the box, use center of screen + logger.error( + f"Box ID {box_id} not found in structured elements (count={len(parsed_screen.elements)})" + ) + + # Use center of screen as fallback + width = parsed_screen.metadata.width if parsed_screen.metadata else 1920 + height = parsed_screen.metadata.height if 
parsed_screen.metadata else 1080 + logger.warning(f"Using fallback position in center of screen ({width//2}, {height//2})") + return width // 2, height // 2 + + async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]: + """Run the agent loop with provided messages. + + Args: + messages: List of message objects + + Yields: + Dict containing response data + """ + # Keep track of conversation history + conversation_history = messages.copy() + + # Continue running until explicitly told to stop + running = True + turn_created = False + # Track if an action-specific screenshot has been saved this turn + action_screenshot_saved = False + + attempt = 0 + max_attempts = 3 + + while running and attempt < max_attempts: + try: + # Create a new turn directory if it's not already created + if not turn_created: + self._create_turn_dir() + turn_created = True + + # Ensure client is initialized + if self.client is None: + logger.info("Initializing client...") + await self.initialize_client() + if self.client is None: + raise RuntimeError("Failed to initialize client") + logger.info("Client initialized successfully") + + # Get up-to-date screen information + parsed_screen = await self._get_parsed_screen_som() + + # Process screen info and update messages + await self._process_screen(parsed_screen, conversation_history) + + # Get system prompt + system_prompt = self._get_system_prompt() + + # Make API call with retries + response = await self._make_api_call(conversation_history, system_prompt) + + # Handle the response (may execute actions) + # Returns: (should_continue, action_screenshot_saved) + should_continue, new_screenshot_saved = await self._handle_response( + response, conversation_history, parsed_screen + ) + + # Update whether an action screenshot was saved this turn + action_screenshot_saved = action_screenshot_saved or new_screenshot_saved + + # Yield the response to the caller + yield {"response": response} + + # Check if we should 
continue this conversation + running = should_continue + + # Create a new turn directory if we're continuing + if running: + turn_created = False + + # Reset attempt counter on success + attempt = 0 + + except Exception as e: + attempt += 1 + error_msg = f"Error in run method (attempt {attempt}/{max_attempts}): {str(e)}" + logger.error(error_msg) + + # If this is our last attempt, provide more info about the error + if attempt >= max_attempts: + logger.error(f"Maximum retry attempts reached. Last error was: {str(e)}") + + yield { + "error": str(e), + "metadata": {"title": "❌ Error"}, + } + + # Create a brief delay before retrying + await asyncio.sleep(1) diff --git a/libs/agent/agent/providers/omni/messages.py b/libs/agent/agent/providers/omni/messages.py new file mode 100644 index 00000000..8c1824d7 --- /dev/null +++ b/libs/agent/agent/providers/omni/messages.py @@ -0,0 +1,171 @@ +"""Omni message manager implementation.""" + +import base64 +from typing import Any, Dict, List, Optional +from io import BytesIO +from PIL import Image + +from ...core.messages import BaseMessageManager, ImageRetentionConfig + + +class OmniMessageManager(BaseMessageManager): + """Message manager for multi-provider support.""" + + def __init__(self, config: Optional[ImageRetentionConfig] = None): + """Initialize the message manager. + + Args: + config: Optional configuration for image retention + """ + super().__init__(config) + self.messages: List[Dict[str, Any]] = [] + self.config = config + + def add_user_message(self, content: str, images: Optional[List[bytes]] = None) -> None: + """Add a user message to the history. 
+ + Args: + content: Message content + images: Optional list of image data + """ + # Add images if present + if images: + # Initialize with proper typing for mixed content + message_content: List[Dict[str, Any]] = [{"type": "text", "text": content}] + + # Add each image + for img in images: + message_content.append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{base64.b64encode(img).decode()}" + }, + } + ) + + message = {"role": "user", "content": message_content} + else: + # Simple text message + message = {"role": "user", "content": content} + + self.messages.append(message) + + # Apply retention policy + if self.config and self.config.num_images_to_keep: + self._apply_image_retention_policy() + + def add_assistant_message(self, content: str) -> None: + """Add an assistant message to the history. + + Args: + content: Message content + """ + self.messages.append({"role": "assistant", "content": content}) + + def add_system_message(self, content: str) -> None: + """Add a system message to the history. + + Args: + content: Message content + """ + self.messages.append({"role": "system", "content": content}) + + def _apply_image_retention_policy(self) -> None: + """Apply image retention policy to message history.""" + if not self.config or not self.config.num_images_to_keep: + return + + # Count images from newest to oldest + image_count = 0 + for message in reversed(self.messages): + if message["role"] != "user": + continue + + # Handle multimodal messages + if isinstance(message["content"], list): + new_content = [] + for item in message["content"]: + if item["type"] == "text": + new_content.append(item) + elif item["type"] == "image_url": + if image_count < self.config.num_images_to_keep: + new_content.append(item) + image_count += 1 + message["content"] = new_content + + def get_formatted_messages(self, provider: str) -> List[Dict[str, Any]]: + """Get messages formatted for specific provider. 
+ + Args: + provider: Provider name to format messages for + + Returns: + List of formatted messages + """ + # Set the provider for message formatting + self.set_provider(provider) + + if provider == "anthropic": + return self._format_for_anthropic() + elif provider == "openai": + return self._format_for_openai() + elif provider == "groq": + return self._format_for_groq() + elif provider == "qwen": + return self._format_for_qwen() + else: + raise ValueError(f"Unsupported provider: {provider}") + + def _format_for_anthropic(self) -> List[Dict[str, Any]]: + """Format messages for Anthropic API.""" + formatted = [] + for msg in self.messages: + formatted_msg = {"role": msg["role"]} + + # Handle multimodal content + if isinstance(msg["content"], list): + formatted_msg["content"] = [] + for item in msg["content"]: + if item["type"] == "text": + formatted_msg["content"].append({"type": "text", "text": item["text"]}) + elif item["type"] == "image_url": + formatted_msg["content"].append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": item["image_url"]["url"].split(",")[1], + }, + } + ) + else: + formatted_msg["content"] = msg["content"] + + formatted.append(formatted_msg) + return formatted + + def _format_for_openai(self) -> List[Dict[str, Any]]: + """Format messages for OpenAI API.""" + # OpenAI already uses the same format + return self.messages + + def _format_for_groq(self) -> List[Dict[str, Any]]: + """Format messages for Groq API.""" + # Groq uses OpenAI-compatible format + return self.messages + + def _format_for_qwen(self) -> List[Dict[str, Any]]: + """Format messages for Qwen API.""" + formatted = [] + for msg in self.messages: + if isinstance(msg["content"], list): + # Convert multimodal content to text-only + text_content = next( + (item["text"] for item in msg["content"] if item["type"] == "text"), "" + ) + formatted.append({"role": msg["role"], "content": text_content}) + else: + formatted.append(msg) + return 
formatted diff --git a/libs/agent/agent/providers/omni/parser.py b/libs/agent/agent/providers/omni/parser.py new file mode 100644 index 00000000..1ecc381b --- /dev/null +++ b/libs/agent/agent/providers/omni/parser.py @@ -0,0 +1,252 @@ +"""Parser implementation for the Omni provider.""" + +import logging +from typing import Any, Dict, List, Optional, Tuple +import base64 +from PIL import Image +from io import BytesIO +import json +import torch + +# Import from the SOM package +from som import OmniParser as OmniDetectParser +from som.models import ParseResult, BoundingBox, UIElement, ImageData, ParserMetadata + +logger = logging.getLogger(__name__) + + +class OmniParser: + """Parser for handling responses from multiple providers.""" + + # Class-level shared OmniDetectParser instance + _shared_parser = None + + def __init__(self, force_device: Optional[str] = None): + """Initialize the OmniParser. + + Args: + force_device: Optional device to force for detection (cpu/cuda/mps) + """ + self.response_buffer = [] + + # Use shared parser if available, otherwise create a new one + if OmniParser._shared_parser is None: + logger.info("Initializing shared OmniDetectParser...") + + # Determine the best device to use + device = force_device + if not device: + if torch.cuda.is_available(): + device = "cuda" + elif ( + hasattr(torch, "backends") + and hasattr(torch.backends, "mps") + and torch.backends.mps.is_available() + ): + device = "mps" + else: + device = "cpu" + + logger.info(f"Using device: {device} for OmniDetectParser") + self.detect_parser = OmniDetectParser(force_device=device) + + # Preload the detection model to avoid repeated loading + try: + # Access the detector to trigger model loading + detector = self.detect_parser.detector + if detector.model is None: + logger.info("Preloading detection model...") + detector.load_model() + logger.info("Detection model preloaded successfully") + except Exception as e: + logger.error(f"Error preloading detection model: 
{str(e)}") + + # Store as shared instance + OmniParser._shared_parser = self.detect_parser + else: + logger.info("Using existing shared OmniDetectParser") + self.detect_parser = OmniParser._shared_parser + + async def parse_screen(self, computer: Any) -> ParseResult: + """Parse a screenshot and extract screen information. + + Args: + computer: Computer instance + + Returns: + ParseResult with screen elements and image data + """ + try: + # Get screenshot from computer + logger.info("Taking screenshot...") + screenshot = await computer.interface.screenshot() + + # Log screenshot info + logger.info(f"Screenshot type: {type(screenshot)}") + logger.info(f"Screenshot is bytes: {isinstance(screenshot, bytes)}") + logger.info(f"Screenshot is str: {isinstance(screenshot, str)}") + logger.info(f"Screenshot length: {len(screenshot) if screenshot else 0}") + + # If screenshot is a string (likely base64), convert it to bytes + if isinstance(screenshot, str): + try: + screenshot = base64.b64decode(screenshot) + logger.info("Successfully converted base64 string to bytes") + logger.info(f"Decoded bytes length: {len(screenshot)}") + except Exception as e: + logger.error(f"Error decoding base64: {str(e)}") + logger.error(f"First 100 chars of screenshot string: {screenshot[:100]}") + + # Pass screenshot to OmniDetectParser + logger.info("Passing screenshot to OmniDetectParser...") + parse_result = self.detect_parser.parse( + screenshot_data=screenshot, box_threshold=0.3, iou_threshold=0.1, use_ocr=True + ) + logger.info("Screenshot parsed successfully") + logger.info(f"Parse result has {len(parse_result.elements)} elements") + + # Log element IDs for debugging + for i, elem in enumerate(parse_result.elements): + logger.info( + f"Element {i+1} (ID: {elem.id}): {elem.type} with confidence {elem.confidence:.3f}" + ) + + return parse_result + + except Exception as e: + logger.error(f"Error parsing screen: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + + # 
Create a minimal valid result for error cases + return ParseResult( + elements=[], + annotated_image_base64="", + parsed_content_list=[f"Error: {str(e)}"], + metadata=ParserMetadata( + image_size=(0, 0), + num_icons=0, + num_text=0, + device="cpu", + ocr_enabled=False, + latency=0.0, + ), + ) + + def parse_tool_call(self, response: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Parse a tool call from the response. + + Args: + response: Response from the provider + + Returns: + Parsed tool call or None if no tool call found + """ + try: + # Handle Anthropic format + if "tool_calls" in response: + tool_call = response["tool_calls"][0] + return { + "name": tool_call["function"]["name"], + "arguments": tool_call["function"]["arguments"], + } + + # Handle OpenAI format + if "function_call" in response: + return { + "name": response["function_call"]["name"], + "arguments": response["function_call"]["arguments"], + } + + # Handle Groq format (OpenAI-compatible) + if "choices" in response and response["choices"]: + choice = response["choices"][0] + if "function_call" in choice: + return { + "name": choice["function_call"]["name"], + "arguments": choice["function_call"]["arguments"], + } + + return None + + except Exception as e: + logger.error(f"Error parsing tool call: {str(e)}") + return None + + def parse_response(self, response: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]: + """Parse a response from any provider. 
+ + Args: + response: Response from the provider + + Returns: + Tuple of (content, metadata) + """ + try: + content = "" + metadata = {} + + # Handle Anthropic format + if "content" in response and isinstance(response["content"], list): + for item in response["content"]: + if item["type"] == "text": + content += item["text"] + + # Handle OpenAI format + elif "choices" in response and response["choices"]: + content = response["choices"][0]["message"]["content"] + + # Handle direct content + elif isinstance(response.get("content"), str): + content = response["content"] + + # Extract metadata if present + if "metadata" in response: + metadata = response["metadata"] + + return content, metadata + + except Exception as e: + logger.error(f"Error parsing response: {str(e)}") + return str(e), {"error": True} + + def format_for_provider( + self, messages: List[Dict[str, Any]], provider: str + ) -> List[Dict[str, Any]]: + """Format messages for a specific provider. + + Args: + messages: List of messages to format + provider: Provider to format for + + Returns: + Formatted messages + """ + try: + formatted = [] + + for msg in messages: + formatted_msg = {"role": msg["role"]} + + # Handle content formatting + if isinstance(msg["content"], list): + # For providers that support multimodal + if provider in ["anthropic", "openai"]: + formatted_msg["content"] = msg["content"] + else: + # Extract text only for other providers + text_content = next( + (item["text"] for item in msg["content"] if item["type"] == "text"), "" + ) + formatted_msg["content"] = text_content + else: + formatted_msg["content"] = msg["content"] + + formatted.append(formatted_msg) + + return formatted + + except Exception as e: + logger.error(f"Error formatting messages: {str(e)}") + return messages # Return original messages on error diff --git a/libs/agent/agent/providers/omni/prompts.py b/libs/agent/agent/providers/omni/prompts.py new file mode 100644 index 00000000..d21e8fc5 --- /dev/null +++ 
b/libs/agent/agent/providers/omni/prompts.py @@ -0,0 +1,64 @@ +"""Prompts for the Omni agent.""" + +SYSTEM_PROMPT = """ +You are using a macOS device. +You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot. + +You may be given some history plan and actions, this is the response from the previous loop. +You should carefully consider your plan base on the task, screenshot, and history actions. + +Your available "Next Action" only include: +- type_text: types a string of text. +- left_click: move mouse to box id and left clicks. +- right_click: move mouse to box id and right clicks. +- double_click: move mouse to box id and double clicks. +- move_cursor: move mouse to box id. +- scroll_up: scrolls the screen up to view previous content. +- scroll_down: scrolls the screen down, when the desired button is not visible, or you need to see more content. +- hotkey: press a sequence of keys. +- wait: waits for 1 second for the device to load or respond. + +Based on the visual information from the screenshot image and the detected bounding boxes, please determine the next action, the Box ID you should operate on (if action is one of 'type', 'hover', 'scroll_up', 'scroll_down', 'wait', there should be no Box ID field), and the value (if the action is 'type') in order to complete the task. + +Output format: +{ + "Explanation": str, # describe what is in the current screen, taking into account the history, then describe your step-by-step thoughts on how to achieve the task, choose one action from available actions at a time. + "Action": "action_type, action description" | "None" # one action at a time, describe it in short and precisely. + "Box ID": n, + "Value": "xxx" # only provide value field if the action is type, else don't include value key +} + +One Example: +{ + "Explanation": "The current screen shows google result of amazon, in previous action I have searched amazon on google. 
Then I need to click on the first search result to go to amazon.com.",
    "Action": "left_click",
    "Box ID": 4
}

Another Example:
{
    "Explanation": "The current screen shows the front page of amazon. There is no previous action. Therefore I need to type 'Apple watch' in the search bar.",
    "Action": "type_text",
    "Box ID": 2,
    "Value": "Apple watch"
}

Another Example:
{
    "Explanation": "I am starting a Spotlight search to find the Safari browser.",
    "Action": "hotkey",
    "Value": "command+space"
}

IMPORTANT NOTES:
1. You should only give a single action at a time.
2. The Box ID is the id of the element you should operate on, it is a number. Its background color corresponds to the color of the bounding box of the element.
3. You should give an analysis of the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task.
4. Attach the next action prediction in the "Action" field.
5. For starting applications, always use the "hotkey" action with command+space for starting a Spotlight search.
6. When the task is completed, don't complete additional actions. You should say "Action": "None" in the json field.
7. If the task involves buying multiple products or navigating through multiple pages, break it into subgoals and complete each subgoal one by one in the order of the instructions.
8. Avoid choosing the same action/elements multiple times in a row; if that happens, reflect on what may have gone wrong, and predict a different action.
9. Reflect on whether the element is clickable or not, for example whether it is a hyperlink, a button, or plain text.
10. If you are prompted with a login information page or captcha page, or you think the next action needs the user's permission, you should say "Action": "None" in the json field.
+""" diff --git a/libs/agent/agent/providers/omni/tool_manager.py b/libs/agent/agent/providers/omni/tool_manager.py new file mode 100644 index 00000000..5a9260d2 --- /dev/null +++ b/libs/agent/agent/providers/omni/tool_manager.py @@ -0,0 +1,91 @@ +# """Omni tool manager implementation.""" + +# from typing import Dict, List, Type, Any + +# from computer import Computer +# from ...core.tools import BaseToolManager, BashTool, EditTool + +# class OmniToolManager(BaseToolManager): +# """Tool manager for multi-provider support.""" + +# def __init__(self, computer: Computer): +# """Initialize Omni tool manager. + +# Args: +# computer: Computer instance for tools +# """ +# super().__init__(computer) + +# def get_anthropic_tools(self) -> List[Dict[str, Any]]: +# """Get tools formatted for Anthropic API. + +# Returns: +# List of tool parameters in Anthropic format +# """ +# tools: List[Dict[str, Any]] = [] + +# # Map base tools to Anthropic format +# for tool in self.tools.values(): +# if isinstance(tool, BashTool): +# tools.append({ +# "type": "bash_20241022", +# "name": tool.name +# }) +# elif isinstance(tool, EditTool): +# tools.append({ +# "type": "text_editor_20241022", +# "name": "str_replace_editor" +# }) + +# return tools + +# def get_openai_tools(self) -> List[Dict]: +# """Get tools formatted for OpenAI API. + +# Returns: +# List of tool parameters in OpenAI format +# """ +# tools = [] + +# # Map base tools to OpenAI format +# for tool in self.tools.values(): +# tools.append({ +# "type": "function", +# "function": tool.get_schema() +# }) + +# return tools + +# def get_groq_tools(self) -> List[Dict]: +# """Get tools formatted for Groq API. + +# Returns: +# List of tool parameters in Groq format +# """ +# tools = [] + +# # Map base tools to Groq format +# for tool in self.tools.values(): +# tools.append({ +# "type": "function", +# "function": tool.get_schema() +# }) + +# return tools + +# def get_qwen_tools(self) -> List[Dict]: +# """Get tools formatted for Qwen API. 
+ +# Returns: +# List of tool parameters in Qwen format +# """ +# tools = [] + +# # Map base tools to Qwen format +# for tool in self.tools.values(): +# tools.append({ +# "type": "function", +# "function": tool.get_schema() +# }) + +# return tools diff --git a/libs/agent/agent/providers/omni/tools/__init__.py b/libs/agent/agent/providers/omni/tools/__init__.py new file mode 100644 index 00000000..31b65f8d --- /dev/null +++ b/libs/agent/agent/providers/omni/tools/__init__.py @@ -0,0 +1,13 @@ +"""Omni provider tools - compatible with multiple LLM providers.""" + +from .bash import OmniBashTool +from .computer import OmniComputerTool +from .edit import OmniEditTool +from .manager import OmniToolManager + +__all__ = [ + "OmniBashTool", + "OmniComputerTool", + "OmniEditTool", + "OmniToolManager", +] diff --git a/libs/agent/agent/providers/omni/tools/bash.py b/libs/agent/agent/providers/omni/tools/bash.py new file mode 100644 index 00000000..f21352d1 --- /dev/null +++ b/libs/agent/agent/providers/omni/tools/bash.py @@ -0,0 +1,69 @@ +"""Provider-agnostic implementation of the BashTool.""" + +import logging +from typing import Any, Dict + +from computer.computer import Computer + +from ....core.tools.bash import BaseBashTool +from ....core.tools import ToolResult + + +class OmniBashTool(BaseBashTool): + """A provider-agnostic implementation of the bash tool.""" + + name = "bash" + logger = logging.getLogger(__name__) + + def __init__(self, computer: Computer): + """Initialize the BashTool. + + Args: + computer: Computer instance, may be used for related operations + """ + super().__init__(computer) + + def to_params(self) -> Dict[str, Any]: + """Convert tool to provider-agnostic parameters. 
+ + Returns: + Dictionary with tool parameters + """ + return { + "name": self.name, + "description": "A tool that allows the agent to run bash commands", + "parameters": { + "command": {"type": "string", "description": "The bash command to execute"}, + "restart": { + "type": "boolean", + "description": "Whether to restart the bash session", + }, + }, + } + + async def __call__(self, **kwargs) -> ToolResult: + """Execute the bash tool with the provided arguments. + + Args: + command: The bash command to execute + restart: Whether to restart the bash session + + Returns: + ToolResult with the command output + """ + command = kwargs.get("command") + restart = kwargs.get("restart", False) + + if not command: + return ToolResult(error="Command is required") + + self.logger.info(f"Executing bash command: {command}") + exit_code, stdout, stderr = await self.run_command(command) + + output = stdout + error = None + + if exit_code != 0: + error = f"Command exited with code {exit_code}: {stderr}" + + return ToolResult(output=output, error=error) diff --git a/libs/agent/agent/providers/omni/tools/computer.py b/libs/agent/agent/providers/omni/tools/computer.py new file mode 100644 index 00000000..ccd933ba --- /dev/null +++ b/libs/agent/agent/providers/omni/tools/computer.py @@ -0,0 +1,216 @@ +"""Provider-agnostic implementation of the ComputerTool.""" + +import logging +import base64 +import io +from typing import Any, Dict + +from PIL import Image +from computer.computer import Computer + +from ....core.tools.computer import BaseComputerTool +from ....core.tools import ToolResult, ToolError + + +class OmniComputerTool(BaseComputerTool): + """A provider-agnostic implementation of the computer tool.""" + + name = "computer" + logger = logging.getLogger(__name__) + + def __init__(self, computer: Computer): + """Initialize the ComputerTool. 
+ + Args: + computer: Computer instance for screen interactions + """ + super().__init__(computer) + # Initialize dimensions to None, will be set in initialize_dimensions + self.width = None + self.height = None + self.display_num = None + + def to_params(self) -> Dict[str, Any]: + """Convert tool to provider-agnostic parameters. + + Returns: + Dictionary with tool parameters + """ + return { + "name": self.name, + "description": "A tool that allows the agent to interact with the screen, keyboard, and mouse", + "parameters": { + "action": { + "type": "string", + "enum": [ + "key", + "type", + "mouse_move", + "left_click", + "left_click_drag", + "right_click", + "middle_click", + "double_click", + "screenshot", + "cursor_position", + "scroll", + ], + "description": "The action to perform on the computer", + }, + "text": { + "type": "string", + "description": "Text to type or key to press, required for 'key' and 'type' actions", + }, + "coordinate": { + "type": "array", + "items": {"type": "integer"}, + "description": "X,Y coordinates for mouse actions like click and move", + }, + "direction": { + "type": "string", + "enum": ["up", "down"], + "description": "Direction to scroll, used with the 'scroll' action", + }, + "amount": { + "type": "integer", + "description": "Amount to scroll, used with the 'scroll' action", + }, + }, + **self.options, + } + + async def __call__(self, **kwargs) -> ToolResult: + """Execute the computer tool with the provided arguments. 
+ + Args: + action: The action to perform + text: Text to type or key to press (for key/type actions) + coordinate: X,Y coordinates (for mouse actions) + direction: Direction to scroll (for scroll action) + amount: Amount to scroll (for scroll action) + + Returns: + ToolResult with the action output and optional screenshot + """ + # Ensure dimensions are initialized + if self.width is None or self.height is None: + await self.initialize_dimensions() + + action = kwargs.get("action") + text = kwargs.get("text") + coordinate = kwargs.get("coordinate") + direction = kwargs.get("direction", "down") + amount = kwargs.get("amount", 10) + + self.logger.info(f"Executing computer action: {action}") + + try: + if action == "screenshot": + return await self.screenshot() + elif action == "left_click" and coordinate: + x, y = coordinate + self.logger.info(f"Clicking at ({x}, {y})") + await self.computer.interface.move_cursor(x, y) + await self.computer.interface.left_click() + + # Take screenshot after action + screenshot = await self.computer.interface.screenshot() + screenshot = await self.resize_screenshot_if_needed(screenshot) + return ToolResult( + output=f"Performed left click at ({x}, {y})", + base64_image=base64.b64encode(screenshot).decode(), + ) + elif action == "right_click" and coordinate: + x, y = coordinate + self.logger.info(f"Right clicking at ({x}, {y})") + await self.computer.interface.move_cursor(x, y) + await self.computer.interface.right_click() + + # Take screenshot after action + screenshot = await self.computer.interface.screenshot() + screenshot = await self.resize_screenshot_if_needed(screenshot) + return ToolResult( + output=f"Performed right click at ({x}, {y})", + base64_image=base64.b64encode(screenshot).decode(), + ) + elif action == "double_click" and coordinate: + x, y = coordinate + self.logger.info(f"Double clicking at ({x}, {y})") + await self.computer.interface.move_cursor(x, y) + await self.computer.interface.double_click() + + # Take 
screenshot after action + screenshot = await self.computer.interface.screenshot() + screenshot = await self.resize_screenshot_if_needed(screenshot) + return ToolResult( + output=f"Performed double click at ({x}, {y})", + base64_image=base64.b64encode(screenshot).decode(), + ) + elif action == "mouse_move" and coordinate: + x, y = coordinate + self.logger.info(f"Moving cursor to ({x}, {y})") + await self.computer.interface.move_cursor(x, y) + + # Take screenshot after action + screenshot = await self.computer.interface.screenshot() + screenshot = await self.resize_screenshot_if_needed(screenshot) + return ToolResult( + output=f"Moved cursor to ({x}, {y})", + base64_image=base64.b64encode(screenshot).decode(), + ) + elif action == "type" and text: + self.logger.info(f"Typing text: {text}") + await self.computer.interface.type_text(text) + + # Take screenshot after action + screenshot = await self.computer.interface.screenshot() + screenshot = await self.resize_screenshot_if_needed(screenshot) + return ToolResult( + output=f"Typed text: {text}", + base64_image=base64.b64encode(screenshot).decode(), + ) + elif action == "key" and text: + self.logger.info(f"Pressing key: {text}") + + # Handle special key combinations + if "+" in text: + keys = text.split("+") + await self.computer.interface.hotkey(*keys) + else: + await self.computer.interface.press(text) + + # Take screenshot after action + screenshot = await self.computer.interface.screenshot() + screenshot = await self.resize_screenshot_if_needed(screenshot) + return ToolResult( + output=f"Pressed key: {text}", + base64_image=base64.b64encode(screenshot).decode(), + ) + elif action == "cursor_position": + pos = await self.computer.interface.get_cursor_position() + return ToolResult(output=f"X={int(pos[0])},Y={int(pos[1])}") + elif action == "scroll": + if direction == "down": + self.logger.info(f"Scrolling down, amount: {amount}") + for _ in range(amount): + await self.computer.interface.hotkey("fn", "down") + else: 
+ self.logger.info(f"Scrolling up, amount: {amount}") + for _ in range(amount): + await self.computer.interface.hotkey("fn", "up") + + # Take screenshot after action + screenshot = await self.computer.interface.screenshot() + screenshot = await self.resize_screenshot_if_needed(screenshot) + return ToolResult( + output=f"Scrolled {direction} by {amount} steps", + base64_image=base64.b64encode(screenshot).decode(), + ) + + # Default to screenshot for unimplemented actions + self.logger.warning(f"Action {action} not fully implemented, taking screenshot") + return await self.screenshot() + + except Exception as e: + self.logger.error(f"Error during computer action: {str(e)}") + return ToolResult(error=f"Failed to perform {action}: {str(e)}") diff --git a/libs/agent/agent/providers/omni/tools/manager.py b/libs/agent/agent/providers/omni/tools/manager.py new file mode 100644 index 00000000..2e1152fb --- /dev/null +++ b/libs/agent/agent/providers/omni/tools/manager.py @@ -0,0 +1,83 @@ +"""Omni tool manager implementation.""" + +from typing import Dict, List, Any +from enum import Enum + +from computer.computer import Computer + +from ....core.tools import BaseToolManager +from ....core.tools.collection import ToolCollection + +from .bash import OmniBashTool +from .computer import OmniComputerTool +from .edit import OmniEditTool + + +class ProviderType(Enum): + """Supported provider types.""" + + ANTHROPIC = "anthropic" + OPENAI = "openai" + CLAUDE = "claude" # Alias for Anthropic + GPT = "gpt" # Alias for OpenAI + + +class OmniToolManager(BaseToolManager): + """Tool manager for multi-provider support.""" + + def __init__(self, computer: Computer): + """Initialize Omni tool manager. 
+ + Args: + computer: Computer instance for tools + """ + super().__init__(computer) + # Initialize tools + self.computer_tool = OmniComputerTool(self.computer) + self.bash_tool = OmniBashTool(self.computer) + self.edit_tool = OmniEditTool(self.computer) + + def _initialize_tools(self) -> ToolCollection: + """Initialize all available tools.""" + return ToolCollection(self.computer_tool, self.bash_tool, self.edit_tool) + + async def _initialize_tools_specific(self) -> None: + """Initialize provider-specific tool requirements.""" + await self.computer_tool.initialize_dimensions() + + def get_tool_params(self) -> List[Dict[str, Any]]: + """Get tool parameters for API calls. + + Returns: + List of tool parameters in default format + """ + if self.tools is None: + raise RuntimeError("Tools not initialized. Call initialize() first.") + return self.tools.to_params() + + def get_provider_tools(self, provider: ProviderType) -> List[Dict[str, Any]]: + """Get tools formatted for a specific provider. + + Args: + provider: Provider type to format tools for + + Returns: + List of tool parameters in provider-specific format + """ + if self.tools is None: + raise RuntimeError("Tools not initialized. 
Call initialize() first.") + + # Default is the base implementation + tools = self.tools.to_params() + + # Customize for each provider if needed + if provider in [ProviderType.ANTHROPIC, ProviderType.CLAUDE]: + # Format for Anthropic API + # Additional adjustments can be made here + pass + elif provider in [ProviderType.OPENAI, ProviderType.GPT]: + # Format for OpenAI API + # Future implementation + pass + + return tools diff --git a/libs/agent/agent/providers/omni/types.py b/libs/agent/agent/providers/omni/types.py new file mode 100644 index 00000000..3c442ea5 --- /dev/null +++ b/libs/agent/agent/providers/omni/types.py @@ -0,0 +1,46 @@ +"""Type definitions for the Omni provider.""" + +from enum import StrEnum +from typing import Dict, Optional +from dataclasses import dataclass + + +class LLMProvider(StrEnum): + """Supported LLM providers.""" + + ANTHROPIC = "anthropic" + OPENAI = "openai" + + +LLMProvider + + +@dataclass +class LLM: + """Configuration for LLM model and provider.""" + + provider: LLMProvider + name: Optional[str] = None + + def __post_init__(self): + """Set default model name if not provided.""" + if self.name is None: + self.name = PROVIDER_TO_DEFAULT_MODEL.get(self.provider) + + +# For backward compatibility +LLMModel = LLM +Model = LLM + + +# Default models for each provider +PROVIDER_TO_DEFAULT_MODEL: Dict[LLMProvider, str] = { + LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219", + LLMProvider.OPENAI: "gpt-4o", +} + +# Environment variable names for each provider +PROVIDER_TO_ENV_VAR: Dict[LLMProvider, str] = { + LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY", + LLMProvider.OPENAI: "OPENAI_API_KEY", +} diff --git a/libs/agent/agent/providers/omni/utils.py b/libs/agent/agent/providers/omni/utils.py new file mode 100644 index 00000000..7513caf6 --- /dev/null +++ b/libs/agent/agent/providers/omni/utils.py @@ -0,0 +1,155 @@ +"""Utility functions for Omni provider.""" + +import base64 +import io +import logging +from typing import Tuple +from PIL 
import Image

logger = logging.getLogger(__name__)


def compress_image_base64(
    base64_str: str, max_size_bytes: int = 5 * 1024 * 1024, quality: int = 90
) -> tuple[str, str]:
    """Compress a base64 encoded image to ensure it's below a certain size.

    Tries, in order: PNG re-encode, JPEG at decreasing quality, JPEG at
    decreasing scale, and finally a half-size quality-20 JPEG.

    Args:
        base64_str: Base64 encoded image string (with or without data URL prefix)
        max_size_bytes: Maximum size in bytes (default: 5MB)
        quality: Initial JPEG quality (0-100)

    Returns:
        tuple[str, str]: (Compressed base64 encoded image, media_type)

    Raises:
        Exception: re-raises any decode/encode failure after logging it.
    """
    # Handle data URL prefix if present (e.g., "data:image/png;base64,...")
    original_prefix = ""
    media_type = "image/png"  # Default media type

    if base64_str.startswith("data:"):
        parts = base64_str.split(",", 1)
        if len(parts) == 2:
            original_prefix = parts[0] + ","
            base64_str = parts[1]
            # Try to extract media type from the prefix
            if "image/jpeg" in original_prefix.lower():
                media_type = "image/jpeg"
            elif "image/png" in original_prefix.lower():
                media_type = "image/png"

    # Check if the base64 string is small enough already.
    # NOTE(review): this compares the *encoded* string length (characters)
    # against max_size_bytes; base64 text is ~4/3 the decoded byte size —
    # confirm which size the 5MB limit is meant to bound.
    if len(base64_str) <= max_size_bytes:
        logger.info(f"Image already within size limit: {len(base64_str)} bytes")
        # NOTE(review): only this early return re-attaches the data-URL
        # prefix; all compressed paths below return bare base64 — verify
        # callers tolerate both shapes.
        return original_prefix + base64_str, media_type

    try:
        # Decode base64
        img_data = base64.b64decode(base64_str)
        img_size = len(img_data)
        logger.info(f"Original image size: {img_size} bytes")

        # Open image
        img = Image.open(io.BytesIO(img_data))

        # First, try to compress as PNG (maintains transparency if present)
        buffer = io.BytesIO()
        img.save(buffer, format="PNG", optimize=True)
        buffer.seek(0)
        compressed_data = buffer.getvalue()
        compressed_b64 = base64.b64encode(compressed_data).decode("utf-8")

        if len(compressed_b64) <= max_size_bytes:
            logger.info(f"Compressed to {len(compressed_data)} bytes as PNG")
            return compressed_b64, "image/png"

        # Strategy 1: Try reducing quality with JPEG format
        current_quality = quality
        while current_quality > 20:
            buffer = io.BytesIO()
            # Convert to RGB if image has alpha channel (JPEG doesn't support transparency)
            if img.mode in ("RGBA", "LA") or (img.mode == "P" and "transparency" in img.info):
                logger.info("Converting transparent image to RGB for JPEG compression")
                rgb_img = Image.new("RGB", img.size, (255, 255, 255))
                # Paste with the alpha channel as mask so transparent areas stay white
                rgb_img.paste(img, mask=img.split()[3] if img.mode == "RGBA" else None)
                rgb_img.save(buffer, format="JPEG", quality=current_quality, optimize=True)
            else:
                img.save(buffer, format="JPEG", quality=current_quality, optimize=True)

            buffer.seek(0)
            compressed_data = buffer.getvalue()
            compressed_b64 = base64.b64encode(compressed_data).decode("utf-8")

            if len(compressed_b64) <= max_size_bytes:
                logger.info(
                    f"Compressed to {len(compressed_data)} bytes with JPEG quality {current_quality}"
                )
                return compressed_b64, "image/jpeg"

            # Reduce quality and try again
            current_quality -= 10

        # Strategy 2: If quality reduction isn't enough, reduce dimensions
        scale_factor = 0.8
        current_img = img

        while scale_factor > 0.3:
            # Resize image (always from the original, not cumulatively)
            new_width = int(img.width * scale_factor)
            new_height = int(img.height * scale_factor)
            current_img = img.resize((new_width, new_height), Image.LANCZOS)

            # Try with reduced size and quality
            buffer = io.BytesIO()
            # Convert to RGB if necessary for JPEG
            if current_img.mode in ("RGBA", "LA") or (
                current_img.mode == "P" and "transparency" in current_img.info
            ):
                rgb_img = Image.new("RGB", current_img.size, (255, 255, 255))
                rgb_img.paste(
                    current_img, mask=current_img.split()[3] if current_img.mode == "RGBA" else None
                )
                rgb_img.save(buffer, format="JPEG", quality=70, optimize=True)
            else:
                current_img.save(buffer, format="JPEG", quality=70, optimize=True)

            buffer.seek(0)
            compressed_data = buffer.getvalue()
            compressed_b64 = base64.b64encode(compressed_data).decode("utf-8")

            if len(compressed_b64) <= max_size_bytes:
                logger.info(
                    f"Compressed to {len(compressed_data)} bytes with scale {scale_factor} and JPEG quality 70"
                )
                return compressed_b64, "image/jpeg"

            # Reduce scale factor and try again
            scale_factor -= 0.1

        # If we get here, we couldn't compress enough
        logger.warning("Could not compress image below required size with quality preservation")

        # Last resort: Use minimum quality and size
        buffer = io.BytesIO()
        smallest_img = img.resize((int(img.width * 0.5), int(img.height * 0.5)), Image.LANCZOS)
        # Convert to RGB if necessary
        if smallest_img.mode in ("RGBA", "LA") or (
            smallest_img.mode == "P" and "transparency" in smallest_img.info
        ):
            rgb_img = Image.new("RGB", smallest_img.size, (255, 255, 255))
            rgb_img.paste(
                smallest_img, mask=smallest_img.split()[3] if smallest_img.mode == "RGBA" else None
            )
            rgb_img.save(buffer, format="JPEG", quality=20, optimize=True)
        else:
            smallest_img.save(buffer, format="JPEG", quality=20, optimize=True)

        buffer.seek(0)
        final_data = buffer.getvalue()
        final_b64 = base64.b64encode(final_data).decode("utf-8")

        # May still exceed the limit; caller gets the best effort regardless
        logger.warning(f"Final compressed size: {len(final_b64)} bytes (may still exceed limit)")
        return final_b64, "image/jpeg"

    except Exception as e:
        logger.error(f"Error compressing image: {str(e)}")
        raise
diff --git a/libs/agent/agent/providers/omni/visualization.py b/libs/agent/agent/providers/omni/visualization.py new file mode 100644 index 00000000..5d856457 --- /dev/null +++ b/libs/agent/agent/providers/omni/visualization.py @@ -0,0 +1,130 @@
"""Visualization utilities for the Cua provider."""

import base64
import logging
from io import BytesIO
from typing import Tuple
from PIL import Image, ImageDraw

logger = logging.getLogger(__name__)


def visualize_click(x: int, y: int, img_base64: str) -> Image.Image:
    """Visualize a click action by drawing on the screenshot.
    Args:
        x: X coordinate of the click
        y: Y coordinate of the click
        img_base64: Base64 encoded image to draw on

    Returns:
        PIL Image with visualization
    """
    try:
        # Decode the base64 image
        img_data = base64.b64decode(img_base64)
        img = Image.open(BytesIO(img_data))

        # Create a drawing context
        draw = ImageDraw.Draw(img)

        # Draw concentric circles at the click position
        small_radius = 10
        large_radius = 30

        # Draw filled inner circle
        draw.ellipse(
            [(x - small_radius, y - small_radius), (x + small_radius, y + small_radius)],
            fill="red",
        )

        # Draw outlined outer circle
        draw.ellipse(
            [(x - large_radius, y - large_radius), (x + large_radius, y + large_radius)],
            outline="red",
            width=3,
        )

        return img

    except Exception as e:
        logger.error(f"Error visualizing click: {str(e)}")
        # Return a blank image in case of error so callers always get an Image
        return Image.new("RGB", (800, 600), color="white")


def visualize_scroll(direction: str, clicks: int, img_base64: str) -> Image.Image:
    """Visualize a scroll action by drawing arrows on the screenshot.

    Args:
        direction: 'up' or 'down'
        clicks: Number of scroll clicks
        img_base64: Base64 encoded image to draw on

    Returns:
        PIL Image with visualization
    """
    try:
        # Decode the base64 image
        img_data = base64.b64decode(img_base64)
        img = Image.open(BytesIO(img_data))

        # Get image dimensions
        width, height = img.size

        # Create a drawing context
        draw = ImageDraw.Draw(img)

        # Determine arrow direction and positions
        center_x = width // 2
        arrow_width = 100

        if direction.lower() == "up":
            # Draw up arrow in the middle of the screen
            arrow_y = height // 2
            # Arrow points (apex at the top for an upward arrow)
            points = [
                (center_x, arrow_y - 50),  # Top point
                (center_x - arrow_width // 2, arrow_y + 50),  # Bottom left
                (center_x + arrow_width // 2, arrow_y + 50),  # Bottom right
            ]
            color = "blue"
        else:  # down
            # Draw down arrow in the middle of the screen
            arrow_y = height // 2
            # Arrow points (apex at the bottom for a downward arrow)
            points = [
                (center_x, arrow_y + 50),  # Bottom point
                (center_x - arrow_width // 2, arrow_y - 50),  # Top left
                (center_x + arrow_width // 2, arrow_y - 50),  # Top right
            ]
            color = "green"

        # Draw filled arrow
        draw.polygon(points, fill=color)

        # Add text showing number of clicks, offset away from the arrow
        text_y = arrow_y + 70 if direction.lower() == "down" else arrow_y - 70
        draw.text((center_x - 40, text_y), f"{clicks} clicks", fill="black")

        return img

    except Exception as e:
        logger.error(f"Error visualizing scroll: {str(e)}")
        # Return a blank image in case of error so callers always get an Image
        return Image.new("RGB", (800, 600), color="white")


def calculate_element_center(box: Tuple[int, int, int, int]) -> Tuple[int, int]:
    """Calculate the center coordinates of a bounding box.
+ + Args: + box: Tuple of (left, top, right, bottom) coordinates + + Returns: + Tuple of (center_x, center_y) coordinates + """ + left, top, right, bottom = box + center_x = (left + right) // 2 + center_y = (top + bottom) // 2 + return center_x, center_y diff --git a/libs/agent/agent/telemetry.py b/libs/agent/agent/telemetry.py new file mode 100644 index 00000000..50017e8f --- /dev/null +++ b/libs/agent/agent/telemetry.py @@ -0,0 +1,21 @@ +"""Telemetry support for Agent class.""" + +import os +import platform +import sys +import time +from typing import Any, Dict, Optional + +from core.telemetry import ( + record_event, + is_telemetry_enabled, + flush, + get_telemetry_client, + increment, +) + +# System information used for telemetry +SYSTEM_INFO = { + "os": sys.platform, + "python_version": platform.python_version(), +} diff --git a/libs/agent/agent/types/__init__.py b/libs/agent/agent/types/__init__.py new file mode 100644 index 00000000..f42a6efc --- /dev/null +++ b/libs/agent/agent/types/__init__.py @@ -0,0 +1,26 @@ +"""Type definitions for the agent package.""" + +from .base import Provider, HostConfig, TaskResult, Annotation +from .messages import Message, Request, Response, StepMessage, DisengageMessage +from .tools import ToolInvocation, ToolInvocationState, ClientAttachment, ToolResult + +__all__ = [ + # Base types + "Provider", + "HostConfig", + "TaskResult", + "Annotation", + + # Message types + "Message", + "Request", + "Response", + "StepMessage", + "DisengageMessage", + + # Tool types + "ToolInvocation", + "ToolInvocationState", + "ClientAttachment", + "ToolResult", +] diff --git a/libs/agent/agent/types/base.py b/libs/agent/agent/types/base.py new file mode 100644 index 00000000..23cc9a7b --- /dev/null +++ b/libs/agent/agent/types/base.py @@ -0,0 +1,53 @@ +"""Base type definitions.""" + +from enum import Enum, auto +from typing import Dict, Any +from pydantic import BaseModel, ConfigDict + + +class Provider(str, Enum): + """Available AI providers.""" 
+ + UNKNOWN = "unknown" # Default provider for base class + ANTHROPIC = "anthropic" + OPENAI = "openai" + OLLAMA = "ollama" + OMNI = "omni" + GROQ = "groq" + + +class HostConfig(BaseModel): + """Host configuration.""" + + model_config = ConfigDict(extra="forbid") + hostname: str + port: int + + @property + def address(self) -> str: + return f"{self.hostname}:{self.port}" + + +class TaskResult(BaseModel): + """Result of a task execution.""" + + model_config = ConfigDict(extra="forbid") + result: str + vnc_password: str + + +class Annotation(BaseModel): + """Annotation metadata.""" + + model_config = ConfigDict(extra="forbid") + id: str + vm_url: str + + +class AgentLoop(Enum): + """Enumeration of available loop types.""" + + ANTHROPIC = auto() # Anthropic implementation + OPENAI = auto() # OpenAI implementation + OMNI = auto() # OmniLoop implementation + # Add more loop types as needed diff --git a/libs/agent/agent/types/messages.py b/libs/agent/agent/types/messages.py new file mode 100644 index 00000000..ead23d99 --- /dev/null +++ b/libs/agent/agent/types/messages.py @@ -0,0 +1,36 @@ +"""Message-related type definitions.""" + +from typing import List, Dict, Any, Optional +from pydantic import BaseModel, ConfigDict + +from .tools import ToolInvocation + +class Message(BaseModel): + """Base message type.""" + model_config = ConfigDict(extra='forbid') + role: str + content: str + annotations: Optional[List[Dict[str, Any]]] = None + toolInvocations: Optional[List[ToolInvocation]] = None + data: Optional[List[Dict[str, Any]]] = None + errors: Optional[List[str]] = None + +class Request(BaseModel): + """Request type.""" + model_config = ConfigDict(extra='forbid') + messages: List[Message] + selectedModel: str + +class Response(BaseModel): + """Response type.""" + model_config = ConfigDict(extra='forbid') + messages: List[Message] + vm_url: str + +class StepMessage(Message): + """Message for a single step.""" + pass + +class DisengageMessage(BaseModel): + """Message 
indicating disengagement.""" + pass diff --git a/libs/agent/agent/types/tools.py b/libs/agent/agent/types/tools.py new file mode 100644 index 00000000..13b1f8de --- /dev/null +++ b/libs/agent/agent/types/tools.py @@ -0,0 +1,32 @@ +"""Tool-related type definitions.""" + +from enum import Enum +from typing import Dict, Any, Optional +from pydantic import BaseModel, ConfigDict + +class ToolInvocationState(str, Enum): + """States for tool invocation.""" + CALL = 'call' + PARTIAL_CALL = 'partial-call' + RESULT = 'result' + +class ToolInvocation(BaseModel): + """Tool invocation type.""" + model_config = ConfigDict(extra='forbid') + state: Optional[str] = None + toolCallId: str + toolName: Optional[str] = None + args: Optional[Dict[str, Any]] = None + +class ClientAttachment(BaseModel): + """Client attachment type.""" + name: str + contentType: str + url: str + +class ToolResult(BaseModel): + """Result of a tool execution.""" + model_config = ConfigDict(extra='forbid') + output: Optional[str] = None + error: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None diff --git a/libs/agent/poetry.toml b/libs/agent/poetry.toml new file mode 100644 index 00000000..ab1033bd --- /dev/null +++ b/libs/agent/poetry.toml @@ -0,0 +1,2 @@ +[virtualenvs] +in-project = true diff --git a/libs/agent/pyproject.toml b/libs/agent/pyproject.toml new file mode 100644 index 00000000..dc2dd04d --- /dev/null +++ b/libs/agent/pyproject.toml @@ -0,0 +1,103 @@ +[build-system] +requires = ["pdm-backend"] +build-backend = "pdm.backend" + +[project] +name = "cua-agent" +version = "0.1.0" +description = "CUA (Computer Use) Agent for AI-driven computer interaction" +readme = "README.md" +authors = [ + { name = "TryCua", email = "gh@trycua.com" } +] +dependencies = [ + "httpx>=0.27.0,<0.29.0", + "aiohttp>=3.9.3,<4.0.0", + "asyncio", + "anyio>=4.4.1,<5.0.0", + "typing-extensions>=4.12.2,<5.0.0", + "pydantic>=2.6.4,<3.0.0", + "rich>=13.7.1,<14.0.0", + "python-dotenv>=1.0.1,<2.0.0", + 
"cua-computer>=0.1.0,<0.2.0", + "cua-core>=0.1.0,<0.2.0", + "certifi>=2024.2.2" +] +requires-python = ">=3.10,<3.13" + +[project.optional-dependencies] +anthropic = [ + "anthropic>=0.49.0", + "boto3>=1.35.81,<2.0.0", +] +som = [ + "torch>=2.2.1", + "torchvision>=0.17.1", + "ultralytics>=8.0.0", + "transformers>=4.38.2", + "cua-som>=0.1.0,<0.2.0", + # Include all provider dependencies + "anthropic>=0.46.0,<0.47.0", + "boto3>=1.35.81,<2.0.0", + "openai>=1.14.0,<2.0.0", + "groq>=0.4.0,<0.5.0", + "dashscope>=1.13.0,<2.0.0", + "requests>=2.31.0,<3.0.0" +] +all = [ + # Include all optional dependencies + "torch>=2.2.1", + "torchvision>=0.17.1", + "ultralytics>=8.0.0", + "transformers>=4.38.2", + "cua-som>=0.1.0,<0.2.0", + "anthropic>=0.46.0,<0.47.0", + "boto3>=1.35.81,<2.0.0", + "openai>=1.14.0,<2.0.0", + "groq>=0.4.0,<0.5.0", + "dashscope>=1.13.0,<2.0.0", + "requests>=2.31.0,<3.0.0" +] + +[tool.pdm] +distribution = true + +[tool.pdm.build] +includes = ["agent/"] +source-includes = ["tests/", "README.md", "LICENSE"] + +[tool.black] +line-length = 100 +target-version = ["py310"] + +[tool.ruff] +line-length = 100 +target-version = "py310" +select = ["E", "F", "B", "I"] +fix = true + +[tool.ruff.format] +docstring-code-format = true + +[tool.mypy] +strict = true +python_version = "3.10" +ignore_missing_imports = true +disallow_untyped_defs = true +check_untyped_defs = true +warn_return_any = true +show_error_codes = true +warn_unused_ignores = false + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +python_files = "test_*.py" +[dependency-groups] +cua-som = [ + "path", + "develop", + "optional", + "groups", + "anthropic>=0.49.0", +] diff --git a/libs/agent/tests/test_agent.py b/libs/agent/tests/test_agent.py new file mode 100644 index 00000000..3030bd1f --- /dev/null +++ b/libs/agent/tests/test_agent.py @@ -0,0 +1,91 @@ +# """Basic tests for the agent package.""" + +# import pytest +# from agent import OmniComputerAgent, LLMProvider +# from 
agent.base.agent import BaseComputerAgent +# from computer import Computer + +# def test_agent_import(): +# """Test that we can import the OmniComputerAgent class.""" +# assert OmniComputerAgent is not None +# assert LLMProvider is not None + +# def test_agent_init(): +# """Test that we can create an OmniComputerAgent instance.""" +# agent = OmniComputerAgent( +# provider=LLMProvider.OPENAI, +# use_host_computer_server=True +# ) +# assert agent is not None + +# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed") +# def test_computer_agent_anthropic(): +# """Test creating an Anthropic agent.""" +# agent = ComputerAgent(provider=Provider.ANTHROPIC) +# assert isinstance(agent._agent, BaseComputerAgent) + +# def test_computer_agent_invalid_provider(): +# """Test creating an agent with an invalid provider.""" +# with pytest.raises(ValueError, match="Unsupported provider"): +# ComputerAgent(provider="invalid_provider") + +# def test_computer_agent_uninstalled_provider(): +# """Test creating an agent with an uninstalled provider.""" +# with pytest.raises(NotImplementedError, match="OpenAI provider not yet implemented"): +# # OpenAI provider is not implemented yet +# ComputerAgent(provider=Provider.OPENAI) + +# @pytest.mark.asyncio +# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed") +# async def test_agent_cleanup(): +# """Test agent cleanup.""" +# agent = ComputerAgent(provider=Provider.ANTHROPIC) +# await agent.cleanup() # Should not raise any errors + +# @pytest.mark.asyncio +# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed") +# async def test_agent_direct_initialization(): +# """Test direct initialization of the agent.""" +# # Create with default computer +# agent = ComputerAgent(provider=Provider.ANTHROPIC) +# try: +# # Should not raise any errors +# await agent.run("test 
task") +# finally: +# await agent.cleanup() + +# # Create with custom computer +# custom_computer = Computer( +# display="1920x1080", +# memory="8GB", +# cpu="4", +# os="macos", +# use_host_computer_server=False, +# ) +# agent = ComputerAgent(provider=Provider.ANTHROPIC, computer=custom_computer) +# try: +# # Should not raise any errors +# await agent.run("test task") +# finally: +# await agent.cleanup() + +# @pytest.mark.asyncio +# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed") +# async def test_agent_context_manager(): +# """Test context manager initialization of the agent.""" +# # Test with default computer +# async with ComputerAgent(provider=Provider.ANTHROPIC) as agent: +# # Should not raise any errors +# await agent.run("test task") + +# # Test with custom computer +# custom_computer = Computer( +# display="1920x1080", +# memory="8GB", +# cpu="4", +# os="macos", +# use_host_computer_server=False, +# ) +# async with ComputerAgent(provider=Provider.ANTHROPIC, computer=custom_computer) as agent: +# # Should not raise any errors +# await agent.run("test task") diff --git a/libs/computer-server/README.md b/libs/computer-server/README.md new file mode 100644 index 00000000..54cd0f44 --- /dev/null +++ b/libs/computer-server/README.md @@ -0,0 +1,38 @@ +
+

+
+ + + + Shows my svg + +
+ + [![Python](https://img.shields.io/badge/Python-333333?logo=python&logoColor=white&labelColor=333333)](#) + [![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](#) + [![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.com/invite/mVnXXpdE85) + [![PyPI](https://img.shields.io/pypi/v/cua-computer-server?color=333333)](https://pypi.org/project/cua-computer-server/) +

+
+ +**Computer Server** is the server component for the Computer-Use Interface (CUI) framework powering Cua for interacting with local macOS and Linux sandboxes, PyAutoGUI-compatible, and pluggable with any AI agent systems (Cua, Langchain, CrewAI, AutoGen). + +## Features + +- WebSocket API for computer-use +- Cross-platform support (macOS, Linux) +- Integration with CUA computer library for screen control, keyboard/mouse automation, and accessibility + +## Install + +To install the Computer-Use Interface (CUI): + +```bash +pip install cua-computer-server +``` + +## Run + +Refer to this notebook for a step-by-step guide on how to use the Computer-Use Server on the host system or VM: + +- [Computer-Use Server](../../notebooks/computer_server_nb.ipynb) \ No newline at end of file diff --git a/libs/computer-server/computer_server/__init__.py b/libs/computer-server/computer_server/__init__.py new file mode 100644 index 00000000..ef28cbb1 --- /dev/null +++ b/libs/computer-server/computer_server/__init__.py @@ -0,0 +1,20 @@ +""" +Computer API package. +Provides a server interface for the Computer API. +""" + +from __future__ import annotations + +__version__: str = "0.1.0" + +# Explicitly export Server for static type checkers +from .server import Server as Server # noqa: F401 + +__all__ = ["Server", "run_cli"] + + +def run_cli() -> None: + """Entry point for CLI""" + from .cli import main + + main() diff --git a/libs/computer-server/computer_server/cli.py b/libs/computer-server/computer_server/cli.py new file mode 100644 index 00000000..416e5e95 --- /dev/null +++ b/libs/computer-server/computer_server/cli.py @@ -0,0 +1,59 @@ +""" +Command-line interface for the Computer API server. 
+""" + +import argparse +import logging +import sys +from typing import List, Optional + +from .server import Server + +logger = logging.getLogger(__name__) + + +def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="Start the Computer API server") + parser.add_argument( + "--host", default="0.0.0.0", help="Host to bind the server to (default: 0.0.0.0)" + ) + parser.add_argument( + "--port", type=int, default=8000, help="Port to bind the server to (default: 8000)" + ) + parser.add_argument( + "--log-level", + choices=["debug", "info", "warning", "error", "critical"], + default="info", + help="Logging level (default: info)", + ) + + return parser.parse_args(args) + + +def main() -> None: + """Main entry point for the CLI.""" + args = parse_args() + + # Configure logging + logging.basicConfig( + level=getattr(logging, args.log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + # Create and start the server + logger.info(f"Starting CUA Computer API server on {args.host}:{args.port}...") + server = Server(host=args.host, port=args.port, log_level=args.log_level) + + try: + server.start() + except KeyboardInterrupt: + logger.info("Server stopped by user") + sys.exit(0) + except Exception as e: + logger.error(f"Error starting server: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/libs/computer-server/computer_server/handlers/base.py b/libs/computer-server/computer_server/handlers/base.py new file mode 100644 index 00000000..818d367c --- /dev/null +++ b/libs/computer-server/computer_server/handlers/base.py @@ -0,0 +1,120 @@ +from abc import ABC, abstractmethod +from typing import Optional, Dict, Any + +class BaseAccessibilityHandler(ABC): + """Abstract base class for OS-specific accessibility handlers.""" + + @abstractmethod + async def get_accessibility_tree(self) -> Dict[str, Any]: + """Get the 
accessibility tree of the current window.""" + pass + + @abstractmethod + async def find_element(self, role: Optional[str] = None, + title: Optional[str] = None, + value: Optional[str] = None) -> Dict[str, Any]: + """Find an element in the accessibility tree by criteria.""" + pass + +class BaseAutomationHandler(ABC): + """Abstract base class for OS-specific automation handlers. + + Categories: + - Mouse Actions: Methods for mouse control + - Keyboard Actions: Methods for keyboard input + - Scrolling Actions: Methods for scrolling + - Screen Actions: Methods for screen interaction + - Clipboard Actions: Methods for clipboard operations + """ + + # Mouse Actions + @abstractmethod + async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]: + """Perform a left click at the current or specified position.""" + pass + + @abstractmethod + async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]: + """Perform a right click at the current or specified position.""" + pass + + @abstractmethod + async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]: + """Perform a double click at the current or specified position.""" + pass + + @abstractmethod + async def move_cursor(self, x: int, y: int) -> Dict[str, Any]: + """Move the cursor to the specified position.""" + pass + + @abstractmethod + async def drag_to(self, x: int, y: int, button: str = "left", duration: float = 0.5) -> Dict[str, Any]: + """Drag the cursor from current position to specified coordinates. 
+ + Args: + x: The x coordinate to drag to + y: The y coordinate to drag to + button: The mouse button to use ('left', 'middle', 'right') + duration: How long the drag should take in seconds + """ + pass + + # Keyboard Actions + @abstractmethod + async def type_text(self, text: str) -> Dict[str, Any]: + """Type the specified text.""" + pass + + @abstractmethod + async def press_key(self, key: str) -> Dict[str, Any]: + """Press the specified key.""" + pass + + @abstractmethod + async def hotkey(self, *keys: str) -> Dict[str, Any]: + """Press a combination of keys together.""" + pass + + # Scrolling Actions + @abstractmethod + async def scroll_down(self, clicks: int = 1) -> Dict[str, Any]: + """Scroll down by the specified number of clicks.""" + pass + + @abstractmethod + async def scroll_up(self, clicks: int = 1) -> Dict[str, Any]: + """Scroll up by the specified number of clicks.""" + pass + + # Screen Actions + @abstractmethod + async def screenshot(self) -> Dict[str, Any]: + """Take a screenshot and return base64 encoded image data.""" + pass + + @abstractmethod + async def get_screen_size(self) -> Dict[str, Any]: + """Get the screen size of the VM.""" + pass + + @abstractmethod + async def get_cursor_position(self) -> Dict[str, Any]: + """Get the current cursor position.""" + pass + + # Clipboard Actions + @abstractmethod + async def copy_to_clipboard(self) -> Dict[str, Any]: + """Get the current clipboard content.""" + pass + + @abstractmethod + async def set_clipboard(self, text: str) -> Dict[str, Any]: + """Set the clipboard content.""" + pass + + @abstractmethod + async def run_command(self, command: str) -> Dict[str, Any]: + """Run a command and return the output.""" + pass \ No newline at end of file diff --git a/libs/computer-server/computer_server/handlers/factory.py b/libs/computer-server/computer_server/handlers/factory.py new file mode 100644 index 00000000..3e21af11 --- /dev/null +++ b/libs/computer-server/computer_server/handlers/factory.py @@ -0,0 
+1,49 @@ +import platform +import subprocess +from typing import Tuple, Type +from .base import BaseAccessibilityHandler, BaseAutomationHandler +from .macos import MacOSAccessibilityHandler, MacOSAutomationHandler +# from .linux import LinuxAccessibilityHandler, LinuxAutomationHandler + +class HandlerFactory: + """Factory for creating OS-specific handlers.""" + + @staticmethod + def _get_current_os() -> str: + """Determine the current OS. + + Returns: + str: The OS type ('darwin' for macOS or 'linux' for Linux) + + Raises: + RuntimeError: If unable to determine the current OS + """ + try: + # Use uname -s to determine OS since this runs on the target machine + result = subprocess.run(['uname', '-s'], capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError(f"uname command failed: {result.stderr}") + return result.stdout.strip().lower() + except Exception as e: + raise RuntimeError(f"Failed to determine current OS: {str(e)}") + + @staticmethod + def create_handlers() -> Tuple[BaseAccessibilityHandler, BaseAutomationHandler]: + """Create and return appropriate handlers for the current OS. + + Returns: + Tuple[BaseAccessibilityHandler, BaseAutomationHandler]: A tuple containing + the appropriate accessibility and automation handlers for the current OS. 
+ + Raises: + NotImplementedError: If the current OS is not supported + RuntimeError: If unable to determine the current OS + """ + os_type = HandlerFactory._get_current_os() + + if os_type == 'darwin': + return MacOSAccessibilityHandler(), MacOSAutomationHandler() + # elif os_type == 'linux': + # return LinuxAccessibilityHandler(), LinuxAutomationHandler() + else: + raise NotImplementedError(f"OS '{os_type}' is not supported") \ No newline at end of file diff --git a/libs/computer-server/computer_server/handlers/macos.py b/libs/computer-server/computer_server/handlers/macos.py new file mode 100644 index 00000000..024757fc --- /dev/null +++ b/libs/computer-server/computer_server/handlers/macos.py @@ -0,0 +1,654 @@ +import pyautogui +import base64 +from io import BytesIO +from typing import Optional, Dict, Any, List +from ctypes import byref, c_void_p, POINTER +from AppKit import NSWorkspace # type: ignore +import AppKit +from Quartz.CoreGraphics import * # type: ignore +from Quartz.CoreGraphics import CGPoint, CGSize # type: ignore +import Foundation +from ApplicationServices import ( + AXUIElementCreateSystemWide, # type: ignore + AXUIElementCreateApplication, # type: ignore + AXUIElementCopyAttributeValue, # type: ignore + AXUIElementCopyAttributeValues, # type: ignore + kAXFocusedWindowAttribute, # type: ignore + kAXWindowsAttribute, # type: ignore + kAXMainWindowAttribute, # type: ignore + kAXChildrenAttribute, # type: ignore + kAXRoleAttribute, # type: ignore + kAXTitleAttribute, # type: ignore + kAXValueAttribute, # type: ignore + kAXDescriptionAttribute, # type: ignore + kAXEnabledAttribute, # type: ignore + kAXPositionAttribute, # type: ignore + kAXSizeAttribute, # type: ignore + kAXErrorSuccess, # type: ignore + AXValueGetType, # type: ignore + kAXValueCGSizeType, # type: ignore + kAXValueCGPointType, # type: ignore + kAXValueCFRangeType, # type: ignore + AXUIElementGetTypeID, # type: ignore + AXValueGetValue, # type: ignore + kAXVisibleChildrenAttribute, 
# type: ignore + kAXRoleDescriptionAttribute, # type: ignore +) +import objc +import re +import json +import copy +from .base import BaseAccessibilityHandler, BaseAutomationHandler + + +def CFAttributeToPyObject(attrValue): + def list_helper(list_value): + list_builder = [] + for item in list_value: + list_builder.append(CFAttributeToPyObject(item)) + return list_builder + + def number_helper(number_value): + success, int_value = Foundation.CFNumberGetValue( # type: ignore + number_value, Foundation.kCFNumberIntType, None # type: ignore + ) + if success: + return int(int_value) + + success, float_value = Foundation.CFNumberGetValue( # type: ignore + number_value, Foundation.kCFNumberDoubleType, None # type: ignore + ) + if success: + return float(float_value) + return None + + def axuielement_helper(element_value): + return element_value + + cf_attr_type = Foundation.CFGetTypeID(attrValue) # type: ignore + cf_type_mapping = { + Foundation.CFStringGetTypeID(): str, # type: ignore + Foundation.CFBooleanGetTypeID(): bool, # type: ignore + Foundation.CFArrayGetTypeID(): list_helper, # type: ignore + Foundation.CFNumberGetTypeID(): number_helper, # type: ignore + AXUIElementGetTypeID(): axuielement_helper, # type: ignore + } + try: + return cf_type_mapping[cf_attr_type](attrValue) + except KeyError: + # did not get a supported CF type. 
Move on to AX type + pass + + ax_attr_type = AXValueGetType(attrValue) + ax_type_map = { + kAXValueCGSizeType: Foundation.NSSizeFromString, # type: ignore + kAXValueCGPointType: Foundation.NSPointFromString, # type: ignore + kAXValueCFRangeType: Foundation.NSRangeFromString, # type: ignore + } + try: + search_result = re.search("{.*}", attrValue.description()) + if search_result: + extracted_str = search_result.group() + return tuple(ax_type_map[ax_attr_type](extracted_str)) + return None + except KeyError: + return None + + +def element_attribute(element, attribute): + if attribute == kAXChildrenAttribute: + err, value = AXUIElementCopyAttributeValues(element, attribute, 0, 999, None) + if err == kAXErrorSuccess: + if isinstance(value, Foundation.NSArray): # type: ignore + return CFAttributeToPyObject(value) + else: + return value + err, value = AXUIElementCopyAttributeValue(element, attribute, None) + if err == kAXErrorSuccess: + if isinstance(value, Foundation.NSArray): # type: ignore + return CFAttributeToPyObject(value) + else: + return value + return None + + +def element_value(element, type): + err, value = AXValueGetValue(element, type, None) + if err == True: + return value + return None + + +class UIElement: + def __init__(self, element, offset_x=0, offset_y=0, max_depth=None, parents_visible_bbox=None): + self.ax_element = element + self.content_identifier = "" + self.identifier = "" + self.name = "" + self.children = [] + self.description = "" + self.role_description = "" + self.value = None + self.max_depth = max_depth + + # Set role + self.role = element_attribute(element, kAXRoleAttribute) + if self.role is None: + self.role = "No role" + + # Set name + self.name = element_attribute(element, kAXTitleAttribute) + if self.name is not None: + # Convert tuple to string if needed + if isinstance(self.name, tuple): + self.name = str(self.name[0]) if self.name else "" + self.name = self.name.replace(" ", "_") + + # Set enabled + self.enabled = 
element_attribute(element, kAXEnabledAttribute) + if self.enabled is None: + self.enabled = False + + # Set position and size + position = element_attribute(element, kAXPositionAttribute) + size = element_attribute(element, kAXSizeAttribute) + start_position = element_value(position, kAXValueCGPointType) + + if self.role == "AXWindow" and start_position is not None: + offset_x = start_position.x + offset_y = start_position.y + + self.absolute_position = copy.copy(start_position) + self.position = start_position + if self.position is not None: + self.position.x -= max(0, offset_x) + self.position.y -= max(0, offset_y) + self.size = element_value(size, kAXValueCGSizeType) + + self._set_bboxes(parents_visible_bbox) + + # Set component center + if start_position is None or self.size is None: + print("Position is None") + return + self.center = ( + start_position.x + offset_x + self.size.width / 2, + start_position.y + offset_y + self.size.height / 2, + ) + + self.description = element_attribute(element, kAXDescriptionAttribute) + self.role_description = element_attribute(element, kAXRoleDescriptionAttribute) + attribute_value = element_attribute(element, kAXValueAttribute) + + # Set value + self.value = attribute_value + if attribute_value is not None: + if isinstance(attribute_value, Foundation.NSArray): # type: ignore + self.value = [] + for value in attribute_value: + self.value.append(value) + # Check if it's an accessibility element by checking its type ID + elif Foundation.CFGetTypeID(attribute_value) == AXUIElementGetTypeID(): # type: ignore + self.value = UIElement(attribute_value, offset_x, offset_y) + + # Set children + if self.max_depth is None or self.max_depth > 0: + self.children = self._get_children(element, start_position, offset_x, offset_y) + else: + self.children = [] + + self.calculate_hashes() + + def _set_bboxes(self, parents_visible_bbox): + if not self.position or not self.size: + self.bbox = None + self.visible_bbox = None + return + self.bbox 
= [ + int(self.position.x), + int(self.position.y), + int(self.position.x + self.size.width), + int(self.position.y + self.size.height), + ] + if parents_visible_bbox: + # check if not intersected + if ( + self.bbox[0] > parents_visible_bbox[2] + or self.bbox[1] > parents_visible_bbox[3] + or self.bbox[2] < parents_visible_bbox[0] + or self.bbox[3] < parents_visible_bbox[1] + ): + self.visible_bbox = None + else: + self.visible_bbox = [ + int(max(self.bbox[0], parents_visible_bbox[0])), + int(max(self.bbox[1], parents_visible_bbox[1])), + int(min(self.bbox[2], parents_visible_bbox[2])), + int(min(self.bbox[3], parents_visible_bbox[3])), + ] + else: + self.visible_bbox = self.bbox + + def _get_children(self, element, start_position, offset_x, offset_y): + children = element_attribute(element, kAXChildrenAttribute) + visible_children = element_attribute(element, kAXVisibleChildrenAttribute) + found_children = [] + if children is not None: + found_children.extend(children) + else: + if visible_children is not None: + found_children.extend(visible_children) + + result = [] + if self.max_depth is None or self.max_depth > 0: + for child in found_children: + child = UIElement( + child, + offset_x, + offset_y, + self.max_depth - 1 if self.max_depth is not None else None, + self.visible_bbox, + ) + result.append(child) + return result + + def calculate_hashes(self): + self.identifier = self.component_hash() + self.content_identifier = self.children_content_hash(self.children) + + def component_hash(self): + if self.position is None or self.size is None: + return "" + position_string = f"{self.position.x:.0f};{self.position.y:.0f}" + size_string = f"{self.size.width:.0f};{self.size.height:.0f}" + enabled_string = str(self.enabled) + # Ensure role is a string + role_string = "" + if self.role is not None: + role_string = str(self.role[0]) if isinstance(self.role, tuple) else str(self.role) + return self.hash_from_string(position_string + size_string + enabled_string + 
role_string) + + def hash_from_string(self, string): + if string is None or string == "": + return "" + from hashlib import md5 + + return md5(string.encode()).hexdigest() + + def children_content_hash(self, children): + if len(children) == 0: + return "" + all_content_hashes = [] + all_hashes = [] + for child in children: + all_content_hashes.append(child.content_identifier) + all_hashes.append(child.identifier) + all_content_hashes.sort() + if len(all_content_hashes) == 0: + return "" + content_hash = self.hash_from_string("".join(all_content_hashes)) + content_structure_hash = self.hash_from_string("".join(all_hashes)) + return self.hash_from_string(content_hash.join(content_structure_hash)) + + def to_dict(self): + def children_to_dict(children): + result = [] + for child in children: + result.append(child.to_dict()) + return result + + value = self.value + if isinstance(value, UIElement): + value = json.dumps(value.to_dict(), indent=4) + elif isinstance(value, AppKit.NSDate): # type: ignore + value = str(value) + + if self.absolute_position is not None: + absolute_position = f"{self.absolute_position.x:.2f};{self.absolute_position.y:.2f}" + else: + absolute_position = "" + + if self.position is not None: + position = f"{self.position.x:.2f};{self.position.y:.2f}" + else: + position = "" + + if self.size is not None: + size = f"{self.size.width:.0f};{self.size.height:.0f}" + else: + size = "" + + return { + "id": self.identifier, + "name": self.name, + "role": self.role, + "description": self.description, + "role_description": self.role_description, + "value": value, + "absolute_position": absolute_position, + "position": position, + "size": size, + "enabled": self.enabled, + "bbox": self.bbox, + "visible_bbox": self.visible_bbox, + "children": children_to_dict(self.children), + } + + +class MacOSAccessibilityHandler(BaseAccessibilityHandler): + def get_application_windows(self, pid: int): + """Get all windows for a specific application.""" + try: + app = 
AXUIElementCreateApplication(pid) + err, windows = AXUIElementCopyAttributeValue(app, kAXWindowsAttribute, None) + if err == kAXErrorSuccess and windows: + if isinstance(windows, Foundation.NSArray): # type: ignore + return windows + return [] + except: + return [] + + def get_all_windows(self): + """Get all visible windows in the system.""" + try: + windows = [] + running_apps = NSWorkspace.sharedWorkspace().runningApplications() + + for app in running_apps: + try: + app_name = app.localizedName() + pid = app.processIdentifier() + + # Skip system processes and background apps + if not app.activationPolicy() == 0: # NSApplicationActivationPolicyRegular + continue + + # Get application windows + app_windows = self.get_application_windows(pid) + + windows.append( + { + "app_name": app_name, + "pid": pid, + "frontmost": app.isActive(), + "has_windows": len(app_windows) > 0, + "windows": app_windows, + } + ) + except: + continue + + return windows + except: + return [] + + def get_ax_attribute(self, element, attribute): + return element_attribute(element, attribute) + + def serialize_node(self, element): + # Create a serializable dictionary representation of an accessibility element + result = {} + + # Get basic attributes + result["role"] = self.get_ax_attribute(element, kAXRoleAttribute) + result["title"] = self.get_ax_attribute(element, kAXTitleAttribute) + result["value"] = self.get_ax_attribute(element, kAXValueAttribute) + + # Get position and size if available + position = self.get_ax_attribute(element, kAXPositionAttribute) + if position: + try: + position_dict = {"x": position[0], "y": position[1]} + result["position"] = position_dict + except (IndexError, TypeError): + pass + + size = self.get_ax_attribute(element, kAXSizeAttribute) + if size: + try: + size_dict = {"width": size[0], "height": size[1]} + result["size"] = size_dict + except (IndexError, TypeError): + pass + + return result + + async def get_accessibility_tree(self) -> Dict[str, Any]: + try: + # 
Get all visible windows first + windows = self.get_all_windows() + if not windows: + return {"success": False, "error": "No visible windows found in the system"} + + # Get the frontmost window + frontmost_app = next((w for w in windows if w["frontmost"]), None) + if not frontmost_app: + frontmost_app = windows[0] + + app_name = frontmost_app["app_name"] + + # Process all applications and their windows + processed_windows = [] + for app in windows: + app_windows = app.get("windows", []) + if app_windows: + window_trees = [] + for window in app_windows: + try: + window_element = UIElement(window) + window_trees.append(window_element.to_dict()) + except: + continue + + processed_windows.append( + { + "app_name": app["app_name"], + "pid": app["pid"], + "frontmost": app["frontmost"], + "has_windows": app["has_windows"], + "windows": window_trees, + } + ) + + if not any(app["windows"] for app in processed_windows): + return { + "success": False, + "error": f"No accessible windows found. Available applications:\n" + + "\n".join( + [ + f"- {w['app_name']} (PID: {w['pid']}, Active: {w['frontmost']}, Has Windows: {w['has_windows']})" + for w in windows + ] + ) + + "\nPlease ensure:\n" + + "1. The terminal has accessibility permissions\n" + + "2. The applications have visible windows\n" + + "3. 
Try clicking on a window you want to inspect", + } + + return { + "success": True, + "frontmost_application": app_name, + "windows": processed_windows, + } + + except Exception as e: + return {"success": False, "error": str(e)} + + async def find_element( + self, role: Optional[str] = None, title: Optional[str] = None, value: Optional[str] = None + ) -> Dict[str, Any]: + try: + system = AXUIElementCreateSystemWide() + + def match_element(element): + if role and self.get_ax_attribute(element, kAXRoleAttribute) != role: + return False + if title and self.get_ax_attribute(element, kAXTitleAttribute) != title: + return False + if value and str(self.get_ax_attribute(element, kAXValueAttribute)) != value: + return False + return True + + def search_tree(element): + if match_element(element): + return self.serialize_node(element) + + children = self.get_ax_attribute(element, kAXChildrenAttribute) + if children: + for child in children: + result = search_tree(child) + if result: + return result + return None + + element = search_tree(system) + return {"success": True, "element": element} + + except Exception as e: + return {"success": False, "error": str(e)} + + +class MacOSAutomationHandler(BaseAutomationHandler): + # Mouse Actions + async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]: + try: + if x is not None and y is not None: + pyautogui.moveTo(x, y) + pyautogui.click() + return {"success": True} + except Exception as e: + return {"success": False, "error": str(e)} + + async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]: + try: + if x is not None and y is not None: + pyautogui.moveTo(x, y) + pyautogui.rightClick() + return {"success": True} + except Exception as e: + return {"success": False, "error": str(e)} + + async def double_click( + self, x: Optional[int] = None, y: Optional[int] = None + ) -> Dict[str, Any]: + try: + if x is not None and y is not None: + pyautogui.moveTo(x, 
y) + pyautogui.doubleClick() + return {"success": True} + except Exception as e: + return {"success": False, "error": str(e)} + + async def move_cursor(self, x: int, y: int) -> Dict[str, Any]: + try: + pyautogui.moveTo(x, y) + return {"success": True} + except Exception as e: + return {"success": False, "error": str(e)} + + async def drag_to( + self, x: int, y: int, button: str = "left", duration: float = 0.5 + ) -> Dict[str, Any]: + try: + pyautogui.dragTo(x, y, button=button, duration=duration) + return {"success": True} + except Exception as e: + return {"success": False, "error": str(e)} + + # Keyboard Actions + async def type_text(self, text: str) -> Dict[str, Any]: + try: + pyautogui.write(text) + return {"success": True} + except Exception as e: + return {"success": False, "error": str(e)} + + async def press_key(self, key: str) -> Dict[str, Any]: + try: + pyautogui.press(key) + return {"success": True} + except Exception as e: + return {"success": False, "error": str(e)} + + async def hotkey(self, keys: List[str]) -> Dict[str, Any]: + try: + pyautogui.hotkey(*keys) + return {"success": True} + except Exception as e: + return {"success": False, "error": str(e)} + + # Scrolling Actions + async def scroll_down(self, clicks: int = 1) -> Dict[str, Any]: + try: + pyautogui.scroll(-clicks) + return {"success": True} + except Exception as e: + return {"success": False, "error": str(e)} + + async def scroll_up(self, clicks: int = 1) -> Dict[str, Any]: + try: + pyautogui.scroll(clicks) + return {"success": True} + except Exception as e: + return {"success": False, "error": str(e)} + + # Screen Actions + async def screenshot(self) -> Dict[str, Any]: + try: + from PIL import Image + + screenshot = pyautogui.screenshot() + if not isinstance(screenshot, Image.Image): + return {"success": False, "error": "Failed to capture screenshot"} + + buffered = BytesIO() + screenshot.save(buffered, format="PNG", optimize=True) + buffered.seek(0) + image_data = 
base64.b64encode(buffered.getvalue()).decode() + return {"success": True, "image_data": image_data} + except Exception as e: + return {"success": False, "error": f"Screenshot error: {str(e)}"} + + async def get_screen_size(self) -> Dict[str, Any]: + try: + size = pyautogui.size() + return {"success": True, "size": {"width": size.width, "height": size.height}} + except Exception as e: + return {"success": False, "error": str(e)} + + async def get_cursor_position(self) -> Dict[str, Any]: + try: + pos = pyautogui.position() + return {"success": True, "position": {"x": pos.x, "y": pos.y}} + except Exception as e: + return {"success": False, "error": str(e)} + + # Clipboard Actions + async def copy_to_clipboard(self) -> Dict[str, Any]: + try: + import pyperclip + + content = pyperclip.paste() + return {"success": True, "content": content} + except Exception as e: + return {"success": False, "error": str(e)} + + async def set_clipboard(self, text: str) -> Dict[str, Any]: + try: + import pyperclip + + pyperclip.copy(text) + return {"success": True} + except Exception as e: + return {"success": False, "error": str(e)} + + async def run_command(self, command: str) -> Dict[str, Any]: + """Run a shell command and return its output.""" + try: + import subprocess + + process = subprocess.run(command, shell=True, capture_output=True, text=True) + return {"success": True, "stdout": process.stdout, "stderr": process.stderr} + except Exception as e: + return {"success": False, "error": str(e)} diff --git a/libs/computer-server/computer_server/main.py b/libs/computer-server/computer_server/main.py new file mode 100644 index 00000000..c95918d8 --- /dev/null +++ b/libs/computer-server/computer_server/main.py @@ -0,0 +1,123 @@ +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from typing import List, Dict, Any +import uvicorn +import logging +import asyncio +import json +import traceback +from contextlib import redirect_stdout, redirect_stderr +from io import StringIO +from 
.handlers.factory import HandlerFactory + +# Set up logging with more detail +logging.basicConfig( + level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +# Configure WebSocket with larger message size +WEBSOCKET_MAX_SIZE = 1024 * 1024 * 10 # 10MB limit + +# Configure application with WebSocket settings +app = FastAPI( + title="Computer API", + description="API for the Computer project", + version="0.1.0", + websocket_max_size=WEBSOCKET_MAX_SIZE, +) + + +class ConnectionManager: + def __init__(self): + self.active_connections: List[WebSocket] = [] + # Create OS-specific handlers + self.accessibility_handler, self.automation_handler = HandlerFactory.create_handlers() + + async def connect(self, websocket: WebSocket): + await websocket.accept() + self.active_connections.append(websocket) + + def disconnect(self, websocket: WebSocket): + self.active_connections.remove(websocket) + + +manager = ConnectionManager() + + +@app.websocket("/ws", name="websocket_endpoint") +async def websocket_endpoint(websocket: WebSocket): + # WebSocket message size is configured at the app or endpoint level, not on the instance + await manager.connect(websocket) + + # Map commands to appropriate handler methods + handlers = { + # Accessibility commands + "get_accessibility_tree": manager.accessibility_handler.get_accessibility_tree, + "find_element": manager.accessibility_handler.find_element, + # Automation commands + "screenshot": manager.automation_handler.screenshot, + "left_click": manager.automation_handler.left_click, + "right_click": manager.automation_handler.right_click, + "double_click": manager.automation_handler.double_click, + "scroll_down": manager.automation_handler.scroll_down, + "scroll_up": manager.automation_handler.scroll_up, + "move_cursor": manager.automation_handler.move_cursor, + "type_text": manager.automation_handler.type_text, + "press_key": manager.automation_handler.press_key, + 
"drag_to": manager.automation_handler.drag_to, + "hotkey": manager.automation_handler.hotkey, + "get_cursor_position": manager.automation_handler.get_cursor_position, + "get_screen_size": manager.automation_handler.get_screen_size, + "copy_to_clipboard": manager.automation_handler.copy_to_clipboard, + "set_clipboard": manager.automation_handler.set_clipboard, + "run_command": manager.automation_handler.run_command, + } + + try: + while True: + try: + data = await websocket.receive_json() + command = data.get("command") + params = data.get("params", {}) + + if command not in handlers: + await websocket.send_json( + {"success": False, "error": f"Unknown command: {command}"} + ) + continue + + try: + result = await handlers[command](**params) + await websocket.send_json({"success": True, **result}) + except Exception as cmd_error: + logger.error(f"Error executing command {command}: {str(cmd_error)}") + logger.error(traceback.format_exc()) + await websocket.send_json({"success": False, "error": str(cmd_error)}) + + except WebSocketDisconnect: + raise + except json.JSONDecodeError as json_err: + logger.error(f"JSON decode error: {str(json_err)}") + await websocket.send_json( + {"success": False, "error": f"Invalid JSON: {str(json_err)}"} + ) + except Exception as loop_error: + logger.error(f"Error in message loop: {str(loop_error)}") + logger.error(traceback.format_exc()) + await websocket.send_json({"success": False, "error": str(loop_error)}) + + except WebSocketDisconnect: + logger.info("Client disconnected") + manager.disconnect(websocket) + except Exception as e: + logger.error(f"Fatal error in websocket connection: {str(e)}") + logger.error(traceback.format_exc()) + try: + await websocket.close() + except: + pass + manager.disconnect(websocket) + + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/libs/computer-server/computer_server/server.py b/libs/computer-server/computer_server/server.py new file mode 100644 index 
00000000..2a3d4340 --- /dev/null +++ b/libs/computer-server/computer_server/server.py @@ -0,0 +1,93 @@ +""" +Server interface for Computer API. +Provides a clean API for starting and stopping the server. +""" + +import asyncio +import logging +import uvicorn +from typing import Optional +from fastapi import FastAPI + +from .main import app as fastapi_app + +logger = logging.getLogger(__name__) + + +class Server: + """ + Server interface for Computer API. + + Usage: + from computer_api import Server + + # Synchronous usage + server = Server() + server.start() # Blocks until server is stopped + + # Asynchronous usage + server = Server() + await server.start_async() # Starts server in background + # Do other things + await server.stop() # Stop the server + """ + + def __init__(self, host: str = "0.0.0.0", port: int = 8000, log_level: str = "info"): + """ + Initialize the server. + + Args: + host: Host to bind the server to + port: Port to bind the server to + log_level: Logging level (debug, info, warning, error, critical) + """ + self.host = host + self.port = port + self.log_level = log_level + self.app = fastapi_app + self._server_task: Optional[asyncio.Task] = None + self._should_exit = asyncio.Event() + + def start(self) -> None: + """ + Start the server synchronously. This will block until the server is stopped. + """ + uvicorn.run(self.app, host=self.host, port=self.port, log_level=self.log_level) + + async def start_async(self) -> None: + """ + Start the server asynchronously. This will return immediately and the server + will run in the background. 
+ """ + server_config = uvicorn.Config( + self.app, host=self.host, port=self.port, log_level=self.log_level + ) + + self._should_exit.clear() + server = uvicorn.Server(server_config) + + # Create a task to run the server + self._server_task = asyncio.create_task(server.serve()) + + # Wait a short time to ensure the server starts + await asyncio.sleep(0.5) + + logger.info(f"Server started at http://{self.host}:{self.port}") + + async def stop(self) -> None: + """ + Stop the server if it's running asynchronously. + """ + if self._server_task and not self._server_task.done(): + # Signal the server to exit + self._should_exit.set() + + # Cancel the server task + self._server_task.cancel() + + try: + await self._server_task + except asyncio.CancelledError: + logger.info("Server stopped") + + self._server_task = None diff --git a/libs/computer-server/examples/__init__.py b/libs/computer-server/examples/__init__.py new file mode 100644 index 00000000..918bf6ad --- /dev/null +++ b/libs/computer-server/examples/__init__.py @@ -0,0 +1,3 @@ +""" +Examples package for the CUA Computer API. +""" diff --git a/libs/computer-server/examples/usage_example.py b/libs/computer-server/examples/usage_example.py new file mode 100644 index 00000000..46d02e53 --- /dev/null +++ b/libs/computer-server/examples/usage_example.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +""" +Example showing how to use the CUA Computer API as an imported package. +""" + +import asyncio +import logging +from typing import TYPE_CHECKING + +# For type checking only +if TYPE_CHECKING: + from computer_api import Server + +# Setup logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +# Example 1: Synchronous usage (blocks until server is stopped) +def example_sync(): + """ + Example of synchronous server usage. This will block until interrupted. 
+ Run with: python3 -m examples.usage_example sync + """ + # Import directly to avoid any confusion + from computer_api.server import Server + + server = Server(port=8080) + print("Server started at http://localhost:8080") + print("Press Ctrl+C to stop the server") + + try: + server.start() # This will block until the server is stopped + except KeyboardInterrupt: + print("Server stopped by user") + + +# Example 2: Asynchronous usage +async def example_async(): + """ + Example of asynchronous server usage. This will start the server in the background + and allow other operations to run concurrently. + Run with: python3 -m examples.usage_example async + """ + # Import directly to avoid any confusion + from computer_api.server import Server + + server = Server(port=8080) + + # Start the server in the background + await server.start_async() + + print("Server is running in the background") + print("Performing other tasks...") + + # Do other things while the server is running + for i in range(5): + print(f"Doing work iteration {i+1}/5...") + await asyncio.sleep(2) + + print("Work complete, stopping server...") + + # Stop the server when done + await server.stop() + print("Server stopped") + + +if __name__ == "__main__": + import sys + + if len(sys.argv) > 1 and sys.argv[1] == "async": + asyncio.run(example_async()) + else: + example_sync() diff --git a/libs/computer-server/pyproject.toml b/libs/computer-server/pyproject.toml new file mode 100644 index 00000000..ac8e49ff --- /dev/null +++ b/libs/computer-server/pyproject.toml @@ -0,0 +1,75 @@ +[build-system] +requires = ["pdm-backend"] +build-backend = "pdm.backend" + +[project] +name = "cua-computer-server" +version = "0.1.0" +description = "Server component for the Computer-Use Interface (CUI) framework powering Cua" +authors = [ + { name = "TryCua", email = "gh@trycua.com" } +] +dependencies = [ + "fastapi>=0.111.0", + "uvicorn[standard]>=0.27.0", + "pydantic>=2.0.0", + "pyautogui>=0.9.54", + 
"pyobjc-framework-Cocoa>=10.1; sys_platform == 'darwin'", + "pyobjc-framework-Quartz>=10.1; sys_platform == 'darwin'", + "pyobjc-framework-ApplicationServices>=10.1; sys_platform == 'darwin'", + "python-xlib>=0.33; sys_platform == 'linux'", + "pillow>=10.2.0" +] +requires-python = ">=3.10,<3.13" +readme = "README.md" +license = { text = "MIT" } + +[project.urls] +homepage = "https://github.com/trycua/cua" +repository = "https://github.com/trycua/cua" + +[project.scripts] +cua-computer-server = "computer_server:run_cli" + +[tool.pdm] +distribution = true + +[tool.pdm.build] +includes = ["computer_server"] +package-data = {"computer_server" = ["py.typed"]} + +[tool.pdm.dev-dependencies] +test = [ + "pytest>=7.0.0", + "pytest-asyncio>=0.23.0" +] +format = [ + "black>=23.0.0", + "isort>=5.12.0" +] + +[tool.pdm.scripts] +api = "python -m computer_server" + +[tool.black] +line-length = 100 +target-version = ["py310"] + +[tool.ruff] +line-length = 100 +target-version = "py310" +select = ["E", "F", "B", "I"] +fix = true + +[tool.ruff.format] +docstring-code-format = true + +[tool.mypy] +strict = true +python_version = "3.10" +ignore_missing_imports = true +disallow_untyped_defs = true +check_untyped_defs = true +warn_return_any = true +show_error_codes = true +warn_unused_ignores = false \ No newline at end of file diff --git a/libs/computer-server/run_server.py b/libs/computer-server/run_server.py new file mode 100755 index 00000000..1818caa1 --- /dev/null +++ b/libs/computer-server/run_server.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +""" +Entrypoint script for the Computer Server. + +This script provides a simple way to start the Computer Server from the command line +or using a launch configuration in an IDE. 
+ +Usage: + python run_server.py [--host HOST] [--port PORT] [--log-level LEVEL] +""" + +import sys +from computer_server.cli import main + +if __name__ == "__main__": + sys.exit(main()) diff --git a/libs/computer-server/test_connection.py b/libs/computer-server/test_connection.py new file mode 100755 index 00000000..dee73e8d --- /dev/null +++ b/libs/computer-server/test_connection.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +""" +Connection test script for Computer Server. + +This script tests the WebSocket connection to the Computer Server and keeps +it alive, allowing you to verify the server is running correctly. +""" + +import asyncio +import json +import websockets +import argparse +import sys + + +async def test_connection(host="localhost", port=8000, keep_alive=False): + """Test connection to the Computer Server.""" + uri = f"ws://{host}:{port}/ws" + print(f"Connecting to {uri}...") + + try: + async with websockets.connect(uri) as websocket: + print("Connection established!") + + # Send a test command to get screen size + await websocket.send(json.dumps({"command": "get_screen_size", "params": {}})) + response = await websocket.recv() + print(f"Response: {response}") + + if keep_alive: + print("\nKeeping connection alive. Press Ctrl+C to exit...") + while True: + # Send a command every 5 seconds to keep the connection alive + await asyncio.sleep(5) + await websocket.send( + json.dumps({"command": "get_cursor_position", "params": {}}) + ) + response = await websocket.recv() + print(f"Cursor position: {response}") + except websockets.exceptions.ConnectionClosed as e: + print(f"Connection closed: {e}") + return False + except ConnectionRefusedError: + print(f"Connection refused. 
Is the server running at {host}:{port}?") + return False + except Exception as e: + print(f"Error: {e}") + return False + + return True + + +def parse_args(): + parser = argparse.ArgumentParser(description="Test connection to Computer Server") + parser.add_argument("--host", default="localhost", help="Host address (default: localhost)") + parser.add_argument("--port", type=int, default=8000, help="Port number (default: 8000)") + parser.add_argument("--keep-alive", action="store_true", help="Keep connection alive") + return parser.parse_args() + + +async def main(): + args = parse_args() + success = await test_connection(args.host, args.port, args.keep_alive) + return 0 if success else 1 + + +if __name__ == "__main__": + try: + sys.exit(asyncio.run(main())) + except KeyboardInterrupt: + print("\nExiting...") + sys.exit(0) diff --git a/libs/computer/README.md b/libs/computer/README.md new file mode 100644 index 00000000..38247b35 --- /dev/null +++ b/libs/computer/README.md @@ -0,0 +1,66 @@ +
+

+
+ + + + Shows my svg + +
+ + [![Python](https://img.shields.io/badge/Python-333333?logo=python&logoColor=white&labelColor=333333)](#) + [![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](#) + [![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.com/invite/mVnXXpdE85) + [![PyPI](https://img.shields.io/pypi/v/cua-computer?color=333333)](https://pypi.org/project/cua-computer/) +

+
+ +**Computer** is a Computer-Use Interface (CUI) framework powering Cua for interacting with local macOS and Linux sandboxes, PyAutoGUI-compatible, and pluggable with any AI agent systems (Cua, Langchain, CrewAI, AutoGen). Computer relies on [Lume](https://github.com/trycua/lume) for creating and managing sandbox environments. + +### Get started with Computer + +
+ +
+ +```python +from computer import Computer + +computer = Computer(os="macos", display="1024x768", memory="8GB", cpu="4") +try: + await computer.run() + + screenshot = await computer.interface.screenshot() + with open("screenshot.png", "wb") as f: + f.write(screenshot) + + await computer.interface.move_cursor(100, 100) + await computer.interface.left_click() + await computer.interface.right_click(300, 300) + await computer.interface.double_click(400, 400) + + await computer.interface.type("Hello, World!") + await computer.interface.press_key("enter") + + await computer.interface.set_clipboard("Test clipboard") + content = await computer.interface.copy_to_clipboard() + print(f"Clipboard content: {content}") +finally: + await computer.stop() +``` + +## Install + +To install the Computer-Use Interface (CUI): + +```bash +pip install cua-computer +``` + +The `cua-computer` PyPi package pulls automatically the latest executable version of Lume through [pylume](https://github.com/trycua/pylume). 
+ +## Run + +Refer to this notebook for a step-by-step guide on how to use the Computer-Use Interface (CUI): + +- [Computer-Use Interface (CUI)](../../notebooks/computer_nb.ipynb) \ No newline at end of file diff --git a/libs/computer/computer/__init__.py b/libs/computer/computer/__init__.py new file mode 100644 index 00000000..e50f5983 --- /dev/null +++ b/libs/computer/computer/__init__.py @@ -0,0 +1,47 @@ +"""CUA Computer Interface for cross-platform computer control.""" + +import logging +import sys + +__version__ = "0.1.0" + +# Initialize logging +logger = logging.getLogger("cua.computer") + +# Initialize telemetry when the package is imported +try: + # Import from core telemetry + from core.telemetry import ( + is_telemetry_enabled, + flush, + record_event, + ) + + # Check if telemetry is enabled + if is_telemetry_enabled(): + logger.info("Telemetry is enabled") + + # Record package initialization + record_event( + "module_init", + { + "module": "computer", + "version": __version__, + "python_version": sys.version, + }, + ) + + # Flush events to ensure they're sent + flush() + else: + logger.info("Telemetry is disabled") +except ImportError as e: + # Telemetry not available + logger.warning(f"Telemetry not available: {e}") +except Exception as e: + # Other issues with telemetry + logger.warning(f"Error initializing telemetry: {e}") + +from .computer import Computer + +__all__ = ["Computer"] diff --git a/libs/computer/computer/computer.py b/libs/computer/computer/computer.py new file mode 100644 index 00000000..8d86b672 --- /dev/null +++ b/libs/computer/computer/computer.py @@ -0,0 +1,515 @@ +from typing import Optional, List, Literal, Dict, Any, Union, TYPE_CHECKING, cast +from pylume import PyLume +from pylume.models import VMRunOpts, VMUpdateOpts, ImageRef, SharedDirectory +import asyncio +from .models import Computer as ComputerConfig, Display +from .interface.factory import InterfaceFactory +import time +from PIL import Image +import io +from .utils import 
bytes_to_image +import re +from .logger import Logger, LogLevel +import json +import logging +from .telemetry import record_computer_initialization + +OSType = Literal["macos", "linux"] + +# Import BaseComputerInterface for type annotations +if TYPE_CHECKING: + from .interface.base import BaseComputerInterface + + +class Computer: + """Computer is the main class for interacting with the computer.""" + + def __init__( + self, + display: Union[Display, Dict[str, int], str] = "1024x768", + memory: str = "8GB", + cpu: str = "4", + os: OSType = "macos", + name: str = "", + image: str = "macos-sequoia-cua:latest", + shared_directories: Optional[List[str]] = None, + use_host_computer_server: bool = False, + verbosity: Union[int, LogLevel] = logging.INFO, + telemetry_enabled: bool = True, + ): + """Initialize a new Computer instance. + + Args: + display: The display configuration. Can be: + - A Display object + - A dict with 'width' and 'height' + - A string in format "WIDTHxHEIGHT" (e.g. "1920x1080") + Defaults to "1024x768" + memory: The VM memory allocation. Defaults to "8GB" + cpu: The VM CPU allocation. Defaults to "4" + os: The operating system type ('macos' or 'linux') + name: The VM name + image: The VM image name + shared_directories: Optional list of directory paths to share with the VM + use_host_computer_server: If True, target localhost instead of starting a VM + verbosity: Logging level (standard Python logging levels: logging.DEBUG, logging.INFO, etc.) + LogLevel enum values are still accepted for backward compatibility + telemetry_enabled: Whether to enable telemetry tracking. Defaults to True. 
+ """ + if TYPE_CHECKING: + from .interface.base import BaseComputerInterface + + self.logger = Logger("cua.computer", verbosity) + self.logger.info("Initializing Computer...") + + # Store original parameters + self.image = image + + # Store telemetry preference + self._telemetry_enabled = telemetry_enabled + + # Set initialization flag + self._initialized = False + self._running = False + + # Configure root logger + self.verbosity = verbosity + self.logger = Logger("cua", verbosity) + + # Configure component loggers with proper hierarchy + self.vm_logger = Logger("cua.vm", verbosity) + self.interface_logger = Logger("cua.interface", verbosity) + + if not use_host_computer_server: + if ":" not in image or len(image.split(":")) != 2: + raise ValueError("Image must be in the format :") + + if not name: + # Normalize the name to be used for the VM + name = image.replace(":", "_") + + # Convert display parameter to Display object + if isinstance(display, str): + # Parse string format "WIDTHxHEIGHT" + match = re.match(r"(\d+)x(\d+)", display) + if not match: + raise ValueError( + "Display string must be in format 'WIDTHxHEIGHT' (e.g. 
'1024x768')" + ) + width, height = map(int, match.groups()) + display_config = Display(width=width, height=height) + elif isinstance(display, dict): + display_config = Display(**display) + else: + display_config = display + + self.config = ComputerConfig( + image=image.split(":")[0], + tag=image.split(":")[1], + name=name, + display=display_config, + memory=memory, + cpu=cpu, + ) + # Initialize PyLume but don't start the server yet - we'll do that in run() + self.config.pylume = PyLume( + debug=(self.verbosity == LogLevel.DEBUG), + port=3000, + use_existing_server=False, + server_start_timeout=120, # Increase timeout to 2 minutes + ) + + # Initialize with proper typing - None at first, will be set in run() + self._interface = None + self.os = os + self.shared_paths = [] + if shared_directories: + for path in shared_directories: + abs_path = os.path.abspath(os.path.expanduser(path)) # type: ignore[attr-defined] + if not os.path.exists(abs_path): # type: ignore[attr-defined] + raise ValueError(f"Shared directory does not exist: {path}") + self.shared_paths.append(abs_path) + self._pylume_context = None + self.use_host_computer_server = use_host_computer_server + + # Record initialization in telemetry (if enabled) + if telemetry_enabled: + record_computer_initialization() + else: + self.logger.debug("Telemetry disabled - skipping initialization tracking") + + async def __aenter__(self): + """Enter async context manager.""" + await self.run() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Exit async context manager.""" + pass + + async def run(self) -> None: + """Initialize the VM and computer interface.""" + if TYPE_CHECKING: + from .interface.base import BaseComputerInterface + + # If already initialized, just log and return + if hasattr(self, "_initialized") and self._initialized: + self.logger.info("Computer already initialized, skipping initialization") + return + + self.logger.info("Starting computer...") + start_time = time.time() + 
+ try: + # If using host computer server + if self.use_host_computer_server: + self.logger.info("Using host computer server") + # Set ip_address for host computer server mode + ip_address = "localhost" + # Create the interface with explicit type annotation + from .interface.base import BaseComputerInterface + + self._interface = cast( + BaseComputerInterface, + InterfaceFactory.create_interface_for_os( + os=self.os, ip_address=ip_address # type: ignore[arg-type] + ), + ) + + self.logger.info("Waiting for host computer server to be ready...") + await self._interface.wait_for_ready() + self.logger.info("Host computer server ready") + else: + # Start or connect to VM + self.logger.info(f"Starting VM: {self.image}") + if not self._pylume_context: + try: + self.logger.verbose("Initializing PyLume context...") + self._pylume_context = await self.config.pylume.__aenter__() # type: ignore[attr-defined] + self.logger.verbose("PyLume context initialized successfully") + except Exception as e: + self.logger.error(f"Failed to initialize PyLume context: {e}") + raise RuntimeError(f"Failed to initialize PyLume: {e}") + + # Try to get the VM, if it doesn't exist, create it and pull the image + try: + vm = await self.config.pylume.get_vm(self.config.name) # type: ignore[attr-defined] + self.logger.verbose(f"Found existing VM: {self.config.name}") + except Exception as e: + self.logger.verbose(f"VM not found, pulling image: {e}") + image_ref = ImageRef( + image=self.config.image, + tag=self.config.tag, + registry="ghcr.io", + organization="trycua", + ) + self.logger.info(f"Pulling image {self.config.image}:{self.config.tag}...") + try: + await self.config.pylume.pull_image(image_ref, name=self.config.name) # type: ignore[attr-defined] + except Exception as pull_error: + self.logger.error(f"Failed to pull image: {pull_error}") + raise RuntimeError(f"Failed to pull VM image: {pull_error}") + + # Convert paths to SharedDirectory objects + shared_directories = [] + for path in 
self.shared_paths: + self.logger.verbose(f"Adding shared directory: {path}") + shared_directories.append( + SharedDirectory(host_path=path) # type: ignore[arg-type] + ) + + # Run with shared directories + self.logger.info(f"Starting VM {self.config.name}...") + run_opts = VMRunOpts( + no_display=False, # type: ignore[arg-type] + shared_directories=shared_directories, # type: ignore[arg-type] + ) + + # Log the run options for debugging + self.logger.info(f"VM run options: {vars(run_opts)}") + + # Log the equivalent curl command for debugging + payload = json.dumps({"noDisplay": False, "sharedDirectories": []}) + curl_cmd = f"curl -X POST 'http://localhost:3000/lume/vms/{self.config.name}/run' -H 'Content-Type: application/json' -d '{payload}'" + print(f"\nEquivalent curl command:\n{curl_cmd}\n") + + try: + response = await self.config.pylume.run_vm(self.config.name, run_opts) # type: ignore[attr-defined] + self.logger.info(f"VM run response: {response if response else 'None'}") + except Exception as run_error: + self.logger.error(f"Failed to run VM: {run_error}") + raise RuntimeError(f"Failed to start VM: {run_error}") + + # Wait for VM to be ready with required properties + self.logger.info("Waiting for VM to be ready...") + try: + vm = await self.wait_vm_ready() + if not vm or not vm.ip_address: # type: ignore[attr-defined] + raise RuntimeError(f"VM {self.config.name} failed to get IP address") + ip_address = vm.ip_address # type: ignore[attr-defined] + self.logger.info(f"VM is ready with IP: {ip_address}") + except Exception as wait_error: + self.logger.error(f"Error waiting for VM: {wait_error}") + raise RuntimeError(f"VM failed to become ready: {wait_error}") + except Exception as e: + self.logger.error(f"Failed to initialize computer: {e}") + raise RuntimeError(f"Failed to initialize computer: {e}") + + try: + # Initialize the interface using the factory with the specified OS + self.logger.info(f"Initializing interface for {self.os} at {ip_address}") + from 
.interface.base import BaseComputerInterface + + self._interface = cast( + BaseComputerInterface, + InterfaceFactory.create_interface_for_os( + os=self.os, ip_address=ip_address # type: ignore[arg-type] + ), + ) + + # Wait for the WebSocket interface to be ready + self.logger.info("Connecting to WebSocket interface...") + + try: + # Use a single timeout for the entire connection process + await self._interface.wait_for_ready(timeout=60) + self.logger.info("WebSocket interface connected successfully") + except TimeoutError as e: + self.logger.error("Failed to connect to WebSocket interface") + raise TimeoutError( + f"Could not connect to WebSocket interface at {ip_address}:8000/ws: {str(e)}" + ) + + # Create an event to keep the VM running in background if needed + if not self.use_host_computer_server: + self._stop_event = asyncio.Event() + self._keep_alive_task = asyncio.create_task(self._stop_event.wait()) + + self.logger.info("Computer is ready") + + # Set the initialization flag and clear the initializing flag + self._initialized = True + self.logger.info("Computer successfully initialized") + except Exception as e: + raise + finally: + # Log initialization time for performance monitoring + duration_ms = (time.time() - start_time) * 1000 + self.logger.debug(f"Computer initialization took {duration_ms:.2f}ms") + return + + async def stop(self) -> None: + """Stop computer control.""" + start_time = time.time() + + try: + if self._running: + self._running = False + self.logger.info("Stopping Computer...") + + if hasattr(self, "_stop_event"): + self._stop_event.set() + if hasattr(self, "_keep_alive_task"): + await self._keep_alive_task + + if self._interface: # Only try to close interface if it exists + self.logger.verbose("Closing interface...") + # For host computer server, just use normal close to keep the server running + if self.use_host_computer_server: + self._interface.close() + else: + # For VM mode, force close the connection + if hasattr(self._interface, 
"force_close"): + self._interface.force_close() + else: + self._interface.close() + + if not self.use_host_computer_server and self._pylume_context: + try: + self.logger.info(f"Stopping VM {self.config.name}...") + await self.config.pylume.stop_vm(self.config.name) # type: ignore[attr-defined] + except Exception as e: + self.logger.verbose(f"Error stopping VM: {e}") # VM might already be stopped + self.logger.verbose("Closing PyLume context...") + await self.config.pylume.__aexit__(None, None, None) # type: ignore[attr-defined] + self._pylume_context = None + self.logger.info("Computer stopped") + except Exception as e: + self.logger.debug( + f"Error during cleanup: {e}" + ) # Log as debug since this might be expected + finally: + # Log stop time for performance monitoring + duration_ms = (time.time() - start_time) * 1000 + self.logger.debug(f"Computer stop process took {duration_ms:.2f}ms") + return + + # @property + async def get_ip(self) -> str: + """Get the IP address of the VM or localhost if using host computer server.""" + if self.use_host_computer_server: + return "127.0.0.1" + ip = await self.config.get_ip() + return ip or "unknown" # Return "unknown" if ip is None + + async def wait_vm_ready(self) -> Optional[Union[Dict[str, Any], "VMStatus"]]: + """Wait for VM to be ready with an IP address. + + Returns: + VM status information or None if using host computer server. 
+ """ + if self.use_host_computer_server: + return None + + timeout = 600 # 10 minutes timeout (increased from 4 minutes) + interval = 2.0 # 2 seconds between checks (increased to reduce API load) + start_time = time.time() + last_status = None + attempts = 0 + + self.logger.info(f"Waiting for VM {self.config.name} to be ready (timeout: {timeout}s)...") + + while time.time() - start_time < timeout: + attempts += 1 + elapsed = time.time() - start_time + + try: + # Keep polling for VM info + vm = await self.config.pylume.get_vm(self.config.name) # type: ignore[attr-defined] + + # Log full VM properties for debugging (every 30 attempts) + if attempts % 30 == 0: + self.logger.info( + f"VM properties at attempt {attempts}: {vars(vm) if vm else 'None'}" + ) + + # Get current status for logging + current_status = getattr(vm, "status", None) if vm else None + if current_status != last_status: + self.logger.info( + f"VM status changed to: {current_status} (after {elapsed:.1f}s)" + ) + last_status = current_status + + # Check for IP address - ensure it's not None or empty + ip = getattr(vm, "ip_address", None) if vm else None + if ip and ip.strip(): # Check for non-empty string + self.logger.info( + f"VM {self.config.name} got IP address: {ip} (after {elapsed:.1f}s)" + ) + return vm + + if attempts % 10 == 0: # Log every 10 attempts to avoid flooding + self.logger.info( + f"Still waiting for VM IP address... (elapsed: {elapsed:.1f}s)" + ) + else: + self.logger.debug( + f"Waiting for VM IP address... 
Current IP: {ip}, Status: {current_status}" + ) + + except Exception as e: + self.logger.warning(f"Error checking VM status (attempt {attempts}): {str(e)}") + # If we've been trying for a while and still getting errors, log more details + if elapsed > 60: # After 1 minute of errors, log more details + self.logger.error(f"Persistent error getting VM status: {str(e)}") + self.logger.info("Trying to get VM list for debugging...") + try: + vms = await self.config.pylume.list_vms() # type: ignore[attr-defined] + self.logger.info( + f"Available VMs: {[vm.name for vm in vms if hasattr(vm, 'name')]}" + ) + except Exception as list_error: + self.logger.error(f"Failed to list VMs: {str(list_error)}") + + await asyncio.sleep(interval) + + # If we get here, we've timed out + elapsed = time.time() - start_time + self.logger.error(f"VM {self.config.name} not ready after {elapsed:.1f} seconds") + + # Try to get final VM status for debugging + try: + vm = await self.config.pylume.get_vm(self.config.name) # type: ignore[attr-defined] + status = getattr(vm, "status", "unknown") if vm else "unknown" + ip = getattr(vm, "ip_address", None) if vm else None + self.logger.error(f"Final VM status: {status}, IP: {ip}") + except Exception as e: + self.logger.error(f"Failed to get final VM status: {str(e)}") + + raise TimeoutError( + f"VM {self.config.name} not ready after {elapsed:.1f} seconds - IP address not assigned" + ) + + async def update(self, cpu: Optional[int] = None, memory: Optional[str] = None): + """Update VM settings.""" + self.logger.info( + f"Updating VM settings: CPU={cpu or self.config.cpu}, Memory={memory or self.config.memory}" + ) + update_opts = VMUpdateOpts( + cpu=cpu or int(self.config.cpu), memory=memory or self.config.memory + ) + await self.config.pylume.update_vm(self.config.image, update_opts) # type: ignore[attr-defined] + + def get_screenshot_size(self, screenshot: bytes) -> Dict[str, int]: + """Get the dimensions of a screenshot. 
+ + Args: + screenshot: The screenshot bytes + + Returns: + Dict[str, int]: Dictionary containing 'width' and 'height' of the image + """ + image = Image.open(io.BytesIO(screenshot)) + width, height = image.size + return {"width": width, "height": height} + + @property + def interface(self): + """Get the computer interface for interacting with the VM. + + Returns: + The computer interface + """ + if not hasattr(self, "_interface") or self._interface is None: + error_msg = "Computer interface not initialized. Call run() first." + self.logger.error(error_msg) + self.logger.error( + "Make sure to call await computer.run() before using any interface methods." + ) + raise RuntimeError(error_msg) + + return self._interface + + @property + def telemetry_enabled(self) -> bool: + """Check if telemetry is enabled for this computer instance. + + Returns: + bool: True if telemetry is enabled, False otherwise + """ + return self._telemetry_enabled + + async def to_screen_coordinates(self, x: float, y: float) -> tuple[float, float]: + """Convert normalized coordinates to screen coordinates. + + Args: + x: X coordinate between 0 and 1 + y: Y coordinate between 0 and 1 + + Returns: + tuple[float, float]: Screen coordinates (x, y) + """ + return await self.interface.to_screen_coordinates(x, y) + + async def to_screenshot_coordinates(self, x: float, y: float) -> tuple[float, float]: + """Convert screen coordinates to screenshot coordinates. + + Args: + x: X coordinate in screen space + y: Y coordinate in screen space + + Returns: + tuple[float, float]: (x, y) coordinates in screenshot space + """ + return await self.interface.to_screenshot_coordinates(x, y) diff --git a/libs/computer/computer/interface/__init__.py b/libs/computer/computer/interface/__init__.py new file mode 100644 index 00000000..6d7e1b78 --- /dev/null +++ b/libs/computer/computer/interface/__init__.py @@ -0,0 +1,13 @@ +""" +Interface package for Computer SDK. 
+""" + +from .factory import InterfaceFactory +from .base import BaseComputerInterface +from .macos import MacOSComputerInterface + +__all__ = [ + "InterfaceFactory", + "BaseComputerInterface", + "MacOSComputerInterface", +] \ No newline at end of file diff --git a/libs/computer/computer/interface/base.py b/libs/computer/computer/interface/base.py new file mode 100644 index 00000000..31106c14 --- /dev/null +++ b/libs/computer/computer/interface/base.py @@ -0,0 +1,190 @@ +"""Base interface for computer control.""" + +from abc import ABC, abstractmethod +from typing import Optional, Dict, Any, Tuple, List +from ..logger import Logger, LogLevel + + +class BaseComputerInterface(ABC): + """Base class for computer control interfaces.""" + + def __init__(self, ip_address: str, username: str = "lume", password: str = "lume"): + """Initialize interface. + + Args: + ip_address: IP address of the computer to control + username: Username for authentication + password: Password for authentication + """ + self.ip_address = ip_address + self.username = username + self.password = password + self.logger = Logger("cua.interface", LogLevel.NORMAL) + + @abstractmethod + async def wait_for_ready(self, timeout: int = 60) -> None: + """Wait for interface to be ready. + + Args: + timeout: Maximum time to wait in seconds + + Raises: + TimeoutError: If interface is not ready within timeout + """ + pass + + @abstractmethod + def close(self) -> None: + """Close the interface connection.""" + pass + + def force_close(self) -> None: + """Force close the interface connection. + + By default, this just calls close(), but subclasses can override + to provide more forceful cleanup. 
+ """ + self.close() + + # Mouse Actions + @abstractmethod + async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None: + """Perform a left click.""" + pass + + @abstractmethod + async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None: + """Perform a right click.""" + pass + + @abstractmethod + async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None: + """Perform a double click.""" + pass + + @abstractmethod + async def move_cursor(self, x: int, y: int) -> None: + """Move the cursor to specified position.""" + pass + + @abstractmethod + async def drag_to(self, x: int, y: int, button: str = "left", duration: float = 0.5) -> None: + """Drag from current position to specified coordinates. + + Args: + x: The x coordinate to drag to + y: The y coordinate to drag to + button: The mouse button to use ('left', 'middle', 'right') + duration: How long the drag should take in seconds + """ + pass + + # Keyboard Actions + @abstractmethod + async def type_text(self, text: str) -> None: + """Type the specified text.""" + pass + + @abstractmethod + async def press_key(self, key: str) -> None: + """Press a single key.""" + pass + + @abstractmethod + async def hotkey(self, *keys: str) -> None: + """Press multiple keys simultaneously.""" + pass + + # Scrolling Actions + @abstractmethod + async def scroll_down(self, clicks: int = 1) -> None: + """Scroll down.""" + pass + + @abstractmethod + async def scroll_up(self, clicks: int = 1) -> None: + """Scroll up.""" + pass + + # Screen Actions + @abstractmethod + async def screenshot(self) -> bytes: + """Take a screenshot. + + Returns: + Raw bytes of the screenshot image + """ + pass + + @abstractmethod + async def get_screen_size(self) -> Dict[str, int]: + """Get the screen dimensions. 
+ + Returns: + Dict with 'width' and 'height' keys + """ + pass + + @abstractmethod + async def get_cursor_position(self) -> Dict[str, int]: + """Get current cursor position.""" + pass + + # Clipboard Actions + @abstractmethod + async def copy_to_clipboard(self) -> str: + """Get clipboard content.""" + pass + + @abstractmethod + async def set_clipboard(self, text: str) -> None: + """Set clipboard content.""" + pass + + # File System Actions + @abstractmethod + async def file_exists(self, path: str) -> bool: + """Check if file exists.""" + pass + + @abstractmethod + async def directory_exists(self, path: str) -> bool: + """Check if directory exists.""" + pass + + @abstractmethod + async def run_command(self, command: str) -> Tuple[str, str]: + """Run shell command.""" + pass + + # Accessibility Actions + @abstractmethod + async def get_accessibility_tree(self) -> Dict: + """Get the accessibility tree of the current screen.""" + pass + + @abstractmethod + async def to_screen_coordinates(self, x: float, y: float) -> tuple[float, float]: + """Convert screenshot coordinates to screen coordinates. + + Args: + x: X coordinate in screenshot space + y: Y coordinate in screenshot space + + Returns: + tuple[float, float]: (x, y) coordinates in screen space + """ + pass + + @abstractmethod + async def to_screenshot_coordinates(self, x: float, y: float) -> tuple[float, float]: + """Convert screen coordinates to screenshot coordinates. 
+ + Args: + x: X coordinate in screen space + y: Y coordinate in screen space + + Returns: + tuple[float, float]: (x, y) coordinates in screenshot space + """ + pass diff --git a/libs/computer/computer/interface/factory.py b/libs/computer/computer/interface/factory.py new file mode 100644 index 00000000..cb9a8d93 --- /dev/null +++ b/libs/computer/computer/interface/factory.py @@ -0,0 +1,32 @@ +"""Factory for creating computer interfaces.""" + +from typing import Literal +from .base import BaseComputerInterface + +class InterfaceFactory: + """Factory for creating OS-specific computer interfaces.""" + + @staticmethod + def create_interface_for_os( + os: Literal['macos', 'linux'], + ip_address: str + ) -> BaseComputerInterface: + """Create an interface for the specified OS. + + Args: + os: Operating system type ('macos' or 'linux') + ip_address: IP address of the computer to control + + Returns: + BaseComputerInterface: The appropriate interface for the OS + + Raises: + ValueError: If the OS type is not supported + """ + # Import implementations here to avoid circular imports + from .macos import MacOSComputerInterface + + if os == 'macos': + return MacOSComputerInterface(ip_address) + else: + raise ValueError(f"Unsupported OS type: {os}") \ No newline at end of file diff --git a/libs/computer/computer/interface/linux.py b/libs/computer/computer/interface/linux.py new file mode 100644 index 00000000..52e173b4 --- /dev/null +++ b/libs/computer/computer/interface/linux.py @@ -0,0 +1,27 @@ +"""Linux computer interface implementation.""" + +from typing import Dict +from .base import BaseComputerInterface + +class LinuxInterface(BaseComputerInterface): + """Linux-specific computer interface.""" + + async def wait_for_ready(self, timeout: int = 60) -> None: + """Wait for interface to be ready.""" + # Placeholder implementation + pass + + def close(self) -> None: + """Close the interface connection.""" + # Placeholder implementation + pass + + async def get_screen_size(self) 
-> Dict[str, int]: + """Get the screen dimensions.""" + # Placeholder implementation + return {"width": 1920, "height": 1080} + + async def screenshot(self) -> bytes: + """Take a screenshot.""" + # Placeholder implementation + return b"" \ No newline at end of file diff --git a/libs/computer/computer/interface/macos.py b/libs/computer/computer/interface/macos.py new file mode 100644 index 00000000..dcacb13c --- /dev/null +++ b/libs/computer/computer/interface/macos.py @@ -0,0 +1,548 @@ +import asyncio +import json +import time +from typing import Any, Dict, List, Optional, Tuple +from PIL import Image + +import websockets + +from ..logger import Logger, LogLevel +from .base import BaseComputerInterface +from ..utils import decode_base64_image, bytes_to_image, draw_box, resize_image +from .models import Key, KeyType + + +class MacOSComputerInterface(BaseComputerInterface): + """Interface for MacOS.""" + + def __init__(self, ip_address: str, username: str = "lume", password: str = "lume"): + super().__init__(ip_address, username, password) + self.ws_uri = f"ws://{ip_address}:8000/ws" + self._ws = None + self._reconnect_task = None + self._closed = False + self._last_ping = 0 + self._ping_interval = 5 # Send ping every 5 seconds + self._ping_timeout = 10 # Wait 10 seconds for pong response + self._reconnect_delay = 1 # Start with 1 second delay + self._max_reconnect_delay = 30 # Maximum delay between reconnection attempts + self._log_connection_attempts = True # Flag to control connection attempt logging + + # Set logger name for MacOS interface + self.logger = Logger("cua.interface.macos", LogLevel.NORMAL) + + async def _keep_alive(self): + """Keep the WebSocket connection alive with automatic reconnection.""" + retry_count = 0 + max_log_attempts = 1 # Only log the first attempt at INFO level + log_interval = 500 # Then log every 500th attempt (significantly increased from 30) + last_warning_time = 0 + min_warning_interval = 30 # Minimum seconds between connection 
lost warnings + min_retry_delay = 0.5 # Minimum delay between connection attempts (500ms) + + while not self._closed: + try: + if self._ws is None or ( + self._ws and self._ws.state == websockets.protocol.State.CLOSED + ): + try: + retry_count += 1 + + # Add a minimum delay between connection attempts to avoid flooding + if retry_count > 1: + await asyncio.sleep(min_retry_delay) + + # Only log the first attempt at INFO level, then every Nth attempt + if retry_count == 1: + self.logger.info(f"Attempting WebSocket connection to {self.ws_uri}") + elif retry_count % log_interval == 0: + self.logger.info( + f"Still attempting WebSocket connection (attempt {retry_count})..." + ) + else: + # All other attempts are logged at DEBUG level + self.logger.debug( + f"Attempting WebSocket connection to {self.ws_uri} (attempt {retry_count})" + ) + + self._ws = await asyncio.wait_for( + websockets.connect( + self.ws_uri, + max_size=1024 * 1024 * 10, # 10MB limit + max_queue=32, + ping_interval=self._ping_interval, + ping_timeout=self._ping_timeout, + close_timeout=5, + compression=None, # Disable compression to reduce overhead + ), + timeout=30, + ) + self.logger.info("WebSocket connection established") + self._reconnect_delay = 1 # Reset reconnect delay on successful connection + self._last_ping = time.time() + retry_count = 0 # Reset retry count on successful connection + except (asyncio.TimeoutError, websockets.exceptions.WebSocketException) as e: + next_retry = self._reconnect_delay + + # Only log the first error at WARNING level, then every Nth attempt + if retry_count == 1: + self.logger.warning( + f"Computer API Server not ready yet. Will retry automatically." + ) + elif retry_count % log_interval == 0: + self.logger.warning( + f"Still waiting for Computer API Server (attempt {retry_count})..." 
+ ) + else: + # All other errors are logged at DEBUG level + self.logger.debug(f"Connection attempt {retry_count} failed: {e}") + + if self._ws: + try: + await self._ws.close() + except: + pass + self._ws = None + + # Use exponential backoff for connection retries + await asyncio.sleep(self._reconnect_delay) + self._reconnect_delay = min( + self._reconnect_delay * 2, self._max_reconnect_delay + ) + continue + + # Regular ping to check connection + if self._ws and self._ws.state == websockets.protocol.State.OPEN: + try: + if time.time() - self._last_ping >= self._ping_interval: + pong_waiter = await self._ws.ping() + await asyncio.wait_for(pong_waiter, timeout=self._ping_timeout) + self._last_ping = time.time() + except Exception as e: + self.logger.debug(f"Ping failed: {e}") + if self._ws: + try: + await self._ws.close() + except: + pass + self._ws = None + continue + + await asyncio.sleep(1) + + except Exception as e: + current_time = time.time() + # Only log connection lost warnings at most once every min_warning_interval seconds + if current_time - last_warning_time >= min_warning_interval: + self.logger.warning( + f"Computer API Server connection lost. Will retry automatically." 
+ ) + last_warning_time = current_time + else: + # Log at debug level instead + self.logger.debug(f"Connection lost: {e}") + + if self._ws: + try: + await self._ws.close() + except: + pass + self._ws = None + + async def _ensure_connection(self): + """Ensure WebSocket connection is established.""" + if self._reconnect_task is None or self._reconnect_task.done(): + self._reconnect_task = asyncio.create_task(self._keep_alive()) + + retry_count = 0 + max_retries = 5 + + while retry_count < max_retries: + try: + if self._ws and self._ws.state == websockets.protocol.State.OPEN: + return + retry_count += 1 + await asyncio.sleep(1) + except Exception as e: + # Only log at ERROR level for the last retry attempt + if retry_count == max_retries - 1: + self.logger.error( + f"Persistent connection check error after {retry_count} attempts: {e}" + ) + else: + self.logger.debug(f"Connection check error (attempt {retry_count}): {e}") + retry_count += 1 + await asyncio.sleep(1) + continue + + raise ConnectionError("Failed to establish WebSocket connection after multiple retries") + + async def _send_command(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]: + """Send command through WebSocket.""" + max_retries = 3 + retry_count = 0 + last_error = None + + while retry_count < max_retries: + try: + await self._ensure_connection() + if not self._ws: + raise ConnectionError("WebSocket connection is not established") + + message = {"command": command, "params": params or {}} + await self._ws.send(json.dumps(message)) + response = await asyncio.wait_for(self._ws.recv(), timeout=30) + return json.loads(response) + except Exception as e: + last_error = e + retry_count += 1 + if retry_count < max_retries: + # Only log at debug level for intermediate retries + self.logger.debug( + f"Command '{command}' failed (attempt {retry_count}/{max_retries}): {e}" + ) + await asyncio.sleep(1) + continue + else: + # Only log at error level for the final failure + self.logger.error( + 
f"Failed to send command '{command}' after {max_retries} retries" + ) + self.logger.debug(f"Command failure details: {e}") + raise + + raise last_error if last_error else RuntimeError("Failed to send command") + + async def wait_for_ready(self, timeout: int = 60, interval: float = 1.0): + """Wait for WebSocket connection to become available.""" + start_time = time.time() + last_error = None + attempt_count = 0 + progress_interval = 10 # Log progress every 10 seconds + last_progress_time = start_time + + # Disable detailed logging for connection attempts + self._log_connection_attempts = False + + try: + self.logger.info( + f"Waiting for Computer API Server to be ready (timeout: {timeout}s)..." + ) + + # Start the keep-alive task if it's not already running + if self._reconnect_task is None or self._reconnect_task.done(): + self._reconnect_task = asyncio.create_task(self._keep_alive()) + + # Wait for the connection to be established + while time.time() - start_time < timeout: + try: + attempt_count += 1 + current_time = time.time() + + # Log progress periodically without flooding logs + if current_time - last_progress_time >= progress_interval: + elapsed = current_time - start_time + self.logger.info( + f"Still waiting for Computer API Server... 
(elapsed: {elapsed:.1f}s, attempts: {attempt_count})" + ) + last_progress_time = current_time + + # Check if we have a connection + if self._ws and self._ws.state == websockets.protocol.State.OPEN: + # Test the connection with a simple command + try: + await self._send_command("get_screen_size") + elapsed = time.time() - start_time + self.logger.info( + f"Computer API Server is ready (after {elapsed:.1f}s, {attempt_count} attempts)" + ) + return # Connection is fully working + except Exception as e: + last_error = e + self.logger.debug(f"Connection test failed: {e}") + + # Wait before trying again + await asyncio.sleep(interval) + + except Exception as e: + last_error = e + self.logger.debug(f"Connection attempt {attempt_count} failed: {e}") + await asyncio.sleep(interval) + + # If we get here, we've timed out + error_msg = f"Could not connect to {self.ip_address} after {timeout} seconds" + if last_error: + error_msg += f": {str(last_error)}" + self.logger.error(error_msg) + raise TimeoutError(error_msg) + finally: + # Reset to default logging behavior + self._log_connection_attempts = False + + def close(self): + """Close WebSocket connection. + + Note: In host computer server mode, we leave the connection open + to allow other clients to connect to the same server. The server + will handle cleaning up idle connections. + """ + # Only cancel the reconnect task + if self._reconnect_task: + self._reconnect_task.cancel() + + # Don't set closed flag or close websocket by default + # This allows the server to stay connected for other clients + # self._closed = True + # if self._ws: + # asyncio.create_task(self._ws.close()) + # self._ws = None + + def force_close(self): + """Force close the WebSocket connection. + + This method should be called when you want to completely + shut down the connection, not just for regular cleanup. 
+ """ + self._closed = True + if self._reconnect_task: + self._reconnect_task.cancel() + if self._ws: + asyncio.create_task(self._ws.close()) + self._ws = None + + # Mouse Actions + async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None: + await self._send_command("left_click", {"x": x, "y": y}) + + async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None: + await self._send_command("right_click", {"x": x, "y": y}) + + async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None: + await self._send_command("double_click", {"x": x, "y": y}) + + async def move_cursor(self, x: int, y: int) -> None: + await self._send_command("move_cursor", {"x": x, "y": y}) + + async def drag_to(self, x: int, y: int, button: str = "left", duration: float = 0.5) -> None: + await self._send_command( + "drag_to", {"x": x, "y": y, "button": button, "duration": duration} + ) + + # Keyboard Actions + async def type_text(self, text: str) -> None: + await self._send_command("type_text", {"text": text}) + + async def press(self, key: "KeyType") -> None: + """Press a single key. + + Args: + key: The key to press. Can be any of: + - A Key enum value (recommended), e.g. Key.PAGE_DOWN + - A direct key value string, e.g. 'pagedown' + - A single character string, e.g. 
'a' + + Examples: + ```python + # Using enum (recommended) + await interface.press(Key.PAGE_DOWN) + await interface.press(Key.ENTER) + + # Using direct values + await interface.press('pagedown') + await interface.press('enter') + + # Using single characters + await interface.press('a') + ``` + + Raises: + ValueError: If the key type is invalid or the key is not recognized + """ + if isinstance(key, Key): + actual_key = key.value + elif isinstance(key, str): + # Try to convert to enum if it matches a known key + key_or_enum = Key.from_string(key) + actual_key = key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum + else: + raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.") + + await self._send_command("press_key", {"key": actual_key}) + + async def press_key(self, key: "KeyType") -> None: + """DEPRECATED: Use press() instead. + + This method is kept for backward compatibility but will be removed in a future version. + Please use the press() method instead. + """ + await self.press(key) + + async def hotkey(self, *keys: str) -> None: + await self._send_command("hotkey", {"keys": list(keys)}) + + # Scrolling Actions + async def scroll_down(self, clicks: int = 1) -> None: + for _ in range(clicks): + await self.hotkey("pagedown") + + async def scroll_up(self, clicks: int = 1) -> None: + for _ in range(clicks): + await self.hotkey("pageup") + + # Screen Actions + async def screenshot( + self, + boxes: Optional[List[Tuple[int, int, int, int]]] = None, + box_color: str = "#FF0000", + box_thickness: int = 2, + scale_factor: float = 1.0, + ) -> bytes: + """Take a screenshot with optional box drawing and scaling. 
+ + Args: + boxes: Optional list of (x, y, width, height) tuples defining boxes to draw in screen coordinates + box_color: Color of the boxes in hex format (default: "#FF0000" red) + box_thickness: Thickness of the box borders in pixels (default: 2) + scale_factor: Factor to scale the final image by (default: 1.0) + Use > 1.0 to enlarge, < 1.0 to shrink (e.g., 0.5 for half size, 2.0 for double) + + Returns: + bytes: The screenshot image data, optionally with boxes drawn on it and scaled + """ + result = await self._send_command("screenshot") + if not result.get("image_data"): + raise RuntimeError("Failed to take screenshot") + + screenshot = decode_base64_image(result["image_data"]) + + if boxes: + # Get the natural scaling between screen and screenshot + screen_size = await self.get_screen_size() + screenshot_width, screenshot_height = bytes_to_image(screenshot).size + width_scale = screenshot_width / screen_size["width"] + height_scale = screenshot_height / screen_size["height"] + + # Scale box coordinates from screen space to screenshot space + for box in boxes: + scaled_box = ( + int(box[0] * width_scale), # x + int(box[1] * height_scale), # y + int(box[2] * width_scale), # width + int(box[3] * height_scale), # height + ) + screenshot = draw_box( + screenshot, + x=scaled_box[0], + y=scaled_box[1], + width=scaled_box[2], + height=scaled_box[3], + color=box_color, + thickness=box_thickness, + ) + + if scale_factor != 1.0: + screenshot = resize_image(screenshot, scale_factor) + + return screenshot + + async def get_screen_size(self) -> Dict[str, int]: + result = await self._send_command("get_screen_size") + if result["success"] and result["size"]: + return result["size"] + raise RuntimeError("Failed to get screen size") + + async def get_cursor_position(self) -> Dict[str, int]: + result = await self._send_command("get_cursor_position") + if result["success"] and result["position"]: + return result["position"] + raise RuntimeError("Failed to get cursor position") + 
+ # Clipboard Actions + async def copy_to_clipboard(self) -> str: + result = await self._send_command("copy_to_clipboard") + if result["success"] and result["content"]: + return result["content"] + raise RuntimeError("Failed to get clipboard content") + + async def set_clipboard(self, text: str) -> None: + await self._send_command("set_clipboard", {"text": text}) + + # File System Actions + async def file_exists(self, path: str) -> bool: + result = await self._send_command("file_exists", {"path": path}) + return result.get("exists", False) + + async def directory_exists(self, path: str) -> bool: + result = await self._send_command("directory_exists", {"path": path}) + return result.get("exists", False) + + async def run_command(self, command: str) -> Tuple[str, str]: + result = await self._send_command("run_command", {"command": command}) + if not result.get("success", False): + raise RuntimeError(result.get("error", "Failed to run command")) + return result.get("stdout", ""), result.get("stderr", "") + + # Accessibility Actions + async def get_accessibility_tree(self) -> Dict[str, Any]: + """Get the accessibility tree of the current screen.""" + result = await self._send_command("get_accessibility_tree") + if not result.get("success", False): + raise RuntimeError(result.get("error", "Failed to get accessibility tree")) + return result.get("tree", {}) + + async def get_active_window_bounds(self) -> Dict[str, int]: + """Get the bounds of the currently active window.""" + result = await self._send_command("get_active_window_bounds") + if result["success"] and result["bounds"]: + return result["bounds"] + raise RuntimeError("Failed to get active window bounds") + + async def to_screen_coordinates(self, x: float, y: float) -> tuple[float, float]: + """Convert screenshot coordinates to screen coordinates. 
+ + Args: + x: X coordinate in screenshot space + y: Y coordinate in screenshot space + + Returns: + tuple[float, float]: (x, y) coordinates in screen space + """ + screen_size = await self.get_screen_size() + screenshot = await self.screenshot() + screenshot_img = bytes_to_image(screenshot) + screenshot_width, screenshot_height = screenshot_img.size + + # Calculate scaling factors + width_scale = screen_size["width"] / screenshot_width + height_scale = screen_size["height"] / screenshot_height + + # Convert coordinates + screen_x = x * width_scale + screen_y = y * height_scale + + return screen_x, screen_y + + async def to_screenshot_coordinates(self, x: float, y: float) -> tuple[float, float]: + """Convert screen coordinates to screenshot coordinates. + + Args: + x: X coordinate in screen space + y: Y coordinate in screen space + + Returns: + tuple[float, float]: (x, y) coordinates in screenshot space + """ + screen_size = await self.get_screen_size() + screenshot = await self.screenshot() + screenshot_img = bytes_to_image(screenshot) + screenshot_width, screenshot_height = screenshot_img.size + + # Calculate scaling factors + width_scale = screenshot_width / screen_size["width"] + height_scale = screenshot_height / screen_size["height"] + + # Convert coordinates + screenshot_x = x * width_scale + screenshot_y = y * height_scale + + return screenshot_x, screenshot_y diff --git a/libs/computer/computer/interface/models.py b/libs/computer/computer/interface/models.py new file mode 100644 index 00000000..b586a9f7 --- /dev/null +++ b/libs/computer/computer/interface/models.py @@ -0,0 +1,97 @@ +from enum import Enum +from typing import Dict, List, Any, TypedDict, Union, Literal + +# Navigation key literals +NavigationKey = Literal['pagedown', 'pageup', 'home', 'end', 'left', 'right', 'up', 'down'] + +# Special key literals +SpecialKey = Literal['enter', 'esc', 'tab', 'space', 'backspace', 'del'] + +# Function key literals +FunctionKey = Literal['f1', 'f2', 'f3', 'f4', 
from enum import Enum
from typing import Any, Dict, List, Literal, TypedDict, Union

# Navigation key literals
NavigationKey = Literal['pagedown', 'pageup', 'home', 'end', 'left', 'right', 'up', 'down']

# Special key literals
SpecialKey = Literal['enter', 'esc', 'tab', 'space', 'backspace', 'del']

# Function key literals
FunctionKey = Literal['f1', 'f2', 'f3', 'f4', 'f5', 'f6',
                      'f7', 'f8', 'f9', 'f10', 'f11', 'f12']


class Key(Enum):
    """Keyboard keys that can be used with press_key.

    Values are PyAutoGUI's expected key names. Several members are
    deliberate aliases of each other (ENTER/RETURN, ESC/ESCAPE).
    """

    # Navigation
    PAGE_DOWN = 'pagedown'
    PAGE_UP = 'pageup'
    HOME = 'home'
    END = 'end'
    LEFT = 'left'
    RIGHT = 'right'
    UP = 'up'
    DOWN = 'down'

    # Special keys
    RETURN = 'enter'
    ENTER = 'enter'
    ESCAPE = 'esc'
    ESC = 'esc'
    TAB = 'tab'
    SPACE = 'space'
    BACKSPACE = 'backspace'
    DELETE = 'del'

    # Function keys
    F1 = 'f1'
    F2 = 'f2'
    F3 = 'f3'
    F4 = 'f4'
    F5 = 'f5'
    F6 = 'f6'
    F7 = 'f7'
    F8 = 'f8'
    F9 = 'f9'
    F10 = 'f10'
    F11 = 'f11'
    F12 = 'f12'

    @classmethod
    def from_string(cls, key: str) -> 'Key | str':
        """Resolve a human-entered key name to a Key enum member.

        Args:
            key: String key name to convert

        Returns:
            The matching Key member for known aliases; otherwise the
            original string unchanged (e.g. single character keys)
        """
        aliases = {
            'page_down': cls.PAGE_DOWN,
            'page down': cls.PAGE_DOWN,
            'pagedown': cls.PAGE_DOWN,
            'page_up': cls.PAGE_UP,
            'page up': cls.PAGE_UP,
            'pageup': cls.PAGE_UP,
            'return': cls.RETURN,
            'enter': cls.ENTER,
            'escape': cls.ESCAPE,
            'esc': cls.ESC,
            'delete': cls.DELETE,
            'del': cls.DELETE,
        }
        normalized = key.lower().strip()
        if normalized in aliases:
            return aliases[normalized]
        # Unrecognized names (including bare enum values like 'home')
        # pass through untouched, matching the documented contract.
        return key


# Combined key type accepted by keyboard APIs.
KeyType = Union[Key, NavigationKey, SpecialKey, FunctionKey, str]


class AccessibilityWindow(TypedDict):
    """Information about a window in the accessibility tree."""

    app_name: str
    pid: int
    frontmost: bool
    has_windows: bool
    windows: List[Dict[str, Any]]


class AccessibilityTree(TypedDict):
    """Complete accessibility tree information."""

    success: bool
    frontmost_application: str
    windows: List[AccessibilityWindow]
"""Logging utilities for the Computer module."""

import logging
from enum import IntEnum


# Keep LogLevel for backward compatibility, but it will be deprecated
class LogLevel(IntEnum):
    """Log levels for logging. Deprecated - use standard logging levels instead."""

    QUIET = 0    # Only warnings and errors
    NORMAL = 1   # Info level, standard output
    VERBOSE = 2  # More detailed information
    DEBUG = 3    # Full debug information


# Translation table from the legacy LogLevel enum to stdlib logging levels.
LOGLEVEL_MAP = {
    LogLevel.QUIET: logging.WARNING,
    LogLevel.NORMAL: logging.INFO,
    LogLevel.VERBOSE: logging.DEBUG,
    LogLevel.DEBUG: logging.DEBUG,
}


class Logger:
    """Thin wrapper around the stdlib logging module used by Computer."""

    def __init__(self, name: str, verbosity: int):
        """Initialize the logger.

        Args:
            name: The name of the logger.
            verbosity: The log level (use standard logging levels like
                logging.INFO). For backward compatibility, LogLevel enum
                values are also accepted.
        """
        self.logger = logging.getLogger(name)

        # Accept legacy LogLevel values transparently; unknown legacy
        # values fall back to INFO.
        if isinstance(verbosity, LogLevel):
            self.verbosity = LOGLEVEL_MAP.get(verbosity, logging.INFO)
        else:
            self.verbosity = verbosity

        self._configure()

    def _configure(self):
        """Apply the configured level and announce which band it falls in."""
        self.logger.setLevel(self.verbosity)

        # Announce the effective band. DEBUG/INFO announcements go through
        # info(); coarser bands go through warning() so they remain visible.
        for threshold, emit, label in (
            (logging.DEBUG, self.logger.info, "DEBUG"),
            (logging.INFO, self.logger.info, "INFO"),
            (logging.WARNING, self.logger.warning, "WARNING"),
            (logging.ERROR, self.logger.warning, "ERROR"),
            (logging.CRITICAL, self.logger.warning, "CRITICAL"),
        ):
            if self.verbosity <= threshold:
                emit(f"Logger set to {label} level")
                break

    def debug(self, message: str):
        """Log a debug message if log level is DEBUG or lower."""
        self.logger.debug(message)

    def info(self, message: str):
        """Log an info message if log level is INFO or lower."""
        self.logger.info(message)

    def verbose(self, message: str):
        """Log a verbose message between INFO and DEBUG levels."""
        # No standard verbose level exists, so emit at debug with a
        # [VERBOSE] prefix for backward compatibility.
        self.logger.debug(f"[VERBOSE] {message}")

    def warning(self, message: str):
        """Log a warning message."""
        self.logger.warning(message)

    def error(self, message: str):
        """Log an error message."""
        self.logger.error(message)
"""Models for computer configuration."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Only needed for type annotations; avoids a hard runtime dependency
    # on pylume just to import these dataclasses.
    from pylume import PyLume


@dataclass
class Display:
    """Display configuration."""

    width: int   # display width in pixels
    height: int  # display height in pixels


@dataclass
class Image:
    """VM image configuration."""

    image: str  # base image identifier
    tag: str    # image tag
    name: str   # local name for the image


@dataclass
class Computer:
    """Computer (VM) configuration."""

    image: str
    tag: str
    name: str
    display: Display
    memory: str
    cpu: str
    pylume: Optional[PyLume] = None

    async def get_ip(self) -> Optional[str]:
        """Get the IP address of the VM.

        Returns:
            The VM's IP address, or None when no pylume client is attached
            or the VM cannot be found.
        """
        # BUG FIX: the original dereferenced self.pylume unconditionally
        # (hidden behind a `type: ignore`), raising AttributeError when no
        # client was attached. Return None explicitly instead.
        if self.pylume is None:
            return None
        vm = await self.pylume.get_vm(self.name)
        return vm.ip_address if vm else None
logger = logging.getLogger("cua.computer.telemetry")

# When the core telemetry package is missing, wire every public hook to a
# no-op so callers never need to check availability themselves.
if not TELEMETRY_AVAILABLE:
    logger.debug("Telemetry not available, using no-op functions")
    record_event = _noop  # type: ignore
    increment_counter = _noop  # type: ignore
    set_dimension = _noop  # type: ignore
    get_telemetry_client = lambda: None  # type: ignore
    flush = _noop  # type: ignore
    is_telemetry_enabled = lambda: False  # type: ignore
    is_telemetry_globally_disabled = lambda: True  # type: ignore

# Captured once at import time; attached to telemetry events as dimensions.
SYSTEM_INFO = {
    "os": platform.system().lower(),
    "os_version": platform.release(),
    "python_version": platform.python_version(),
}


def enable_telemetry() -> bool:
    """Enable telemetry if available.

    Returns:
        bool: True if telemetry was successfully enabled, False otherwise

    NOTE(review): when telemetry was unavailable at import time, this only
    flips TELEMETRY_AVAILABLE — the no-op stubs bound above are not
    rebound to the real implementations, so events may still be dropped.
    Confirm whether that is the intended behavior.
    """
    global TELEMETRY_AVAILABLE

    # Respect the global kill-switch even when telemetry is importable.
    if TELEMETRY_AVAILABLE and is_telemetry_globally_disabled():
        logger.info("Telemetry is globally disabled via environment variable - cannot enable")
        return False

    if TELEMETRY_AVAILABLE:
        # Already live - nothing to do.
        return True

    try:
        from core.telemetry import (
            is_telemetry_globally_disabled,
        )

        # Re-check with the freshly imported implementation.
        if is_telemetry_globally_disabled():
            logger.info("Telemetry is globally disabled via environment variable - cannot enable")
            return False

        TELEMETRY_AVAILABLE = True
        logger.info("Telemetry successfully enabled")
        return True
    except ImportError as e:
        logger.warning(f"Could not enable telemetry: {e}")
        return False
import base64
import io
from typing import Any, Dict, Optional, Tuple

from PIL import Image, ImageDraw


def decode_base64_image(base64_str: str) -> bytes:
    """Decode a base64 string into image bytes."""
    return base64.b64decode(base64_str)


def encode_base64_image(image_bytes: bytes) -> str:
    """Encode image bytes to a base64 string."""
    return base64.b64encode(image_bytes).decode('utf-8')


def bytes_to_image(image_bytes: bytes) -> Image.Image:
    """Convert bytes to a PIL Image.

    Args:
        image_bytes: Raw image bytes

    Returns:
        PIL.Image: The converted image
    """
    return Image.open(io.BytesIO(image_bytes))


def image_to_bytes(image: Image.Image, format: str = 'PNG') -> bytes:
    """Convert a PIL Image to bytes in the given format."""
    buf = io.BytesIO()
    image.save(buf, format=format)
    return buf.getvalue()


def resize_image(image_bytes: bytes, scale_factor: float) -> bytes:
    """Resize an image by a scale factor.

    Args:
        image_bytes: The original image as bytes
        scale_factor: Factor to scale the image by (e.g., 0.5 for half
            size, 2.0 for double)

    Returns:
        bytes: The resized image as bytes
    """
    image = bytes_to_image(image_bytes)
    if scale_factor != 1.0:
        new_size = (int(image.width * scale_factor), int(image.height * scale_factor))
        image = image.resize(new_size, Image.Resampling.LANCZOS)
    return image_to_bytes(image)


def draw_box(
    image_bytes: bytes,
    x: int,
    y: int,
    width: int,
    height: int,
    color: str = "#FF0000",
    thickness: int = 2,
) -> bytes:
    """Draw a rectangular outline on an image.

    Args:
        image_bytes: The original image as bytes
        x: X coordinate of top-left corner
        y: Y coordinate of top-left corner
        width: Width of the box
        height: Height of the box
        color: Color of the box in hex format
        thickness: Thickness of the box border in pixels

    Returns:
        bytes: The modified image as bytes
    """
    image = bytes_to_image(image_bytes)
    draw = ImageDraw.Draw(image)
    draw.rectangle(
        [(x, y), (x + width, y + height)],
        outline=color,
        width=thickness,
    )
    return image_to_bytes(image)


def get_image_size(image_bytes: bytes) -> Tuple[int, int]:
    """Get the dimensions of an image.

    Args:
        image_bytes: The image as bytes

    Returns:
        Tuple[int, int]: Width and height of the image
    """
    return bytes_to_image(image_bytes).size


def parse_vm_info(vm_info: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Parse VM info from a pylume response.

    Returns:
        None for an empty/falsy response; otherwise the VM info mapping.
    """
    if not vm_info:
        return None
    # BUG FIX: the original implementation fell off the end and returned
    # None for every input. Pass the payload through unchanged; schema
    # specific parsing can be layered on once the pylume response format
    # is pinned down (TODO confirm expected schema with callers).
    return vm_info
"""Basic tests for the computer package."""

import pytest
from computer import Computer


def test_computer_import():
    """The Computer class is importable from the package root."""
    assert Computer is not None


def test_computer_init():
    """A Computer instance can be constructed with basic settings."""
    instance = Computer(
        display={"width": 1920, "height": 1080},
        memory="16GB",
        cpu="4",
        use_host_computer_server=True,
    )
    assert instance is not None
+

+
+ + + + Shows my svg + +
+ + [![Python](https://img.shields.io/badge/Python-333333?logo=python&logoColor=white&labelColor=333333)](#) + [![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](#) + [![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.com/invite/mVnXXpdE85) + [![PyPI](https://img.shields.io/pypi/v/cua-core?color=333333)](https://pypi.org/project/cua-core/) +

+
"""This module provides the core telemetry functionality for CUA libraries.

It provides a low-overhead way to collect anonymous usage data.
"""

from core.telemetry.telemetry import (
    UniversalTelemetryClient,
    disable_telemetry,
    enable_telemetry,
    flush,
    get_telemetry_client,
    increment,
    is_telemetry_enabled,
    is_telemetry_globally_disabled,
    record_event,
)

# Public, re-exported API of the telemetry package.
__all__ = [
    "UniversalTelemetryClient",
    "enable_telemetry",
    "disable_telemetry",
    "flush",
    "get_telemetry_client",
    "increment",
    "record_event",
    "is_telemetry_enabled",
    "is_telemetry_globally_disabled",
]
# Controls how frequently telemetry will be sent (percentage)
TELEMETRY_SAMPLE_RATE = 5  # 5% sampling rate


@dataclass
class TelemetryConfig:
    """Configuration for telemetry collection."""

    enabled: bool = False  # opt-in: disabled unless explicitly turned on
    sample_rate: float = TELEMETRY_SAMPLE_RATE
    project_root: Optional[Path] = None

    @classmethod
    def from_env(cls, project_root: Optional[Path] = None) -> TelemetryConfig:
        """Load config from environment variables.

        CUA_TELEMETRY must be set to "on" to enable telemetry (opt-in).
        """
        return cls(
            enabled=os.environ.get("CUA_TELEMETRY", "").lower() == "on",
            sample_rate=float(os.environ.get("CUA_TELEMETRY_SAMPLE_RATE", TELEMETRY_SAMPLE_RATE)),
            project_root=project_root,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to dictionary."""
        return {"enabled": self.enabled, "sample_rate": self.sample_rate}


class TelemetryClient:
    """Collects and reports telemetry data with transparency and sampling."""

    def __init__(
        self, project_root: Optional[Path] = None, config: Optional[TelemetryConfig] = None
    ):
        """Initialize telemetry client.

        Args:
            project_root: Root directory of the project
            config: Telemetry configuration, or None to load from environment
        """
        self.config = config or TelemetryConfig.from_env(project_root)
        self.installation_id = self._get_or_create_installation_id()
        self.counters: Dict[str, int] = {}
        self.events: List[Dict[str, Any]] = []
        self.start_time = time.time()

        # Log telemetry status on startup
        if self.config.enabled:
            logger.info(f"Telemetry enabled (sampling at {self.config.sample_rate}%)")
        else:
            logger.info("Telemetry disabled")

        # Create .cua directory if a project root was provided
        if self.config.project_root:
            self._setup_local_storage()

    def _get_or_create_installation_id(self) -> str:
        """Return a persisted random installation ID, minting one if needed.

        The ID is an anonymous UUID, not tied to any personal information.
        """
        root = self.config.project_root
        if root:
            id_file = root / ".cua" / "installation_id"
            if id_file.exists():
                try:
                    return id_file.read_text().strip()
                except Exception:
                    pass

            # No readable ID on disk: mint and try to persist a new one.
            fresh = str(uuid.uuid4())
            try:
                id_file.parent.mkdir(parents=True, exist_ok=True)
                id_file.write_text(fresh)
                return fresh
            except Exception:
                pass

        # File persistence failed (or no project root): in-memory ID only.
        return str(uuid.uuid4())

    def _setup_local_storage(self) -> None:
        """Create the .cua directory and persist the current config."""
        if not self.config.project_root:
            return

        cua_dir = self.config.project_root / ".cua"
        cua_dir.mkdir(parents=True, exist_ok=True)

        with open(cua_dir / "telemetry_config.json", "w") as f:
            json.dump(self.config.to_dict(), f)

    def increment(self, counter_name: str, value: int = 1) -> None:
        """Increment a named counter.

        Args:
            counter_name: Name of the counter
            value: Amount to increment by (default: 1)
        """
        if not self.config.enabled:
            return
        self.counters[counter_name] = self.counters.get(counter_name, 0) + value

    def record_event(self, event_name: str, properties: Optional[Dict[str, Any]] = None) -> None:
        """Record an event with optional properties.

        Args:
            event_name: Name of the event
            properties: Event properties (must not contain sensitive data)
        """
        if not self.config.enabled:
            return

        # Always count the event type; detailed properties are kept only
        # for the sampled fraction of calls.
        self.increment(f"event:{event_name}")

        if properties and random.random() * 100 <= self.config.sample_rate:
            self.events.append(
                {"name": event_name, "properties": properties, "timestamp": time.time()}
            )

    def flush(self) -> bool:
        """Send collected telemetry if the sampling draw allows it.

        Returns:
            bool: True if telemetry was sent, False otherwise
        """
        if not self.config.enabled or (not self.counters and not self.events):
            return False

        # Installation-level sampling: most runs discard their data.
        if random.random() * 100 > self.config.sample_rate:
            logger.debug("Telemetry sampled out")
            self.counters.clear()
            self.events.clear()
            return False

        payload = {
            "version": __version__,
            "installation_id": self.installation_id,
            "counters": self.counters.copy(),
            "events": self.events.copy(),
            "duration": time.time() - self.start_time,
            "timestamp": time.time(),
        }

        try:
            sent = send_telemetry(payload)
            if sent:
                logger.debug(
                    f"Telemetry sent: {len(self.counters)} counters, {len(self.events)} events"
                )
            else:
                logger.debug("Failed to send telemetry")
            return sent
        except Exception as e:
            logger.debug(f"Failed to send telemetry: {e}")
            return False
        finally:
            # Data is dropped whether or not the send succeeded.
            self.counters.clear()
            self.events.clear()

    def enable(self) -> None:
        """Enable telemetry collection."""
        self.config.enabled = True
        logger.info("Telemetry enabled")
        if self.config.project_root:
            self._setup_local_storage()

    def disable(self) -> None:
        """Disable telemetry collection."""
        self.config.enabled = False
        logger.info("Telemetry disabled")
        if self.config.project_root:
            self._setup_local_storage()


# Process-wide singleton client.
_client: Optional[TelemetryClient] = None


def get_telemetry_client(project_root: Optional[Path] = None) -> TelemetryClient:
    """Get or initialize the global telemetry client.

    Args:
        project_root: Root directory of the project

    Returns:
        The global telemetry client instance
    """
    global _client
    if _client is None:
        _client = TelemetryClient(project_root)
    return _client


def disable_telemetry() -> None:
    """Disable telemetry collection globally."""
    if _client is not None:
        _client.disable()
UserRecord(BaseModel): + """User record stored in the telemetry database.""" + + id: str + version: Optional[str] = None + created_at: Optional[datetime] = None + last_seen_at: Optional[datetime] = None + is_ci: bool = False diff --git a/libs/core/core/telemetry/posthog_client.py b/libs/core/core/telemetry/posthog_client.py new file mode 100644 index 00000000..8ddb1dd2 --- /dev/null +++ b/libs/core/core/telemetry/posthog_client.py @@ -0,0 +1,338 @@ +"""Telemetry client using PostHog for collecting anonymous usage data.""" + +from __future__ import annotations + +import json +import logging +import os +import random +import time +import uuid +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional + +import posthog +from core import __version__ + +logger = logging.getLogger("cua.telemetry") + +# Controls how frequently telemetry will be sent (percentage) +TELEMETRY_SAMPLE_RATE = 100 # 100% sampling rate (was 5%) + +# Public PostHog config for anonymous telemetry +# These values are intentionally public and meant for anonymous telemetry only +# https://posthog.com/docs/product-analytics/troubleshooting#is-it-ok-for-my-api-key-to-be-exposed-and-public +PUBLIC_POSTHOG_API_KEY = "phc_eSkLnbLxsnYFaXksif1ksbrNzYlJShr35miFLDppF14" +PUBLIC_POSTHOG_HOST = "https://eu.i.posthog.com" + + +@dataclass +class TelemetryConfig: + """Configuration for telemetry collection.""" + + enabled: bool = True # Default to enabled (opt-out) + sample_rate: float = TELEMETRY_SAMPLE_RATE + + @classmethod + def from_env(cls) -> TelemetryConfig: + """Load config from environment variables.""" + # Check for multiple environment variables that can disable telemetry: + # CUA_TELEMETRY=off to disable telemetry (legacy way) + # CUA_TELEMETRY_DISABLED=1 to disable telemetry (new, more explicit way) + telemetry_disabled = os.environ.get("CUA_TELEMETRY", "").lower() == "off" or os.environ.get( + "CUA_TELEMETRY_DISABLED", "" + ).lower() in ("1", 
"true", "yes", "on") + + return cls( + enabled=not telemetry_disabled, + sample_rate=float(os.environ.get("CUA_TELEMETRY_SAMPLE_RATE", TELEMETRY_SAMPLE_RATE)), + ) + + def to_dict(self) -> Dict[str, Any]: + """Convert config to dictionary.""" + return { + "enabled": self.enabled, + "sample_rate": self.sample_rate, + } + + +def get_posthog_config() -> dict: + """Get PostHog configuration for anonymous telemetry. + + Uses the public API key that's specifically intended for anonymous telemetry collection. + No private keys are used or required from users. + + Returns: + Dict with PostHog configuration + """ + # Return the public config + logger.debug("Using public PostHog configuration") + return {"api_key": PUBLIC_POSTHOG_API_KEY, "host": PUBLIC_POSTHOG_HOST} + + +class PostHogTelemetryClient: + """Collects and reports telemetry data via PostHog.""" + + def __init__(self): + """Initialize PostHog telemetry client.""" + self.config = TelemetryConfig.from_env() + self.installation_id = self._get_or_create_installation_id() + self.initialized = False + self.queued_events: List[Dict[str, Any]] = [] + self.start_time = time.time() + + # Log telemetry status on startup + if self.config.enabled: + logger.info(f"Telemetry enabled (sampling at {self.config.sample_rate}%)") + # Initialize PostHog client if config is available + self._initialize_posthog() + else: + logger.info("Telemetry disabled") + + def _initialize_posthog(self) -> bool: + """Initialize the PostHog client with configuration. 
+ + Returns: + bool: True if initialized successfully, False otherwise + """ + if self.initialized: + return True + + posthog_config = get_posthog_config() + + try: + # Initialize the PostHog client + posthog.api_key = posthog_config["api_key"] + posthog.host = posthog_config["host"] + + # Configure the client + posthog.debug = os.environ.get("CUA_TELEMETRY_DEBUG", "").lower() == "on" + posthog.disabled = not self.config.enabled + + # Log telemetry status + if not posthog.disabled: + logger.info( + f"Initializing PostHog telemetry with installation ID: {self.installation_id}" + ) + if posthog.debug: + logger.debug(f"PostHog API Key: {posthog.api_key}") + logger.debug(f"PostHog Host: {posthog.host}") + else: + logger.info("PostHog telemetry is disabled") + + # Identify this installation + self._identify() + + # Process any queued events + for event in self.queued_events: + posthog.capture( + distinct_id=self.installation_id, + event=event["event"], + properties=event["properties"], + ) + self.queued_events = [] + + self.initialized = True + return True + except Exception as e: + logger.warning(f"Failed to initialize PostHog: {e}") + return False + + def _identify(self) -> None: + """Identify the current installation with PostHog.""" + try: + properties = { + "version": __version__, + "is_ci": "CI" in os.environ, + "os": os.name, + "python_version": sys.version.split()[0], + } + + logger.debug( + f"Identifying PostHog user: {self.installation_id} with properties: {properties}" + ) + posthog.identify( + distinct_id=self.installation_id, + properties=properties, + ) + except Exception as e: + logger.warning(f"Failed to identify with PostHog: {e}") + + def _get_or_create_installation_id(self) -> str: + """Get or create a unique installation ID that persists across runs. + + The ID is always stored within the core library directory itself, + ensuring it persists regardless of how the library is used. + + This ID is not tied to any personal information. 
+ """ + # Get the core library directory (where this file is located) + try: + # Find the core module directory using this file's location + core_module_dir = Path( + __file__ + ).parent.parent # core/telemetry/posthog_client.py -> core/telemetry -> core + storage_dir = core_module_dir / ".storage" + storage_dir.mkdir(exist_ok=True) + + id_file = storage_dir / "installation_id" + + # Try to read existing ID + if id_file.exists(): + try: + stored_id = id_file.read_text().strip() + if stored_id: # Make sure it's not empty + logger.debug(f"Using existing installation ID: {stored_id}") + return stored_id + except Exception as e: + logger.debug(f"Error reading installation ID file: {e}") + + # Create new ID + new_id = str(uuid.uuid4()) + try: + id_file.write_text(new_id) + logger.debug(f"Created new installation ID: {new_id}") + return new_id + except Exception as e: + logger.warning(f"Could not write installation ID: {e}") + except Exception as e: + logger.warning(f"Error accessing core module directory: {e}") + + # Last resort: Create a new in-memory ID + logger.warning("Using random installation ID (will not persist across runs)") + return str(uuid.uuid4()) + + def increment(self, counter_name: str, value: int = 1) -> None: + """Increment a named counter. 
+ + Args: + counter_name: Name of the counter + value: Amount to increment by (default: 1) + """ + if not self.config.enabled: + return + + # Apply sampling to reduce number of events + if random.random() * 100 > self.config.sample_rate: + return + + properties = { + "value": value, + "counter_name": counter_name, + "version": __version__, + } + + if self.initialized: + try: + posthog.capture( + distinct_id=self.installation_id, + event="counter_increment", + properties=properties, + ) + except Exception as e: + logger.debug(f"Failed to send counter event to PostHog: {e}") + else: + # Queue the event for later + self.queued_events.append({"event": "counter_increment", "properties": properties}) + # Try to initialize now if not already + self._initialize_posthog() + + def record_event(self, event_name: str, properties: Optional[Dict[str, Any]] = None) -> None: + """Record an event with optional properties. + + Args: + event_name: Name of the event + properties: Event properties (must not contain sensitive data) + """ + if not self.config.enabled: + logger.debug(f"Telemetry disabled, skipping event: {event_name}") + return + + # Apply sampling to reduce number of events + if random.random() * 100 > self.config.sample_rate: + logger.debug( + f"Event sampled out due to sampling rate {self.config.sample_rate}%: {event_name}" + ) + return + + event_properties = {"version": __version__, **(properties or {})} + + logger.info(f"Recording event: {event_name} with properties: {event_properties}") + + if self.initialized: + try: + posthog.capture( + distinct_id=self.installation_id, event=event_name, properties=event_properties + ) + logger.info(f"Sent event to PostHog: {event_name}") + # Flush immediately to ensure delivery + posthog.flush() + except Exception as e: + logger.warning(f"Failed to send event to PostHog: {e}") + else: + # Queue the event for later + logger.info(f"PostHog not initialized, queuing event for later: {event_name}") + 
self.queued_events.append({"event": event_name, "properties": event_properties}) + # Try to initialize now if not already + initialize_result = self._initialize_posthog() + logger.info(f"Attempted to initialize PostHog: {initialize_result}") + + def flush(self) -> bool: + """Flush any pending events to PostHog. + + Returns: + bool: True if successful, False otherwise + """ + if not self.config.enabled: + return False + + if not self.initialized and not self._initialize_posthog(): + return False + + try: + posthog.flush() + return True + except Exception as e: + logger.debug(f"Failed to flush PostHog events: {e}") + return False + + def enable(self) -> None: + """Enable telemetry collection.""" + self.config.enabled = True + if posthog: + posthog.disabled = False + logger.info("Telemetry enabled") + self._initialize_posthog() + + def disable(self) -> None: + """Disable telemetry collection.""" + self.config.enabled = False + if posthog: + posthog.disabled = True + logger.info("Telemetry disabled") + + +# Global telemetry client instance +_client: Optional[PostHogTelemetryClient] = None + + +def get_posthog_telemetry_client() -> PostHogTelemetryClient: + """Get or initialize the global PostHog telemetry client. + + Returns: + The global telemetry client instance + """ + global _client + + if _client is None: + _client = PostHogTelemetryClient() + + return _client + + +def disable_telemetry() -> None: + """Disable telemetry collection globally.""" + if _client is not None: + _client.disable() diff --git a/libs/core/core/telemetry/telemetry.py b/libs/core/core/telemetry/telemetry.py new file mode 100644 index 00000000..2a6e052e --- /dev/null +++ b/libs/core/core/telemetry/telemetry.py @@ -0,0 +1,311 @@ +"""Universal telemetry module for collecting anonymous usage data. +This module provides a unified interface for telemetry collection, +using PostHog as the backend. 
+""" + +from __future__ import annotations + +import logging +import os +import sys +from enum import Enum +from pathlib import Path +from typing import Any, Dict, Optional, Union + + +# Configure telemetry logging before importing anything else +# By default, set telemetry loggers to WARNING level to hide INFO messages +# This can be overridden with CUA_TELEMETRY_LOG_LEVEL environment variable +def _configure_telemetry_logging() -> None: + """Set up initial logging configuration for telemetry.""" + # Determine log level from environment variable or use WARNING by default + env_level = os.environ.get("CUA_TELEMETRY_LOG_LEVEL", "WARNING").upper() + level = logging.WARNING # Default to WARNING to hide INFO messages + + if env_level == "DEBUG": + level = logging.DEBUG + elif env_level == "INFO": + level = logging.INFO + elif env_level == "ERROR": + level = logging.ERROR + + # Configure the main telemetry logger + telemetry_logger = logging.getLogger("cua.telemetry") + telemetry_logger.setLevel(level) + + +# Configure logging immediately +_configure_telemetry_logging() + +# Import telemetry backend +try: + from core.telemetry.posthog_client import ( + PostHogTelemetryClient, + get_posthog_telemetry_client, + ) + + POSTHOG_AVAILABLE = True +except ImportError: + logger = logging.getLogger("cua.telemetry") + logger.info("PostHog not available. Install with: pdm add posthog") + POSTHOG_AVAILABLE = False + +logger = logging.getLogger("cua.telemetry") + + +# Check environment variables for global telemetry opt-out +def is_telemetry_globally_disabled() -> bool: + """Check if telemetry is globally disabled via environment variables. 
+ + Returns: + bool: True if telemetry is globally disabled, False otherwise + """ + # Only check for CUA_TELEMETRY_ENABLED - telemetry is enabled only if explicitly set to a truthy value + telemetry_enabled = os.environ.get("CUA_TELEMETRY_ENABLED", "true").lower() + return telemetry_enabled not in ("1", "true", "yes", "on") + + +class TelemetryBackend(str, Enum): + """Available telemetry backend types.""" + + POSTHOG = "posthog" + NONE = "none" + + +class UniversalTelemetryClient: + """Universal telemetry client that delegates to the PostHog backend.""" + + def __init__( + self, + backend: Optional[str] = None, + ): + """Initialize the universal telemetry client. + + Args: + backend: Backend to use ("posthog" or "none") + If not specified, will try PostHog + """ + # Check for global opt-out first + if is_telemetry_globally_disabled(): + self.backend_type = TelemetryBackend.NONE + logger.info("Telemetry globally disabled via environment variable") + # Determine which backend to use + elif backend and backend.lower() == "none": + self.backend_type = TelemetryBackend.NONE + else: + # Auto-detect based on environment variables and available backends + if POSTHOG_AVAILABLE: + self.backend_type = TelemetryBackend.POSTHOG + else: + self.backend_type = TelemetryBackend.NONE + logger.warning("PostHog is not available, telemetry will be disabled") + + # Initialize the appropriate client + self._client = self._initialize_client() + self._enabled = self.backend_type != TelemetryBackend.NONE + + def _initialize_client(self) -> Any: + """Initialize the appropriate telemetry client based on the selected backend.""" + if self.backend_type == TelemetryBackend.POSTHOG and POSTHOG_AVAILABLE: + logger.debug("Initializing PostHog telemetry client") + return get_posthog_telemetry_client() + else: + logger.debug("No telemetry client initialized") + return None + + def increment(self, counter_name: str, value: int = 1) -> None: + """Increment a named counter. 
+ + Args: + counter_name: Name of the counter + value: Amount to increment by (default: 1) + """ + if self._client and self._enabled: + self._client.increment(counter_name, value) + + def record_event(self, event_name: str, properties: Optional[Dict[str, Any]] = None) -> None: + """Record an event with optional properties. + + Args: + event_name: Name of the event + properties: Event properties (must not contain sensitive data) + """ + if self._client and self._enabled: + self._client.record_event(event_name, properties) + + def flush(self) -> bool: + """Flush any pending events to the backend. + + Returns: + bool: True if successful, False otherwise + """ + if self._client and self._enabled: + return self._client.flush() + return False + + def enable(self) -> None: + """Enable telemetry collection.""" + if self._client and not is_telemetry_globally_disabled(): + self._client.enable() + self._enabled = True + else: + if is_telemetry_globally_disabled(): + logger.info("Cannot enable telemetry: globally disabled via environment variable") + self._enabled = False + + def disable(self) -> None: + """Disable telemetry collection.""" + if self._client: + self._client.disable() + self._enabled = False + + def is_enabled(self) -> bool: + """Check if telemetry is enabled. + + Returns: + bool: True if telemetry is enabled, False otherwise + """ + return self._enabled and not is_telemetry_globally_disabled() + + +# Global telemetry client instance +_universal_client: Optional[UniversalTelemetryClient] = None + + +def get_telemetry_client( + backend: Optional[str] = None, +) -> UniversalTelemetryClient: + """Get or initialize the global telemetry client. 
+ + Args: + backend: Backend to use ("posthog" or "none") + + Returns: + The global telemetry client instance + """ + global _universal_client + + if _universal_client is None: + _universal_client = UniversalTelemetryClient(backend) + + return _universal_client + + +def increment(counter_name: str, value: int = 1) -> None: + """Increment a named counter using the global telemetry client. + + Args: + counter_name: Name of the counter + value: Amount to increment by (default: 1) + """ + client = get_telemetry_client() + client.increment(counter_name, value) + + +def record_event(event_name: str, properties: Optional[Dict[str, Any]] = None) -> None: + """Record an event with optional properties using the global telemetry client. + + Args: + event_name: Name of the event + properties: Event properties (must not contain sensitive data) + """ + client = get_telemetry_client() + client.record_event(event_name, properties) + + +def flush() -> bool: + """Flush any pending events using the global telemetry client. + + Returns: + bool: True if successful, False otherwise + """ + client = get_telemetry_client() + return client.flush() + + +def enable_telemetry() -> bool: + """Enable telemetry collection globally. + + Returns: + bool: True if successfully enabled, False if globally disabled + """ + if is_telemetry_globally_disabled(): + logger.info("Cannot enable telemetry: globally disabled via environment variable") + return False + + client = get_telemetry_client() + client.enable() + return True + + +def disable_telemetry() -> None: + """Disable telemetry collection globally.""" + client = get_telemetry_client() + client.disable() + + +def is_telemetry_enabled() -> bool: + """Check if telemetry is enabled. 
+ + Returns: + bool: True if telemetry is enabled, False otherwise + """ + # First check for global disable + if is_telemetry_globally_disabled(): + return False + + # Get the global client and check + client = get_telemetry_client() + return client.is_enabled() + + +def set_telemetry_log_level(level: Optional[int] = None) -> None: + """Set the logging level for telemetry loggers to reduce console output. + + By default, checks the CUA_TELEMETRY_LOG_LEVEL environment variable: + - If set to "DEBUG", sets level to logging.DEBUG + - If set to "INFO", sets level to logging.INFO + - If set to "WARNING", sets level to logging.WARNING + - If set to "ERROR", sets level to logging.ERROR + - If not set, defaults to logging.WARNING + + This means telemetry logs will only show up when explicitly requested via + the environment variable, not during normal operation. + + Args: + level: The logging level to set (overrides environment variable if provided) + """ + # Determine the level from environment variable if not explicitly provided + if level is None: + env_level = os.environ.get("CUA_TELEMETRY_LOG_LEVEL", "WARNING").upper() + if env_level == "DEBUG": + level = logging.DEBUG + elif env_level == "INFO": + level = logging.INFO + elif env_level == "WARNING": + level = logging.WARNING + elif env_level == "ERROR": + level = logging.ERROR + else: + # Default to WARNING if environment variable is not recognized + level = logging.WARNING + + # Set the level for all telemetry-related loggers + telemetry_loggers = [ + "cua.telemetry", + "core.telemetry", + "cua.agent.telemetry", + "cua.computer.telemetry", + "posthog", + ] + + for logger_name in telemetry_loggers: + try: + logging.getLogger(logger_name).setLevel(level) + except Exception: + pass + + +# Set telemetry loggers to appropriate level based on environment variable +# This is called at module import time to ensure proper configuration before any logging happens +set_telemetry_log_level() diff --git a/libs/core/pdm.lock 
b/libs/core/pdm.lock new file mode 100644 index 00000000..61935145 --- /dev/null +++ b/libs/core/pdm.lock @@ -0,0 +1,411 @@ +# This file is @generated by PDM. +# It is not intended for manual editing. + +[metadata] +groups = ["default", "dev"] +strategy = [] +lock_version = "4.5.0" +content_hash = "sha256:012f523673653e261a7b65007c36c67b540b2477da9bf3a71a849ae36aeeb7b1" + +[[metadata.targets]] +requires_python = ">=3.10,<3.13" + +[[package]] +name = "annotated-types" +version = "0.7.0" +summary = "" +files = [ + {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, + {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, +] + +[[package]] +name = "anyio" +version = "4.8.0" +summary = "" +dependencies = [ + "exceptiongroup; python_full_version < \"3.11\"", + "idna", + "sniffio", + "typing-extensions", +] +files = [ + {file = "anyio-4.8.0-py3-none-any.whl", hash = "sha256:b5011f270ab5eb0abf13385f851315585cc37ef330dd88e27ec3d34d651fd47a"}, + {file = "anyio-4.8.0.tar.gz", hash = "sha256:1d9fe889df5212298c0c0723fa20479d1b94883a2df44bd3897aa91083316f7a"}, +] + +[[package]] +name = "backoff" +version = "2.2.1" +summary = "" +files = [ + {file = "backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8"}, + {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, +] + +[[package]] +name = "certifi" +version = "2025.1.31" +summary = "" +files = [ + {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"}, + {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"}, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.1" +summary = "" +files = [ + {file = 
"charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e218488cd232553829be0664c2292d3af2eeeb94b32bea483cf79ac6a694e037"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80ed5e856eb7f30115aaf94e4a08114ccc8813e6ed1b5efa74f9f82e8509858f"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b010a7a4fd316c3c484d482922d13044979e78d1861f0e0650423144c616a46a"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4532bff1b8421fd0a320463030c7520f56a79c9024a4e88f01c537316019005a"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d973f03c0cb71c5ed99037b870f2be986c3c05e63622c017ea9816881d2dd247"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3a3bd0dcd373514dcec91c411ddb9632c0d7d92aed7093b8c3bbb6d69ca74408"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:d9c3cdf5390dcd29aa8056d13e8e99526cda0305acc038b96b30352aff5ff2bb"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:2bdfe3ac2e1bbe5b59a1a63721eb3b95fc9b6817ae4a46debbb4e11f6232428d"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:eab677309cdb30d047996b36d34caeda1dc91149e4fdca0b1a039b3f79d9a807"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-win32.whl", hash = 
"sha256:c0429126cf75e16c4f0ad00ee0eae4242dc652290f940152ca8c75c3a4b6ee8f"}, + {file = "charset_normalizer-3.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:9f0b8b1c6d84c8034a44893aba5e767bf9c7a211e313a9605d9c617d7083829f"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8bfa33f4f2672964266e940dd22a195989ba31669bd84629f05fab3ef4e2d125"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28bf57629c75e810b6ae989f03c0828d64d6b26a5e205535585f96093e405ed1"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f08ff5e948271dc7e18a35641d2f11a4cd8dfd5634f55228b691e62b37125eb3"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:234ac59ea147c59ee4da87a0c0f098e9c8d169f4dc2a159ef720f1a61bbe27cd"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd4ec41f914fa74ad1b8304bbc634b3de73d2a0889bd32076342a573e0779e00"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eea6ee1db730b3483adf394ea72f808b6e18cf3cb6454b4d86e04fa8c4327a12"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c96836c97b1238e9c9e3fe90844c947d5afbf4f4c92762679acfe19927d81d77"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4d86f7aff21ee58f26dcf5ae81a9addbd914115cdebcbb2217e4f0ed8982e146"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:09b5e6733cbd160dcc09589227187e242a30a49ca5cefa5a7edd3f9d19ed53fd"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:5777ee0881f9499ed0f71cc82cf873d9a0ca8af166dfa0af8ec4e675b7df48e6"}, + {file = 
"charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:237bdbe6159cff53b4f24f397d43c6336c6b0b42affbe857970cefbb620911c8"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-win32.whl", hash = "sha256:8417cb1f36cc0bc7eaba8ccb0e04d55f0ee52df06df3ad55259b9a323555fc8b"}, + {file = "charset_normalizer-3.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:d7f50a1f8c450f3925cb367d011448c39239bb3eb4117c36a6d354794de4ce76"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:73d94b58ec7fecbc7366247d3b0b10a21681004153238750bb67bd9012414545"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dad3e487649f498dd991eeb901125411559b22e8d7ab25d3aeb1af367df5efd7"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c30197aa96e8eed02200a83fba2657b4c3acd0f0aa4bdc9f6c1af8e8962e0757"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2369eea1ee4a7610a860d88f268eb39b95cb588acd7235e02fd5a5601773d4fa"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc2722592d8998c870fa4e290c2eec2c1569b87fe58618e67d38b4665dfa680d"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffc9202a29ab3920fa812879e95a9e78b2465fd10be7fcbd042899695d75e616"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:804a4d582ba6e5b747c625bf1255e6b1507465494a40a2130978bda7b932c90b"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0f55e69f030f7163dffe9fd0752b32f070566451afe180f99dbeeb81f511ad8d"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = 
"sha256:c4c3e6da02df6fa1410a7680bd3f63d4f710232d3139089536310d027950696a"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:5df196eb874dae23dcfb968c83d4f8fdccb333330fe1fc278ac5ceeb101003a9"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e358e64305fe12299a08e08978f51fc21fac060dcfcddd95453eabe5b93ed0e1"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-win32.whl", hash = "sha256:9b23ca7ef998bc739bf6ffc077c2116917eabcc901f88da1b9856b210ef63f35"}, + {file = "charset_normalizer-3.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:6ff8a4a60c227ad87030d76e99cd1698345d4491638dfa6673027c48b3cd395f"}, + {file = "charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85"}, + {file = "charset_normalizer-3.4.1.tar.gz", hash = "sha256:44251f18cd68a75b56585dd00dae26183e102cd5e0f9f1466e6df5da2ed64ea3"}, +] + +[[package]] +name = "colorama" +version = "0.4.6" +summary = "" +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] + +[[package]] +name = "distro" +version = "1.9.0" +summary = "" +files = [ + {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, + {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, +] + +[[package]] +name = "exceptiongroup" +version = "1.2.2" +summary = "" +files = [ + {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, + {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, +] + +[[package]] +name = "h11" +version = "0.14.0" 
+summary = "" +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + +[[package]] +name = "httpcore" +version = "1.0.7" +summary = "" +dependencies = [ + "certifi", + "h11", +] +files = [ + {file = "httpcore-1.0.7-py3-none-any.whl", hash = "sha256:a3fff8f43dc260d5bd363d9f9cf1830fa3a458b332856f34282de498ed420edd"}, + {file = "httpcore-1.0.7.tar.gz", hash = "sha256:8551cb62a169ec7162ac7be8d4817d561f60e08eaa485234898414bb5a8a0b4c"}, +] + +[[package]] +name = "httpx" +version = "0.28.1" +summary = "" +dependencies = [ + "anyio", + "certifi", + "httpcore", + "idna", +] +files = [ + {file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"}, + {file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"}, +] + +[[package]] +name = "idna" +version = "3.10" +summary = "" +files = [ + {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, + {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, +] + +[[package]] +name = "iniconfig" +version = "2.0.0" +summary = "" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + +[[package]] +name = "monotonic" +version = "1.6" +summary = "" +files = [ + {file = "monotonic-1.6-py2.py3-none-any.whl", hash = "sha256:68687e19a14f11f26d140dd5c86f3dba4bf5df58003000ed467e0e2a69bca96c"}, + {file = "monotonic-1.6.tar.gz", hash = "sha256:3a55207bcfed53ddd5c5bae174524062935efed17792e9de2ad0205ce9ad63f7"}, +] + 
+[[package]] +name = "packaging" +version = "24.2" +summary = "" +files = [ + {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, + {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, +] + +[[package]] +name = "pluggy" +version = "1.5.0" +summary = "" +files = [ + {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, + {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, +] + +[[package]] +name = "posthog" +version = "3.20.0" +summary = "" +dependencies = [ + "backoff", + "distro", + "monotonic", + "python-dateutil", + "requests", + "six", +] +files = [ + {file = "posthog-3.20.0-py2.py3-none-any.whl", hash = "sha256:ce3aa75a39c36bc3af2b6947757493e6c7d021fe5088b185d3277157770d4ef4"}, + {file = "posthog-3.20.0.tar.gz", hash = "sha256:7933f7c98c0152a34e387e441fefdc62e2b86aade5dea94dc6ecbe7358138828"}, +] + +[[package]] +name = "pydantic" +version = "2.10.6" +summary = "" +dependencies = [ + "annotated-types", + "pydantic-core", + "typing-extensions", +] +files = [ + {file = "pydantic-2.10.6-py3-none-any.whl", hash = "sha256:427d664bf0b8a2b34ff5dd0f5a18df00591adcee7198fbd71981054cef37b584"}, + {file = "pydantic-2.10.6.tar.gz", hash = "sha256:ca5daa827cce33de7a42be142548b0096bf05a7e7b365aebfa5f8eeec7128236"}, +] + +[[package]] +name = "pydantic-core" +version = "2.27.2" +summary = "" +dependencies = [ + "typing-extensions", +] +files = [ + {file = "pydantic_core-2.27.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2d367ca20b2f14095a8f4fa1210f5a7b78b8a20009ecced6b12818f455b1e9fa"}, + {file = "pydantic_core-2.27.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:491a2b73db93fab69731eaee494f320faa4e093dbed776be1a829c2eb222c34c"}, + {file = 
"pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7969e133a6f183be60e9f6f56bfae753585680f3b7307a8e555a948d443cc05a"}, + {file = "pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3de9961f2a346257caf0aa508a4da705467f53778e9ef6fe744c038119737ef5"}, + {file = "pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e2bb4d3e5873c37bb3dd58714d4cd0b0e6238cebc4177ac8fe878f8b3aa8e74c"}, + {file = "pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:280d219beebb0752699480fe8f1dc61ab6615c2046d76b7ab7ee38858de0a4e7"}, + {file = "pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47956ae78b6422cbd46f772f1746799cbb862de838fd8d1fbd34a82e05b0983a"}, + {file = "pydantic_core-2.27.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:14d4a5c49d2f009d62a2a7140d3064f686d17a5d1a268bc641954ba181880236"}, + {file = "pydantic_core-2.27.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:337b443af21d488716f8d0b6164de833e788aa6bd7e3a39c005febc1284f4962"}, + {file = "pydantic_core-2.27.2-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:03d0f86ea3184a12f41a2d23f7ccb79cdb5a18e06993f8a45baa8dfec746f0e9"}, + {file = "pydantic_core-2.27.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7041c36f5680c6e0f08d922aed302e98b3745d97fe1589db0a3eebf6624523af"}, + {file = "pydantic_core-2.27.2-cp310-cp310-win32.whl", hash = "sha256:50a68f3e3819077be2c98110c1f9dcb3817e93f267ba80a2c05bb4f8799e2ff4"}, + {file = "pydantic_core-2.27.2-cp310-cp310-win_amd64.whl", hash = "sha256:e0fd26b16394ead34a424eecf8a31a1f5137094cabe84a1bcb10fa6ba39d3d31"}, + {file = "pydantic_core-2.27.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:8e10c99ef58cfdf2a66fc15d66b16c4a04f62bca39db589ae8cba08bc55331bc"}, + {file = 
"pydantic_core-2.27.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:26f32e0adf166a84d0cb63be85c562ca8a6fa8de28e5f0d92250c6b7e9e2aff7"}, + {file = "pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c19d1ea0673cd13cc2f872f6c9ab42acc4e4f492a7ca9d3795ce2b112dd7e15"}, + {file = "pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5e68c4446fe0810e959cdff46ab0a41ce2f2c86d227d96dc3847af0ba7def306"}, + {file = "pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d9640b0059ff4f14d1f37321b94061c6db164fbe49b334b31643e0528d100d99"}, + {file = "pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:40d02e7d45c9f8af700f3452f329ead92da4c5f4317ca9b896de7ce7199ea459"}, + {file = "pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c1fd185014191700554795c99b347d64f2bb637966c4cfc16998a0ca700d048"}, + {file = "pydantic_core-2.27.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d81d2068e1c1228a565af076598f9e7451712700b673de8f502f0334f281387d"}, + {file = "pydantic_core-2.27.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1a4207639fb02ec2dbb76227d7c751a20b1a6b4bc52850568e52260cae64ca3b"}, + {file = "pydantic_core-2.27.2-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:3de3ce3c9ddc8bbd88f6e0e304dea0e66d843ec9de1b0042b0911c1663ffd474"}, + {file = "pydantic_core-2.27.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:30c5f68ded0c36466acede341551106821043e9afaad516adfb6e8fa80a4e6a6"}, + {file = "pydantic_core-2.27.2-cp311-cp311-win32.whl", hash = "sha256:c70c26d2c99f78b125a3459f8afe1aed4d9687c24fd677c6a4436bc042e50d6c"}, + {file = "pydantic_core-2.27.2-cp311-cp311-win_amd64.whl", hash = "sha256:08e125dbdc505fa69ca7d9c499639ab6407cfa909214d500897d02afb816e7cc"}, + {file = "pydantic_core-2.27.2-cp311-cp311-win_arm64.whl", 
hash = "sha256:26f0d68d4b235a2bae0c3fc585c585b4ecc51382db0e3ba402a22cbc440915e4"}, + {file = "pydantic_core-2.27.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:9e0c8cfefa0ef83b4da9588448b6d8d2a2bf1a53c3f1ae5fca39eb3061e2f0b0"}, + {file = "pydantic_core-2.27.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:83097677b8e3bd7eaa6775720ec8e0405f1575015a463285a92bfdfe254529ef"}, + {file = "pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:172fce187655fece0c90d90a678424b013f8fbb0ca8b036ac266749c09438cb7"}, + {file = "pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:519f29f5213271eeeeb3093f662ba2fd512b91c5f188f3bb7b27bc5973816934"}, + {file = "pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:05e3a55d124407fffba0dd6b0c0cd056d10e983ceb4e5dbd10dda135c31071d6"}, + {file = "pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c3ed807c7b91de05e63930188f19e921d1fe90de6b4f5cd43ee7fcc3525cb8c"}, + {file = "pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fb4aadc0b9a0c063206846d603b92030eb6f03069151a625667f982887153e2"}, + {file = "pydantic_core-2.27.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:28ccb213807e037460326424ceb8b5245acb88f32f3d2777427476e1b32c48c4"}, + {file = "pydantic_core-2.27.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:de3cd1899e2c279b140adde9357c4495ed9d47131b4a4eaff9052f23398076b3"}, + {file = "pydantic_core-2.27.2-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:220f892729375e2d736b97d0e51466252ad84c51857d4d15f5e9692f9ef12be4"}, + {file = "pydantic_core-2.27.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a0fcd29cd6b4e74fe8ddd2c90330fd8edf2e30cb52acda47f06dd615ae72da57"}, + {file = "pydantic_core-2.27.2-cp312-cp312-win32.whl", hash = 
"sha256:1e2cb691ed9834cd6a8be61228471d0a503731abfb42f82458ff27be7b2186fc"}, + {file = "pydantic_core-2.27.2-cp312-cp312-win_amd64.whl", hash = "sha256:cc3f1a99a4f4f9dd1de4fe0312c114e740b5ddead65bb4102884b384c15d8bc9"}, + {file = "pydantic_core-2.27.2-cp312-cp312-win_arm64.whl", hash = "sha256:3911ac9284cd8a1792d3cb26a2da18f3ca26c6908cc434a18f730dc0db7bfa3b"}, + {file = "pydantic_core-2.27.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:2bf14caea37e91198329b828eae1618c068dfb8ef17bb33287a7ad4b61ac314e"}, + {file = "pydantic_core-2.27.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:b0cb791f5b45307caae8810c2023a184c74605ec3bcbb67d13846c28ff731ff8"}, + {file = "pydantic_core-2.27.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:688d3fd9fcb71f41c4c015c023d12a79d1c4c0732ec9eb35d96e3388a120dcf3"}, + {file = "pydantic_core-2.27.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d591580c34f4d731592f0e9fe40f9cc1b430d297eecc70b962e93c5c668f15f"}, + {file = "pydantic_core-2.27.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:82f986faf4e644ffc189a7f1aafc86e46ef70372bb153e7001e8afccc6e54133"}, + {file = "pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:bec317a27290e2537f922639cafd54990551725fc844249e64c523301d0822fc"}, + {file = "pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:0296abcb83a797db256b773f45773da397da75a08f5fcaef41f2044adec05f50"}, + {file = "pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:0d75070718e369e452075a6017fbf187f788e17ed67a3abd47fa934d001863d9"}, + {file = "pydantic_core-2.27.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:7e17b560be3c98a8e3aa66ce828bdebb9e9ac6ad5466fba92eb74c4c95cb1151"}, + {file = "pydantic_core-2.27.2.tar.gz", hash = "sha256:eb026e5a4c1fee05726072337ff51d1efb6f59090b7da90d30ea58625b1ffb39"}, +] + +[[package]] 
+name = "pytest" +version = "8.3.5" +summary = "" +dependencies = [ + "colorama; sys_platform == \"win32\"", + "exceptiongroup; python_full_version < \"3.11\"", + "iniconfig", + "packaging", + "pluggy", + "tomli; python_full_version < \"3.11\"", +] +files = [ + {file = "pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820"}, + {file = "pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845"}, +] + +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +summary = "" +dependencies = [ + "six", +] +files = [ + {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, + {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, +] + +[[package]] +name = "requests" +version = "2.32.3" +summary = "" +dependencies = [ + "certifi", + "charset-normalizer", + "idna", + "urllib3", +] +files = [ + {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, + {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, +] + +[[package]] +name = "six" +version = "1.17.0" +summary = "" +files = [ + {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, + {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, +] + +[[package]] +name = "sniffio" +version = "1.3.1" +summary = "" +files = [ + {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, + {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, +] + +[[package]] +name = "tomli" +version 
= "2.2.1" +summary = "" +files = [ + {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, + {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8"}, + {file = "tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff"}, + {file = "tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b"}, + {file = "tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea"}, + {file = "tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e"}, + {file = "tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98"}, + {file = "tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4"}, + {file = "tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc"}, + {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"}, +] + +[[package]] +name = "typing-extensions" +version = "4.12.2" +summary = "" +files = [ + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, +] + +[[package]] +name = "urllib3" +version = "2.3.0" +summary = "" +files = [ + {file = "urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df"}, + {file = "urllib3-2.3.0.tar.gz", hash = 
"sha256:f8c5449b3cf0861679ce7e0503c7b44b5ec981bec0d1d3795a07f1ba96f0204d"}, +] diff --git a/libs/core/poetry.toml b/libs/core/poetry.toml new file mode 100644 index 00000000..b6125d29 --- /dev/null +++ b/libs/core/poetry.toml @@ -0,0 +1,2 @@ +[virtualenvs] +in-project = true \ No newline at end of file diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml new file mode 100644 index 00000000..db7ffbab --- /dev/null +++ b/libs/core/pyproject.toml @@ -0,0 +1,57 @@ +[build-system] +requires = ["pdm-backend"] +build-backend = "pdm.backend" + +[project] +name = "cua-core" +version = "0.1.0" +description = "Core functionality for Cua including telemetry and shared utilities" +readme = "README.md" +authors = [ + { name = "TryCua", email = "gh@trycua.com" } +] +dependencies = [ + "pydantic>=2.0.0", + "httpx>=0.24.0", + "posthog>=3.20.0" +] +requires-python = ">=3.10,<3.13" + +[tool.pdm] +distribution = true + +[tool.pdm.build] +includes = ["core/"] +source-includes = ["tests/", "README.md", "LICENSE"] + +[tool.black] +line-length = 100 +target-version = ["py310"] + +[tool.ruff] +line-length = 100 +target-version = "py310" +select = ["E", "F", "B", "I"] +fix = true + +[tool.ruff.format] +docstring-code-format = true + +[tool.mypy] +strict = true +python_version = "3.10" +ignore_missing_imports = true +disallow_untyped_defs = true +check_untyped_defs = true +warn_return_any = true +show_error_codes = true +warn_unused_ignores = false + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +python_files = "test_*.py" +[dependency-groups] +dev = [ + "pytest>=8.3.5", +] diff --git a/libs/core/tests/test_posthog_telemetry.py b/libs/core/tests/test_posthog_telemetry.py new file mode 100644 index 00000000..d3d0f2ef --- /dev/null +++ b/libs/core/tests/test_posthog_telemetry.py @@ -0,0 +1,154 @@ +"""Tests for the PostHog telemetry client.""" + +import os +from unittest.mock import MagicMock, patch + +import pytest + +from core.telemetry.posthog_client 
import ( + PostHogTelemetryClient, + TelemetryConfig, + get_posthog_config, + get_posthog_telemetry_client, +) + + +@pytest.fixture +def mock_environment(): + """Set up and tear down environment variables for testing.""" + original_env = os.environ.copy() + os.environ["CUA_TELEMETRY_SAMPLE_RATE"] = "100" + # Remove PostHog env vars as they're hardcoded now + # os.environ["CUA_POSTHOG_API_KEY"] = "test-api-key" + # os.environ["CUA_POSTHOG_HOST"] = "https://test.posthog.com" + + yield + + # Restore original environment + os.environ.clear() + os.environ.update(original_env) + + +@pytest.fixture +def mock_disabled_environment(): + """Set up and tear down environment variables with telemetry disabled.""" + original_env = os.environ.copy() + os.environ["CUA_TELEMETRY"] = "off" + os.environ["CUA_TELEMETRY_SAMPLE_RATE"] = "100" + # Remove PostHog env vars as they're hardcoded now + # os.environ["CUA_POSTHOG_API_KEY"] = "test-api-key" + # os.environ["CUA_POSTHOG_HOST"] = "https://test.posthog.com" + + yield + + # Restore original environment + os.environ.clear() + os.environ.update(original_env) + + +class TestTelemetryConfig: + """Tests for telemetry configuration.""" + + def test_from_env_defaults(self): + """Test loading config from environment with defaults.""" + # Clear relevant environment variables + with patch.dict( + os.environ, + { + k: v + for k, v in os.environ.items() + if k not in ["CUA_TELEMETRY", "CUA_TELEMETRY_SAMPLE_RATE"] + }, + ): + config = TelemetryConfig.from_env() + assert config.enabled is True # Default is now enabled + assert config.sample_rate == 5 + assert config.project_root is None + + def test_from_env_with_vars(self, mock_environment): + """Test loading config from environment variables.""" + config = TelemetryConfig.from_env() + assert config.enabled is True + assert config.sample_rate == 100 + assert config.project_root is None + + def test_from_env_disabled(self, mock_disabled_environment): + """Test disabling telemetry via environment 
variable.""" + config = TelemetryConfig.from_env() + assert config.enabled is False + assert config.sample_rate == 100 + assert config.project_root is None + + def test_to_dict(self): + """Test converting config to dictionary.""" + config = TelemetryConfig(enabled=True, sample_rate=50) + config_dict = config.to_dict() + assert config_dict == {"enabled": True, "sample_rate": 50} + + +class TestPostHogConfig: + """Tests for PostHog configuration.""" + + def test_get_posthog_config(self): + """Test getting PostHog config.""" + config = get_posthog_config() + assert config is not None + assert config["api_key"] == "phc_eSkLnbLxsnYFaXksif1ksbrNzYlJShr35miFLDppF14" + assert config["host"] == "https://eu.i.posthog.com" + + +class TestPostHogTelemetryClient: + """Tests for PostHog telemetry client.""" + + @patch("posthog.capture") + @patch("posthog.identify") + def test_initialization(self, mock_identify, mock_capture, mock_environment): + """Test client initialization.""" + client = PostHogTelemetryClient() + assert client.config.enabled is True + assert client.initialized is True + mock_identify.assert_called_once() + + @patch("posthog.capture") + def test_increment_counter(self, mock_capture, mock_environment): + """Test incrementing a counter.""" + client = PostHogTelemetryClient() + client.increment("test_counter", 5) + mock_capture.assert_called_once() + args, kwargs = mock_capture.call_args + assert kwargs["event"] == "counter_increment" + assert kwargs["properties"]["counter_name"] == "test_counter" + assert kwargs["properties"]["value"] == 5 + + @patch("posthog.capture") + def test_record_event(self, mock_capture, mock_environment): + """Test recording an event.""" + client = PostHogTelemetryClient() + client.record_event("test_event", {"param": "value"}) + mock_capture.assert_called_once() + args, kwargs = mock_capture.call_args + assert kwargs["event"] == "test_event" + assert kwargs["properties"]["param"] == "value" + + @patch("posthog.capture") + def 
test_disabled_client(self, mock_capture, mock_environment): + """Test that disabled client doesn't send events.""" + client = PostHogTelemetryClient() + client.disable() + client.increment("test_counter") + client.record_event("test_event") + mock_capture.assert_not_called() + + @patch("posthog.flush") + def test_flush(self, mock_flush, mock_environment): + """Test flushing events.""" + client = PostHogTelemetryClient() + result = client.flush() + assert result is True + mock_flush.assert_called_once() + + def test_global_client(self, mock_environment): + """Test global client initialization.""" + client1 = get_posthog_telemetry_client() + client2 = get_posthog_telemetry_client() + assert client1 is client2 # Same instance diff --git a/libs/core/tests/test_telemetry.py b/libs/core/tests/test_telemetry.py new file mode 100644 index 00000000..5b9c256d --- /dev/null +++ b/libs/core/tests/test_telemetry.py @@ -0,0 +1,169 @@ +"""Tests for the telemetry module.""" + +import os +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from core.telemetry import ( + UniversalTelemetryClient, + disable_telemetry, + enable_telemetry, + get_telemetry_client, +) + + +@pytest.fixture +def mock_project_root(tmp_path): + """Create a temporary directory for testing.""" + return tmp_path + + +@pytest.fixture +def mock_environment(): + """Set up and tear down environment variables for testing.""" + original_env = os.environ.copy() + os.environ["CUA_TELEMETRY_SAMPLE_RATE"] = "100" + + yield + + # Restore original environment + os.environ.clear() + os.environ.update(original_env) + + +@pytest.fixture +def mock_disabled_environment(): + """Set up environment variables with telemetry disabled.""" + original_env = os.environ.copy() + os.environ["CUA_TELEMETRY"] = "off" + os.environ["CUA_TELEMETRY_SAMPLE_RATE"] = "100" + + yield + + # Restore original environment + os.environ.clear() + os.environ.update(original_env) + + +class TestTelemetryClient: + 
"""Tests for the universal telemetry client.""" + + @patch("core.telemetry.telemetry.POSTHOG_AVAILABLE", True) + @patch("core.telemetry.telemetry.get_posthog_telemetry_client") + def test_initialization(self, mock_get_posthog, mock_project_root, mock_environment): + """Test client initialization.""" + mock_client = MagicMock() + mock_get_posthog.return_value = mock_client + + client = UniversalTelemetryClient(mock_project_root) + assert client._client is not None + mock_get_posthog.assert_called_once_with(mock_project_root) + + @patch("core.telemetry.telemetry.POSTHOG_AVAILABLE", True) + @patch("core.telemetry.telemetry.get_posthog_telemetry_client") + def test_increment(self, mock_get_posthog, mock_project_root, mock_environment): + """Test incrementing counters.""" + mock_client = MagicMock() + mock_get_posthog.return_value = mock_client + + client = UniversalTelemetryClient(mock_project_root) + client.increment("test_counter", 5) + + mock_client.increment.assert_called_once_with("test_counter", 5) + + @patch("core.telemetry.telemetry.POSTHOG_AVAILABLE", True) + @patch("core.telemetry.telemetry.get_posthog_telemetry_client") + def test_record_event(self, mock_get_posthog, mock_project_root, mock_environment): + """Test recording events.""" + mock_client = MagicMock() + mock_get_posthog.return_value = mock_client + + client = UniversalTelemetryClient(mock_project_root) + client.record_event("test_event", {"prop1": "value1"}) + + mock_client.record_event.assert_called_once_with("test_event", {"prop1": "value1"}) + + @patch("core.telemetry.telemetry.POSTHOG_AVAILABLE", True) + @patch("core.telemetry.telemetry.get_posthog_telemetry_client") + def test_flush(self, mock_get_posthog, mock_project_root, mock_environment): + """Test flushing telemetry data.""" + mock_client = MagicMock() + mock_client.flush.return_value = True + mock_get_posthog.return_value = mock_client + + client = UniversalTelemetryClient(mock_project_root) + result = client.flush() + + assert result 
is True + mock_client.flush.assert_called_once() + + @patch("core.telemetry.telemetry.POSTHOG_AVAILABLE", True) + @patch("core.telemetry.telemetry.get_posthog_telemetry_client") + def test_enable_disable(self, mock_get_posthog, mock_project_root): + """Test enabling and disabling telemetry.""" + mock_client = MagicMock() + mock_get_posthog.return_value = mock_client + + client = UniversalTelemetryClient(mock_project_root) + + client.enable() + mock_client.enable.assert_called_once() + + client.disable() + mock_client.disable.assert_called_once() + + +def test_get_telemetry_client(): + """Test the global client getter.""" + # Reset global state + from core.telemetry.telemetry import _universal_client + + _universal_client = None + + with patch("core.telemetry.telemetry.UniversalTelemetryClient") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + # First call should create a new client + client1 = get_telemetry_client() + assert client1 is mock_client + mock_client_class.assert_called_once() + + # Second call should return the same client + client2 = get_telemetry_client() + assert client2 is client1 + assert mock_client_class.call_count == 1 + + +def test_disable_telemetry(): + """Test the global disable function.""" + # Reset global state + from core.telemetry.telemetry import _universal_client + + _universal_client = None + + with patch("core.telemetry.telemetry.get_telemetry_client") as mock_get_client: + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + # Disable globally + disable_telemetry() + mock_client.disable.assert_called_once() + + +def test_enable_telemetry(): + """Test the global enable function.""" + # Reset global state + from core.telemetry.telemetry import _universal_client + + _universal_client = None + + with patch("core.telemetry.telemetry.get_telemetry_client") as mock_get_client: + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + # Enable 
globally + enable_telemetry() + mock_client.enable.assert_called_once() diff --git a/libs/lume/CONTRIBUTING.md b/libs/lume/CONTRIBUTING.md new file mode 100644 index 00000000..6c51a416 --- /dev/null +++ b/libs/lume/CONTRIBUTING.md @@ -0,0 +1,39 @@ +# Contributing to lume + +We deeply appreciate your interest in contributing to lume! Whether you're reporting bugs, suggesting enhancements, improving docs, or submitting pull requests, your contributions help improve the project for everyone. + +## Reporting Bugs + +If you've encountered a bug in the project, we encourage you to report it. Please follow these steps: + +1. **Check the Issue Tracker**: Before submitting a new bug report, please check our issue tracker to see if the bug has already been reported. +2. **Create a New Issue**: If the bug hasn't been reported, create a new issue with: + - A clear title and detailed description + - Steps to reproduce the issue + - Expected vs actual behavior + - Your environment (macOS version, lume version) + - Any relevant logs or error messages +3. **Label Your Issue**: Label your issue as a `bug` to help maintainers identify it quickly. + +## Suggesting Enhancements + +We're always looking for suggestions to make lume better. If you have an idea: + +1. **Check Existing Issues**: See if someone else has already suggested something similar. +2. **Create a New Issue**: If your enhancement is new, create an issue describing: + - The problem your enhancement solves + - How your enhancement would work + - Any potential implementation details + - Why this enhancement would benefit lume users + +## Documentation + +Documentation improvements are always welcome. You can: +- Fix typos or unclear explanations +- Add examples and use cases +- Improve API documentation +- Add tutorials or guides + +For detailed instructions on setting up your development environment and submitting code contributions, please see our [Development.md](docs/Development.md) guide. 
+ +Feel free to join our [Discord community](https://discord.com/invite/mVnXXpdE85) to discuss ideas or get help with your contributions. \ No newline at end of file diff --git a/Package.resolved b/libs/lume/Package.resolved similarity index 100% rename from Package.resolved rename to libs/lume/Package.resolved diff --git a/Package.swift b/libs/lume/Package.swift similarity index 100% rename from Package.swift rename to libs/lume/Package.swift diff --git a/libs/lume/README.md b/libs/lume/README.md new file mode 100644 index 00000000..d5e271b5 --- /dev/null +++ b/libs/lume/README.md @@ -0,0 +1,175 @@ +
+

+
+ + + + Shows my svg + +
+ + [![Swift 6](https://img.shields.io/badge/Swift_6-F54A2A?logo=swift&logoColor=white&labelColor=F54A2A)](#) + [![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](#) + [![Homebrew](https://img.shields.io/badge/Homebrew-FBB040?logo=homebrew&logoColor=fff)](#install) + [![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.com/invite/mVnXXpdE85) +

+
+ + +**lume** is a lightweight Command Line Interface and local API server to create, run and manage macOS and Linux virtual machines (VMs) with near-native performance on Apple Silicon, using Apple's `Virtualization.Framework`. + +### Run prebuilt macOS images in just 1 step + +
+lume cli +
+ + +```bash +lume run macos-sequoia-vanilla:latest +``` + +## Development Environment + +If you're working on Lume in the context of the CUA monorepo, we recommend using the dedicated VS Code workspace configuration: + +```bash +# Open VS Code workspace from the root of the monorepo +code .vscode/lume.code-workspace +``` +This workspace is preconfigured with Swift language support, build tasks, and debug configurations. + +## Usage + +```bash +lume + +Commands: + lume create Create a new macOS or Linux VM + lume run Run a VM + lume ls List all VMs + lume get Get detailed information about a VM + lume set Modify VM configuration + lume stop Stop a running VM + lume delete Delete a VM + lume pull Pull a macOS image from container registry + lume clone Clone an existing VM + lume images List available macOS images in local cache + lume ipsw Get the latest macOS restore image URL + lume prune Remove cached images + lume serve Start the API server + +Options: + --help Show help [boolean] + --version Show version number [boolean] + +Command Options: + create: + --os Operating system to install (macOS or linux, default: macOS) + --cpu Number of CPU cores (default: 4) + --memory Memory size, e.g., 8GB (default: 4GB) + --disk-size Disk size, e.g., 50GB (default: 40GB) + --display Display resolution (default: 1024x768) + --ipsw Path to IPSW file or 'latest' for macOS VMs + + run: + --no-display Do not start the VNC client app + --shared-dir Share directory with VM (format: path[:ro|rw]) + --mount For Linux VMs only, attach a read-only disk image + --registry Container registry URL (default: ghcr.io) + --organization Organization to pull from (default: trycua) + --vnc-port Port to use for the VNC server (default: 0 for auto-assign) + --recovery-mode For MacOS VMs only, start VM in recovery mode (default: false) + + set: + --cpu New number of CPU cores (e.g., 4) + --memory New memory size (e.g., 8192MB or 8GB) + --disk-size New disk size (e.g., 40960MB or 40GB) + --display 
New display resolution in format WIDTHxHEIGHT (e.g., 1024x768) + + delete: + --force Force deletion without confirmation + + pull: + --registry Container registry URL (default: ghcr.io) + --organization Organization to pull from (default: trycua) + + serve: + --port Port to listen on (default: 3000) +``` + +## Install + +```bash +brew tap trycua/lume +brew install lume +``` + +You can also download the `lume.pkg.tar.gz` archive from the [latest release](https://github.com/trycua/lume/releases), extract it, and install the package manually. + +## Prebuilt Images + +Pre-built images are available in the registry [ghcr.io/trycua](https://github.com/orgs/trycua/packages). +These images come with an SSH server pre-configured and auto-login enabled. + +For the security of your VM, change the default password `lume` immediately after your first login. + +| Image | Tag | Description | Size | +|-------|------------|-------------|------| +| `macos-sequoia-vanilla` | `latest`, `15.2` | macOS Sequoia 15.2 | 40GB | +| `macos-sequoia-xcode` | `latest`, `15.2` | macOS Sequoia 15.2 with Xcode command line tools | 50GB | +| `ubuntu-noble-vanilla` | `latest`, `24.04.1` | [Ubuntu Server for ARM 24.04.1 LTS](https://ubuntu.com/download/server/arm) with Ubuntu Desktop | 20GB | + +For additional disk space, resize the VM disk after pulling the image using the `lume set --disk-size ` command. + +## Local API Server + +`lume` exposes a local HTTP API server that listens on `http://localhost:3000/lume`, enabling automated management of VMs. + +```bash +lume serve +``` + +For detailed API documentation, please refer to [API Reference](docs/API-Reference.md). + +## Docs + +- [API Reference](docs/API-Reference.md) +- [Development](docs/Development.md) +- [FAQ](docs/FAQ.md) + +## Contributing + +We welcome and greatly appreciate contributions to lume! 
Whether you're improving documentation, adding new features, fixing bugs, or adding new VM images, your efforts help make lume better for everyone. For detailed instructions on how to contribute, please refer to our [Contributing Guidelines](CONTRIBUTING.md). + +Join our [Discord community](https://discord.com/invite/mVnXXpdE85) to discuss ideas or get assistance. + +## License + +lume is open-sourced under the MIT License - see the [LICENSE](LICENSE) file for details. + +## Trademarks + +Apple, macOS, and Apple Silicon are trademarks of Apple Inc. Ubuntu and Canonical are registered trademarks of Canonical Ltd. This project is not affiliated with, endorsed by, or sponsored by Apple Inc. or Canonical Ltd. + +## Stargazers over time + +[![Stargazers over time](https://starchart.cc/trycua/lume.svg?variant=adaptive)](https://starchart.cc/trycua/lume) + +## Contributors + + + + + + + + + + +
f-trycua
f-trycua

💻
+ + + + + diff --git a/docs/API-Reference.md b/libs/lume/docs/API-Reference.md similarity index 100% rename from docs/API-Reference.md rename to libs/lume/docs/API-Reference.md diff --git a/docs/Development.md b/libs/lume/docs/Development.md similarity index 100% rename from docs/Development.md rename to libs/lume/docs/Development.md diff --git a/libs/lume/docs/FAQ.md b/libs/lume/docs/FAQ.md new file mode 100644 index 00000000..9150fbb5 --- /dev/null +++ b/libs/lume/docs/FAQ.md @@ -0,0 +1,55 @@ +# FAQs + +### Where are the VMs stored? + +VMs are stored in `~/.lume`. + +### How are images cached? + +Images are cached in `~/.lume/cache`. When doing `lume pull `, it will check if the image is already cached. If not, it will download the image and cache it, removing any older versions. + +### Are VM disks taking up all the disk space? + +No, macOS uses sparse files, which only allocate space as needed. For example, VM disks totaling 50 GB may only use 20 GB on disk. + +### How do I get the latest macOS restore image URL? + +```bash +lume ipsw +``` + +### How do I delete a VM? + +```bash +lume delete +``` + +### How to Install macOS from an IPSW Image + +#### Create a new macOS VM using the latest supported IPSW image: +Run the following command to create a new macOS virtual machine using the latest available IPSW image: + +```bash +lume create --os macos --ipsw latest +``` + +#### Create a new macOS VM using a specific IPSW image: +To create a macOS virtual machine from an older or specific IPSW file, first download the desired IPSW (UniversalMac) from a trusted source. + +Then, use the downloaded IPSW path: + +```bash +lume create --os macos --ipsw +``` + +### How do I install a custom Linux image? + +The process for creating a custom Linux image differs than macOS, with IPSW restore files not being used. You need to create a linux VM first, then mount a setup image file to the VM for the first boot. 
+ +```bash +lume create --os linux + +lume run --mount + +lume run +``` diff --git a/libs/lume/img/cli.png b/libs/lume/img/cli.png new file mode 100644 index 00000000..2d0b6465 Binary files /dev/null and b/libs/lume/img/cli.png differ diff --git a/libs/lume/img/logo_black.png b/libs/lume/img/logo_black.png new file mode 100644 index 00000000..8198ec61 Binary files /dev/null and b/libs/lume/img/logo_black.png differ diff --git a/libs/lume/img/logo_white.png b/libs/lume/img/logo_white.png new file mode 100644 index 00000000..cd83c45b Binary files /dev/null and b/libs/lume/img/logo_white.png differ diff --git a/resources/lume.entitlements b/libs/lume/resources/lume.entitlements similarity index 100% rename from resources/lume.entitlements rename to libs/lume/resources/lume.entitlements diff --git a/scripts/build/build-debug.sh b/libs/lume/scripts/build/build-debug.sh similarity index 100% rename from scripts/build/build-debug.sh rename to libs/lume/scripts/build/build-debug.sh diff --git a/scripts/build/build-release-notarized.sh b/libs/lume/scripts/build/build-release-notarized.sh similarity index 100% rename from scripts/build/build-release-notarized.sh rename to libs/lume/scripts/build/build-release-notarized.sh diff --git a/scripts/build/build-release.sh b/libs/lume/scripts/build/build-release.sh similarity index 100% rename from scripts/build/build-release.sh rename to libs/lume/scripts/build/build-release.sh diff --git a/scripts/ghcr/pull-ghcr.sh b/libs/lume/scripts/ghcr/pull-ghcr.sh similarity index 100% rename from scripts/ghcr/pull-ghcr.sh rename to libs/lume/scripts/ghcr/pull-ghcr.sh diff --git a/scripts/ghcr/push-ghcr.sh b/libs/lume/scripts/ghcr/push-ghcr.sh similarity index 100% rename from scripts/ghcr/push-ghcr.sh rename to libs/lume/scripts/ghcr/push-ghcr.sh diff --git a/src/Commands/Clone.swift b/libs/lume/src/Commands/Clone.swift similarity index 100% rename from src/Commands/Clone.swift rename to libs/lume/src/Commands/Clone.swift diff --git 
a/src/Commands/Create.swift b/libs/lume/src/Commands/Create.swift similarity index 100% rename from src/Commands/Create.swift rename to libs/lume/src/Commands/Create.swift diff --git a/src/Commands/Delete.swift b/libs/lume/src/Commands/Delete.swift similarity index 100% rename from src/Commands/Delete.swift rename to libs/lume/src/Commands/Delete.swift diff --git a/src/Commands/Get.swift b/libs/lume/src/Commands/Get.swift similarity index 100% rename from src/Commands/Get.swift rename to libs/lume/src/Commands/Get.swift diff --git a/src/Commands/IPSW.swift b/libs/lume/src/Commands/IPSW.swift similarity index 100% rename from src/Commands/IPSW.swift rename to libs/lume/src/Commands/IPSW.swift diff --git a/src/Commands/Images.swift b/libs/lume/src/Commands/Images.swift similarity index 100% rename from src/Commands/Images.swift rename to libs/lume/src/Commands/Images.swift diff --git a/src/Commands/List.swift b/libs/lume/src/Commands/List.swift similarity index 100% rename from src/Commands/List.swift rename to libs/lume/src/Commands/List.swift diff --git a/src/Commands/Options/FormatOption.swift b/libs/lume/src/Commands/Options/FormatOption.swift similarity index 100% rename from src/Commands/Options/FormatOption.swift rename to libs/lume/src/Commands/Options/FormatOption.swift diff --git a/src/Commands/Prune.swift b/libs/lume/src/Commands/Prune.swift similarity index 100% rename from src/Commands/Prune.swift rename to libs/lume/src/Commands/Prune.swift diff --git a/src/Commands/Pull.swift b/libs/lume/src/Commands/Pull.swift similarity index 100% rename from src/Commands/Pull.swift rename to libs/lume/src/Commands/Pull.swift diff --git a/src/Commands/Run.swift b/libs/lume/src/Commands/Run.swift similarity index 100% rename from src/Commands/Run.swift rename to libs/lume/src/Commands/Run.swift diff --git a/src/Commands/Serve.swift b/libs/lume/src/Commands/Serve.swift similarity index 100% rename from src/Commands/Serve.swift rename to 
libs/lume/src/Commands/Serve.swift diff --git a/src/Commands/Set.swift b/libs/lume/src/Commands/Set.swift similarity index 100% rename from src/Commands/Set.swift rename to libs/lume/src/Commands/Set.swift diff --git a/src/Commands/Stop.swift b/libs/lume/src/Commands/Stop.swift similarity index 100% rename from src/Commands/Stop.swift rename to libs/lume/src/Commands/Stop.swift diff --git a/src/ContainerRegistry/ImageContainerRegistry.swift b/libs/lume/src/ContainerRegistry/ImageContainerRegistry.swift similarity index 100% rename from src/ContainerRegistry/ImageContainerRegistry.swift rename to libs/lume/src/ContainerRegistry/ImageContainerRegistry.swift diff --git a/src/ContainerRegistry/ImageList.swift b/libs/lume/src/ContainerRegistry/ImageList.swift similarity index 100% rename from src/ContainerRegistry/ImageList.swift rename to libs/lume/src/ContainerRegistry/ImageList.swift diff --git a/src/ContainerRegistry/ImagesPrinter.swift b/libs/lume/src/ContainerRegistry/ImagesPrinter.swift similarity index 100% rename from src/ContainerRegistry/ImagesPrinter.swift rename to libs/lume/src/ContainerRegistry/ImagesPrinter.swift diff --git a/src/Errors/Errors.swift b/libs/lume/src/Errors/Errors.swift similarity index 100% rename from src/Errors/Errors.swift rename to libs/lume/src/Errors/Errors.swift diff --git a/src/FileSystem/Home.swift b/libs/lume/src/FileSystem/Home.swift similarity index 100% rename from src/FileSystem/Home.swift rename to libs/lume/src/FileSystem/Home.swift diff --git a/src/FileSystem/VMConfig.swift b/libs/lume/src/FileSystem/VMConfig.swift similarity index 100% rename from src/FileSystem/VMConfig.swift rename to libs/lume/src/FileSystem/VMConfig.swift diff --git a/src/FileSystem/VMDirectory.swift b/libs/lume/src/FileSystem/VMDirectory.swift similarity index 100% rename from src/FileSystem/VMDirectory.swift rename to libs/lume/src/FileSystem/VMDirectory.swift diff --git a/src/LumeController.swift b/libs/lume/src/LumeController.swift similarity 
index 100% rename from src/LumeController.swift rename to libs/lume/src/LumeController.swift diff --git a/src/Main.swift b/libs/lume/src/Main.swift similarity index 100% rename from src/Main.swift rename to libs/lume/src/Main.swift diff --git a/src/Server/HTTP.swift b/libs/lume/src/Server/HTTP.swift similarity index 100% rename from src/Server/HTTP.swift rename to libs/lume/src/Server/HTTP.swift diff --git a/src/Server/Handlers.swift b/libs/lume/src/Server/Handlers.swift similarity index 100% rename from src/Server/Handlers.swift rename to libs/lume/src/Server/Handlers.swift diff --git a/src/Server/Requests.swift b/libs/lume/src/Server/Requests.swift similarity index 100% rename from src/Server/Requests.swift rename to libs/lume/src/Server/Requests.swift diff --git a/src/Server/Responses.swift b/libs/lume/src/Server/Responses.swift similarity index 100% rename from src/Server/Responses.swift rename to libs/lume/src/Server/Responses.swift diff --git a/src/Server/Server.swift b/libs/lume/src/Server/Server.swift similarity index 100% rename from src/Server/Server.swift rename to libs/lume/src/Server/Server.swift diff --git a/src/Utils/CommandRegistry.swift b/libs/lume/src/Utils/CommandRegistry.swift similarity index 100% rename from src/Utils/CommandRegistry.swift rename to libs/lume/src/Utils/CommandRegistry.swift diff --git a/src/Utils/CommandUtils.swift b/libs/lume/src/Utils/CommandUtils.swift similarity index 100% rename from src/Utils/CommandUtils.swift rename to libs/lume/src/Utils/CommandUtils.swift diff --git a/src/Utils/Logger.swift b/libs/lume/src/Utils/Logger.swift similarity index 100% rename from src/Utils/Logger.swift rename to libs/lume/src/Utils/Logger.swift diff --git a/src/Utils/NetworkUtils.swift b/libs/lume/src/Utils/NetworkUtils.swift similarity index 100% rename from src/Utils/NetworkUtils.swift rename to libs/lume/src/Utils/NetworkUtils.swift diff --git a/src/Utils/Path.swift b/libs/lume/src/Utils/Path.swift similarity index 100% rename from 
src/Utils/Path.swift rename to libs/lume/src/Utils/Path.swift diff --git a/src/Utils/ProcessRunner.swift b/libs/lume/src/Utils/ProcessRunner.swift similarity index 100% rename from src/Utils/ProcessRunner.swift rename to libs/lume/src/Utils/ProcessRunner.swift diff --git a/src/Utils/ProgressLogger.swift b/libs/lume/src/Utils/ProgressLogger.swift similarity index 100% rename from src/Utils/ProgressLogger.swift rename to libs/lume/src/Utils/ProgressLogger.swift diff --git a/src/Utils/String.swift b/libs/lume/src/Utils/String.swift similarity index 100% rename from src/Utils/String.swift rename to libs/lume/src/Utils/String.swift diff --git a/src/Utils/Utils.swift b/libs/lume/src/Utils/Utils.swift similarity index 100% rename from src/Utils/Utils.swift rename to libs/lume/src/Utils/Utils.swift diff --git a/src/VM/DarwinVM.swift b/libs/lume/src/VM/DarwinVM.swift similarity index 100% rename from src/VM/DarwinVM.swift rename to libs/lume/src/VM/DarwinVM.swift diff --git a/src/VM/LinuxVM.swift b/libs/lume/src/VM/LinuxVM.swift similarity index 100% rename from src/VM/LinuxVM.swift rename to libs/lume/src/VM/LinuxVM.swift diff --git a/src/VM/VM.swift b/libs/lume/src/VM/VM.swift similarity index 100% rename from src/VM/VM.swift rename to libs/lume/src/VM/VM.swift diff --git a/src/VM/VMDetails.swift b/libs/lume/src/VM/VMDetails.swift similarity index 100% rename from src/VM/VMDetails.swift rename to libs/lume/src/VM/VMDetails.swift diff --git a/src/VM/VMDetailsPrinter.swift b/libs/lume/src/VM/VMDetailsPrinter.swift similarity index 100% rename from src/VM/VMDetailsPrinter.swift rename to libs/lume/src/VM/VMDetailsPrinter.swift diff --git a/src/VM/VMDisplayResolution.swift b/libs/lume/src/VM/VMDisplayResolution.swift similarity index 100% rename from src/VM/VMDisplayResolution.swift rename to libs/lume/src/VM/VMDisplayResolution.swift diff --git a/src/VM/VMFactory.swift b/libs/lume/src/VM/VMFactory.swift similarity index 100% rename from src/VM/VMFactory.swift rename to 
libs/lume/src/VM/VMFactory.swift diff --git a/src/VNC/PassphraseGenerator.swift b/libs/lume/src/VNC/PassphraseGenerator.swift similarity index 100% rename from src/VNC/PassphraseGenerator.swift rename to libs/lume/src/VNC/PassphraseGenerator.swift diff --git a/src/VNC/VNCService.swift b/libs/lume/src/VNC/VNCService.swift similarity index 100% rename from src/VNC/VNCService.swift rename to libs/lume/src/VNC/VNCService.swift diff --git a/src/Virtualization/DHCPLeaseParser.swift b/libs/lume/src/Virtualization/DHCPLeaseParser.swift similarity index 100% rename from src/Virtualization/DHCPLeaseParser.swift rename to libs/lume/src/Virtualization/DHCPLeaseParser.swift diff --git a/src/Virtualization/DarwinImageLoader.swift b/libs/lume/src/Virtualization/DarwinImageLoader.swift similarity index 100% rename from src/Virtualization/DarwinImageLoader.swift rename to libs/lume/src/Virtualization/DarwinImageLoader.swift diff --git a/src/Virtualization/ImageLoaderFactory.swift b/libs/lume/src/Virtualization/ImageLoaderFactory.swift similarity index 100% rename from src/Virtualization/ImageLoaderFactory.swift rename to libs/lume/src/Virtualization/ImageLoaderFactory.swift diff --git a/src/Virtualization/VMVirtualizationService.swift b/libs/lume/src/Virtualization/VMVirtualizationService.swift similarity index 100% rename from src/Virtualization/VMVirtualizationService.swift rename to libs/lume/src/Virtualization/VMVirtualizationService.swift diff --git a/tests/Mocks/MockVM.swift b/libs/lume/tests/Mocks/MockVM.swift similarity index 100% rename from tests/Mocks/MockVM.swift rename to libs/lume/tests/Mocks/MockVM.swift diff --git a/tests/Mocks/MockVMVirtualizationService.swift b/libs/lume/tests/Mocks/MockVMVirtualizationService.swift similarity index 100% rename from tests/Mocks/MockVMVirtualizationService.swift rename to libs/lume/tests/Mocks/MockVMVirtualizationService.swift diff --git a/tests/Mocks/MockVNCService.swift b/libs/lume/tests/Mocks/MockVNCService.swift similarity 
index 100% rename from tests/Mocks/MockVNCService.swift rename to libs/lume/tests/Mocks/MockVNCService.swift diff --git a/tests/VM/VMDetailsPrinterTests.swift b/libs/lume/tests/VM/VMDetailsPrinterTests.swift similarity index 100% rename from tests/VM/VMDetailsPrinterTests.swift rename to libs/lume/tests/VM/VMDetailsPrinterTests.swift diff --git a/tests/VMTests.swift b/libs/lume/tests/VMTests.swift similarity index 100% rename from tests/VMTests.swift rename to libs/lume/tests/VMTests.swift diff --git a/tests/VMVirtualizationServiceTests.swift b/libs/lume/tests/VMVirtualizationServiceTests.swift similarity index 100% rename from tests/VMVirtualizationServiceTests.swift rename to libs/lume/tests/VMVirtualizationServiceTests.swift diff --git a/tests/VNCServiceTests.swift b/libs/lume/tests/VNCServiceTests.swift similarity index 100% rename from tests/VNCServiceTests.swift rename to libs/lume/tests/VNCServiceTests.swift diff --git a/libs/pylume/README.md b/libs/pylume/README.md new file mode 100644 index 00000000..7f5888fc --- /dev/null +++ b/libs/pylume/README.md @@ -0,0 +1,51 @@ +
+

+
+ + + + pylume logo + +
+ + [![Python](https://img.shields.io/badge/Python-333333?logo=python&logoColor=white&labelColor=333333)](#) + [![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](#) + [![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.com/invite/mVnXXpdE85) + [![PyPI](https://img.shields.io/pypi/v/pylume?color=333333)](https://pypi.org/project/pylume/) +

+
+ + +**pylume** is a lightweight Python library based on [lume](https://github.com/trycua/lume) to create, run and manage macOS and Linux virtual machines (VMs) natively on Apple Silicon. + +
+lume-py +
+ + +```bash +pip install pylume +``` + +## Usage + +Please refer to this [Notebook](./samples/nb.ipynb) for a quickstart. More details about the underlying API used by pylume are available [here](https://github.com/trycua/lume/docs/API-Reference.md). + +## Prebuilt Images + +Pre-built images are available on [ghcr.io/trycua](https://github.com/orgs/trycua/packages). +These images come pre-configured with an SSH server and auto-login enabled. + +## Contributing + +We welcome and greatly appreciate contributions to lume! Whether you're improving documentation, adding new features, fixing bugs, or adding new VM images, your efforts help make pylume better for everyone. + +Join our [Discord community](https://discord.com/invite/mVnXXpdE85) to discuss ideas or get assistance. + +## License + +lume is open-sourced under the MIT License - see the [LICENSE](LICENSE) file for details. + +## Stargazers over time + +[![Stargazers over time](https://starchart.cc/trycua/pylume.svg?variant=adaptive)](https://starchart.cc/trycua/pylume) diff --git a/libs/pylume/__init__.py b/libs/pylume/__init__.py new file mode 100644 index 00000000..65cacee1 --- /dev/null +++ b/libs/pylume/__init__.py @@ -0,0 +1,9 @@ +""" +PyLume Python SDK - A client library for managing macOS VMs with PyLume. +""" + +from pylume.pylume import * +from pylume.models import * +from pylume.exceptions import * + +__version__ = "0.1.0" diff --git a/libs/pylume/pylume/__init__.py b/libs/pylume/pylume/__init__.py new file mode 100644 index 00000000..c7a1c082 --- /dev/null +++ b/libs/pylume/pylume/__init__.py @@ -0,0 +1,66 @@ +""" +PyLume Python SDK - A client library for managing macOS VMs with PyLume. + +Example: + >>> from pylume import PyLume, VMConfig + >>> client = PyLume() + >>> config = VMConfig( + ... name="my-vm", + ... cpu=4, + ... memory="8GB", + ... disk_size="64GB" + ... 
) + >>> client.create_vm(config) + >>> client.run_vm("my-vm") +""" + +# Import all models first +from .models import ( + VMConfig, + VMStatus, + VMRunOpts, + VMUpdateOpts, + ImageRef, + CloneSpec, + SharedDirectory, + ImageList, + ImageInfo, +) + +# Import exceptions +from .exceptions import ( + LumeError, + LumeServerError, + LumeConnectionError, + LumeTimeoutError, + LumeNotFoundError, + LumeConfigError, + LumeVMError, + LumeImageError, +) + +# Import main class last to avoid circular imports +from .pylume import PyLume + +__version__ = "0.1.0" + +__all__ = [ + "PyLume", + "VMConfig", + "VMStatus", + "VMRunOpts", + "VMUpdateOpts", + "ImageRef", + "CloneSpec", + "SharedDirectory", + "ImageList", + "ImageInfo", + "LumeError", + "LumeServerError", + "LumeConnectionError", + "LumeTimeoutError", + "LumeNotFoundError", + "LumeConfigError", + "LumeVMError", + "LumeImageError", +] diff --git a/libs/pylume/pylume/client.py b/libs/pylume/pylume/client.py new file mode 100644 index 00000000..607ddd0a --- /dev/null +++ b/libs/pylume/pylume/client.py @@ -0,0 +1,112 @@ +import json +import asyncio +import subprocess +from typing import Optional, Any, Dict +import shlex + +from .exceptions import ( + LumeError, + LumeServerError, + LumeConnectionError, + LumeTimeoutError, + LumeNotFoundError, + LumeConfigError, +) + +class LumeClient: + def __init__(self, base_url: str, timeout: float = 60.0, debug: bool = False): + self.base_url = base_url + self.timeout = timeout + self.debug = debug + + def _log_debug(self, message: str, **kwargs) -> None: + """Log debug information if debug mode is enabled.""" + if self.debug: + print(f"DEBUG: {message}") + if kwargs: + print(json.dumps(kwargs, indent=2)) + + async def _run_curl(self, method: str, path: str, data: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None) -> Any: + """Execute a curl command and return the response.""" + url = f"{self.base_url}{path}" + if params: + param_str = "&".join(f"{k}={v}" for k, v in 
params.items()) + url = f"{url}?{param_str}" + + cmd = ["curl", "-X", method, "-s", "-w", "%{http_code}", "-m", str(self.timeout)] + + if data is not None: + cmd.extend(["-H", "Content-Type: application/json", "-d", json.dumps(data)]) + + cmd.append(url) + + self._log_debug(f"Running curl command: {' '.join(map(shlex.quote, cmd))}") + + try: + process = await asyncio.create_subprocess_exec( + *cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + stdout, stderr = await process.communicate() + + if process.returncode != 0: + raise LumeConnectionError(f"Curl command failed: {stderr.decode()}") + + # The last 3 characters are the status code + response = stdout.decode() + status_code = int(response[-3:]) + response_body = response[:-3] # Remove status code from response + + if status_code >= 400: + if status_code == 404: + raise LumeNotFoundError(f"Resource not found: {path}") + elif status_code == 400: + raise LumeConfigError(f"Invalid request: {response_body}") + elif status_code >= 500: + raise LumeServerError(f"Server error: {response_body}") + else: + raise LumeError(f"Request failed with status {status_code}: {response_body}") + + return json.loads(response_body) if response_body.strip() else None + + except asyncio.TimeoutError: + raise LumeTimeoutError(f"Request timed out after {self.timeout} seconds") + + async def get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Any: + """Make a GET request.""" + return await self._run_curl("GET", path, params=params) + + async def post(self, path: str, data: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None) -> Any: + """Make a POST request.""" + old_timeout = self.timeout + if timeout is not None: + self.timeout = timeout + try: + return await self._run_curl("POST", path, data=data) + finally: + self.timeout = old_timeout + + async def patch(self, path: str, data: Dict[str, Any]) -> None: + """Make a PATCH request.""" + await self._run_curl("PATCH", path, data=data) + + async def 
delete(self, path: str) -> None: + """Make a DELETE request.""" + await self._run_curl("DELETE", path) + + def print_curl(self, method: str, path: str, data: Optional[Dict[str, Any]] = None) -> None: + """Print equivalent curl command for debugging.""" + curl_cmd = f"""curl -X {method} \\ + '{self.base_url}{path}'""" + + if data: + curl_cmd += f" \\\n -H 'Content-Type: application/json' \\\n -d '{json.dumps(data)}'" + + print("\nEquivalent curl command:") + print(curl_cmd) + print() + + async def close(self) -> None: + """Close the client resources.""" + pass # No shared resources to clean up \ No newline at end of file diff --git a/libs/pylume/pylume/exceptions.py b/libs/pylume/pylume/exceptions.py new file mode 100644 index 00000000..420b5d42 --- /dev/null +++ b/libs/pylume/pylume/exceptions.py @@ -0,0 +1,36 @@ +from typing import Optional + +class LumeError(Exception): + """Base exception for all PyLume errors.""" + pass + +class LumeServerError(LumeError): + """Raised when there's an error with the PyLume server.""" + def __init__(self, message: str, status_code: Optional[int] = None, response_text: Optional[str] = None): + self.status_code = status_code + self.response_text = response_text + super().__init__(message) + +class LumeConnectionError(LumeError): + """Raised when there's an error connecting to the PyLume server.""" + pass + +class LumeTimeoutError(LumeError): + """Raised when a request to the PyLume server times out.""" + pass + +class LumeNotFoundError(LumeError): + """Raised when a requested resource is not found.""" + pass + +class LumeConfigError(LumeError): + """Raised when there's an error with the configuration.""" + pass + +class LumeVMError(LumeError): + """Raised when there's an error with a VM operation.""" + pass + +class LumeImageError(LumeError): + """Raised when there's an error with an image operation.""" + pass \ No newline at end of file diff --git a/libs/pylume/pylume/lume b/libs/pylume/pylume/lume new file mode 100755 index 
00000000..5ea1be47 Binary files /dev/null and b/libs/pylume/pylume/lume differ diff --git a/libs/pylume/pylume/models.py b/libs/pylume/pylume/models.py new file mode 100644 index 00000000..664065ad --- /dev/null +++ b/libs/pylume/pylume/models.py @@ -0,0 +1,136 @@ +from typing import Optional, List, Literal, Dict, Any +import re +from pydantic import BaseModel, Field, computed_field, validator, ConfigDict, RootModel + +class DiskInfo(BaseModel): + total: int + allocated: int + +class VMConfig(BaseModel): + """Configuration for creating a new VM. + + Note: Memory and disk sizes should be specified with units (e.g., "4GB", "64GB") + """ + name: str + os: Literal["macOS", "linux"] = "macOS" + cpu: int = Field(default=2, ge=1) + memory: str = "4GB" + disk_size: str = Field(default="64GB", alias="diskSize") + display: str = "1024x768" + ipsw: Optional[str] = Field(default=None, description="IPSW path or 'latest', for macOS VMs") + + class Config: + populate_by_alias = True + +class SharedDirectory(BaseModel): + """Configuration for a shared directory.""" + host_path: str = Field(..., alias="hostPath") # Allow host_path but serialize as hostPath + read_only: bool = False + + class Config: + populate_by_name = True # Allow both alias and original name + alias_generator = lambda s: ''.join(word.capitalize() if i else word for i, word in enumerate(s.split('_'))) + +class VMRunOpts(BaseModel): + """Configuration for running a VM. 
+ + Args: + no_display: Whether to not display the VNC client + shared_directories: List of directories to share with the VM + """ + no_display: bool = Field(default=False, alias="noDisplay") + shared_directories: Optional[list[SharedDirectory]] = Field( + default=None, + alias="sharedDirectories" + ) + + model_config = ConfigDict( + populate_by_name=True, + alias_generator=lambda s: ''.join(word.capitalize() if i else word for i, word in enumerate(s.split('_'))) + ) + + def model_dump(self, **kwargs): + data = super().model_dump(**kwargs) + # Convert shared directory fields to match API expectations + if self.shared_directories and "by_alias" in kwargs and kwargs["by_alias"]: + data["sharedDirectories"] = [ + { + "hostPath": d.host_path, + "readOnly": d.read_only + } + for d in self.shared_directories + ] + # Remove the snake_case version if it exists + data.pop("shared_directories", None) + return data + +class VMStatus(BaseModel): + name: str + status: str + os: Literal["macOS", "linux"] + cpu_count: int = Field(alias="cpuCount") + memory_size: int = Field(alias="memorySize") # API returns memory size in bytes + disk_size: DiskInfo = Field(alias="diskSize") + vnc_url: Optional[str] = Field(default=None, alias="vncUrl") + ip_address: Optional[str] = Field(default=None, alias="ipAddress") + + class Config: + populate_by_alias = True + + @computed_field + @property + def state(self) -> str: + return self.status + + @computed_field + @property + def cpu(self) -> int: + return self.cpu_count + + @computed_field + @property + def memory(self) -> str: + # Convert bytes to GB + gb = self.memory_size / (1024 * 1024 * 1024) + return f"{int(gb)}GB" + +class VMUpdateOpts(BaseModel): + cpu: Optional[int] = None + memory: Optional[str] = None + disk_size: Optional[str] = None + +class ImageRef(BaseModel): + """Reference to a VM image.""" + image: str + tag: str = "latest" + registry: Optional[str] = "ghcr.io" + organization: Optional[str] = "trycua" + + def model_dump(self, 
**kwargs): + """Override model_dump to return just the image:tag format.""" + return f"{self.image}:{self.tag}" + +class CloneSpec(BaseModel): + """Specification for cloning a VM.""" + name: str + new_name: str = Field(alias="newName") + + class Config: + populate_by_alias = True + +class ImageInfo(BaseModel): + """Model for individual image information.""" + imageId: str + +class ImageList(RootModel): + """Response model for the images endpoint.""" + root: List[ImageInfo] + + def __iter__(self): + return iter(self.root) + + def __getitem__(self, item): + return self.root[item] + + def __len__(self): + return len(self.root) \ No newline at end of file diff --git a/libs/pylume/pylume/pylume.py b/libs/pylume/pylume/pylume.py new file mode 100644 index 00000000..66fb6bc4 --- /dev/null +++ b/libs/pylume/pylume/pylume.py @@ -0,0 +1,308 @@ +import os +import sys +import json +import time +import asyncio +import subprocess +from typing import Optional, List, Union, Callable, TypeVar, Any +from functools import wraps +import re +import signal + +from .server import LumeServer +from .client import LumeClient +from .models import ( + VMConfig, + VMStatus, + VMRunOpts, + VMUpdateOpts, + ImageRef, + CloneSpec, + SharedDirectory, + ImageList, +) +from .exceptions import ( + LumeError, + LumeServerError, + LumeConnectionError, + LumeTimeoutError, + LumeNotFoundError, + LumeConfigError, + LumeVMError, + LumeImageError, +) + +# Type variable for the decorator +T = TypeVar('T') + +def ensure_server(func: Callable[..., T]) -> Callable[..., T]: + """Decorator to ensure server is running before executing the method.""" + @wraps(func) + async def wrapper(self: 'PyLume', *args: Any, **kwargs: Any) -> T: + # ensure_running is an async method, so we need to await it + await self.server.ensure_running() + # Initialize client if needed + await self._init_client() + return await func(self, *args, **kwargs) # type: ignore + return wrapper # type: ignore + +class PyLume: + def __init__( + 
self, + debug: bool = False, + server_start_timeout: int = 60, + port: Optional[int] = None, + use_existing_server: bool = False + ): + """Initialize the async PyLume client. + + Args: + debug: Enable debug logging + auto_start_server: Whether to automatically start the lume server if not running + server_start_timeout: Timeout in seconds to wait for server to start + port: Port number for the lume server. Required when use_existing_server is True. + use_existing_server: If True, will try to connect to an existing server on the specified port + instead of starting a new one. + """ + if use_existing_server and port is None: + raise LumeConfigError("Port must be specified when using an existing server") + + self.server = LumeServer( + debug=debug, + server_start_timeout=server_start_timeout, + port=port, + use_existing_server=use_existing_server + ) + self.client = None + + async def __aenter__(self) -> 'PyLume': + """Async context manager entry.""" + if self.server.use_existing_server: + # Just set up the base URL and initialize client for existing server + self.server.port = self.server.requested_port + self.server.base_url = f"http://localhost:{self.server.port}/lume" + else: + await self.server.ensure_running() + + await self._init_client() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + """Async context manager exit.""" + if self.client is not None: + await self.client.close() + await self.server.stop() + + async def _init_client(self) -> None: + """Initialize the client if not already initialized.""" + if self.client is None: + if self.server.base_url is None: + raise RuntimeError("Server base URL not set") + self.client = LumeClient( + base_url=self.server.base_url, + timeout=300.0, + debug=self.server.debug + ) + + def _log_debug(self, message: str, **kwargs) -> None: + """Log debug information if debug mode is enabled.""" + if self.server.debug: + print(f"DEBUG: {message}") + if kwargs: + print(json.dumps(kwargs, indent=2)) 
+ + async def _handle_api_error(self, e: Exception, operation: str) -> None: + """Handle API errors and raise appropriate custom exceptions.""" + if isinstance(e, subprocess.SubprocessError): + raise LumeConnectionError(f"Failed to connect to PyLume server: {str(e)}") + elif isinstance(e, asyncio.TimeoutError): + raise LumeTimeoutError(f"Request timed out: {str(e)}") + + if not hasattr(e, 'status') and not isinstance(e, subprocess.CalledProcessError): + raise LumeServerError(f"Unknown error during {operation}: {str(e)}") + + status_code = getattr(e, 'status', 500) + response_text = str(e) + + self._log_debug( + f"{operation} request failed", + status_code=status_code, + response_text=response_text + ) + + if status_code == 404: + raise LumeNotFoundError(f"Resource not found during {operation}") + elif status_code == 400: + raise LumeConfigError(f"Invalid configuration for {operation}: {response_text}") + elif status_code >= 500: + raise LumeServerError( + f"Server error during {operation}", + status_code=status_code, + response_text=response_text + ) + else: + raise LumeServerError( + f"Error during {operation}", + status_code=status_code, + response_text=response_text + ) + + async def _read_output(self) -> None: + """Read and log server output.""" + try: + while True: + if not self.server.server_process or self.server.server_process.poll() is not None: + self._log_debug("Server process ended") + break + + # Read stdout without blocking + if self.server.server_process.stdout: + while True: + line = self.server.server_process.stdout.readline() + if not line: + break + line = line.strip() + self._log_debug(f"Server stdout: {line}") + if "Server started" in line.decode('utf-8'): + self._log_debug("Detected server started message") + return + + # Read stderr without blocking + if self.server.server_process.stderr: + while True: + line = self.server.server_process.stderr.readline() + if not line: + break + line = line.strip() + self._log_debug(f"Server stderr: {line}") 
+ if "error" in line.decode('utf-8').lower(): + raise RuntimeError(f"Server error: {line}") + + await asyncio.sleep(0.1) # Small delay to prevent CPU spinning + except Exception as e: + self._log_debug(f"Error in output reader: {str(e)}") + raise + + @ensure_server + async def create_vm(self, spec: Union[VMConfig, dict]) -> None: + """Create a VM with the given configuration.""" + # Ensure client is initialized + await self._init_client() + + if isinstance(spec, VMConfig): + spec = spec.model_dump(by_alias=True, exclude_none=True) + + # Suppress optional attribute access errors + self.client.print_curl("POST", "/vms", spec) # type: ignore[attr-defined] + await self.client.post("/vms", spec) # type: ignore[attr-defined] + + @ensure_server + async def run_vm(self, name: str, opts: Optional[Union[VMRunOpts, dict]] = None) -> None: + """Run a VM.""" + if opts is None: + opts = VMRunOpts(no_display=False) # type: ignore[attr-defined] + elif isinstance(opts, dict): + opts = VMRunOpts(**opts) + + payload = opts.model_dump(by_alias=True, exclude_none=True) + self.client.print_curl("POST", f"/vms/{name}/run", payload) # type: ignore[attr-defined] + await self.client.post(f"/vms/{name}/run", payload) # type: ignore[attr-defined] + + @ensure_server + async def list_vms(self) -> List[VMStatus]: + """List all VMs.""" + data = await self.client.get("/vms") # type: ignore[attr-defined] + return [VMStatus.model_validate(vm) for vm in data] + + @ensure_server + async def get_vm(self, name: str) -> VMStatus: + """Get VM details.""" + data = await self.client.get(f"/vms/{name}") # type: ignore[attr-defined] + return VMStatus.model_validate(data) + + @ensure_server + async def update_vm(self, name: str, params: Union[VMUpdateOpts, dict]) -> None: + """Update VM settings.""" + if isinstance(params, dict): + params = VMUpdateOpts(**params) + + payload = params.model_dump(by_alias=True, exclude_none=True) + self.client.print_curl("PATCH", f"/vms/{name}", payload) # type: 
ignore[attr-defined] + await self.client.patch(f"/vms/{name}", payload) # type: ignore[attr-defined] + + @ensure_server + async def stop_vm(self, name: str) -> None: + """Stop a VM.""" + await self.client.post(f"/vms/{name}/stop") # type: ignore[attr-defined] + + @ensure_server + async def delete_vm(self, name: str) -> None: + """Delete a VM.""" + await self.client.delete(f"/vms/{name}") # type: ignore[attr-defined] + + @ensure_server + async def pull_image(self, spec: Union[ImageRef, dict, str], name: Optional[str] = None) -> None: + """Pull a VM image.""" + await self._init_client() + if isinstance(spec, str): + if ":" in spec: + image_str = spec + else: + image_str = f"{spec}:latest" + registry = "ghcr.io" + organization = "trycua" + elif isinstance(spec, dict): + image = spec.get("image", "") + tag = spec.get("tag", "latest") + image_str = f"{image}:{tag}" + registry = spec.get("registry", "ghcr.io") + organization = spec.get("organization", "trycua") + else: + image_str = f"{spec.image}:{spec.tag}" + registry = spec.registry + organization = spec.organization + + payload = { + "image": image_str, + "name": name, + "registry": registry, + "organization": organization + } + + self.client.print_curl("POST", "/pull", payload) # type: ignore[attr-defined] + await self.client.post("/pull", payload, timeout=300.0) # type: ignore[attr-defined] + + @ensure_server + async def clone_vm(self, name: str, new_name: str) -> None: + """Clone a VM with the given name to a new VM with new_name.""" + config = CloneSpec(name=name, newName=new_name) + self.client.print_curl("POST", "/vms/clone", config.model_dump()) # type: ignore[attr-defined] + await self.client.post("/vms/clone", config.model_dump()) # type: ignore[attr-defined] + + @ensure_server + async def get_latest_ipsw_url(self) -> str: + """Get the latest IPSW URL.""" + await self._init_client() + data = await self.client.get("/ipsw") # type: ignore[attr-defined] + return data["url"] + + @ensure_server + async def 
get_images(self, organization: Optional[str] = None) -> ImageList: + """Get list of available images.""" + await self._init_client() + params = {"organization": organization} if organization else None + data = await self.client.get("/images", params) # type: ignore[attr-defined] + return ImageList(root=data) + + async def close(self) -> None: + """Close the client and stop the server.""" + if self.client is not None: + await self.client.close() + self.client = None + await asyncio.sleep(1) + await self.server.stop() + + async def _ensure_client(self) -> None: + """Ensure client is initialized.""" + if self.client is None: + await self._init_client() \ No newline at end of file diff --git a/libs/pylume/pylume/server.py b/libs/pylume/pylume/server.py new file mode 100644 index 00000000..ab0b9a71 --- /dev/null +++ b/libs/pylume/pylume/server.py @@ -0,0 +1,404 @@ +import os +import time +import asyncio +import subprocess +import tempfile +import logging +import socket +from typing import Optional +import sys +from .exceptions import LumeConnectionError +import signal + +class LumeServer: + def __init__( + self, + debug: bool = False, + server_start_timeout: int = 60, + port: Optional[int] = None, + use_existing_server: bool = False + ): + """Initialize the LumeServer.""" + self.debug = debug + self.server_start_timeout = server_start_timeout + self.server_process = None + self.output_file = None + self.requested_port = port + self.port = None + self.base_url = None + self.use_existing_server = use_existing_server + + # Configure logging + self.logger = logging.getLogger('lume_server') + if not self.logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + self.logger.addHandler(handler) + self.logger.setLevel(logging.DEBUG if debug else logging.INFO) + + def _check_port_available(self, port: int) -> bool: + """Check if a specific port is available.""" 
+ try: + # Create a socket + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.logger.debug(f"Created socket for port {port} check") + + # Set socket options + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.logger.debug("Set SO_REUSEADDR") + + # Bind to the port + try: + s.bind(('127.0.0.1', port)) + self.logger.debug(f"Successfully bound to port {port}") + s.listen(1) + self.logger.debug(f"Successfully listening on port {port}") + s.close() + self.logger.debug(f"Port {port} is available") + return True + except OSError as e: + self.logger.debug(f"Failed to bind to port {port}: {str(e)}") + return False + finally: + try: + s.close() + self.logger.debug("Socket closed") + except: + pass + + except Exception as e: + self.logger.debug(f"Unexpected error checking port {port}: {str(e)}") + return False + + def _get_server_port(self) -> int: + """Get and validate the server port.""" + from .exceptions import LumeConfigError + + if self.requested_port is None: + raise LumeConfigError("Port must be specified when starting a new server") + + self.logger.debug(f"Checking availability of port {self.requested_port}") + + # Try multiple times with a small delay + for attempt in range(3): + if attempt > 0: + self.logger.debug(f"Retrying port check (attempt {attempt + 1})") + time.sleep(1) + + if self._check_port_available(self.requested_port): + self.logger.debug(f"Port {self.requested_port} is available") + return self.requested_port + else: + self.logger.debug(f"Port {self.requested_port} check failed on attempt {attempt + 1}") + + raise LumeConfigError(f"Requested port {self.requested_port} is not available after 3 attempts") + + async def _ensure_server_running(self) -> None: + """Ensure the lume server is running, start it if it's not.""" + try: + self.logger.debug("Checking if lume server is running...") + # Try to connect to the server with a short timeout + cmd = ["curl", "-s", "-w", "%{http_code}", "-m", "5", f"{self.base_url}/vms"] + process 
= await asyncio.create_subprocess_exec( + *cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + stdout, stderr = await process.communicate() + + if process.returncode == 0: + response = stdout.decode() + status_code = int(response[-3:]) + if status_code == 200: + self.logger.debug("PyLume server is running") + return + + self.logger.debug("PyLume server not running, attempting to start it") + # Server not running, try to start it + lume_path = os.path.join(os.path.dirname(__file__), "lume") + if not os.path.exists(lume_path): + raise RuntimeError(f"Could not find lume binary at {lume_path}") + + # Make sure the file is executable + os.chmod(lume_path, 0o755) + + # Create a temporary file for server output + self.output_file = tempfile.NamedTemporaryFile(mode='w+', delete=False) + self.logger.debug(f"Using temporary file for server output: {self.output_file.name}") + + # Start the server + self.logger.debug(f"Starting lume server with: {lume_path} serve --port {self.port}") + + # Start server in background using subprocess.Popen + try: + self.server_process = subprocess.Popen( + [lume_path, "serve", "--port", str(self.port)], + stdout=self.output_file, + stderr=self.output_file, + cwd=os.path.dirname(lume_path), + start_new_session=True # Run in new session to avoid blocking + ) + except Exception as e: + self.output_file.close() + os.unlink(self.output_file.name) + raise RuntimeError(f"Failed to start lume server process: {str(e)}") + + # Wait for server to start + self.logger.debug(f"Waiting up to {self.server_start_timeout} seconds for server to start...") + start_time = time.time() + server_ready = False + last_size = 0 + + while time.time() - start_time < self.server_start_timeout: + if self.server_process.poll() is not None: + # Process has terminated + self.output_file.seek(0) + output = self.output_file.read() + self.output_file.close() + os.unlink(self.output_file.name) + error_msg = ( + f"Server process terminated unexpectedly.\n" + f"Exit code: 
{self.server_process.returncode}\n" + f"Output: {output}" + ) + raise RuntimeError(error_msg) + + # Check output file for server ready message + self.output_file.seek(0, os.SEEK_END) + size = self.output_file.tell() + if size > last_size: # Only read if there's new content + self.output_file.seek(last_size) + new_output = self.output_file.read() + if new_output.strip(): # Only log non-empty output + self.logger.debug(f"Server output: {new_output.strip()}") + last_size = size + + if "Server started" in new_output: + server_ready = True + self.logger.debug("Server startup detected") + break + + # Try to connect to the server periodically + try: + cmd = ["curl", "-s", "-w", "%{http_code}", "-m", "5", f"{self.base_url}/vms"] + process = await asyncio.create_subprocess_exec( + *cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + stdout, stderr = await process.communicate() + + if process.returncode == 0: + response = stdout.decode() + status_code = int(response[-3:]) + if status_code == 200: + server_ready = True + self.logger.debug("Server is responding to requests") + break + except: + pass # Server not ready yet + + await asyncio.sleep(1.0) + + if not server_ready: + # Cleanup if server didn't start + if self.server_process: + self.server_process.terminate() + try: + self.server_process.wait(timeout=5) + except subprocess.TimeoutExpired: + self.server_process.kill() + self.output_file.close() + os.unlink(self.output_file.name) + raise RuntimeError( + f"Failed to start lume server after {self.server_start_timeout} seconds. " + "Check the debug output for more details." 
+ ) + + # Give the server a moment to fully initialize + await asyncio.sleep(2.0) + + # Verify server is responding + try: + cmd = ["curl", "-s", "-w", "%{http_code}", "-m", "10", f"{self.base_url}/vms"] + process = await asyncio.create_subprocess_exec( + *cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + stdout, stderr = await process.communicate() + + if process.returncode != 0: + raise RuntimeError(f"Curl command failed: {stderr.decode()}") + + response = stdout.decode() + status_code = int(response[-3:]) + + if status_code != 200: + raise RuntimeError(f"Server returned status code {status_code}") + + self.logger.debug("PyLume server started successfully") + except Exception as e: + self.logger.debug(f"Server verification failed: {str(e)}") + if self.server_process: + self.server_process.terminate() + try: + self.server_process.wait(timeout=5) + except subprocess.TimeoutExpired: + self.server_process.kill() + self.output_file.close() + os.unlink(self.output_file.name) + raise RuntimeError(f"Server started but is not responding: {str(e)}") + + self.logger.debug("Server startup completed successfully") + + except Exception as e: + raise RuntimeError(f"Failed to start lume server: {str(e)}") + + async def _start_server(self) -> None: + """Start the lume server using the lume executable.""" + self.logger.debug("Starting PyLume server") + + # Get absolute path to lume executable in the same directory as this file + lume_path = os.path.join(os.path.dirname(__file__), "lume") + if not os.path.exists(lume_path): + raise RuntimeError(f"Could not find lume binary at {lume_path}") + + try: + # Make executable + os.chmod(lume_path, 0o755) + + # Get and validate port + self.port = self._get_server_port() + self.base_url = f"http://localhost:{self.port}/lume" + + # Set up output handling + self.output_file = tempfile.NamedTemporaryFile(mode='w+', delete=False) + + # Start the server process with the lume executable + env = os.environ.copy() + env["RUST_BACKTRACE"] 
= "1" # Enable backtrace for better error reporting + + self.server_process = subprocess.Popen( + [lume_path, "serve", "--port", str(self.port)], + stdout=self.output_file, + stderr=subprocess.STDOUT, + cwd=os.path.dirname(lume_path), # Run from same directory as executable + env=env + ) + + # Wait for server to initialize + await asyncio.sleep(2) + await self._wait_for_server() + + except Exception as e: + await self._cleanup() + raise RuntimeError(f"Failed to start lume server process: {str(e)}") + + async def _tail_log(self) -> None: + """Read and display server log output in debug mode.""" + while True: + try: + self.output_file.seek(0, os.SEEK_END) # type: ignore[attr-defined] + line = self.output_file.readline() # type: ignore[attr-defined] + if line: + line = line.strip() + if line: + print(f"SERVER: {line}") + if self.server_process.poll() is not None: # type: ignore[attr-defined] + print("Server process ended") + break + await asyncio.sleep(0.1) + except Exception as e: + print(f"Error reading log: {e}") + await asyncio.sleep(0.1) + + async def _wait_for_server(self) -> None: + """Wait for server to start and become responsive with increased timeout.""" + start_time = time.time() + while time.time() - start_time < self.server_start_timeout: + if self.server_process.poll() is not None: # type: ignore[attr-defined] + error_msg = await self._get_error_output() + await self._cleanup() + raise RuntimeError(error_msg) + + try: + await self._verify_server() + self.logger.debug("Server is now responsive") + return + except Exception as e: + self.logger.debug(f"Server not ready yet: {str(e)}") + await asyncio.sleep(1.0) + + await self._cleanup() + raise RuntimeError(f"Server failed to start after {self.server_start_timeout} seconds") + + async def _verify_server(self) -> None: + """Verify server is responding to requests.""" + try: + cmd = ["curl", "-s", "-w", "%{http_code}", "-m", "10", f"{self.base_url}/vms"] + process = await asyncio.create_subprocess_exec( + 
*cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + stdout, stderr = await process.communicate() + + if process.returncode != 0: + raise RuntimeError(f"Curl command failed: {stderr.decode()}") + + response = stdout.decode() + status_code = int(response[-3:]) + + if status_code != 200: + raise RuntimeError(f"Server returned status code {status_code}") + + self.logger.debug("PyLume server started successfully") + except Exception as e: + raise RuntimeError(f"Server not responding: {str(e)}") + + async def _get_error_output(self) -> str: + """Get error output from the server process.""" + if not self.output_file: + return "No output available" + self.output_file.seek(0) + output = self.output_file.read() + return ( + f"Server process terminated unexpectedly.\n" + f"Exit code: {self.server_process.returncode}\n" # type: ignore[attr-defined] + f"Output: {output}" + ) + + async def _cleanup(self) -> None: + """Clean up all server resources.""" + if self.server_process: + try: + self.server_process.terminate() + try: + self.server_process.wait(timeout=5) + except subprocess.TimeoutExpired: + self.server_process.kill() + except: + pass + self.server_process = None + + # Clean up output file + if self.output_file: + try: + self.output_file.close() + os.unlink(self.output_file.name) + except Exception as e: + self.logger.debug(f"Error cleaning up output file: {e}") + self.output_file = None + + async def ensure_running(self) -> None: + """Start the server if we're managing it.""" + if not self.use_existing_server: + await self._start_server() + + async def stop(self) -> None: + """Stop the server if we're managing it.""" + if not self.use_existing_server: + self.logger.debug("Stopping lume server...") + await self._cleanup() \ No newline at end of file diff --git a/libs/pylume/pyproject.toml b/libs/pylume/pyproject.toml new file mode 100644 index 00000000..5cb44806 --- /dev/null +++ b/libs/pylume/pyproject.toml @@ -0,0 +1,78 @@ +[build-system] +requires = 
["pdm-backend"] +build-backend = "pdm.backend" + +[project] +name = "pylume" +version = "0.1.0" +description = "Python SDK for lume - run macOS and Linux VMs on Apple Silicon" +authors = [ + { name = "TryCua", email = "gh@trycua.com" } +] +dependencies = [ + "pydantic>=2.0.0" +] +requires-python = ">=3.9" +readme = "README.md" +license = { text = "MIT" } +keywords = ["macos", "virtualization", "vm", "apple-silicon"] +classifiers = [ + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: MacOS :: MacOS X", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] + +[project.urls] +homepage = "https://github.com/trycua/pylume" +repository = "https://github.com/trycua/pylume" + +[tool.pdm] +distribution = true +package-dir = "." +includes = [ + "pylume/lume" +] + +[tool.pdm.dev-dependencies] +dev = [ + "pytest>=7.0.0", + "pytest-asyncio>=0.23.0", + "black>=23.0.0", + "isort>=5.12.0" +] + +[tool.black] +line-length = 100 +target-version = ["py310"] + +[tool.ruff] +line-length = 100 +target-version = "py310" +select = ["E", "F", "B", "I"] +fix = true + +[tool.ruff.format] +docstring-code-format = true + +[tool.mypy] +strict = true +python_version = "3.10" +ignore_missing_imports = true +disallow_untyped_defs = true +check_untyped_defs = true +warn_return_any = true +show_error_codes = true +warn_unused_ignores = false + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +python_files = "test_*.py" + +[tool.pdm.build] +includes = ["pylume/"] +source-includes = ["tests/", "README.md", "LICENSE"] \ No newline at end of file diff --git a/libs/pylume/tests/__init__.py b/libs/pylume/tests/__init__.py new file mode 100644 index 00000000..e1b66e8e --- /dev/null +++ b/libs/pylume/tests/__init__.py @@ -0,0 +1,3 @@ +""" +PyLume tests package +""" \ No newline at end of file diff --git 
a/libs/pylume/tests/test_basic.py b/libs/pylume/tests/test_basic.py new file mode 100644 index 00000000..18c2f30e --- /dev/null +++ b/libs/pylume/tests/test_basic.py @@ -0,0 +1,20 @@ +""" +Basic tests for the pylume package +""" +import pytest + + +def test_import(): + """Test that the package can be imported""" + import pylume + try: + assert pylume.__version__ == "0.1.0" + except AttributeError: + # If __version__ is not defined, that's okay for this test + pass + + +def test_pylume_import(): + """Test that the PyLume class can be imported""" + from pylume import PyLume + assert PyLume is not None \ No newline at end of file diff --git a/libs/som/README.md b/libs/som/README.md new file mode 100644 index 00000000..23423daa --- /dev/null +++ b/libs/som/README.md @@ -0,0 +1,184 @@ +
+

+
+ + + + Shows my svg + +
+ + [![Python](https://img.shields.io/badge/Python-333333?logo=python&logoColor=white&labelColor=333333)](#) + [![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](#) + [![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.com/invite/mVnXXpdE85) + [![PyPI](https://img.shields.io/pypi/v/cua-computer?color=333333)](https://pypi.org/project/cua-computer/) +

+
+ +**Som** (Set-of-Mark) is a visual grounding component for the Computer-Use Agent (CUA) framework powering Cua, for detecting and analyzing UI elements in screenshots. Optimized for macOS Silicon with Metal Performance Shaders (MPS), it combines YOLO-based icon detection with EasyOCR text recognition to provide comprehensive UI element analysis. + +## Features + +- Optimized for Apple Silicon with MPS acceleration +- Icon detection using YOLO with multi-scale processing +- Text recognition using EasyOCR (GPU-accelerated) +- Automatic hardware detection (MPS → CUDA → CPU) +- Smart detection parameters tuned for UI elements +- Detailed visualization with numbered annotations +- Performance benchmarking tools + +## System Requirements + +- **Recommended**: macOS with Apple Silicon + - Uses Metal Performance Shaders (MPS) + - Multi-scale detection enabled + - ~0.4s average detection time + +- **Supported**: Any Python 3.11+ environment + - Falls back to CPU if no GPU available + - Single-scale detection on CPU + - ~1.3s average detection time + +## Installation + +```bash +# Using PDM (recommended) +pdm install + +# Using pip +pip install -e . 
+``` + +## Quick Start + +```python +from som import OmniParser +from PIL import Image + +# Initialize parser +parser = OmniParser() + +# Process an image +image = Image.open("screenshot.png") +result = parser.parse( + image, + box_threshold=0.3, # Confidence threshold + iou_threshold=0.1, # Overlap threshold + use_ocr=True # Enable text detection +) + +# Access results +for elem in result.elements: + if elem.type == "icon": + print(f"Icon: confidence={elem.confidence:.3f}, bbox={elem.bbox.coordinates}") + else: # text + print(f"Text: '{elem.content}', confidence={elem.confidence:.3f}") +``` + +## Configuration + +### Detection Parameters + +#### Box Threshold (0.3) +Controls the confidence threshold for accepting detections: +``` +High Threshold (0.3): Low Threshold (0.01): ++----------------+ +----------------+ +| | | +--------+ | +| Confident | | |Unsure?| | +| Detection | | +--------+ | +| (✓ Accept) | | (? Reject) | +| | | | ++----------------+ +----------------+ +conf = 0.85 conf = 0.02 +``` +- Higher values (0.3) yield more precise but fewer detections +- Lower values (0.01) catch more potential icons but increase false positives +- Default is 0.3 for optimal precision/recall balance + +#### IOU Threshold (0.1) +Controls how overlapping detections are merged: +``` +IOU = Intersection Area / Union Area + +Low Overlap (Keep Both): High Overlap (Merge): ++----------+ +----------+ +| Box1 | | Box1 | +| | vs. 
|+-----+ | ++----------+ ||Box2 | | + +----------+ |+-----+ | + | Box2 | +----------+ + | | + +----------+ +IOU ≈ 0.05 (Keep Both) IOU ≈ 0.7 (Merge) +``` +- Lower values (0.1) more aggressively remove overlapping boxes +- Higher values (0.5) allow more overlapping detections +- Default is 0.1 to handle densely packed UI elements + +### OCR Configuration + +- **Engine**: EasyOCR + - Primary choice for all platforms + - Fast initialization and processing + - Built-in English language support + - GPU acceleration when available + +- **Settings**: + - Timeout: 5 seconds + - Confidence threshold: 0.5 + - Paragraph mode: Disabled + - Language: English only + +## Performance + +### Hardware Acceleration + +#### MPS (Metal Performance Shaders) +- Multi-scale detection (640px, 1280px, 1920px) +- Test-time augmentation enabled +- Half-precision (FP16) +- Average detection time: ~0.4s +- Best for production use when available + +#### CPU +- Single-scale detection (1280px) +- Full-precision (FP32) +- Average detection time: ~1.3s +- Reliable fallback option + +### Example Output Structure + +``` +examples/output/ +├── {timestamp}_no_ocr/ +│ ├── annotated_images/ +│ │ └── screenshot_analyzed.png +│ ├── screen_details.txt +│ └── summary.json +└── {timestamp}_ocr/ + ├── annotated_images/ + │ └── screenshot_analyzed.png + ├── screen_details.txt + └── summary.json +``` + +## Development + +### Test Data +- Place test screenshots in `examples/test_data/` +- Not tracked in git to keep repository size manageable +- Default test image: `test_screen.png` (1920x1080) + +### Running Tests +```bash +# Run benchmark with no OCR +python examples/omniparser_examples.py examples/test_data/test_screen.png --runs 5 --ocr none + +# Run benchmark with OCR +python examples/omniparser_examples.py examples/test_data/test_screen.png --runs 5 --ocr easyocr +``` + +## License + +MIT License - See LICENSE file for details. 
diff --git a/libs/som/poetry.toml b/libs/som/poetry.toml new file mode 100644 index 00000000..ab1033bd --- /dev/null +++ b/libs/som/poetry.toml @@ -0,0 +1,2 @@ +[virtualenvs] +in-project = true diff --git a/libs/som/pyproject.toml b/libs/som/pyproject.toml new file mode 100644 index 00000000..0bae7ea2 --- /dev/null +++ b/libs/som/pyproject.toml @@ -0,0 +1,81 @@ +[build-system] +requires = ["pdm-backend"] +build-backend = "pdm.backend" + +[project] +name = "cua-som" +version = "0.1.0" +description = "Computer Vision and OCR library for detecting and analyzing UI elements" +authors = [ + { name = "TryCua", email = "gh@trycua.com" } +] +dependencies = [ + "torch>=2.2.1", + "torchvision>=0.17.1", + "ultralytics>=8.1.28", + "easyocr>=1.7.1", + "numpy>=1.26.4", + "pillow>=10.2.0", + "setuptools>=75.8.1", + "opencv-python-headless>=4.11.0.86", + "matplotlib>=3.8.3", + "huggingface-hub>=0.21.4", + "supervision>=0.25.1", + "typing-extensions>=4.9.0", + "pydantic>=2.6.3" +] +requires-python = ">=3.11" +readme = "README.md" +license = {text = "MIT"} +keywords = ["computer-vision", "ocr", "ui-analysis", "icon-detection"] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Image Recognition" +] + +[project.urls] +Homepage = "https://github.com/trycua/cua" +Repository = "https://github.com/trycua/cua" +Documentation = "https://github.com/trycua/cua/tree/main/docs" + +[tool.pdm] +distribution = true +package-type = "library" +src-layout = false + +[tool.pdm.build] +includes = ["som/"] +source-includes = ["tests/", "README.md", "LICENSE"] + +[tool.black] +line-length = 100 +target-version = ["py311"] + +[tool.ruff] +line-length = 100 +target-version = "py311" +select = ["E", "F", "B", "I"] +fix = true + 
+[tool.ruff.format] +docstring-code-format = true + +[tool.mypy] +strict = true +python_version = "3.11" +ignore_missing_imports = true +disallow_untyped_defs = true +check_untyped_defs = true +warn_return_any = true +show_error_codes = true +warn_unused_ignores = false + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +python_files = "test_*.py" diff --git a/libs/som/som/__init__.py b/libs/som/som/__init__.py new file mode 100644 index 00000000..906da597 --- /dev/null +++ b/libs/som/som/__init__.py @@ -0,0 +1,23 @@ +"""SOM - Computer Vision and OCR library for detecting and analyzing UI elements.""" + +__version__ = "0.1.0" + +from .detect import OmniParser +from .models import ( + BoundingBox, + UIElement, + IconElement, + TextElement, + ParserMetadata, + ParseResult +) + +__all__ = [ + "OmniParser", + "BoundingBox", + "UIElement", + "IconElement", + "TextElement", + "ParserMetadata", + "ParseResult" +] \ No newline at end of file diff --git a/libs/som/som/detect.py b/libs/som/som/detect.py new file mode 100644 index 00000000..25b95c8d --- /dev/null +++ b/libs/som/som/detect.py @@ -0,0 +1,430 @@ +from pathlib import Path +from typing import Union, List, Dict, Any, Tuple, Optional +import logging +import torch +import torchvision.ops +import cv2 +import numpy as np +import time +import torchvision.transforms as T +from PIL import Image +import io +import base64 +import argparse +import signal +from contextlib import contextmanager + +from ultralytics import YOLO +from huggingface_hub import hf_hub_download +import supervision as sv +from supervision.detection.core import Detections + +from .detection import DetectionProcessor +from .ocr import OCRProcessor +from .visualization import BoxAnnotator +from .models import BoundingBox, UIElement, IconElement, TextElement, ParserMetadata, ParseResult + +logger = logging.getLogger(__name__) + + +class TimeoutException(Exception): + pass + + +@contextmanager +def timeout(seconds: int): + def 
timeout_handler(signum, frame): + raise TimeoutException("OCR process timed out") + + # Register the signal handler + original_handler = signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(seconds) + + try: + yield + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, original_handler) + + +def process_text_box(box, image): + """Process a single text box with OCR.""" + try: + import easyocr + + x1 = int(min(point[0] for point in box)) + y1 = int(min(point[1] for point in box)) + x2 = int(max(point[0] for point in box)) + y2 = int(max(point[1] for point in box)) + + # Add padding + pad = 2 + x1 = max(0, x1 - pad) + y1 = max(0, y1 - pad) + x2 = min(image.shape[1], x2 + pad) + y2 = min(image.shape[0], y2 + pad) + + region = image[y1:y2, x1:x2] + if region.size > 0: + reader = easyocr.Reader(["en"]) + result = reader.readtext(region) + if result: + text = result[0][1] # Get text + conf = result[0][2] # Get confidence + if conf > 0.5: + return text, [x1, y1, x2, y2], conf + except Exception: + pass + return None + + +def check_ocr_box(image_path: Union[str, Path]) -> Tuple[List[str], List[List[float]]]: + """Check OCR box using EasyOCR.""" + # Read image once + if isinstance(image_path, str): + image_path = Path(image_path) + + # Read image into memory + image_cv = cv2.imread(str(image_path)) + if image_cv is None: + logger.error(f"Failed to read image: {image_path}") + return [], [] + + # Use EasyOCR + import ssl + import easyocr + + # Create unverified SSL context for development + ssl._create_default_https_context = ssl._create_unverified_context + try: + reader = easyocr.Reader(["en"]) + with timeout(5): # 5 second timeout for EasyOCR + results = reader.readtext(image_cv, paragraph=False, text_threshold=0.5) + except TimeoutException: + logger.warning("EasyOCR timed out, returning no results") + return [], [] + except Exception as e: + logger.warning(f"EasyOCR failed: {str(e)}") + return [], [] + finally: + # Restore default SSL context + 
ssl._create_default_https_context = ssl.create_default_context + + texts = [] + boxes = [] + + for box, text, conf in results: + # Convert box format to [x1, y1, x2, y2] + x1 = min(point[0] for point in box) + y1 = min(point[1] for point in box) + x2 = max(point[0] for point in box) + y2 = max(point[1] for point in box) + + if conf > 0.5: # Only keep higher confidence detections + texts.append(text) + boxes.append([x1, y1, x2, y2]) + + return texts, boxes + + +class OmniParser: + """Enhanced UI parser using computer vision and OCR for detecting interactive elements.""" + + def __init__( + self, + model_path: Optional[Union[str, Path]] = None, + cache_dir: Optional[Union[str, Path]] = None, + force_device: Optional[str] = None, + ): + """Initialize the OmniParser. + + Args: + model_path: Optional path to the YOLO model + cache_dir: Optional directory to cache model files + force_device: Force specific device (cpu/cuda/mps) + """ + self.detector = DetectionProcessor( + model_path=Path(model_path) if model_path else None, + cache_dir=Path(cache_dir) if cache_dir else None, + force_device=force_device, + ) + self.ocr = OCRProcessor() + self.visualizer = BoxAnnotator() + + def process_image( + self, + image: Image.Image, + box_threshold: float = 0.3, + iou_threshold: float = 0.1, + use_ocr: bool = True, + ) -> Tuple[Image.Image, List[UIElement]]: + """Process an image to detect UI elements and optionally text. 
+ + Args: + image: Input PIL Image + box_threshold: Confidence threshold for detection + iou_threshold: IOU threshold for NMS + use_ocr: Whether to enable OCR processing + + Returns: + Tuple of (annotated image, list of detections) + """ + try: + logger.info("Starting UI element detection...") + + # Detect icons + icon_detections = self.detector.detect_icons( + image=image, box_threshold=box_threshold, iou_threshold=iou_threshold + ) + logger.info(f"Found {len(icon_detections)} interactive elements") + + # Convert icon detections to typed objects + elements: List[UIElement] = [ + IconElement( + bbox=BoundingBox( + x1=det["bbox"][0], y1=det["bbox"][1], x2=det["bbox"][2], y2=det["bbox"][3] + ), + confidence=det["confidence"], + scale=det.get("scale"), + ) + for det in icon_detections + ] + + # Run OCR if enabled + if use_ocr: + logger.info("Running OCR detection...") + text_detections = self.ocr.detect_text(image=image, confidence_threshold=0.5) + if text_detections is None: + text_detections = [] + logger.info(f"Found {len(text_detections)} text regions") + + # Convert text detections to typed objects + elements.extend( + [ + TextElement( + bbox=BoundingBox( + x1=det["bbox"][0], + y1=det["bbox"][1], + x2=det["bbox"][2], + y2=det["bbox"][3], + ), + content=det["content"], + confidence=det["confidence"], + ) + for det in text_detections + ] + ) + + # Calculate drawing parameters based on image size + box_overlay_ratio = max(image.size) / 3200 + draw_config = { + "font_size": int(12 * box_overlay_ratio), + "box_thickness": max(int(2 * box_overlay_ratio), 1), + "text_padding": max(int(3 * box_overlay_ratio), 1), + } + + # Convert elements back to dict format for visualization + detection_dicts = [ + { + "type": elem.type, + "bbox": elem.bbox.coordinates, + "confidence": elem.confidence, + "content": elem.content if isinstance(elem, TextElement) else None, + } + for elem in elements + ] + + # Create visualization + logger.info("Creating visualization...") + 
annotated_image = self.visualizer.draw_boxes( + image=image.copy(), detections=detection_dicts, draw_config=draw_config + ) + logger.info("Visualization complete") + + return annotated_image, elements + + except Exception as e: + logger.error(f"Error in process_image: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + raise + + def parse( + self, + screenshot_data: Union[bytes, str], + box_threshold: float = 0.3, + iou_threshold: float = 0.1, + use_ocr: bool = True, + ) -> ParseResult: + """Parse a UI screenshot to detect interactive elements and text. + + Args: + screenshot_data: Raw bytes or base64 string of the screenshot + box_threshold: Confidence threshold for detection + iou_threshold: IOU threshold for NMS + use_ocr: Whether to enable OCR processing + + Returns: + ParseResult object containing elements, annotated image, and metadata + """ + try: + start_time = time.time() + + # Convert input to PIL Image + if isinstance(screenshot_data, str): + screenshot_data = base64.b64decode(screenshot_data) + image = Image.open(io.BytesIO(screenshot_data)).convert("RGB") + + # Process image + annotated_image, elements = self.process_image( + image=image, + box_threshold=box_threshold, + iou_threshold=iou_threshold, + use_ocr=use_ocr, + ) + + # Convert annotated image to base64 + buffered = io.BytesIO() + annotated_image.save(buffered, format="PNG") + annotated_image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") + + # Generate screen info text + screen_info = [] + parsed_content_list = [] + + # Set element IDs and generate human-readable descriptions + for i, elem in enumerate(elements): + # Set the ID (1-indexed) + elem.id = i + 1 + + if isinstance(elem, IconElement): + screen_info.append( + f"Box #{i+1}: Icon (confidence={elem.confidence:.3f}, bbox={elem.bbox.coordinates})" + ) + parsed_content_list.append( + { + "id": i + 1, + "type": "icon", + "bbox": elem.bbox.coordinates, + "confidence": elem.confidence, + "content": None, + 
} + ) + elif isinstance(elem, TextElement): + screen_info.append( + f"Box #{i+1}: Text '{elem.content}' (confidence={elem.confidence:.3f}, bbox={elem.bbox.coordinates})" + ) + parsed_content_list.append( + { + "id": i + 1, + "type": "text", + "bbox": elem.bbox.coordinates, + "confidence": elem.confidence, + "content": elem.content, + } + ) + + # Calculate metadata + latency = time.time() - start_time + width, height = image.size + + # Create ParseResult object with enhanced properties + result = ParseResult( + elements=elements, + annotated_image_base64=annotated_image_base64, + screen_info=screen_info, + parsed_content_list=parsed_content_list, + metadata=ParserMetadata( + image_size=(width, height), + num_icons=len([e for e in elements if isinstance(e, IconElement)]), + num_text=len([e for e in elements if isinstance(e, TextElement)]), + device=self.detector.device, + ocr_enabled=use_ocr, + latency=latency, + ), + ) + + # Return the ParseResult object directly + return result + + except Exception as e: + logger.error(f"Error in parse: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + raise + + +def main(): + """Command line interface for UI element detection.""" + parser = argparse.ArgumentParser(description="Detect UI elements and text in images") + parser.add_argument("image_path", help="Path to the input image") + parser.add_argument("--model-path", help="Path to YOLO model") + parser.add_argument( + "--box-threshold", type=float, default=0.3, help="Box confidence threshold (default: 0.3)" + ) + parser.add_argument( + "--iou-threshold", type=float, default=0.1, help="IOU threshold (default: 0.1)" + ) + parser.add_argument( + "--ocr", action="store_true", default=True, help="Enable OCR processing (default: True)" + ) + parser.add_argument("--output", help="Output path for annotated image") + args = parser.parse_args() + + # Setup logging + logging.basicConfig(level=logging.INFO) + + try: + # Initialize parser + parser = 
OmniParser(model_path=args.model_path) + + # Load and process image + logger.info(f"Loading image from: {args.image_path}") + image = Image.open(args.image_path).convert("RGB") + logger.info(f"Image loaded successfully, size: {image.size}") + + # Process image + annotated_image, elements = parser.process_image( + image=image, + box_threshold=args.box_threshold, + iou_threshold=args.iou_threshold, + use_ocr=args.ocr, + ) + + # Save output image + output_path = args.output or str( + Path(args.image_path).parent + / f"{Path(args.image_path).stem}_analyzed{Path(args.image_path).suffix}" + ) + logger.info(f"Saving annotated image to: {output_path}") + + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + annotated_image.save(output_path) + logger.info(f"Image saved successfully to {output_path}") + + # Print detections + logger.info("\nDetections:") + for i, elem in enumerate(elements): + if isinstance(elem, IconElement): + logger.info( + f"Interactive element {i}: confidence={elem.confidence:.3f}, bbox={elem.bbox.coordinates}" + ) + elif isinstance(elem, TextElement): + logger.info(f"Text {i}: '{elem.content}', bbox={elem.bbox.coordinates}") + + except Exception as e: + logger.error(f"Error processing image: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + return 1 + + return 0 + + +if __name__ == "__main__": + import sys + + sys.exit(main()) diff --git a/libs/som/som/detection.py b/libs/som/som/detection.py new file mode 100644 index 00000000..3b585d9e --- /dev/null +++ b/libs/som/som/detection.py @@ -0,0 +1,240 @@ +from typing import List, Dict, Any, Tuple, Optional +import logging +import torch +import torchvision +from PIL import Image +import numpy as np +from ultralytics import YOLO +from huggingface_hub import hf_hub_download +from pathlib import Path + +logger = logging.getLogger(__name__) + + +class DetectionProcessor: + """Class for handling YOLO-based icon detection.""" + + def __init__( + self, + model_path: Optional[Path] 
= None, + cache_dir: Optional[Path] = None, + force_device: Optional[str] = None, + ): + """Initialize the detection processor. + + Args: + model_path: Path to YOLOv8 model + cache_dir: Directory to cache downloaded models + force_device: Force specific device (cuda, cpu, mps) + """ + self.model_path = model_path + self.cache_dir = cache_dir + self.model = None # type: Any # Will be set to YOLO model in load_model + + # Set device + self.device = "cpu" + if torch.cuda.is_available() and force_device != "cpu": + self.device = "cuda" + elif ( + hasattr(torch, "backends") + and hasattr(torch.backends, "mps") + and torch.backends.mps.is_available() + and force_device != "cpu" + ): + self.device = "mps" + + if force_device: + self.device = force_device + + logger.info(f"Using device: {self.device}") + + def load_model(self) -> None: + """Load or download the YOLO model.""" + try: + # Set default model path if none provided + if self.model_path is None: + self.model_path = Path(__file__).parent / "weights" / "icon_detect" / "model.pt" + + # Check if the model file already exists + if not self.model_path.exists(): + logger.info( + "Model not found locally, downloading from Microsoft OmniParser-v2.0..." 
+ ) + + # Create directory + self.model_path.parent.mkdir(parents=True, exist_ok=True) + + try: + # Check if the model exists in cache + cache_path = None + if self.cache_dir: + # Try to find the model in the cache + potential_paths = list(Path(self.cache_dir).glob("**/model.pt")) + if potential_paths: + cache_path = str(potential_paths[0]) + logger.info(f"Found model in cache: {cache_path}") + + if not cache_path: + # Download from HuggingFace + downloaded_path = hf_hub_download( + repo_id="microsoft/OmniParser-v2.0", + filename="icon_detect/model.pt", + cache_dir=self.cache_dir, + ) + cache_path = downloaded_path + logger.info(f"Model downloaded to cache: {cache_path}") + + # Copy to package directory + import shutil + + shutil.copy2(cache_path, self.model_path) + logger.info(f"Model copied to: {self.model_path}") + except Exception as e: + raise FileNotFoundError( + f"Failed to download model: {str(e)}\n" + "Please ensure you have internet connection and huggingface-hub installed." + ) from e + + # Make sure the model path exists before loading + if not self.model_path.exists(): + raise FileNotFoundError(f"Model file not found at: {self.model_path}") + + # If model is already loaded, skip reloading + if self.model is not None: + logger.info("Model already loaded, skipping reload") + return + + logger.info(f"Loading YOLOv8 model from {self.model_path}") + from ultralytics import YOLO + + self.model = YOLO(str(self.model_path)) # Convert Path to string for compatibility + + # Verify model loaded successfully + if self.model is None: + raise ValueError("Model failed to initialize but didn't raise an exception") + + if self.device in ["cuda", "mps"]: + self.model.to(self.device) + + logger.info(f"Model loaded successfully with device: {self.device}") + except Exception as e: + logger.error(f"Failed to load model: {str(e)}") + # Re-raise with more informative message but preserve the model as None + self.model = None + raise RuntimeError(f"Failed to initialize 
detection model: {str(e)}") from e + + def detect_icons( + self, + image: Image.Image, + box_threshold: float = 0.05, + iou_threshold: float = 0.1, + multi_scale: bool = True, + ) -> List[Dict[str, Any]]: + """Detect icons in an image using YOLO. + + Args: + image: PIL Image to process + box_threshold: Confidence threshold for detection + iou_threshold: IOU threshold for NMS + multi_scale: Whether to use multi-scale detection + + Returns: + List of icon detection dictionaries + """ + # Load model if not already loaded + if self.model is None: + self.load_model() + + # Double-check the model was successfully loaded + if self.model is None: + logger.error("Model failed to load and is still None") + return [] # Return empty list instead of crashing + + img_width, img_height = image.size + all_detections = [] + + # Define detection scales + scales = ( + [{"size": 1280, "conf": box_threshold}] # Single scale for CPU + if self.device == "cpu" + else [ + {"size": 640, "conf": box_threshold}, # Base scale + {"size": 1280, "conf": box_threshold}, # Medium scale + {"size": 1920, "conf": box_threshold}, # Large scale + ] + ) + + if not multi_scale: + scales = [scales[0]] + + # Run detection at each scale + for scale in scales: + try: + if self.model is None: + logger.error("Model is None, skipping detection") + continue + + results = self.model.predict( + source=image, + conf=scale["conf"], + iou=iou_threshold, + max_det=1000, + verbose=False, + augment=self.device != "cpu", + agnostic_nms=True, + imgsz=scale["size"], + device=self.device, + ) + + # Process results + for r in results: + boxes = r.boxes + if not hasattr(boxes, "conf") or not hasattr(boxes, "xyxy"): + logger.warning("Boxes object missing expected attributes") + continue + + confidences = boxes.conf + coords = boxes.xyxy + + # Handle different types of tensors (PyTorch, NumPy, etc.) 
+ if hasattr(confidences, "cpu"): + confidences = confidences.cpu() + if hasattr(coords, "cpu"): + coords = coords.cpu() + + for conf, bbox in zip(confidences, coords): + # Normalize coordinates + x1, y1, x2, y2 = bbox.tolist() + norm_bbox = [ + x1 / img_width, + y1 / img_height, + x2 / img_width, + y2 / img_height, + ] + + all_detections.append( + { + "type": "icon", + "confidence": conf.item(), + "bbox": norm_bbox, + "scale": scale["size"], + "interactivity": True, + } + ) + + except Exception as e: + logger.warning(f"Detection failed at scale {scale['size']}: {str(e)}") + continue + + # Merge detections using NMS + if len(all_detections) > 0: + boxes = torch.tensor([d["bbox"] for d in all_detections]) + scores = torch.tensor([d["confidence"] for d in all_detections]) + + keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold) + + merged_detections = [all_detections[i] for i in keep_indices] + else: + merged_detections = [] + + return merged_detections diff --git a/libs/som/som/models.py b/libs/som/som/models.py new file mode 100644 index 00000000..aa116106 --- /dev/null +++ b/libs/som/som/models.py @@ -0,0 +1,119 @@ +from typing import List, Tuple, Optional, Literal, Dict, Any, Union +from pydantic import BaseModel, Field, validator + + +class BoundingBox(BaseModel): + """Normalized bounding box coordinates.""" + + x1: float = Field(..., description="Normalized left coordinate") + y1: float = Field(..., description="Normalized top coordinate") + x2: float = Field(..., description="Normalized right coordinate") + y2: float = Field(..., description="Normalized bottom coordinate") + + @property + def coordinates(self) -> List[float]: + """Get coordinates as a list [x1, y1, x2, y2].""" + return [self.x1, self.y1, self.x2, self.y2] + + +class UIElement(BaseModel): + """Base class for UI elements.""" + + id: Optional[int] = Field(None, description="Unique identifier for the element (1-indexed)") + type: Literal["icon", "text"] + bbox: BoundingBox + 
interactivity: bool = Field(default=False, description="Whether the element is interactive")
    confidence: float = Field(default=1.0, description="Detection confidence score")


class IconElement(UIElement):
    """An interactive icon element.

    Narrows the base ``type`` discriminator to ``"icon"`` and flips
    ``interactivity`` on by default; ``scale`` records which detection
    scale produced the bounding box.
    """

    type: Literal["icon"] = "icon"
    # Icons are treated as clickable, overriding the base default of False.
    interactivity: bool = True
    scale: Optional[int] = Field(None, description="Detection scale used")


class TextElement(UIElement):
    """A text element.

    Narrows the base ``type`` discriminator to ``"text"`` and carries the
    recognized string in ``content``.
    """

    type: Literal["text"] = "text"
    content: str = Field(..., description="The text content")
    # OCR text regions are treated as non-interactive.
    interactivity: bool = False


class ImageData(BaseModel):
    """Image data with dimensions (base64 payload plus pixel size)."""

    base64: str = Field(..., description="Base64 encoded image data")
    width: int = Field(..., description="Image width in pixels")
    height: int = Field(..., description="Image height in pixels")

    # NOTE(review): @validator is the pydantic v1 API, but ParseResult below
    # uses the v2 `model_dump` API. Under pydantic v2 this only works via the
    # deprecated shim -- consider `field_validator`. TODO confirm the pinned
    # pydantic version.
    @validator("width", "height")
    def dimensions_must_be_positive(cls, v):
        # Reject zero/negative dimensions early rather than failing downstream.
        if v <= 0:
            raise ValueError("Dimensions must be positive")
        return v


class ParserMetadata(BaseModel):
    """Metadata about the parsing process (inputs, counts, timing)."""

    image_size: Tuple[int, int] = Field(
        ..., description="Original image dimensions (width, height)"
    )
    num_icons: int = Field(..., description="Number of icons detected")
    num_text: int = Field(..., description="Number of text elements detected")
    device: str = Field(..., description="Device used for detection (cpu/cuda/mps)")
    ocr_enabled: bool = Field(..., description="Whether OCR was enabled")
    latency: float = Field(..., description="Total processing time in seconds")

    @property
    def width(self) -> int:
        """Get image width from image_size."""
        return self.image_size[0]

    @property
    def height(self) -> int:
        """Get image height from image_size."""
        return self.image_size[1]


class ParseResult(BaseModel):
    """Result of parsing a UI screenshot."""

    elements: List[UIElement] = Field(..., description="Detected UI elements")
    annotated_image_base64: str = Field(..., description="Base64
encoded annotated image") + metadata: ParserMetadata = Field(..., description="Processing metadata") + screen_info: Optional[List[str]] = Field( + None, description="Human-readable descriptions of elements" + ) + parsed_content_list: Optional[List[Dict[str, Any]]] = Field( + None, description="Parsed elements as dictionaries" + ) + + @property + def image(self) -> ImageData: + """Get image data as a convenience property.""" + return ImageData( + base64=self.annotated_image_base64, + width=self.metadata.width, + height=self.metadata.height, + ) + + @property + def width(self) -> int: + """Get image width from metadata.""" + return self.metadata.width + + @property + def height(self) -> int: + """Get image height from metadata.""" + return self.metadata.height + + def model_dump(self) -> Dict[str, Any]: + """Convert model to dict for compatibility with older code.""" + result = super().model_dump() + # Add image data dict for backward compatibility + result["image"] = self.image.model_dump() + return result diff --git a/libs/som/som/ocr.py b/libs/som/som/ocr.py new file mode 100644 index 00000000..245fc9dc --- /dev/null +++ b/libs/som/som/ocr.py @@ -0,0 +1,162 @@ +from typing import List, Dict, Any, Tuple +import logging +import signal +from contextlib import contextmanager +from pathlib import Path +import easyocr +from PIL import Image +import numpy as np +import torch + +logger = logging.getLogger(__name__) + + +class TimeoutException(Exception): + pass + + +@contextmanager +def timeout(seconds: int): + def timeout_handler(signum, frame): + raise TimeoutException("OCR process timed out") + + original_handler = signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(seconds) + + try: + yield + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, original_handler) + + +class OCRProcessor: + """Class for handling OCR text detection.""" + + _shared_reader = None # Class-level shared reader instance + + def __init__(self): + """Initialize the OCR 
processor.""" + self.reader = None + # Determine best available device + self.device = "cpu" + if torch.cuda.is_available(): + self.device = "cuda" + elif ( + hasattr(torch, "backends") + and hasattr(torch.backends, "mps") + and torch.backends.mps.is_available() + ): + self.device = "mps" + logger.info(f"OCR processor initialized with device: {self.device}") + + def _ensure_reader(self): + """Ensure EasyOCR reader is initialized. + + Uses a class-level cached reader to avoid reinitializing on every instance. + """ + # First check if we already have a class-level reader + if OCRProcessor._shared_reader is not None: + self.reader = OCRProcessor._shared_reader + return + + # Otherwise initialize a new one + if self.reader is None: + try: + logger.info("Initializing EasyOCR reader...") + import easyocr + + # Use GPU if available + use_gpu = self.device in ["cuda"] # MPS not directly supported by EasyOCR + + # If using MPS, add warnings to explain why CPU is used + if self.device == "mps": + logger.warning("EasyOCR doesn't support MPS directly. Using CPU instead.") + logger.warning( + "To silence this warning, set environment variable: PYTORCH_ENABLE_MPS_FALLBACK=1" + ) + + self.reader = easyocr.Reader(["en"], gpu=use_gpu) + + # Verify reader initialization + if self.reader is None: + raise ValueError("Failed to initialize EasyOCR reader") + + # Cache the reader at class level + OCRProcessor._shared_reader = self.reader + + logger.info(f"EasyOCR reader initialized successfully with GPU={use_gpu}") + except Exception as e: + logger.error(f"Failed to initialize EasyOCR reader: {str(e)}") + # Set to a placeholder that will be checked + self.reader = None + raise RuntimeError(f"EasyOCR initialization failed: {str(e)}") from e + + def detect_text( + self, image: Image.Image, confidence_threshold: float = 0.5, timeout_seconds: int = 5 + ) -> List[Dict[str, Any]]: + """Detect text in an image using EasyOCR. 
+ + Args: + image: PIL Image to process + confidence_threshold: Minimum confidence for text detection + timeout_seconds: Maximum time to wait for OCR + + Returns: + List of text detection dictionaries + """ + try: + # Try to initialize reader, catch any exceptions + try: + self._ensure_reader() + except Exception as e: + logger.error(f"Failed to initialize OCR reader: {str(e)}") + return [] + + # Ensure reader was properly initialized + if self.reader is None: + logger.error("OCR reader is None after initialization") + return [] + + # Convert PIL Image to numpy array + image_np = np.array(image) + + try: + with timeout(timeout_seconds): + results = self.reader.readtext( + image_np, paragraph=False, text_threshold=confidence_threshold + ) + except TimeoutException: + logger.warning("OCR timed out") + return [] + except Exception as e: + logger.warning(f"OCR failed: {str(e)}") + return [] + + detections = [] + img_width, img_height = image.size + + for box, text, conf in results: + if conf < confidence_threshold: + continue + + # Convert box format to [x1, y1, x2, y2] + x1 = min(point[0] for point in box) / img_width + y1 = min(point[1] for point in box) / img_height + x2 = max(point[0] for point in box) / img_width + y2 = max(point[1] for point in box) / img_height + + detections.append( + { + "type": "text", + "bbox": [x1, y1, x2, y2], + "content": text, + "confidence": conf, + "interactivity": False, # Text is typically non-interactive + } + ) + + return detections + except Exception as e: + logger.error(f"Unexpected error in OCR processing: {str(e)}") + return [] diff --git a/libs/som/som/util/utils.py b/libs/som/som/util/utils.py new file mode 100644 index 00000000..6658b69a --- /dev/null +++ b/libs/som/som/util/utils.py @@ -0,0 +1,190 @@ +import easyocr +import cv2 +import matplotlib.pyplot as plt +import numpy as np +from PIL import Image +from typing import Union +import time +import signal +from contextlib import contextmanager +import logging + +logger = 
logging.getLogger(__name__)


class TimeoutException(Exception):
    # Raised by the timeout() context manager when OCR exceeds its budget.
    pass


@contextmanager
def timeout(seconds):
    """Raise TimeoutException if the with-block runs longer than `seconds`."""

    def timeout_handler(signum, frame):
        logger.warning(f"OCR process timed out after {seconds} seconds")
        raise TimeoutException("OCR processing timed out")

    # Register the signal handler
    # NOTE(review): SIGALRM only exists on Unix and can only be installed
    # from the main thread -- confirm this module is never imported on
    # Windows or used from worker threads.
    original_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(seconds)

    try:
        yield
    finally:
        # Always cancel the pending alarm and restore the prior handler.
        signal.alarm(0)
        signal.signal(signal.SIGALRM, original_handler)


# Initialize EasyOCR with optimized settings
# NOTE(review): this runs at import time; EasyOCR may download model weights
# on first use, so importing this module is slow and side-effectful. Consider
# lazy initialization (cf. OCRProcessor._ensure_reader in ocr.py).
logger.info("Initializing EasyOCR with optimized settings...")
reader = easyocr.Reader(
    ["en"],
    gpu=True,  # Use GPU if available
    model_storage_directory=None,  # Use default directory
    download_enabled=True,
    detector=True,  # Enable text detection
    recognizer=True,  # Enable text recognition
    verbose=False,  # Disable verbose output
    quantize=True,  # Enable quantization for faster inference
    cudnn_benchmark=True,  # Enable cuDNN benchmarking
)
logger.info("EasyOCR initialization complete")


def check_ocr_box(
    image_source: Union[str, Image.Image],
    display_img=True,
    output_bb_format="xywh",
    goal_filtering=None,
    easyocr_args=None,
    use_paddleocr=False,
):
    """Check OCR box using EasyOCR with optimized settings.
+ + Args: + image_source: Either a file path or PIL Image + display_img: Whether to display the annotated image + output_bb_format: Format for bounding boxes ('xywh' or 'xyxy') + goal_filtering: Optional filtering of results + easyocr_args: Arguments for EasyOCR + use_paddleocr: Ignored (kept for backward compatibility) + """ + logger.info("Starting OCR processing...") + start_time = time.time() + + if isinstance(image_source, str): + logger.info(f"Loading image from path: {image_source}") + image_source = Image.open(image_source) + if image_source.mode == "RGBA": + logger.info("Converting RGBA image to RGB") + image_source = image_source.convert("RGB") + image_np = np.array(image_source) + w, h = image_source.size + logger.info(f"Image size: {w}x{h}") + + # Default EasyOCR arguments optimized for speed + default_args = { + "paragraph": False, # Disable paragraph detection + "text_threshold": 0.5, # Confidence threshold + "link_threshold": 0.4, # Text link threshold + "canvas_size": 2560, # Max image size + "mag_ratio": 1.0, # Magnification ratio + "slope_ths": 0.1, # Slope threshold + "ycenter_ths": 0.5, # Y-center threshold + "height_ths": 0.5, # Height threshold + "width_ths": 0.5, # Width threshold + "add_margin": 0.1, # Margin around text + "min_size": 20, # Minimum text size + } + + # Update with user-provided arguments + if easyocr_args: + logger.info(f"Using custom EasyOCR arguments: {easyocr_args}") + default_args.update(easyocr_args) + + try: + # Use EasyOCR with timeout + logger.info("Starting EasyOCR detection with 5 second timeout...") + with timeout(5): # 5 second timeout + result = reader.readtext(image_np, **default_args) + coord = [item[0] for item in result] + text = [item[1] for item in result] + logger.info(f"OCR completed successfully. 
Found {len(text)} text regions") + logger.info(f"Detected text: {text}") + + except TimeoutException: + logger.error("OCR processing timed out after 5 seconds") + coord = [] + text = [] + except Exception as e: + logger.error(f"OCR processing failed with error: {str(e)}") + coord = [] + text = [] + + processing_time = time.time() - start_time + logger.info(f"Total OCR processing time: {processing_time:.2f} seconds") + + if display_img: + logger.info("Creating visualization of OCR results...") + opencv_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR) + bb = [] + for item in coord: + x, y, a, b = get_xywh(item) + bb.append((x, y, a, b)) + cv2.rectangle(opencv_img, (x, y), (x + a, y + b), (0, 255, 0), 2) + plt.imshow(cv2.cvtColor(opencv_img, cv2.COLOR_BGR2RGB)) + else: + if output_bb_format == "xywh": + bb = [get_xywh(item) for item in coord] + elif output_bb_format == "xyxy": + bb = [get_xyxy(item) for item in coord] + + logger.info("OCR processing complete") + return (text, bb), goal_filtering + + +def get_xywh(box): + """ + Convert a bounding box to xywh format (x, y, width, height). 
    Args:
        box: Bounding box coordinates (various formats supported)

    Returns:
        Tuple of (x, y, width, height)
    """
    # Handle different input formats
    if len(box) == 4:
        # If already in xywh format or xyxy format
        if isinstance(box[0], (int, float)) and isinstance(box[2], (int, float)):
            # NOTE(review): this heuristic is ambiguous. A well-formed xyxy box
            # has box[2] >= box[0], so it falls through to the "already xywh"
            # branch and is returned unchanged; only degenerate/inverted boxes
            # take the xyxy conversion. In practice callers pass EasyOCR
            # 4-point polygons, which are handled by the fallback below --
            # confirm before relying on the flat-4-number path.
            if box[2] < box[0] or box[3] < box[1]:
                # Already xyxy format, convert to xywh
                x1, y1, x2, y2 = box
                return x1, y1, x2 - x1, y2 - y1
            else:
                # Already in xywh format
                return box
    elif len(box) == 2:
        # Format like [[x1,y1],[x2,y2]] from some OCR engines
        (x1, y1), (x2, y2) = box
        return x1, y1, x2 - x1, y2 - y1

    # Default case - try to convert assuming it's a list of points
    # (e.g. the 4-corner polygons EasyOCR returns).
    x_coords = [p[0] for p in box]
    y_coords = [p[1] for p in box]
    x1, y1 = min(x_coords), min(y_coords)
    width, height = max(x_coords) - x1, max(y_coords) - y1
    return x1, y1, width, height


def get_xyxy(box):
    """
    Convert a bounding box to xyxy format (x1, y1, x2, y2).

    Args:
        box: Bounding box coordinates (various formats supported)

    Returns:
        Tuple of (x1, y1, x2, y2)
    """
    # Get xywh first, then convert to xyxy
    x, y, w, h = get_xywh(box)
    return x, y, x + w, y + h
diff --git a/libs/som/som/visualization.py b/libs/som/som/visualization.py
new file mode 100644
index 00000000..4212379c
--- /dev/null
+++ b/libs/som/som/visualization.py
@@ -0,0 +1,274 @@
from typing import List, Dict, Any, Tuple
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import supervision as sv
import platform
import os
import logging

logger = logging.getLogger(__name__)


class BoxAnnotator:
    """Class for drawing bounding boxes and labels on images."""

    def __init__(self):
        """Initialize the box annotator with a color palette."""
        # WCAG 2.1 compliant color palette optimized for accessibility
        self.colors = [
            "#2E7D32",  # Green
            "#C62828",  # Red
            "#1565C0",  # Blue
            "#6A1B9A",  # Purple
            "#EF6C00",  # Orange
            "#283593",  # Indigo
            "#4527A0",  # Deep
Purple + "#00695C", # Teal + "#D84315", # Deep Orange + "#1B5E20", # Dark Green + "#B71C1C", # Dark Red + "#0D47A1", # Dark Blue + "#4A148C", # Dark Purple + "#E65100", # Dark Orange + "#1A237E", # Dark Indigo + "#311B92", # Darker Purple + "#004D40", # Dark Teal + "#BF360C", # Darker Orange + "#33691E", # Darker Green + "#880E4F", # Pink + ] + self.color_index = 0 + self.default_font = None + self._initialize_font() + + def _initialize_font(self) -> None: + """Initialize the default font.""" + # Try to load a system font first + system = platform.system() + font_paths = [] + + if system == "Darwin": # macOS + font_paths = [ + "/System/Library/Fonts/Helvetica.ttc", + "/System/Library/Fonts/Arial.ttf", + "/Library/Fonts/Arial.ttf", + ] + elif system == "Linux": + font_paths = [ + "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", + "/usr/share/fonts/TTF/DejaVuSans.ttf", + "/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf", + ] + else: # Windows + font_paths = ["C:\\Windows\\Fonts\\arial.ttf"] + + # Try each font path + for font_path in font_paths: + if os.path.exists(font_path): + try: + # Test the font with a small size + test_font = ImageFont.truetype(font_path, 12) + # Test if the font can render text + test_font.getbbox("1") + self.default_font = font_path + return + except Exception: + continue + + def _get_next_color(self) -> str: + """Get the next color from the palette.""" + color = self.colors[self.color_index] + self.color_index = (self.color_index + 1) % len(self.colors) + return color + + def _hex_to_rgb(self, hex_color: str) -> Tuple[int, int, int]: + """Convert hex color to RGB tuple.""" + hex_color = hex_color.lstrip("#") + # Create explicit tuple of 3 integers to match the return type + r = int(hex_color[0:2], 16) + g = int(hex_color[2:4], 16) + b = int(hex_color[4:6], 16) + return (r, g, b) + + def draw_boxes( + self, image: Image.Image, detections: List[Dict[str, Any]], draw_config: Dict[str, Any] + ) -> Image.Image: + """Draw 
bounding boxes and labels on the image.""" + draw = ImageDraw.Draw(image) + + # Create smaller font while keeping contrast + try: + if self.default_font: + font = ImageFont.truetype(self.default_font, size=12) # Reduced from 16 to 12 + else: + # If no TrueType font available, use default + font = ImageFont.load_default() + except Exception: + font = ImageFont.load_default() + + padding = 2 # Reduced padding for smaller overall box + spacing = 1 # Reduced spacing between elements + + # Keep track of used label areas to check for collisions + used_areas = [] + + # Store label information for second pass + labels_to_draw = [] + + # First pass: Draw all bounding boxes + for idx, detection in enumerate(detections, 1): + # Get box coordinates + box = detection["bbox"] + x1, y1, x2, y2 = [ + int(coord * dim) for coord, dim in zip(box, [image.width, image.height] * 2) + ] + + # Get color for this detection + color = self._get_next_color() + rgb_color = self._hex_to_rgb(color) + + # Draw bounding box with original width + draw.rectangle(((x1, y1), (x2, y2)), outline=rgb_color, width=2) + + # Use detection number as label + label = str(idx) + + # Get text dimensions using getbbox + bbox = font.getbbox(label) + text_width = bbox[2] - bbox[0] + text_height = bbox[3] - bbox[1] + + # Create box dimensions with padding + box_width = text_width + (padding * 2) # Removed multiplier for tighter box + box_height = text_height + (padding * 2) # Removed multiplier for tighter box + + def is_inside_bbox(x, y): + """Check if a label box would be inside the bounding box.""" + return x >= x1 and x + box_width <= x2 and y >= y1 and y + box_height <= y2 + + # Try different positions until we find one without collision + positions = [ + # Top center (above bbox) + lambda: (x1 + ((x2 - x1) - box_width) // 2, y1 - box_height - spacing), + # Bottom center (below bbox) + lambda: (x1 + ((x2 - x1) - box_width) // 2, y2 + spacing), + # Right center (right of bbox) + lambda: (x2 + spacing, y1 + ((y2 
- y1) - box_height) // 2), + # Left center (left of bbox) + lambda: (x1 - box_width - spacing, y1 + ((y2 - y1) - box_height) // 2), + # Top right (outside corner) + lambda: (x2 + spacing, y1 - box_height - spacing), + # Top left (outside corner) + lambda: (x1 - box_width - spacing, y1 - box_height - spacing), + # Bottom right (outside corner) + lambda: (x2 + spacing, y2 + spacing), + # Bottom left (outside corner) + lambda: (x1 - box_width - spacing, y2 + spacing), + ] + + def check_collision(x, y): + """Check if a label box collides with any existing ones or is inside bbox.""" + # First check if it's inside the bounding box + if is_inside_bbox(x, y): + return True + + # Then check collision with other labels + new_box = (x, y, x + box_width, y + box_height) + for used_box in used_areas: + if not ( + new_box[2] < used_box[0] # new box is left of used box + or new_box[0] > used_box[2] # new box is right of used box + or new_box[3] < used_box[1] # new box is above used box + or new_box[1] > used_box[3] + ): # new box is below used box + return True + return False + + # Try each position until we find one without collision + label_x = None + label_y = None + + for get_pos in positions: + x, y = get_pos() + # Ensure position is within image bounds + if x < 0 or y < 0 or x + box_width > image.width or y + box_height > image.height: + continue + if not check_collision(x, y): + label_x = x + label_y = y + break + + # If all positions collide or are out of bounds, find the best possible position + if label_x is None: + # Try to place it in the nearest valid position outside the bbox + best_pos = positions[0]() # Default to top center + label_x = max(0, min(image.width - box_width, best_pos[0])) + label_y = max(0, min(image.height - box_height, best_pos[1])) + + # Ensure it's not inside the bounding box + if is_inside_bbox(label_x, label_y): + # Force it above the bounding box + label_y = max(0, y1 - box_height - spacing) + + # Add this label area to used areas + if ( + 
label_x is not None + and label_y is not None + and box_width is not None + and box_height is not None + ): + used_areas.append((label_x, label_y, label_x + box_width, label_y + box_height)) + + # Store label information for second pass + labels_to_draw.append( + { + "label": label, + "x": label_x, + "y": label_y, + "width": box_width, + "height": box_height, + "text_width": text_width, + "text_height": text_height, + "color": rgb_color, + } + ) + + # Second pass: Draw all labels on top + for label_info in labels_to_draw: + # Draw background box with white outline + draw.rectangle( + ( + (label_info["x"] - 1, label_info["y"] - 1), + ( + label_info["x"] + label_info["width"] + 1, + label_info["y"] + label_info["height"] + 1, + ), + ), + outline="white", + width=2, + ) + draw.rectangle( + ( + (label_info["x"], label_info["y"]), + (label_info["x"] + label_info["width"], label_info["y"] + label_info["height"]), + ), + fill=label_info["color"], + ) + + # Center text in box + text_x = label_info["x"] + (label_info["width"] - label_info["text_width"]) // 2 + text_y = label_info["y"] + (label_info["height"] - label_info["text_height"]) // 2 + + # Draw text with black outline for better visibility + outline_width = 1 + for dx in [-outline_width, outline_width]: + for dy in [-outline_width, outline_width]: + draw.text( + (text_x + dx, text_y + dy), label_info["label"], fill="black", font=font + ) + + # Draw the main white text + draw.text((text_x, text_y), label_info["label"], fill=(255, 255, 255), font=font) + + logger.info("Finished drawing all boxes") + return image diff --git a/libs/som/tests/test_omniparser.py b/libs/som/tests/test_omniparser.py new file mode 100644 index 00000000..2edbdcd0 --- /dev/null +++ b/libs/som/tests/test_omniparser.py @@ -0,0 +1,13 @@ +# """Basic tests for the omniparser package.""" + +# import pytest +# from omniparser import IconDetector + +# def test_icon_detector_import(): +# """Test that we can import the IconDetector class.""" +# assert 
IconDetector is not None + +# def test_icon_detector_init(): +# """Test that we can create an IconDetector instance.""" +# detector = IconDetector(force_cpu=True) +# assert detector is not None diff --git a/notebooks/agent_nb.ipynb b/notebooks/agent_nb.ipynb new file mode 100644 index 00000000..47f91189 --- /dev/null +++ b/notebooks/agent_nb.ipynb @@ -0,0 +1,229 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Agent\n", + "\n", + "This notebook demonstrates how to use Cua's Agent to run a workflow in a virtual sandbox on Apple Silicon Macs." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Installation" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "zsh:1: no matches found: cua-agent[all]\n" + ] + } + ], + "source": [ + "!pip uninstall cua-agent[all]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install \"cua-agent[all]\"\n", + "\n", + "# Or install individual agent loops:\n", + "# !pip install cua-agent[anthropic]\n", + "# !pip install cua-agent[omni]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# If locally installed, use this instead:\n", + "import os\n", + "\n", + "os.chdir('../libs/agent')\n", + "!poetry install\n", + "!poetry build\n", + "\n", + "!pip uninstall cua-agent -y\n", + "!pip install ./dist/cua_agent-0.1.0-py3-none-any.whl --force-reinstall" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize a Computer Agent" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Computer allows you to run an agentic workflow in a virtual sandbox instances on Apple Silicon. 
Here's a basic example:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from agent import ComputerAgent, AgenticLoop, LLMProvider\n", + "from computer import Computer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Similar to Computer, you can either use the async context manager pattern or initialize the ComputerAgent instance directly." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# Set your API key\n", + "!export ANTHROPIC_API_KEY=\"your-anthropic-api-key\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Direct initialization:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datetime import datetime\n", + "import logging\n", + "from pathlib import Path\n", + "\n", + "computer = Computer(verbosity=logging.INFO)\n", + "\n", + "# Create agent with Anthropic loop and provider\n", + "agent = ComputerAgent(\n", + " computer=computer,\n", + " api_key=\"\",\n", + " loop_type=AgenticLoop.ANTHROPIC,\n", + " ai_provider=LLMProvider.ANTHROPIC,\n", + " model='claude-3-7-sonnet-20250219',\n", + " save_trajectory=True,\n", + " trajectory_dir=str(Path(\"trajectories\") / datetime.now().strftime(\"%Y%m%d_%H%M%S\")),\n", + " only_n_most_recent_images=3, # Slightly different from the omni example\n", + " verbosity=logging.INFO,\n", + ")\n", + "\n", + "tasks = [\n", + "\"\"\"\n", + "Please help me with the following task:\n", + "1. Open Safari browser\n", + "2. Go to Wikipedia.org\n", + "3. Search for \"Claude AI\" \n", + "4. 
Summarize the main points you find about Claude AI\n", + "\"\"\"\n", + "]\n", + "\n", + "async with agent:\n", + " for i, task in enumerate(tasks, 1):\n", + " print(f\"\\nExecuting task {i}/{len(tasks)}: {task}\")\n", + " async for result in agent.run(task):\n", + " print(result)\n", + " print(f\"Task {i} completed\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Or using the Omni Agentic Loop:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datetime import datetime\n", + "import logging\n", + "from pathlib import Path\n", + "\n", + "computer = Computer(verbosity=logging.INFO)\n", + "\n", + "# Create agent with Anthropic loop and provider\n", + "agent = ComputerAgent(\n", + " computer=computer,\n", + " api_key=\"\",\n", + " loop_type=AgenticLoop.OMNI,\n", + " ai_provider=LLMProvider.ANTHROPIC,\n", + " model='claude-3-7-sonnet-20250219',\n", + " save_trajectory=True,\n", + " trajectory_dir=str(Path(\"trajectories\") / datetime.now().strftime(\"%Y%m%d_%H%M%S\")),\n", + " only_n_most_recent_images=3, # Slightly different from the omni example\n", + " verbosity=logging.INFO,\n", + ")\n", + "\n", + "tasks = [\n", + "\"\"\"\n", + "Please help me with the following task:\n", + "1. Open Safari browser\n", + "2. Go to Wikipedia.org\n", + "3. Search for \"Claude AI\" \n", + "4. 
Summarize the main points you find about Claude AI\n", + "\"\"\"\n", + "]\n", + "\n", + "async with agent:\n", + " for i, task in enumerate(tasks, 1):\n", + " print(f\"\\nExecuting task {i}/{len(tasks)}: {task}\")\n", + " async for result in agent.run(task):\n", + " print(result)\n", + " print(f\"Task {i} completed\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/computer_nb.ipynb b/notebooks/computer_nb.ipynb new file mode 100644 index 00000000..caa821e1 --- /dev/null +++ b/notebooks/computer_nb.ipynb @@ -0,0 +1,506 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Computer\n", + "\n", + "This notebook demonstrates how to use Computer to operate a Lume sandbox programmatically on Apple Silicon macOS systems." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Installation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip uninstall -y cua-computer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install cua-computer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# If locally installed, use this instead:\n", + "import os\n", + "\n", + "os.chdir('../libs/computer')\n", + "!poetry install\n", + "!poetry build\n", + "\n", + "!pip uninstall cua-computer -y\n", + "!pip install ./dist/cua_computer-0.1.0-py3-none-any.whl --force-reinstall" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Lume daemon\n", + "\n", + "While a `lume` binary is included with Computer, we recommend installing the standalone version with brew, and starting the lume daemon service before running Computer. Refer to [../libs/lume/README.md](../libs/lume/README.md) for more details on lume cli." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!brew tap trycua/lume\n", + "!brew install lume" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Start the lume daemon service (from another terminal):\n", + "\n", + "```bash\n", + "lume serve\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Pull the latest pre-built macos-sequoia-cua image. This image, based on macOS Sequoia, contains all dependencies needed to be controlled from the Computer interface." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!lume pull macos-sequoia-cua:latest" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The initial image download of thew macos-sequoia-cua image requires 80GB of storage space. However, after the first run, the image size reduces to around 20GB. Thanks to macOS's sparse file system, VM disk space is allocated dynamically - while VMs may show a total size of 50GB, they typically only consume about 20GB of physical disk space.\n", + "\n", + "Sandbox are stored in `~/.lume`, and locally cached images are stored in `~/.lume/cache`.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can always see the list of downloaded VM images with:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!lume ls" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Testing the sandbox\n", + "\n", + "Once downloaded, you can run the sandbox to test everything is working:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!lume run macos-sequoia-cua:latest" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can add additional software and tools to the sandbox - these changes will be saved in the VM disk." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize a Computer instance" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Computer allows you to create and control a virtual sandbox instances from your host on Apple Silicon. 
Here's a basic example:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from computer import Computer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can either use the async context manager or initialize the Computer instance directly." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With the async context manager:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async with Computer(\n", + " # name=\"my_vm\", # optional, in case you want to use any other VM created using lume\n", + " display=\"1024x768\",\n", + " memory=\"8GB\",\n", + " cpu=\"4\",\n", + " os=\"macos\"\n", + ") as computer:\n", + " await computer.run()\n", + " # ... do something with the computer interface" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Direct initialization:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "computer = Computer(\n", + " display=\"1024x768\",\n", + " memory=\"8GB\",\n", + " cpu=\"4\",\n", + " os=\"macos\"\n", + ")\n", + "\n", + "await computer.run()\n", + "# ... do something with the computer interface" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The Computer instance requires a Lume server for communication. Here's how it works:\n", + "\n", + "1. First, it attempts to connect to any existing Lume server running on port 3000\n", + "2. 
If no Lume server is available, it automatically starts a new one via [lume serve](https://github.com/trycua/lume/?tab=readme-ov-file#local-api-server)\n", + "\n", + "The sandbox's lifecycle is tied to the Lume server:\n", + "- If Computer started the Lume server itself, the server will be terminated when Computer stops\n", + "- If Computer connected to a pre-existing server, that server remains running after Computer stops\n", + "\n", + "If you have scenarios where you need to run multiple sandboxes in parallel, we recommend first starting the Lume server separately with `lume serve` (refer to [Lume](https://github.com/trycua/lume/?tab=readme-ov-file#install) on how to install) and then having each Computer instance connect to it." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order to execute commands targeting the sandbox, the Computer interface communicates through websockets to a Fast API server running on the sandbox." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Cursor" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Move and click\n", + "await computer.interface.move_cursor(100, 100)\n", + "await computer.interface.left_click()\n", + "await computer.interface.right_click(300, 300)\n", + "await computer.interface.double_click(400, 400)\n", + "\n", + "# Drag operations\n", + "await computer.interface.drag_to(500, 500, duration=1.0)\n", + "\n", + "# Get cursor position\n", + "cursor_pos = await computer.interface.get_cursor_position()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Keyboard" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Type text\n", + "await computer.interface.type_text(\"Hello, World!\")\n", + "\n", + "# Press individual keys\n", + "await computer.interface.press_key(\"enter\")\n", + "await 
computer.interface.press_key(\"escape\")\n", + "\n", + "# Use hotkeys\n", + "await computer.interface.hotkey(\"command\", \"c\") # Command+C on macOS" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Screen" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# Get screen dimensions\n", + "screen_size = await computer.interface.get_screen_size()\n", + "\n", + "# Take basic screenshot\n", + "screenshot = await computer.interface.screenshot()\n", + "with open(\"screenshot.png\", \"wb\") as f:\n", + " f.write(screenshot)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Clipboard" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set clipboard content\n", + "await computer.interface.set_clipboard(\"Text to copy\")\n", + "\n", + "# Get clipboard content\n", + "clipboard_content = await computer.interface.copy_to_clipboard()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### File System Operations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Check file/directory existence\n", + "file_exists = await computer.interface.file_exists(\"/path/to/file.txt\")\n", + "dir_exists = await computer.interface.directory_exists(\"/path/to/directory\")\n", + "\n", + "# Run shell commands\n", + "stdout, stderr = await computer.interface.run_command(\"ls -la\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Accessibility" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# Get accessibility tree\n", + "accessibility_tree = await computer.interface.get_accessibility_tree()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Sharing a directory with the sandbox" + ] + }, + { + 
"cell_type": "markdown", + "metadata": {}, + "source": [ + "You can share a directory with the sandbox by passing a list of absolute paths to the `shared_directories` argument when initializing the Computer instance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "computer = Computer(\n", + " display=\"1024x768\",\n", + " memory=\"4GB\",\n", + " cpu=\"2\",\n", + " os=\"macos\",\n", + " shared_directories=[\"/absolute/path/to/directory\"]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Using Host Computer\n", + "\n", + "For local development and testing purposes, you can run the Computer API server on the same host machine and target it instead:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "computer = Computer(\n", + " display=\"1024x768\",\n", + " memory=\"4GB\",\n", + " cpu=\"2\",\n", + " os=\"macos\",\n", + " use_host_computer_server=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Examples" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from computer.computer import Computer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async with Computer(\n", + " display=\"1024x768\",\n", + " memory=\"4GB\",\n", + " cpu=\"2\",\n", + " os=\"macos\"\n", + ") as computer:\n", + " await computer.run()\n", + " res = await computer.interface.run_command(\"ls -a\")\n", + "\n", + " # Get screen dimensions\n", + " screen_size = await computer.interface.get_screen_size()\n", + "\n", + " # Move and click\n", + " await computer.interface.move_cursor(100, 100)\n", + " await computer.interface.left_click()\n", + " await computer.interface.right_click(300, 300)\n", + " await computer.interface.double_click(400, 400)\n", + "\n", + " # Drag 
operations\n", + " await computer.interface.drag_to(500, 500, duration=1.0)\n", + "\n", + " # Get cursor position\n", + " cursor_pos = await computer.interface.get_cursor_position()\n", + "\n", + " # Your automation code here\n", + " await computer.stop()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/computer_server_nb.ipynb b/notebooks/computer_server_nb.ipynb new file mode 100644 index 00000000..536e4e67 --- /dev/null +++ b/notebooks/computer_server_nb.ipynb @@ -0,0 +1,117 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Computer Server\n", + "\n", + "This notebook demonstrates how to host the server used by Computer." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Installation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install cua-computer-server" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# If locally installed, use this instead:\n", + "import os\n", + "\n", + "os.chdir('../libs/computer-server')\n", + "!pdm install\n", + "!pdm build\n", + "\n", + "!pip uninstall cua-computer-server -y\n", + "!pip install ./dist/cua_computer_server-0.1.0-py3-none-any.whl --force-reinstall" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Start the Computer server" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==> Starting computer-server on 0.0.0.0:8000...\n", + "Starting computer-server on 0.0.0.0:8000...\n", + "\u001b[32mINFO\u001b[0m: Started server process [\u001b[36m65480\u001b[0m]\n", + "\u001b[32mINFO\u001b[0m: Waiting for application startup.\n", + "\u001b[32mINFO\u001b[0m: Application startup complete.\n", + "\u001b[32mINFO\u001b[0m: Uvicorn running on \u001b[1mhttp://0.0.0.0:8000\u001b[0m (Press CTRL+C to quit)\n", + "^C\n", + "\u001b[32mINFO\u001b[0m: Shutting down\n", + "\u001b[32mINFO\u001b[0m: Waiting for application shutdown.\n", + "\u001b[32mINFO\u001b[0m: Application shutdown complete.\n", + "\u001b[32mINFO\u001b[0m: Finished server process [\u001b[36m65480\u001b[0m]\n" + ] + } + ], + "source": [ + "import os\n", + "# os.chdir('../../scripts')\n", + "\n", + "! 
./run_computer_server.sh\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Synchronous usage\n", + "from computer_server import Server\n", + "\n", + "server = Server(port=8000)\n", + "server.start() # Blocks until stopped" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/pylume_nb.ipynb b/notebooks/pylume_nb.ipynb new file mode 100644 index 00000000..76c37419 --- /dev/null +++ b/notebooks/pylume_nb.ipynb @@ -0,0 +1,348 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Intro\n", + "\n", + "This notebook provides a quickstart guide to the pylume python library." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip uninstall pylume -y" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install pylume" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# If locally installed, use this instead:\n", + "# !poetry install\n", + "# !poetry build\n", + "!pip uninstall pylume -y && pip install ./dist/pylume-0.1.0-py3-none-any.whl --force-reinstall" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import asyncio\n", + "from pylume import (\n", + " PyLume, \n", + " ImageRef, \n", + " VMRunOpts, \n", + " SharedDirectory, \n", + " VMConfig,\n", + " VMUpdateOpts\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Get latest IPSW URL from Apple Server\n", + "\n", + "This is used to create a new macOS VM by providing the downloaded IPSW file path to the `ipsw` argument in the `create_vm` method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def get_ipsw():\n", + " async with PyLume(port=3000) as pylume:\n", + " url = await pylume.get_latest_ipsw_url()\n", + " print(f\"Latest IPSW URL: {url}\")\n", + "\n", + "await get_ipsw()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create a new VM" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### macOS\n", + "\n", + "An IPSW file path is required to create a new macOS VM. To fetch automatically the latest IPSW during the VM creation, use `ipsw=\"latest\"`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def create_macos_vm():\n", + " async with PyLume() as pylume:\n", + " vm_config = VMConfig(\n", + " name=\"macos-vm\",\n", + " os=\"macOS\",\n", + " cpu=4,\n", + " memory=\"4GB\",\n", + " disk_size=\"40GB\",\n", + " display=\"1024x768\",\n", + " ipsw=\"latest\"\n", + " )\n", + " await pylume.create_vm(vm_config)\n", + "\n", + "await create_macos_vm()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Linux\n", + "\n", + "To create a new Linux VM, use the `os=\"linux\"` argument in the `VMConfig` class. Note that this doesn't set up any Linux distribution, it just creates a VM with a Linux kernel." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def create_linux_vm():\n", + " async with PyLume() as pylume:\n", + " vm_config = VMConfig(\n", + " name=\"linux-vm\",\n", + " os=\"linux\",\n", + " cpu=2,\n", + " memory=\"4GB\",\n", + " disk_size=\"25GB\",\n", + " display=\"1024x768\"\n", + " )\n", + " await pylume.create_vm(vm_config)\n", + "\n", + "await create_linux_vm()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Pull an image from ghcr.io/trycua" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Cua provides pre-built images for macOS and Linux." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def pull_macos_image():\n", + " async with PyLume() as pylume:\n", + " image_ref = ImageRef(\n", + " image=\"macos-sequoia-vanilla\",\n", + " tag=\"15.2\",\n", + " registry=\"ghcr.io\",\n", + " organization=\"trycua\"\n", + " )\n", + " await pylume.pull_image(image_ref, name=\"macos-sequoia-vanilla\")\n", + "\n", + "await pull_macos_image()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run\n", + "\n", + "Run a VM by providing the `VMRunConfig` object." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def run_vm():\n", + " async with PyLume() as pylume:\n", + " vm_name = \"macos-sequoia-vanilla\"\n", + " run_opts = VMRunOpts(\n", + " no_display=False,\n", + " shared_directories=[\n", + " SharedDirectory(\n", + " host_path=\"/Users//Shared\",\n", + " read_only=False\n", + " )\n", + " ]\n", + " )\n", + " await pylume.run_vm(vm_name, run_opts)\n", + "\n", + "await run_vm()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### List existing VMs\n", + "\n", + "VMs are stored in the `~/.lume` directory." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async with PyLume() as pylume:\n", + " vms = await pylume.list_vms()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Get VM status" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async with PyLume() as pylume:\n", + " status = await pylume.get_vm(\"macos-sequoia-vanilla\")\n", + " print(status)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Update VM Settings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "update_config = VMUpdateOpts(\n", + " cpu=8,\n", + " memory=\"8GB\"\n", + ")\n", + "async with PyLume() as pylume:\n", + " await pylume.update_vm(\"macos-sequoia-vanilla\", update_config)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Stop a VM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async with PyLume() as pylume:\n", + " await pylume.stop_vm(\"macos-sequoia-vanilla\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Delete a VM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async with PyLume() as pylume:\n", + " await pylume.delete_vm(\"linux-vm\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Images\n", + "\n", + "List the images available locally" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async with PyLume() as pylume:\n", + " await pylume.get_images()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + 
"version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..b00dbc31 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,86 @@ +[build-system] +requires = ["pdm-backend"] +build-backend = "pdm.backend" + +[project] +name = "cua-workspace" +version = "0.1.0" +description = "CUA (Computer Use Agent) mono-repo" +authors = [ + { name = "TryCua", email = "gh@trycua.com" } +] +dependencies = [] +requires-python = ">=3.10,<3.13" +readme = "README.md" +license = { text = "MIT" } + +[project.urls] +repository = "https://github.com/trycua/cua" + +[dependency-groups] +dev = [] +examples = [] + +[tool.pdm] +distribution = false + +[tool.pdm.dev-dependencies] +dev = [ + "black>=23.0.0", + "ruff>=0.9.2", + "mypy>=1.10.0", + "types-requests>=2.31.0", + "ipykernel>=6.29.5", + "jupyter>=1.0.0", + "jedi>=0.19.2" +] +test = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.21.1", + "pytest-cov>=4.1.0", + "pytest-mock>=3.10.0", + "pytest-xdist>=3.6.1", + "aioresponses>=0.7.4" +] +docs = [ + "mkdocs>=1.5.0", + "mkdocs-material>=9.2.0" +] + +[tool.pdm.resolution] +respect-source-order = true + +[tool.pdm.resolution.overrides] +cua-computer = { path = "libs/computer" } +cua-omniparser = { path = "libs/omniparser" } +cua-agent = { path = "libs/agent" } +pylume = { path = "libs/pylume" } +cua-computer-server = { path = "libs/computer-server" } + +[tool.black] +line-length = 100 +target-version = ["py310"] + +[tool.ruff] +line-length = 100 +target-version = "py310" +select = ["E", "F", "B", "I"] +fix = true + +[tool.ruff.format] +docstring-code-format = true + +[tool.mypy] +strict = true +python_version = "3.10" +ignore_missing_imports = true +disallow_untyped_defs = true +check_untyped_defs = true +warn_return_any = true +show_error_codes = true 
+warn_unused_ignores = false + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["libs/*/tests"] +python_files = "test_*.py" \ No newline at end of file diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 00000000..3727e40d --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,24 @@ +{ + "include": [ + "**/*.py" + ], + "exclude": [ + "**/node_modules/**", + "**/__pycache__/**", + "**/.*/**", + "**/venv/**", + "**/.venv/**", + "**/dist/**", + "**/build/**", + ".pdm-build/**", + "**/.git/**", + "examples/**", + "notebooks/**", + "logs/**", + "screenshots/**" + ], + "typeCheckingMode": "basic", + "useLibraryCodeForTypes": true, + "reportMissingImports": false, + "reportMissingModuleSource": false +} diff --git a/scripts/build.sh b/scripts/build.sh new file mode 100755 index 00000000..5545aadd --- /dev/null +++ b/scripts/build.sh @@ -0,0 +1,117 @@ +#!/bin/bash + +# Exit on error +set -e + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# Function to print step information +print_step() { + echo -e "${BLUE}==> $1${NC}" +} + +# Function to print success message +print_success() { + echo -e "${GREEN}==> Success: $1${NC}" +} + +# Function to print error message +print_error() { + echo -e "${RED}==> Error: $1${NC}" >&2 +} + +# Get the script's directory +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +PROJECT_ROOT="$( cd "${SCRIPT_DIR}/.." && pwd )" + +# Change to project root +cd "$PROJECT_ROOT" + +# Load environment variables from .env.local +if [ -f .env.local ]; then + print_step "Loading environment variables from .env.local..." + set -a + source .env.local + set +a + print_success "Environment variables loaded" +else + print_error ".env.local file not found" + exit 1 +fi + +# Clean up existing environments and cache +print_step "Cleaning up existing environments..." +find . -type d -name "__pycache__" -exec rm -rf {} + +find . 
-type d -name ".pytest_cache" -exec rm -rf {} + +find . -type d -name "dist" -exec rm -rf {} + +find . -type d -name ".venv" -exec rm -rf {} + +find . -type d -name "*.egg-info" -exec rm -rf {} + +print_success "Environment cleanup complete" + +# Create and activate virtual environment +print_step "Creating virtual environment..." +python -m venv .venv +source .venv/bin/activate + +# Upgrade pip and install build tools +print_step "Upgrading pip and installing build tools..." +python -m pip install --upgrade pip setuptools wheel + +# Function to install a package and its dependencies +install_package() { + local package_dir=$1 + local package_name=$2 + local extras=$3 + print_step "Installing ${package_name}..." + cd "$package_dir" + + if [ -f "pyproject.toml" ]; then + if [ -n "$extras" ]; then + pip install -e ".[${extras}]" + else + pip install -e . + fi + else + print_error "No pyproject.toml found in ${package_dir}" + return 1 + fi + + cd "$PROJECT_ROOT" +} + +# Install packages in order of dependency +print_step "Installing packages in development mode..." + +# Install core first (base package with telemetry support) +install_package "libs/core" "core" + +# Install pylume (base dependency) +install_package "libs/pylume" "pylume" + +# Install computer (depends on pylume) +install_package "libs/computer" "computer" + +# Install omniparser +install_package "libs/som" "som" + +# Install agent with all its dependencies and extras +install_package "libs/agent" "agent" "all" + +# Install computer-server +install_package "libs/computer-server" "computer-server" + +# Install development tools from root project +print_step "Installing development dependencies..." +pip install -e ".[dev,test,docs]" + +# Create a .env file for VS Code to use the virtual environment +print_step "Creating .env file for VS Code..." 
+echo "PYTHONPATH=${PROJECT_ROOT}/libs/core:${PROJECT_ROOT}/libs/computer:${PROJECT_ROOT}/libs/agent:${PROJECT_ROOT}/libs/som:${PROJECT_ROOT}/libs/pylume:${PROJECT_ROOT}/libs/computer-server" > .env + +print_success "All packages installed successfully!" +print_step "Your virtual environment is ready. To activate it:" +echo " source .venv/bin/activate" diff --git a/scripts/cleanup.sh b/scripts/cleanup.sh new file mode 100755 index 00000000..789affd0 --- /dev/null +++ b/scripts/cleanup.sh @@ -0,0 +1,85 @@ +#!/bin/bash + +# Exit on error +set -e + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# Function to print step information +print_step() { + echo -e "${BLUE}==> $1${NC}" +} + +# Function to print success message +print_success() { + echo -e "${GREEN}==> Success: $1${NC}" +} + +# Function to print error message +print_error() { + echo -e "${RED}==> Error: $1${NC}" >&2 +} + +# Get the script's directory +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +PROJECT_ROOT="$SCRIPT_DIR/.." + +# Change to project root +cd "$PROJECT_ROOT" + +print_step "Starting cleanup of all caches and virtual environments..." + +# Remove all virtual environments +print_step "Removing virtual environments..." +find . -type d -name ".venv" -exec rm -rf {} + +print_success "Virtual environments removed" + +# Remove all Python cache files and directories +print_step "Removing Python cache files and directories..." +find . -type d -name "__pycache__" -exec rm -rf {} + +find . -type d -name ".pytest_cache" -exec rm -rf {} + +find . -type d -name ".mypy_cache" -exec rm -rf {} + +find . -type d -name ".ruff_cache" -exec rm -rf {} + +find . -name "*.pyc" -delete +find . -name "*.pyo" -delete +find . -name "*.pyd" -delete +print_success "Python cache files removed" + +# Remove all build artifacts +print_step "Removing build artifacts..." +find . -type d -name "build" -exec rm -rf {} + +find . 
-type d -name "dist" -exec rm -rf {} + +find . -type d -name "*.egg-info" -exec rm -rf {} + +find . -type d -name "*.egg" -exec rm -rf {} + +print_success "Build artifacts removed" + +# Remove PDM-related files and directories +print_step "Removing PDM-related files and directories..." +find . -name "pdm.lock" -delete +find . -type d -name ".pdm-build" -exec rm -rf {} + +find . -name ".pdm-python" -delete # .pdm-python is a file, not a directory +print_success "PDM-related files removed" + +# Remove .env file +print_step "Removing .env file..." +rm -f .env +print_success ".env file removed" + +# Remove typings directory +print_step "Removing typings directory..." +rm -rf .vscode/typings +print_success "Typings directory removed" + +# Clean up any temporary files +print_step "Removing temporary files..." +find . -name "*.tmp" -delete +find . -name "*.bak" -delete +find . -name "*.swp" -delete +print_success "Temporary files removed" + +print_success "Cleanup complete! All caches and virtual environments have been removed." +print_step "To rebuild the project, run: bash scripts/build.sh"