mirror of
https://github.com/trycua/lume.git
synced 2026-01-06 12:29:56 -06:00
Format codebase with uv run pre-commit run --all-files
This commit is contained in:
@@ -9,60 +9,56 @@ Models are loaded/unloaded one at a time to avoid memory issues.
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from utils import (
|
||||
ModelWrapper,
|
||||
take_screenshot,
|
||||
get_available_models,
|
||||
save_prediction_visualization,
|
||||
get_available_models
|
||||
take_screenshot,
|
||||
)
|
||||
|
||||
|
||||
async def predict_with_all_models(image, instruction: str, models) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Predict click coordinates with all models sequentially.
|
||||
|
||||
|
||||
Args:
|
||||
image: PIL Image to analyze
|
||||
instruction: Instruction text
|
||||
models: List of model instances
|
||||
|
||||
|
||||
Returns:
|
||||
List of prediction results
|
||||
"""
|
||||
predictions = []
|
||||
|
||||
|
||||
for model in models:
|
||||
model_wrapper = ModelWrapper(model)
|
||||
print(f"\n🔄 Loading {model_wrapper.model_name}...")
|
||||
|
||||
|
||||
try:
|
||||
# Load model
|
||||
await model_wrapper.load_model()
|
||||
|
||||
|
||||
# Predict
|
||||
coords = await model_wrapper.predict_click(image, instruction)
|
||||
|
||||
predictions.append({
|
||||
'model_name': model_wrapper.model_name,
|
||||
'coords': coords,
|
||||
'error': None
|
||||
})
|
||||
|
||||
|
||||
predictions.append(
|
||||
{"model_name": model_wrapper.model_name, "coords": coords, "error": None}
|
||||
)
|
||||
|
||||
if coords:
|
||||
print(f"✅ {model_wrapper.model_name}: ({coords[0]}, {coords[1]})")
|
||||
else:
|
||||
print(f"❌ {model_wrapper.model_name}: No prediction")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ {model_wrapper.model_name}: ERROR - {str(e)}")
|
||||
predictions.append({
|
||||
'model_name': model_wrapper.model_name,
|
||||
'coords': None,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
predictions.append(
|
||||
{"model_name": model_wrapper.model_name, "coords": None, "error": str(e)}
|
||||
)
|
||||
|
||||
finally:
|
||||
# Always unload model to free memory
|
||||
try:
|
||||
@@ -70,7 +66,7 @@ async def predict_with_all_models(image, instruction: str, models) -> List[Dict[
|
||||
print(f"🗑️ Unloaded {model_wrapper.model_name}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Error unloading {model_wrapper.model_name}: {e}")
|
||||
|
||||
|
||||
return predictions
|
||||
|
||||
|
||||
@@ -103,87 +99,91 @@ async def main():
|
||||
Main interactive loop.
|
||||
"""
|
||||
print_header()
|
||||
|
||||
|
||||
# Get available models
|
||||
models = get_available_models()
|
||||
print_models(models)
|
||||
|
||||
|
||||
# Create output directory for visualizations
|
||||
output_dir = "interactive_output"
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
|
||||
session_count = 0
|
||||
last_screenshot = None
|
||||
screenshot_timestamp = None
|
||||
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Get user input
|
||||
print(f"\n{'='*40}")
|
||||
user_input = input("🎯 Enter instruction (or command): ").strip()
|
||||
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
|
||||
# Handle commands
|
||||
if user_input.lower() in ['quit', 'exit', 'q']:
|
||||
if user_input.lower() in ["quit", "exit", "q"]:
|
||||
print("👋 Goodbye!")
|
||||
break
|
||||
|
||||
elif user_input.lower() == 'models':
|
||||
|
||||
elif user_input.lower() == "models":
|
||||
print_models(models)
|
||||
continue
|
||||
|
||||
elif user_input.lower() == 'screenshot':
|
||||
|
||||
elif user_input.lower() == "screenshot":
|
||||
print("📸 Taking screenshot...")
|
||||
try:
|
||||
last_screenshot = take_screenshot()
|
||||
screenshot_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
screenshot_path = os.path.join(output_dir, f"screenshot_{screenshot_timestamp}.png")
|
||||
screenshot_path = os.path.join(
|
||||
output_dir, f"screenshot_{screenshot_timestamp}.png"
|
||||
)
|
||||
last_screenshot.save(screenshot_path)
|
||||
print(f"✅ Screenshot captured and saved to: {screenshot_path}")
|
||||
print(f"📝 Ready for instructions! Screenshot size: {last_screenshot.size}")
|
||||
except Exception as e:
|
||||
print(f"❌ Error taking screenshot: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# Handle instruction input
|
||||
if last_screenshot is None:
|
||||
print("⚠️ No screenshot available! Please take a screenshot first using 'screenshot' command.")
|
||||
print(
|
||||
"⚠️ No screenshot available! Please take a screenshot first using 'screenshot' command."
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
session_count += 1
|
||||
print(f"\n🎯 Session {session_count}: '{user_input}'")
|
||||
print(f"📷 Using screenshot from: {screenshot_timestamp}")
|
||||
|
||||
|
||||
# Predict with all models using last screenshot
|
||||
print(f"\n🤖 Testing {len(models)} models on screenshot...")
|
||||
predictions = await predict_with_all_models(last_screenshot, user_input, models)
|
||||
|
||||
|
||||
# Display results summary
|
||||
print(f"\n📊 Results Summary:")
|
||||
print("\n📊 Results Summary:")
|
||||
print("-" * 50)
|
||||
for pred in predictions:
|
||||
if pred['coords']:
|
||||
if pred["coords"]:
|
||||
print(f"✅ {pred['model_name']}: ({pred['coords'][0]}, {pred['coords'][1]})")
|
||||
elif pred['error']:
|
||||
elif pred["error"]:
|
||||
print(f"❌ {pred['model_name']}: ERROR - {pred['error']}")
|
||||
else:
|
||||
print(f"❌ {pred['model_name']}: No prediction")
|
||||
|
||||
|
||||
# Save visualization
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
vis_filename = f"session_{session_count:03d}_{timestamp}.png"
|
||||
vis_path = os.path.join(output_dir, vis_filename)
|
||||
|
||||
|
||||
try:
|
||||
save_prediction_visualization(last_screenshot, user_input, predictions, vis_path)
|
||||
print(f"\n💾 Visualization saved to: {vis_path}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Error saving visualization: {e}")
|
||||
|
||||
|
||||
print(f"\n✨ Session {session_count} completed!")
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n👋 Interrupted by user. Goodbye!")
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user