From 949e5b9be8d2428b2c2bc04512686ff93b8b586c Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 27 Jul 2025 22:02:51 +0200 Subject: [PATCH] feat(rfdetr): add object detection API (#5923) Signed-off-by: Ettore Di Giacinto --- .github/workflows/backend.yml | 76 ++++++- Makefile | 9 + README.md | 2 + backend/backend.proto | 18 ++ backend/index.yaml | 81 ++++++++ backend/python/common/template/protogen.sh | 2 + backend/python/rfdetr/Makefile | 20 ++ backend/python/rfdetr/backend.py | 174 ++++++++++++++++ backend/python/rfdetr/install.sh | 19 ++ backend/python/rfdetr/protogen.sh | 13 ++ backend/python/rfdetr/requirements-cpu.txt | 7 + .../python/rfdetr/requirements-cublas11.txt | 8 + .../python/rfdetr/requirements-cublas12.txt | 7 + .../python/rfdetr/requirements-hipblas.txt | 9 + backend/python/rfdetr/requirements-intel.txt | 13 ++ backend/python/rfdetr/requirements.txt | 3 + backend/python/rfdetr/run.sh | 9 + backend/python/rfdetr/test.sh | 11 + core/backend/detection.go | 34 +++ core/config/backend_config.go | 8 + core/http/endpoints/localai/detection.go | 59 ++++++ core/http/routes/localai.go | 5 + core/http/views/backends.html | 8 + core/http/views/models.html | 8 + core/schema/localai.go | 17 ++ .../content/docs/features/object-detection.md | 193 ++++++++++++++++++ gallery/index.yaml | 22 ++ pkg/grpc/backend.go | 3 +- pkg/grpc/base/base.go | 4 + pkg/grpc/client.go | 22 ++ pkg/grpc/embed.go | 4 + pkg/grpc/interface.go | 3 +- pkg/grpc/server.go | 18 +- pkg/utils/base64.go | 2 +- 34 files changed, 884 insertions(+), 7 deletions(-) create mode 100644 backend/python/rfdetr/Makefile create mode 100755 backend/python/rfdetr/backend.py create mode 100755 backend/python/rfdetr/install.sh create mode 100644 backend/python/rfdetr/protogen.sh create mode 100644 backend/python/rfdetr/requirements-cpu.txt create mode 100644 backend/python/rfdetr/requirements-cublas11.txt create mode 100644 backend/python/rfdetr/requirements-cublas12.txt create mode 100644 backend/python/rfdetr/requirements-hipblas.txt create mode 100644 backend/python/rfdetr/requirements-intel.txt create mode 100644 backend/python/rfdetr/requirements.txt create mode 100755 backend/python/rfdetr/run.sh create mode 100755 backend/python/rfdetr/test.sh create mode 100644 core/backend/detection.go create mode 100644 core/http/endpoints/localai/detection.go create mode 100644 docs/content/docs/features/object-detection.md diff --git a/.github/workflows/backend.yml b/.github/workflows/backend.yml index cf22dea2a..af9b1eb9c 100644 --- a/.github/workflows/backend.yml +++ b/.github/workflows/backend.yml @@ -868,7 +868,81 @@ jobs: skip-drivers: 'false' backend: "huggingface" dockerfile: "./backend/Dockerfile.golang" - context: "./" + context: "./" + # rfdetr + - build-type: '' + cuda-major-version: "" + cuda-minor-version: "" + platforms: 'linux/amd64,linux/arm64' + tag-latest: 'auto' + tag-suffix: '-cpu-rfdetr' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:22.04" + skip-drivers: 'false' + backend: "rfdetr" + dockerfile: "./backend/Dockerfile.python" + context: "./backend" + - build-type: 'cublas' + cuda-major-version: "12" + cuda-minor-version: "0" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-nvidia-cuda-12-rfdetr' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:22.04" + skip-drivers: 'false' + backend: "rfdetr" + dockerfile: "./backend/Dockerfile.python" + context: "./backend" + - build-type: 'cublas' + cuda-major-version: "11" + cuda-minor-version: "7" + platforms: 'linux/amd64' + tag-latest: 'auto' + 
tag-suffix: '-gpu-nvidia-cuda-11-rfdetr' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:22.04" + skip-drivers: 'false' + backend: "rfdetr" + dockerfile: "./backend/Dockerfile.python" + context: "./backend" + - build-type: 'intel' + cuda-major-version: "" + cuda-minor-version: "" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-intel-rfdetr' + runs-on: 'ubuntu-latest' + base-image: "quay.io/go-skynet/intel-oneapi-base:latest" + skip-drivers: 'false' + backend: "rfdetr" + dockerfile: "./backend/Dockerfile.python" + context: "./backend" + - build-type: 'cublas' + cuda-major-version: "12" + cuda-minor-version: "0" + platforms: 'linux/arm64' + skip-drivers: 'true' + tag-latest: 'auto' + tag-suffix: '-nvidia-l4t-arm64-rfdetr' + base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" + runs-on: 'ubuntu-24.04-arm' + backend: "rfdetr" + dockerfile: "./backend/Dockerfile.python" + context: "./backend" + # runs out of space on the runner + # - build-type: 'hipblas' + # cuda-major-version: "" + # cuda-minor-version: "" + # platforms: 'linux/amd64' + # tag-latest: 'auto' + # tag-suffix: '-gpu-hipblas-rfdetr' + # base-image: "rocm/dev-ubuntu-22.04:6.1" + # runs-on: 'ubuntu-latest' + # skip-drivers: 'false' + # backend: "rfdetr" + # dockerfile: "./backend/Dockerfile.python" + # context: "./backend" llama-cpp-darwin: runs-on: macOS-14 strategy: diff --git a/Makefile b/Makefile index e36c72a1d..1f2730a63 100644 --- a/Makefile +++ b/Makefile @@ -155,6 +155,9 @@ backends/local-store: docker-build-local-store docker-save-local-store build backends/huggingface: docker-build-huggingface docker-save-huggingface build ./local-ai backends install "ocifile://$(abspath ./backend-images/huggingface.tar)" +backends/rfdetr: docker-build-rfdetr docker-save-rfdetr build + ./local-ai backends install "ocifile://$(abspath ./backend-images/rfdetr.tar)" + ######################################################## ## AIO tests ######################################################## @@ -373,6 +376,12 @@ docker-build-local-store: docker-build-huggingface: docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:huggingface -f backend/Dockerfile.golang --build-arg BACKEND=huggingface . +docker-build-rfdetr: + docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:rfdetr -f backend/Dockerfile.python --build-arg BACKEND=rfdetr ./backend + +docker-save-rfdetr: backend-images + docker save local-ai-backend:rfdetr -o backend-images/rfdetr.tar + docker-save-huggingface: backend-images docker save local-ai-backend:huggingface -o backend-images/huggingface.tar diff --git a/README.md b/README.md index 904e00efa..3521dffdd 100644 --- a/README.md +++ b/README.md @@ -195,6 +195,7 @@ For more information, see [💻 Getting started](https://localai.io/basics/getti ## 📰 Latest project news +- July/August 2025: 🔍 [Object Detection](https://localai.io/features/object-detection/) added to the API featuring [rf-detr](https://github.com/roboflow/rf-detr) - July 2025: All backends migrated outside of the main binary. LocalAI is now more lightweight, small, and automatically downloads the required backend to run the model. [Read the release notes](https://github.com/mudler/LocalAI/releases/tag/v3.2.0) - June 2025: [Backend management](https://github.com/mudler/LocalAI/pull/5607) has been added. Attention: extras images are going to be deprecated from the next release! 
Read [the backend management PR](https://github.com/mudler/LocalAI/pull/5607).
- May 2025: [Audio input](https://github.com/mudler/LocalAI/pull/5466) and [Reranking](https://github.com/mudler/LocalAI/pull/5396) in llama.cpp backend, [Realtime API](https://github.com/mudler/LocalAI/pull/5392), Support to Gemma, SmollVLM, and more multimodal models (available in the gallery).
@@ -228,6 +229,7 @@ Roadmap items: [List of issues](https://github.com/mudler/LocalAI/issues?q=is%3A
 - ✍️ [Constrained grammars](https://localai.io/features/constrained_grammars/)
 - 🖼️ [Download Models directly from Huggingface ](https://localai.io/models/)
 - 🥽 [Vision API](https://localai.io/features/gpt-vision/)
+- 🔍 [Object Detection](https://localai.io/features/object-detection/)
 - 📈 [Reranker API](https://localai.io/features/reranker/)
 - 🆕🖧 [P2P Inferencing](https://localai.io/features/distribute/)
 - [Agentic capabilities](https://github.com/mudler/LocalAGI)
diff --git a/backend/backend.proto b/backend/backend.proto
index 18ce66155..4acd8504d 100644
--- a/backend/backend.proto
+++ b/backend/backend.proto
@@ -20,6 +20,7 @@ service Backend {
   rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}
   rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
   rpc Status(HealthMessage) returns (StatusResponse) {}
+  rpc Detect(DetectOptions) returns (DetectResponse) {}
 
   rpc StoresSet(StoresSetOptions) returns (Result) {}
   rpc StoresDelete(StoresDeleteOptions) returns (Result) {}
@@ -376,3 +377,20 @@ message Message {
   string role = 1;
   string content = 2;
 }
+
+message DetectOptions {
+  string src = 1;
+}
+
+message Detection {
+  float x = 1;
+  float y = 2;
+  float width = 3;
+  float height = 4;
+  float confidence = 5;
+  string class_name = 6;
+}
+
+message DetectResponse {
+  repeated Detection Detections = 1;
+}
diff --git a/backend/index.yaml b/backend/index.yaml
index aafd4a724..e93a3b4f3 100644
--- a/backend/index.yaml
+++ b/backend/index.yaml
@@ -73,6 +73,28 @@
       nvidia-l4t: "nvidia-l4t-arm64-stablediffusion-ggml"
       # metal: "metal-stablediffusion-ggml"
       # darwin-x86: "darwin-x86-stablediffusion-ggml"
+- &rfdetr
+  name: "rfdetr"
+  alias: "rfdetr"
+  license: apache-2.0
+  icon: https://avatars.githubusercontent.com/u/53104118?s=200&v=4
+  description: |
+    RF-DETR is a real-time, transformer-based object detection model architecture developed by Roboflow and released under the Apache 2.0 license.
+    RF-DETR is the first real-time model to exceed 60 AP on the Microsoft COCO benchmark alongside competitive performance at base sizes. It also achieves state-of-the-art performance on RF100-VL, an object detection benchmark that measures model domain adaptability to real world problems. RF-DETR is the fastest and most accurate for its size when compared to current real-time object detection models.
+    RF-DETR is small enough to run on the edge using Inference, making it an ideal model for deployments that need both strong accuracy and real-time performance.
+  urls:
+    - https://github.com/roboflow/rf-detr
+  tags:
+    - object-detection
+    - rfdetr
+    - gpu
+    - cpu
+  capabilities:
+    nvidia: "cuda12-rfdetr"
+    intel: "intel-rfdetr"
+    #amd: "rocm-rfdetr"
+    nvidia-l4t: "nvidia-l4t-arm64-rfdetr"
+    default: "cpu-rfdetr"
 - &vllm
   name: "vllm"
   license: apache-2.0
@@ -663,6 +685,65 @@
       uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-vllm"
       mirrors:
         - localai/localai-backends:master-gpu-intel-sycl-f16-vllm
+# rfdetr
+- !!merge <<: *rfdetr
+  name: "rfdetr-development"
+  capabilities:
+    nvidia: "cuda12-rfdetr-development"
+    intel: "intel-rfdetr-development"
+    #amd: "rocm-rfdetr-development"
+    nvidia-l4t: "nvidia-l4t-arm64-rfdetr-development"
+    default: "cpu-rfdetr-development"
+- !!merge <<: *rfdetr
+  name: "cuda12-rfdetr"
+  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-rfdetr"
+  mirrors:
+    - localai/localai-backends:latest-gpu-nvidia-cuda-12-rfdetr
+- !!merge <<: *rfdetr
+  name: "intel-rfdetr"
+  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-rfdetr"
+  mirrors:
+    - localai/localai-backends:latest-gpu-intel-rfdetr
+# - !!merge <<: *rfdetr
+#   name: "rocm-rfdetr"
+#   uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-hipblas-rfdetr"
+#   mirrors:
+#     - localai/localai-backends:latest-gpu-hipblas-rfdetr
+- !!merge <<: *rfdetr
+  name: "nvidia-l4t-arm64-rfdetr"
+  uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-rfdetr"
+  mirrors:
+    - localai/localai-backends:latest-nvidia-l4t-arm64-rfdetr
+- !!merge <<: *rfdetr
+  name: "cpu-rfdetr"
+  uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-rfdetr"
+  mirrors:
+    - localai/localai-backends:latest-cpu-rfdetr
+- !!merge <<: *rfdetr
+  name: "cuda12-rfdetr-development"
+  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-rfdetr"
+  mirrors:
+    - localai/localai-backends:master-gpu-nvidia-cuda-12-rfdetr
+- !!merge <<: *rfdetr
+  name: "intel-rfdetr-development"
+  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-rfdetr"
+  mirrors:
+    - localai/localai-backends:master-gpu-intel-rfdetr
+# - !!merge <<: *rfdetr
+#   name: "rocm-rfdetr-development"
+#   uri: "quay.io/go-skynet/local-ai-backends:master-gpu-hipblas-rfdetr"
+#   mirrors:
+#     - localai/localai-backends:master-gpu-hipblas-rfdetr
+- !!merge <<: *rfdetr
+  name: "cpu-rfdetr-development"
+  uri: "quay.io/go-skynet/local-ai-backends:master-cpu-rfdetr"
+  mirrors:
+    - localai/localai-backends:master-cpu-rfdetr
 ## Rerankers
 - !!merge <<: *rerankers
   name: "rerankers-development"
diff --git a/backend/python/common/template/protogen.sh b/backend/python/common/template/protogen.sh
index d608379c1..0569b6c6e 100644
--- a/backend/python/common/template/protogen.sh
+++ b/backend/python/common/template/protogen.sh
@@ -8,4 +8,6 @@ else
     source $backend_dir/../common/libbackend.sh
 fi
 
+ensureVenv
+
 python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=.
backend.proto
\ No newline at end of file
diff --git a/backend/python/rfdetr/Makefile b/backend/python/rfdetr/Makefile
new file mode 100644
index 000000000..c0e5169f7
--- /dev/null
+++ b/backend/python/rfdetr/Makefile
@@ -0,0 +1,20 @@
+.DEFAULT_GOAL := install
+
+.PHONY: install
+install:
+	bash install.sh
+	$(MAKE) protogen
+
+.PHONY: protogen
+protogen: backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: protogen-clean
+protogen-clean:
+	$(RM) backend_pb2_grpc.py backend_pb2.py
+
+backend_pb2_grpc.py backend_pb2.py:
+	bash protogen.sh
+
+.PHONY: clean
+clean: protogen-clean
+	rm -rf venv __pycache__
\ No newline at end of file
diff --git a/backend/python/rfdetr/backend.py b/backend/python/rfdetr/backend.py
new file mode 100755
index 000000000..57f68647f
--- /dev/null
+++ b/backend/python/rfdetr/backend.py
@@ -0,0 +1,174 @@
+#!/usr/bin/env python3
+"""
+gRPC server for RFDETR object detection models.
+"""
+from concurrent import futures
+
+import argparse
+import signal
+import sys
+import os
+import time
+import base64
+import backend_pb2
+import backend_pb2_grpc
+import grpc
+
+from inference import get_model
+from PIL import Image
+from io import BytesIO
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS is specified in the environment, use it; otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+    """
+    A gRPC servicer for the RFDETR backend service.
+
+    This class implements the gRPC methods for object detection using RFDETR models.
+    """
+
+    def __init__(self):
+        self.model = None
+        self.model_name = None
+
+    def Health(self, request, context):
+        """
+        A gRPC method that returns the health status of the backend service.
+
+        Args:
+            request: A HealthMessage object that contains the request parameters.
+            context: A grpc.ServicerContext object that provides information about the RPC.
+
+        Returns:
+            A Reply object that contains the health status of the backend service.
+        """
+        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+
+    def LoadModel(self, request, context):
+        """
+        A gRPC method that loads an RFDETR model into memory.
+
+        Args:
+            request: A ModelOptions object that contains the model parameters.
+            context: A grpc.ServicerContext object that provides information about the RPC.
+
+        Returns:
+            A Result object that contains the result of the LoadModel operation.
+        """
+        model_name = request.Model
+        try:
+            # Load the RFDETR model
+            self.model = get_model(model_name)
+            self.model_name = model_name
+            print(f'Loaded RFDETR model: {model_name}')
+        except Exception as err:
+            return backend_pb2.Result(success=False, message=f"Failed to load model: {err}")
+
+        return backend_pb2.Result(message="Model loaded successfully", success=True)
+
+    def Detect(self, request, context):
+        """
+        A gRPC method that performs object detection on an image.
+
+        Args:
+            request: A DetectOptions object that contains the image source.
+            context: A grpc.ServicerContext object that provides information about the RPC.
+
+        Returns:
+            A DetectResponse object that contains the detection results.
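+
+        Note:
+            request.src is expected to already contain base64-encoded image
+            bytes; the LocalAI HTTP layer (utils.GetContentURIAsBase64) turns
+            URLs and data URIs into plain base64 before invoking this RPC.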
+        """
+        if self.model is None:
+            print("[RFDETR] Detect called before a model was loaded", file=sys.stderr)
+            return backend_pb2.DetectResponse()
+        try:
+            # Decode the base64 image payload
+            image_data = base64.b64decode(request.src)
+            image = Image.open(BytesIO(image_data))
+
+            # Perform inference; boxes scoring below the 0.5 confidence threshold are discarded
+            predictions = self.model.infer(image, confidence=0.5)[0]
+
+            # Convert the predictions to the proto format
+            proto_detections = []
+            for pred in predictions.predictions:
+                proto_detections.append(backend_pb2.Detection(
+                    x=float(pred.x),
+                    y=float(pred.y),
+                    width=float(pred.width),
+                    height=float(pred.height),
+                    confidence=float(pred.confidence),
+                    class_name=pred.class_name
+                ))
+
+            return backend_pb2.DetectResponse(Detections=proto_detections)
+        except Exception as err:
+            print(f"[RFDETR] Detection error: {err}", file=sys.stderr)
+            return backend_pb2.DetectResponse()
+
+    def Status(self, request, context):
+        """
+        A gRPC method that returns the status of the backend service.
+
+        Args:
+            request: A HealthMessage object that contains the request parameters.
+            context: A grpc.ServicerContext object that provides information about the RPC.
+
+        Returns:
+            A StatusResponse object that contains the status information.
+        """
+        state = backend_pb2.StatusResponse.READY if self.model is not None else backend_pb2.StatusResponse.UNINITIALIZED
+        return backend_pb2.StatusResponse(state=state)
+
+def serve(address):
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+        options=[
+            ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
+            ('grpc.max_send_message_length', 50 * 1024 * 1024),  # 50MB
+            ('grpc.max_receive_message_length', 50 * 1024 * 1024),  # 50MB
+        ])
+    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+    server.add_insecure_port(address)
+    server.start()
+    print("[RFDETR] Server started. Listening on: " + address, file=sys.stderr)
+
+    # Define the signal handler function
+    def signal_handler(sig, frame):
+        print("[RFDETR] Received termination signal. Shutting down...")
+        server.stop(0)
+        sys.exit(0)
+
+    # Set the signal handlers for SIGINT and SIGTERM
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTERM, signal_handler)
+
+    try:
+        while True:
+            time.sleep(_ONE_DAY_IN_SECONDS)
+    except KeyboardInterrupt:
+        server.stop(0)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run the RFDETR gRPC server.")
+    parser.add_argument(
+        "--addr", default="localhost:50051", help="The address to bind the server to."
+    )
+    args = parser.parse_args()
+    print(f"[RFDETR] startup: {args}", file=sys.stderr)
+    serve(args.addr)
+
+
diff --git a/backend/python/rfdetr/install.sh b/backend/python/rfdetr/install.sh
new file mode 100755
index 000000000..32befa8e6
--- /dev/null
+++ b/backend/python/rfdetr/install.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+    source $backend_dir/common/libbackend.sh
+else
+    source $backend_dir/../common/libbackend.sh
+fi
+
+# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
+# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
+# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index +# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index +if [ "x${BUILD_PROFILE}" == "xintel" ]; then + EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match" +fi + +installRequirements diff --git a/backend/python/rfdetr/protogen.sh b/backend/python/rfdetr/protogen.sh new file mode 100644 index 000000000..0569b6c6e --- /dev/null +++ b/backend/python/rfdetr/protogen.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +ensureVenv + +python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=. backend.proto \ No newline at end of file diff --git a/backend/python/rfdetr/requirements-cpu.txt b/backend/python/rfdetr/requirements-cpu.txt new file mode 100644 index 000000000..d0d1f4afa --- /dev/null +++ b/backend/python/rfdetr/requirements-cpu.txt @@ -0,0 +1,7 @@ +rfdetr +opencv-python +accelerate +peft +inference +torch==2.7.1 +optimum-quanto \ No newline at end of file diff --git a/backend/python/rfdetr/requirements-cublas11.txt b/backend/python/rfdetr/requirements-cublas11.txt new file mode 100644 index 000000000..14449b3d4 --- /dev/null +++ b/backend/python/rfdetr/requirements-cublas11.txt @@ -0,0 +1,8 @@ +--extra-index-url https://download.pytorch.org/whl/cu118 +torch==2.7.1+cu118 +rfdetr +opencv-python +accelerate +inference +peft +optimum-quanto \ No newline at end of file diff --git a/backend/python/rfdetr/requirements-cublas12.txt b/backend/python/rfdetr/requirements-cublas12.txt new file mode 100644 index 000000000..36eaa47bb --- /dev/null +++ b/backend/python/rfdetr/requirements-cublas12.txt @@ -0,0 +1,7 @@ +torch==2.7.1 +rfdetr +opencv-python +accelerate +inference +peft +optimum-quanto \ No newline at end of file diff --git a/backend/python/rfdetr/requirements-hipblas.txt b/backend/python/rfdetr/requirements-hipblas.txt new file mode 100644 index 000000000..536a31efb --- /dev/null +++ b/backend/python/rfdetr/requirements-hipblas.txt @@ -0,0 +1,9 @@ +--extra-index-url https://download.pytorch.org/whl/rocm6.3 +torch==2.7.1+rocm6.3 +torchvision==0.22.1+rocm6.3 +rfdetr +opencv-python +accelerate +inference +peft +optimum-quanto \ No newline at end of file diff --git a/backend/python/rfdetr/requirements-intel.txt b/backend/python/rfdetr/requirements-intel.txt new file mode 100644 index 000000000..55fcbb318 --- /dev/null +++ b/backend/python/rfdetr/requirements-intel.txt @@ -0,0 +1,13 @@ +--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ +intel-extension-for-pytorch==2.3.110+xpu +torch==2.3.1+cxx11.abi +torchvision==0.18.1+cxx11.abi +oneccl_bind_pt==2.3.100+xpu +optimum[openvino] +setuptools +rfdetr +inference +opencv-python +accelerate +peft +optimum-quanto \ No newline at end of file diff --git a/backend/python/rfdetr/requirements.txt b/backend/python/rfdetr/requirements.txt new file mode 100644 index 000000000..44b40efd0 --- /dev/null +++ b/backend/python/rfdetr/requirements.txt @@ -0,0 +1,3 @@ +grpcio==1.71.0 +protobuf +grpcio-tools diff --git a/backend/python/rfdetr/run.sh b/backend/python/rfdetr/run.sh new file mode 100755 index 000000000..82b7b09ec --- /dev/null +++ b/backend/python/rfdetr/run.sh @@ -0,0 +1,9 @@ +#!/bin/bash +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source 
$backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +startBackend $@ \ No newline at end of file diff --git a/backend/python/rfdetr/test.sh b/backend/python/rfdetr/test.sh new file mode 100755 index 000000000..eb59f2aaf --- /dev/null +++ b/backend/python/rfdetr/test.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +runUnittests diff --git a/core/backend/detection.go b/core/backend/detection.go new file mode 100644 index 000000000..dc89560a8 --- /dev/null +++ b/core/backend/detection.go @@ -0,0 +1,34 @@ +package backend + +import ( + "context" + "fmt" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/mudler/LocalAI/pkg/model" +) + +func Detection( + sourceFile string, + loader *model.ModelLoader, + appConfig *config.ApplicationConfig, + backendConfig config.BackendConfig, +) (*proto.DetectResponse, error) { + opts := ModelOptions(backendConfig, appConfig) + detectionModel, err := loader.Load(opts...) + if err != nil { + return nil, err + } + defer loader.Close() + + if detectionModel == nil { + return nil, fmt.Errorf("could not load detection model") + } + + res, err := detectionModel.Detect(context.Background(), &proto.DetectOptions{ + Src: sourceFile, + }) + + return res, err +} diff --git a/core/config/backend_config.go b/core/config/backend_config.go index 113ae131a..739f876a0 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -458,6 +458,7 @@ const ( FLAG_TOKENIZE BackendConfigUsecases = 0b001000000000 FLAG_VAD BackendConfigUsecases = 0b010000000000 FLAG_VIDEO BackendConfigUsecases = 0b100000000000 + FLAG_DETECTION BackendConfigUsecases = 0b1000000000000 // Common Subsets FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT @@ -479,6 +480,7 @@ func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases { "FLAG_VAD": FLAG_VAD, "FLAG_LLM": FLAG_LLM, "FLAG_VIDEO": FLAG_VIDEO, + "FLAG_DETECTION": FLAG_DETECTION, } } @@ -572,6 +574,12 @@ func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool { } } + if (u & FLAG_DETECTION) == FLAG_DETECTION { + if c.Backend != "rfdetr" { + return false + } + } + if (u & FLAG_SOUND_GENERATION) == FLAG_SOUND_GENERATION { if c.Backend != "transformers-musicgen" { return false diff --git a/core/http/endpoints/localai/detection.go b/core/http/endpoints/localai/detection.go new file mode 100644 index 000000000..496a64c10 --- /dev/null +++ b/core/http/endpoints/localai/detection.go @@ -0,0 +1,59 @@ +package localai + +import ( + "github.com/gofiber/fiber/v2" + "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/utils" + "github.com/rs/zerolog/log" +) + +// DetectionEndpoint is the LocalAI Detection endpoint https://localai.io/docs/api-reference/detection +// @Summary Detects objects in the input image. 
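+// @Accept json
+// @Produce json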
+// @Param request body schema.DetectionRequest true "Detection request"
+// @Success 200 {object} schema.DetectionResponse "Response"
+// @Router /v1/detection [post]
+func DetectionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
+	return func(c *fiber.Ctx) error {
+
+		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
+		if !ok || input.Model == "" {
+			return fiber.ErrBadRequest
+		}
+
+		cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+		if !ok || cfg == nil {
+			return fiber.ErrBadRequest
+		}
+
+		log.Debug().Str("image", input.Image).Str("modelFile", input.Model).Str("backend", cfg.Backend).Msg("Detection")
+
+		image, err := utils.GetContentURIAsBase64(input.Image)
+		if err != nil {
+			return err
+		}
+
+		res, err := backend.Detection(image, ml, appConfig, *cfg)
+		if err != nil {
+			return err
+		}
+
+		response := schema.DetectionResponse{
+			Detections: make([]schema.Detection, len(res.Detections)),
+		}
+		for i, detection := range res.Detections {
+			response.Detections[i] = schema.Detection{
+				X:          detection.X,
+				Y:          detection.Y,
+				Width:      detection.Width,
+				Height:     detection.Height,
+				Confidence: detection.Confidence,
+				ClassName:  detection.ClassName,
+			}
+		}
+
+		return c.JSON(response)
+	}
+}
diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go
index 8abadfcb9..ce9e8496b 100644
--- a/core/http/routes/localai.go
+++ b/core/http/routes/localai.go
@@ -41,6 +41,11 @@ func RegisterLocalAIRoutes(router *fiber.App,
 		router.Get("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint())
 	}
 
+	router.Post("/v1/detection",
+		requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_DETECTION)),
+		requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.DetectionRequest) }),
+		localai.DetectionEndpoint(cl, ml, appConfig))
+
 	router.Post("/tts",
 		requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
 		requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
diff --git a/core/http/views/backends.html b/core/http/views/backends.html
index c08d0e907..23f7d43de 100644
--- a/core/http/views/backends.html
+++ b/core/http/views/backends.html
@@ -90,6 +90,14 @@ hx-indicator=".htmx-indicator"> Whisper +
diff --git a/core/http/views/models.html b/core/http/views/models.html
index 3e243118a..b1f8d7cb4 100644
--- a/core/http/views/models.html
+++ b/core/http/views/models.html
@@ -115,6 +115,14 @@ hx-indicator=".htmx-indicator"> Audio transcription +
diff --git a/core/schema/localai.go b/core/schema/localai.go
index 4e819238a..d093faafe 100644
--- a/core/schema/localai.go
+++ b/core/schema/localai.go
@@ -120,3 +120,20 @@ type SystemInformationResponse struct {
 	Backends []string `json:"backends"`
 	Models   []SysInfoModel `json:"loaded_models"`
 }
+
+type DetectionRequest struct {
+	BasicModelRequest
+	Image string `json:"image"`
+}
+
+type DetectionResponse struct {
+	Detections []Detection `json:"detections"`
+}
+
+type Detection struct {
+	X          float32 `json:"x"`
+	Y          float32 `json:"y"`
+	Width      float32 `json:"width"`
+	Height     float32 `json:"height"`
+	Confidence float32 `json:"confidence"`
+	ClassName  string  `json:"class_name"`
+}
diff --git a/docs/content/docs/features/object-detection.md b/docs/content/docs/features/object-detection.md
new file mode 100644
index 000000000..5126e3001
--- /dev/null
+++ b/docs/content/docs/features/object-detection.md
@@ -0,0
+1,193 @@ ++++ +disableToc = false +title = "🔍 Object detection" +weight = 13 +url = "/features/object-detection/" ++++ + +LocalAI supports object detection through various backends. This feature allows you to identify and locate objects within images with high accuracy and real-time performance. Currently, [RF-DETR](https://github.com/roboflow/rf-detr) is available as an implementation. + +## Overview + +Object detection in LocalAI is implemented through dedicated backends that can identify and locate objects within images. Each backend provides different capabilities and model architectures. + +**Key Features:** +- Real-time object detection +- High accuracy detection with bounding boxes +- Support for multiple hardware accelerators (CPU, NVIDIA GPU, Intel GPU, AMD GPU) +- Structured detection results with confidence scores +- Easy integration through the `/v1/detection` endpoint + +## Usage + +### Detection Endpoint + +LocalAI provides a dedicated `/v1/detection` endpoint for object detection tasks. This endpoint is specifically designed for object detection and returns structured detection results with bounding boxes and confidence scores. + +### API Reference + +To perform object detection, send a POST request to the `/v1/detection` endpoint: + +```bash +curl -X POST http://localhost:8080/v1/detection \ + -H "Content-Type: application/json" \ + -d '{ + "model": "rfdetr-base", + "image": "https://media.roboflow.com/dog.jpeg" + }' +``` + +### Request Format + +The request body should contain: + +- `model`: The name of the object detection model (e.g., "rfdetr-base") +- `image`: The image to analyze, which can be: + - A URL to an image + - A base64-encoded image + +### Response Format + +The API returns a JSON response with detected objects: + +```json +{ + "detections": [ + { + "x": 100.5, + "y": 150.2, + "width": 200.0, + "height": 300.0, + "confidence": 0.95, + "class_name": "dog" + }, + { + "x": 400.0, + "y": 200.0, + "width": 150.0, + "height": 250.0, + "confidence": 0.87, + "class_name": "person" + } + ] +} +``` + +Each detection includes: +- `x`, `y`: Coordinates of the bounding box top-left corner +- `width`, `height`: Dimensions of the bounding box +- `confidence`: Detection confidence score (0.0 to 1.0) +- `class_name`: The detected object class + +## Backends + +### RF-DETR Backend + +The RF-DETR backend is implemented as a Python-based gRPC service that integrates seamlessly with LocalAI. It provides object detection capabilities using the RF-DETR model architecture and supports multiple hardware configurations: + +- **CPU**: Optimized for CPU inference +- **NVIDIA GPU**: CUDA acceleration for NVIDIA GPUs +- **Intel GPU**: Intel oneAPI optimization +- **AMD GPU**: ROCm acceleration for AMD GPUs +- **NVIDIA Jetson**: Optimized for ARM64 NVIDIA Jetson devices + +#### Setup + +1. **Using the Model Gallery (Recommended)** + + The easiest way to get started is using the model gallery. The `rfdetr-base` model is available in the official LocalAI gallery: + + ```bash + # Install and run the rfdetr-base model + local-ai run rfdetr-base + ``` + + You can also install it through the web interface by navigating to the Models section and searching for "rfdetr-base". + +2. 
**Manual Configuration** + + Create a model configuration file in your `models` directory: + + ```yaml + name: rfdetr + backend: rfdetr + parameters: + model: rfdetr-base + ``` + +#### Available Models + +Currently, the following model is available in the [Model Gallery]({{%relref "docs/features/model-gallery" %}}): + +- **rfdetr-base**: Base model with balanced performance and accuracy + +You can browse and install this model through the LocalAI web interface or using the command line. + +## Examples + +### Basic Object Detection + +```bash +# Detect objects in an image from URL +curl -X POST http://localhost:8080/v1/detection \ + -H "Content-Type: application/json" \ + -d '{ + "model": "rfdetr-base", + "image": "https://example.com/image.jpg" + }' +``` + +### Base64 Image Detection + +```bash +# Convert image to base64 and send +base64_image=$(base64 -w 0 image.jpg) +curl -X POST http://localhost:8080/v1/detection \ + -H "Content-Type: application/json" \ + -d "{ + \"model\": \"rfdetr-base\", + \"image\": \"data:image/jpeg;base64,$base64_image\" + }" +``` + +## Troubleshooting + +### Common Issues + +1. **Model Loading Errors** + - Ensure the model file is properly downloaded + - Check available disk space + - Verify model compatibility with your backend version + +2. **Low Detection Accuracy** + - Ensure good image quality and lighting + - Check if objects are clearly visible + - Consider using a larger model for better accuracy + +3. **Slow Performance** + - Enable GPU acceleration if available + - Use a smaller model for faster inference + - Optimize image resolution + +### Debug Mode + +Enable debug logging for troubleshooting: + +```bash +local-ai run --debug rfdetr-base +``` + +## Object Detection Category + +LocalAI includes a dedicated **object-detection** category for models and backends that specialize in identifying and locating objects within images. This category currently includes: + +- **RF-DETR**: Real-time transformer-based object detection + +Additional object detection models and backends will be added to this category in the future. You can filter models by the `object-detection` tag in the model gallery to find all available object detection models. + +## Related Features + +- [🎨 Image generation]({{%relref "docs/features/image-generation" %}}): Generate images with AI +- [📖 Text generation]({{%relref "docs/features/text-generation" %}}): Generate text with language models +- [🔍 GPT Vision]({{%relref "docs/features/gpt-vision" %}}): Analyze images with language models +- [🚀 GPU acceleration]({{%relref "docs/features/GPU-acceleration" %}}): Optimize performance with GPU acceleration \ No newline at end of file diff --git a/gallery/index.yaml b/gallery/index.yaml index 898fc75e2..27c34368a 100644 --- a/gallery/index.yaml +++ b/gallery/index.yaml @@ -1,4 +1,26 @@ --- +- &rfdetr + name: "rfdetr-base" + url: "github:mudler/LocalAI/gallery/virtual.yaml@master" + icon: https://avatars.githubusercontent.com/u/53104118?s=200&v=4 + license: apache-2.0 + description: | + RF-DETR is a real-time, transformer-based object detection model architecture developed by Roboflow and released under the Apache 2.0 license. + RF-DETR is the first real-time model to exceed 60 AP on the Microsoft COCO benchmark alongside competitive performance at base sizes. It also achieves state-of-the-art performance on RF100-VL, an object detection benchmark that measures model domain adaptability to real world problems. 
RF-DETR is the fastest and most accurate for its size when compared to current real-time object detection models.
+    RF-DETR is small enough to run on the edge using Inference, making it an ideal model for deployments that need both strong accuracy and real-time performance.
+  tags:
+    - object-detection
+    - rfdetr
+    - gpu
+    - cpu
+  urls:
+    - https://github.com/roboflow/rf-detr
+  overrides:
+    backend: rfdetr
+    parameters:
+      model: rfdetr-base
+  known_usecases:
+    - detection
 - name: "dream-org_dream-v0-instruct-7b" # chatml
   url: "github:mudler/LocalAI/gallery/chatml.yaml@master"
diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go
index ac3fe757e..846f7231d 100644
--- a/pkg/grpc/backend.go
+++ b/pkg/grpc/backend.go
@@ -9,7 +9,7 @@ import (
 
 var embeds = map[string]*embedBackend{}
 
-func Provide(addr string, llm LLM) {
+func Provide(addr string, llm AIModel) {
 	embeds[addr] = &embedBackend{s: &server{llm: llm}}
 }
 
@@ -42,6 +42,7 @@ type Backend interface {
 	GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error)
 	TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
 	SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
+	Detect(ctx context.Context, in *pb.DetectOptions, opts ...grpc.CallOption) (*pb.DetectResponse, error)
 	AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error)
 	TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error)
 	Status(ctx context.Context) (*pb.StatusResponse, error)
diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go
index 40a775cb2..e59db2e15 100644
--- a/pkg/grpc/base/base.go
+++ b/pkg/grpc/base/base.go
@@ -69,6 +69,10 @@ func (llm *Base) SoundGeneration(*pb.SoundGenerationRequest) error {
 	return fmt.Errorf("unimplemented")
 }
 
+func (llm *Base) Detect(*pb.DetectOptions) (pb.DetectResponse, error) {
+	return pb.DetectResponse{}, fmt.Errorf("unimplemented")
+}
+
 func (llm *Base) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
 	return pb.TokenizationResponse{}, fmt.Errorf("unimplemented")
 }
diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go
index 78e1421d0..f0f9a930e 100644
--- a/pkg/grpc/client.go
+++ b/pkg/grpc/client.go
@@ -504,3 +504,25 @@ func (c *Client) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOp
 	client := pb.NewBackendClient(conn)
 	return client.VAD(ctx, in, opts...)
 }
+
+func (c *Client) Detect(ctx context.Context, in *pb.DetectOptions, opts ...grpc.CallOption) (*pb.DetectResponse, error) {
+	if !c.parallel {
+		c.opMutex.Lock()
+		defer c.opMutex.Unlock()
+	}
+	c.setBusy(true)
+	defer c.setBusy(false)
+	c.wdMark()
+	defer c.wdUnMark()
+	conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
+		grpc.WithDefaultCallOptions(
+			grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
+			grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
+		))
+	if err != nil {
+		return nil, err
+	}
+	defer conn.Close()
+	client := pb.NewBackendClient(conn)
+	return client.Detect(ctx, in, opts...)
+}
diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go
index 417b38904..3369ce0fc 100644
--- a/pkg/grpc/embed.go
+++ b/pkg/grpc/embed.go
@@ -59,6 +59,10 @@ func (e *embedBackend) SoundGeneration(ctx context.Context, in *pb.SoundGenerati
 	return e.s.SoundGeneration(ctx, in)
 }
 
+func (e *embedBackend) Detect(ctx context.Context, in *pb.DetectOptions, opts ...grpc.CallOption) (*pb.DetectResponse, error) {
+	return e.s.Detect(ctx, in)
+}
+
 func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error) {
 	return e.s.AudioTranscription(ctx, in)
 }
diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go
index 35c5d9779..66c38f430 100644
--- a/pkg/grpc/interface.go
+++ b/pkg/grpc/interface.go
@@ -4,7 +4,7 @@ import (
 	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
 )
 
-type LLM interface {
+type AIModel interface {
 	Busy() bool
 	Lock()
 	Unlock()
@@ -15,6 +15,7 @@ type LLM interface {
 	Embeddings(*pb.PredictOptions) ([]float32, error)
 	GenerateImage(*pb.GenerateImageRequest) error
 	GenerateVideo(*pb.GenerateVideoRequest) error
+	Detect(*pb.DetectOptions) (pb.DetectResponse, error)
 	AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error)
 	TTS(*pb.TTSRequest) error
 	SoundGeneration(*pb.SoundGenerationRequest) error
diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go
index 546ed291c..30962e8c8 100644
--- a/pkg/grpc/server.go
+++ b/pkg/grpc/server.go
@@ -22,7 +22,7 @@ import (
 // server is used to implement helloworld.GreeterServer.
 type server struct {
 	pb.UnimplementedBackendServer
-	llm LLM
+	llm AIModel
 }
 
 func (s *server) Health(ctx context.Context, in *pb.HealthMessage) (*pb.Reply, error) {
@@ -111,6 +111,18 @@ func (s *server) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequ
 	return &pb.Result{Message: "Sound Generation audio generated", Success: true}, nil
 }
 
+func (s *server) Detect(ctx context.Context, in *pb.DetectOptions) (*pb.DetectResponse, error) {
+	if s.llm.Locking() {
+		s.llm.Lock()
+		defer s.llm.Unlock()
+	}
+	res, err := s.llm.Detect(in)
+	if err != nil {
+		return nil, err
+	}
+	return &res, nil
+}
+
 func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) {
 	if s.llm.Locking() {
 		s.llm.Lock()
@@ -251,7 +263,7 @@ func (s *server) VAD(ctx context.Context, in *pb.VADRequest) (*pb.VADResponse, e
 	return &res, nil
 }
 
-func StartServer(address string, model LLM) error {
+func StartServer(address string, model AIModel) error {
 	lis, err := net.Listen("tcp", address)
 	if err != nil {
 		return err
@@ -269,7 +281,7 @@ func StartServer(address string, model LLM) error {
 	return nil
 }
 
-func RunServer(address string, model LLM) (func() error, error) {
+func RunServer(address string, model AIModel) (func() error, error) {
 	lis, err := net.Listen("tcp", address)
 	if err != nil {
 		return nil, err
diff --git a/pkg/utils/base64.go b/pkg/utils/base64.go
index 174e80fa3..5c33af463 100644
--- a/pkg/utils/base64.go
+++ b/pkg/utils/base64.go
@@ -20,7 +20,7 @@ var dataURIPattern = regexp.MustCompile(`^data:([^;]+);base64,`)
 
 // GetContentURIAsBase64 checks if the string is an URL, if it's an URL downloads the content in memory encodes it in base64 and returns the base64 string, otherwise returns the string by stripping base64 data headers
 func GetContentURIAsBase64(s string) (string, error) {
-	if strings.HasPrefix(s, "http") {
+	if strings.HasPrefix(s, "http://") || strings.HasPrefix(s, "https://") {
 		// download the image
 		resp, err := base64DownloadClient.Get(s)
 		if err != nil {
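Reviewer note: below is a minimal end-to-end smoke test for the new endpoint. It is a sketch rather than part of the patch: it assumes a LocalAI instance listening on http://localhost:8080 with the `rfdetr-base` gallery model installed, and it relies only on the request/response schema introduced in `core/schema/localai.go`.

```python
#!/usr/bin/env python3
"""Smoke test for the /v1/detection endpoint added by this patch."""
import requests  # third-party HTTP client: pip install requests

BASE_URL = "http://localhost:8080"  # assumed LocalAI address

# The `image` field accepts a URL, a data URI, or a bare base64 string;
# the server normalizes all three through utils.GetContentURIAsBase64.
payload = {
    "model": "rfdetr-base",
    "image": "https://media.roboflow.com/dog.jpeg",
}

resp = requests.post(f"{BASE_URL}/v1/detection", json=payload, timeout=120)
resp.raise_for_status()

# Print one line per detected object, as defined by schema.DetectionResponse.
for det in resp.json().get("detections", []):
    print(
        f"{det['class_name']}: "
        f"x={det['x']:.1f} y={det['y']:.1f} "
        f"w={det['width']:.1f} h={det['height']:.1f} "
        f"confidence={det.get('confidence', 0.0):.2f}"
    )
```

To test with a local file instead, base64-encode it first, e.g. `payload["image"] = base64.b64encode(open("image.jpg", "rb").read()).decode()` after `import base64`.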