mirror of
https://github.com/hatchet-dev/hatchet.git
synced 2026-02-10 18:19:08 -06:00
fix: python sdk graceful shutdown and retry errors (#118)
* fix: python sdk graceful shutdown and retry errors * chore: address changes from review
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
Workers can be created via the `hatchet.worker()` method, after instantiating a `hatchet` instance. It will automatically read in any `HATCHET_CLIENT` environment variables, which can be set in the process by using something like `dotenv`. For example:
|
||||
|
||||
```py
|
||||
from hatchet-sdk import Hatchet
|
||||
from hatchet_sdk import Hatchet
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
To create a workflow, simply create a new class and use the `hatchet.workflow` and `hatchet.step` decorators to define the structure of your workflow. For example, a simple 2-step workflow would look like:
|
||||
|
||||
```py
|
||||
from hatchet-sdk import Hatchet
|
||||
from hatchet_sdk import Hatchet
|
||||
|
||||
hatchet = Hatchet()
|
||||
|
||||
@@ -66,7 +66,7 @@ Future steps can access this output by calling `context.step_output("<step>")`.
|
||||
You can declare a cron schedule by passing `on_crons` to the `hatchet.workflow` decorator. For example, to trigger a workflow every 5 minutes, you can do the following:
|
||||
|
||||
```py
|
||||
from hatchet-sdk import Hatchet
|
||||
from hatchet_sdk import Hatchet
|
||||
|
||||
hatchet = Hatchet()
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Events can be pushed via the client's `client.event.push` method:
|
||||
|
||||
```py
|
||||
from hatchet import new_client
|
||||
from hatchet_sdk import new_client
|
||||
|
||||
client = new_client()
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from hatchet import new_client
|
||||
from hatchet_sdk import new_client
|
||||
|
||||
client = new_client()
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from hatchet import Hatchet, Context
|
||||
from hatchet_sdk import Hatchet, Context
|
||||
|
||||
hatchet = Hatchet()
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from hatchet import new_client
|
||||
from hatchet_sdk import new_client
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
client = new_client()
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from hatchet import Hatchet
|
||||
from hatchet_sdk import Hatchet
|
||||
from dotenv import load_dotenv
|
||||
|
||||
hatchet = Hatchet()
|
||||
load_dotenv()
|
||||
|
||||
hatchet = Hatchet(debug=True)
|
||||
|
||||
@hatchet.workflow(on_events=["user:create"])
|
||||
class MyWorkflow:
|
||||
|
||||
@@ -4,6 +4,7 @@ from ..dispatcher_pb2_grpc import DispatcherStub
|
||||
|
||||
import time
|
||||
from ..loader import ClientConfig
|
||||
from ..logger import logger
|
||||
import json
|
||||
import grpc
|
||||
from typing import Callable, List, Union
|
||||
@@ -23,8 +24,10 @@ class DispatcherClient:
|
||||
def send_action_event(self, ctx, in_):
|
||||
raise NotImplementedError
|
||||
|
||||
DEFAULT_ACTION_LISTENER_RETRY_INTERVAL = 5 # seconds
|
||||
DEFAULT_ACTION_LISTENER_RETRY_INTERVAL = 1 # seconds
|
||||
DEFAULT_ACTION_LISTENER_RETRY_COUNT = 5
|
||||
DEFAULT_ACTION_TIMEOUT = 60 # seconds
|
||||
DEFAULT_REGISTER_TIMEOUT = 5
|
||||
|
||||
class GetActionListenerRequest:
|
||||
def __init__(self, worker_name: str, services: List[str], actions: List[str]):
|
||||
@@ -57,18 +60,21 @@ START_STEP_RUN = 0
|
||||
CANCEL_STEP_RUN = 1
|
||||
|
||||
class ActionListenerImpl(WorkerActionListener):
|
||||
def __init__(self, client : DispatcherStub, tenant_id, listen_client, worker_id):
|
||||
def __init__(self, client : DispatcherStub, tenant_id, worker_id):
|
||||
self.client = client
|
||||
self.tenant_id = tenant_id
|
||||
self.listen_client = listen_client
|
||||
self.worker_id = worker_id
|
||||
self.retries = 0
|
||||
|
||||
# self.logger = logger
|
||||
# self.validator = validator
|
||||
|
||||
def actions(self):
|
||||
while True:
|
||||
logger.info("Listening for actions...")
|
||||
|
||||
try:
|
||||
for assigned_action in self.listen_client:
|
||||
for assigned_action in self.get_listen_client():
|
||||
assigned_action : AssignedAction
|
||||
|
||||
# Process the received action
|
||||
@@ -102,11 +108,16 @@ class ActionListenerImpl(WorkerActionListener):
|
||||
break
|
||||
elif e.code() == grpc.StatusCode.UNAVAILABLE:
|
||||
# Retry logic
|
||||
self.retry_subscribe()
|
||||
logger.info("Could not connect to Hatchet, retrying...")
|
||||
self.retries = self.retries + 1
|
||||
elif e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
|
||||
logger.info("Deadline exceeded, retrying subscription")
|
||||
continue
|
||||
else:
|
||||
# Unknown error, report and break
|
||||
# self.logger.error(f"Failed to receive message: {e}")
|
||||
# err_ch(e)
|
||||
logger.error(f"Failed to receive message: {e}")
|
||||
break
|
||||
|
||||
def parse_action_payload(self, payload : str):
|
||||
@@ -124,30 +135,28 @@ class ActionListenerImpl(WorkerActionListener):
|
||||
else:
|
||||
# self.logger.error(f"Unknown action type: {action_type}")
|
||||
return None
|
||||
|
||||
def retry_subscribe(self):
|
||||
retries = 0
|
||||
while retries < DEFAULT_ACTION_LISTENER_RETRY_COUNT:
|
||||
try:
|
||||
time.sleep(DEFAULT_ACTION_LISTENER_RETRY_INTERVAL)
|
||||
self.listen_client = self.client.Listen(WorkerListenRequest(
|
||||
|
||||
def get_listen_client(self):
|
||||
if self.retries > DEFAULT_ACTION_LISTENER_RETRY_COUNT:
|
||||
raise Exception(f"Could not subscribe to the worker after {DEFAULT_ACTION_LISTENER_RETRY_COUNT} retries")
|
||||
elif self.retries > 1:
|
||||
# logger.info
|
||||
# if we are retrying, we wait for a bit. this should eventually be replaced with exp backoff + jitter
|
||||
time.sleep(DEFAULT_ACTION_LISTENER_RETRY_INTERVAL)
|
||||
|
||||
return self.client.Listen(WorkerListenRequest(
|
||||
tenantId=self.tenant_id,
|
||||
workerId=self.worker_id
|
||||
))
|
||||
return
|
||||
except grpc.RpcError as e:
|
||||
retries += 1
|
||||
# self.logger.error(f"Failed to retry subscription: {e}")
|
||||
|
||||
raise Exception(f"Could not subscribe to the worker after {DEFAULT_ACTION_LISTENER_RETRY_COUNT} retries")
|
||||
), timeout=DEFAULT_ACTION_TIMEOUT)
|
||||
|
||||
def unregister(self):
|
||||
try:
|
||||
self.client.Unsubscribe(
|
||||
WorkerUnsubscribeRequest(
|
||||
tenant_id=self.tenant_id,
|
||||
worker_id=self.worker_id
|
||||
)
|
||||
tenantId=self.tenant_id,
|
||||
workerId=self.worker_id
|
||||
),
|
||||
timeout=DEFAULT_REGISTER_TIMEOUT,
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
raise Exception(f"Failed to unsubscribe: {e}")
|
||||
@@ -166,15 +175,9 @@ class DispatcherClientImpl(DispatcherClient):
|
||||
workerName=req.worker_name,
|
||||
actions=req.actions,
|
||||
services=req.services
|
||||
))
|
||||
), timeout=DEFAULT_REGISTER_TIMEOUT)
|
||||
|
||||
# Subscribe to the worker
|
||||
listener = self.client.Listen(WorkerListenRequest(
|
||||
tenantId=self.tenant_id,
|
||||
workerId=response.workerId,
|
||||
))
|
||||
|
||||
return ActionListenerImpl(self.client, self.tenant_id, listener, response.workerId)
|
||||
return ActionListenerImpl(self.client, self.tenant_id, response.workerId)
|
||||
|
||||
def send_action_event(self, in_: ActionEvent):
|
||||
response : ActionEventResponse = self.client.SendActionEvent(in_)
|
||||
|
||||
@@ -2,12 +2,16 @@ from .client import new_client
|
||||
from typing import List
|
||||
from .workflow import WorkflowMeta
|
||||
from .worker import Worker
|
||||
from .logger import logger
|
||||
|
||||
class Hatchet:
|
||||
def __init__(self):
|
||||
def __init__(self, debug=False):
|
||||
# initialize a client
|
||||
self.client = new_client()
|
||||
|
||||
if not debug:
|
||||
logger.disable("hatchet_sdk")
|
||||
|
||||
def workflow(self, name : str='', on_events : list=[], on_crons : list=[]):
|
||||
def inner(cls):
|
||||
cls.on_events = on_events
|
||||
|
||||
12
python-client/hatchet_sdk/logger.py
Normal file
12
python-client/hatchet_sdk/logger.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import os
|
||||
import sys
|
||||
from loguru import logger
|
||||
|
||||
# Loguru setup for the SDK: one handler that writes to stdout, with every
# record rendered as "hatchet -- <time> - <message>".
config = dict(
    handlers=[
        dict(sink=sys.stdout, format="hatchet -- {time} - {message}"),
    ],
)

# `config` only carries the "handlers" key, so passing it by keyword is the
# same call as unpacking the whole mapping.
logger.configure(handlers=config["handlers"])
|
||||
@@ -1,4 +1,8 @@
|
||||
import json
|
||||
import signal
|
||||
import sys
|
||||
|
||||
import grpc
|
||||
from typing import Any, Callable, Dict
|
||||
from .workflow import WorkflowMeta
|
||||
from .clients.dispatcher import GetActionListenerRequest, ActionListenerImpl, Action
|
||||
@@ -7,15 +11,22 @@ from .client import new_client
|
||||
from concurrent.futures import ThreadPoolExecutor, Future
|
||||
from google.protobuf.timestamp_pb2 import Timestamp
|
||||
from .context import Context
|
||||
from .logger import logger
|
||||
|
||||
# Worker class
|
||||
class Worker:
|
||||
def __init__(self, name: str, max_threads: int = 200):
|
||||
def __init__(self, name: str, max_threads: int = 200, debug=False, handle_kill=True):
|
||||
self.name = name
|
||||
self.thread_pool = ThreadPoolExecutor(max_workers=max_threads)
|
||||
self.futures: Dict[str, Future] = {} # Store step run ids and futures
|
||||
self.action_registry : dict[str, Callable[..., Any]] = {}
|
||||
|
||||
signal.signal(signal.SIGINT, self.exit_gracefully)
|
||||
signal.signal(signal.SIGTERM, self.exit_gracefully)
|
||||
|
||||
self.killing = False
|
||||
self.handle_kill = handle_kill
|
||||
|
||||
def handle_start_step_run(self, action : Action):
|
||||
action_name = action.action_id # Assuming action object has 'name' attribute
|
||||
context = Context(action.action_payload) # Assuming action object has 'context' attribute
|
||||
@@ -29,8 +40,7 @@ class Worker:
|
||||
try:
|
||||
output = future.result()
|
||||
except Exception as e:
|
||||
# TODO: handle errors
|
||||
print("error:", e)
|
||||
logger.error(f"Error on action finished event: {e}")
|
||||
raise e
|
||||
|
||||
# TODO: case on cancelled errors and such
|
||||
@@ -39,7 +49,7 @@ class Worker:
|
||||
try:
|
||||
event = self.get_action_finished_event(action, output)
|
||||
except Exception as e:
|
||||
print("error on action finished event:", e)
|
||||
logger.error(f"Could not get action finished event: {e}")
|
||||
raise e
|
||||
|
||||
# Send the action event to the dispatcher
|
||||
@@ -59,7 +69,7 @@ class Worker:
|
||||
try:
|
||||
event = self.get_action_event(action, STEP_EVENT_TYPE_STARTED)
|
||||
except Exception as e:
|
||||
print("error on action event:", e)
|
||||
logger.error(f"Could not create action event: {e}")
|
||||
|
||||
# Send the action event to the dispatcher
|
||||
self.client.dispatcher.send_action_event(event)
|
||||
@@ -96,7 +106,7 @@ class Worker:
|
||||
try:
|
||||
event = self.get_action_event(action, STEP_EVENT_TYPE_COMPLETED)
|
||||
except Exception as e:
|
||||
print("error on get action event:", e)
|
||||
logger.error(f"Could not create action finished event: {e}")
|
||||
raise e
|
||||
|
||||
output_bytes = ''
|
||||
@@ -116,22 +126,52 @@ class Worker:
|
||||
|
||||
for action_name, action_func in workflow.get_actions():
|
||||
self.action_registry[action_name] = create_action_function(action_func)
|
||||
|
||||
def exit_gracefully(self, signum, frame):
    """Signal handler: drain in-flight step runs, unregister, then exit.

    Sets ``self.killing`` so the start/retry loop knows a shutdown is in
    progress, waits for every tracked future, asks the listener to
    unregister the worker, and finally exits the process when
    ``self.handle_kill`` is true.

    Args:
        signum: Number of the signal that was delivered (unused).
        frame: Stack frame that was interrupted by the signal (unused).

    Raises:
        SystemExit: With status 0 when ``self.handle_kill`` is true.
    """
    self.killing = True

    # Wait on a snapshot of the tracked futures, in case entries are added or
    # removed while we block here.
    for future in list(self.futures.values()):
        try:
            future.result()
        except Exception as e:
            # A failed (or cancelled) step must not abort the shutdown:
            # previously the exception escaped the handler and skipped both
            # the unregister call and the exit below.
            logger.error(f"Step run failed during shutdown: {e}")

    # Best effort: this also covers the listener not having been created yet
    # (AttributeError) when the signal arrives early.
    try:
        self.listener.unregister()
    except Exception as e:
        logger.error(f"Could not unregister worker: {e}")

    if self.handle_kill:
        logger.info("Exiting...")
        sys.exit(0)
|
||||
|
||||
def start(self):
|
||||
def start(self, retry_count=1):
|
||||
logger.info("Starting worker...")
|
||||
|
||||
self.client = new_client()
|
||||
|
||||
listener : ActionListenerImpl = self.client.dispatcher.get_action_listener(GetActionListenerRequest(
|
||||
worker_name="test-worker",
|
||||
services=["default"],
|
||||
actions=self.action_registry.keys(),
|
||||
))
|
||||
try:
|
||||
self.listener : ActionListenerImpl = self.client.dispatcher.get_action_listener(GetActionListenerRequest(
|
||||
worker_name=self.name,
|
||||
services=["default"],
|
||||
actions=self.action_registry.keys(),
|
||||
))
|
||||
|
||||
generator = listener.actions()
|
||||
generator = self.listener.actions()
|
||||
|
||||
for action in generator:
|
||||
if action.action_type == ActionType.START_STEP_RUN:
|
||||
self.handle_start_step_run(action)
|
||||
elif action.action_type == ActionType.CANCEL_STEP_RUN:
|
||||
self.handle_cancel_step_run(action)
|
||||
for action in generator:
|
||||
if action.action_type == ActionType.START_STEP_RUN:
|
||||
self.handle_start_step_run(action)
|
||||
elif action.action_type == ActionType.CANCEL_STEP_RUN:
|
||||
self.handle_cancel_step_run(action)
|
||||
|
||||
pass # Replace this with your actual processing code
|
||||
pass # Replace this with your actual processing code
|
||||
except grpc.RpcError as rpc_error:
|
||||
logger.error(f"Could not start worker: {rpc_error}")
|
||||
|
||||
# if we are here, but not killing, then we should retry start
|
||||
if not self.killing:
|
||||
if retry_count > 5:
|
||||
raise Exception("Could not start worker after 5 retries")
|
||||
|
||||
logger.info("Could not start worker, retrying...")
|
||||
|
||||
self.start(retry_count + 1)
|
||||
45
python-client/poetry.lock
generated
45
python-client/poetry.lock
generated
@@ -1,5 +1,16 @@
|
||||
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "colorama"
|
||||
version = "0.4.6"
|
||||
description = "Cross-platform colored terminal text."
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
||||
files = [
|
||||
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
|
||||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "grpcio"
|
||||
version = "1.60.0"
|
||||
@@ -134,6 +145,24 @@ grpcio = ">=1.60.0"
|
||||
protobuf = ">=4.21.6,<5.0dev"
|
||||
setuptools = "*"
|
||||
|
||||
[[package]]
|
||||
name = "loguru"
|
||||
version = "0.7.2"
|
||||
description = "Python logging made (stupidly) simple"
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
files = [
|
||||
{file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"},
|
||||
{file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""}
|
||||
win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
|
||||
|
||||
[package.extras]
|
||||
dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "protobuf"
|
||||
version = "4.25.2"
|
||||
@@ -243,7 +272,21 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments
|
||||
testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
|
||||
testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
|
||||
|
||||
[[package]]
|
||||
name = "win32-setctime"
|
||||
version = "1.1.0"
|
||||
description = "A small Python utility to set file creation time on Windows"
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
files = [
|
||||
{file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"},
|
||||
{file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.8"
|
||||
content-hash = "e552286ee1c1bf75f33fb747f3380cc865da3e6a0bc5a25fd25e2525d328fd40"
|
||||
content-hash = "5f02802b37a104cad9bf4a41e3f263ef65e6359a18f18fd1e493afcc84b179ed"
|
||||
|
||||
@@ -12,6 +12,7 @@ python-dotenv = "^1.0.0"
|
||||
protobuf = "^4.25.2"
|
||||
pyyaml = "^6.0.1"
|
||||
grpcio-tools = "^1.60.0"
|
||||
loguru = "^0.7.2"
|
||||
|
||||
|
||||
[build-system]
|
||||
|
||||
Reference in New Issue
Block a user