fix: python sdk graceful shutdown and retry errors (#118)

* fix: python sdk graceful shutdown and retry errors

* chore: address changes from review
This commit is contained in:
abelanger5
2024-01-22 16:52:02 -08:00
committed by GitHub
parent cb4072efae
commit cfa4b5c8f4
13 changed files with 169 additions and 60 deletions

View File

@@ -3,7 +3,7 @@
Workers can be created via the `hatchet.worker()` method, after instantiating a `hatchet` instance. It will automatically read in any `HATCHET_CLIENT` environment variables, which can be set in the process by using something like `dotenv`. For example:
```py
from hatchet-sdk import Hatchet
from hatchet_sdk import Hatchet
from dotenv import load_dotenv
load_dotenv()

View File

@@ -3,7 +3,7 @@
To create a workflow, simply create a new class and use the `hatchet.workflow` and `hatchet.step` decorators to define the structure of your workflow. For example, a simple 2-step workflow would look like:
```py
from hatchet-sdk import Hatchet
from hatchet_sdk import Hatchet
hatchet = Hatchet()
@@ -66,7 +66,7 @@ Future steps can access this output by calling `context.step_output("<step>")`.
You can declare a cron schedule by passing `on_crons` to the `hatchet.workflow` decorator. For example, to trigger a workflow every 5 minutes, you can do the following:
```py
from hatchet-sdk import Hatchet
from hatchet_sdk import Hatchet
hatchet = Hatchet()

View File

@@ -3,7 +3,7 @@
Events can be pushed via the client's `client.event.push` method:
```py
from hatchet import new_client
from hatchet_sdk import new_client
client = new_client()

View File

@@ -1,4 +1,4 @@
from hatchet import new_client
from hatchet_sdk import new_client
client = new_client()

View File

@@ -1,4 +1,4 @@
from hatchet import Hatchet, Context
from hatchet_sdk import Hatchet, Context
hatchet = Hatchet()

View File

@@ -1,4 +1,7 @@
from hatchet import new_client
from hatchet_sdk import new_client
from dotenv import load_dotenv
load_dotenv()
client = new_client()

View File

@@ -1,6 +1,9 @@
from hatchet import Hatchet
from hatchet_sdk import Hatchet
from dotenv import load_dotenv
hatchet = Hatchet()
load_dotenv()
hatchet = Hatchet(debug=True)
@hatchet.workflow(on_events=["user:create"])
class MyWorkflow:

View File

@@ -4,6 +4,7 @@ from ..dispatcher_pb2_grpc import DispatcherStub
import time
from ..loader import ClientConfig
from ..logger import logger
import json
import grpc
from typing import Callable, List, Union
@@ -23,8 +24,10 @@ class DispatcherClient:
def send_action_event(self, ctx, in_):
raise NotImplementedError
DEFAULT_ACTION_LISTENER_RETRY_INTERVAL = 5 # seconds
DEFAULT_ACTION_LISTENER_RETRY_INTERVAL = 1 # seconds
DEFAULT_ACTION_LISTENER_RETRY_COUNT = 5
DEFAULT_ACTION_TIMEOUT = 60 # seconds
DEFAULT_REGISTER_TIMEOUT = 5
class GetActionListenerRequest:
def __init__(self, worker_name: str, services: List[str], actions: List[str]):
@@ -57,18 +60,21 @@ START_STEP_RUN = 0
CANCEL_STEP_RUN = 1
class ActionListenerImpl(WorkerActionListener):
def __init__(self, client : DispatcherStub, tenant_id, listen_client, worker_id):
def __init__(self, client : DispatcherStub, tenant_id, worker_id):
self.client = client
self.tenant_id = tenant_id
self.listen_client = listen_client
self.worker_id = worker_id
self.retries = 0
# self.logger = logger
# self.validator = validator
def actions(self):
while True:
logger.info("Listening for actions...")
try:
for assigned_action in self.listen_client:
for assigned_action in self.get_listen_client():
assigned_action : AssignedAction
# Process the received action
@@ -102,11 +108,16 @@ class ActionListenerImpl(WorkerActionListener):
break
elif e.code() == grpc.StatusCode.UNAVAILABLE:
# Retry logic
self.retry_subscribe()
logger.info("Could not connect to Hatchet, retrying...")
self.retries = self.retries + 1
elif e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
logger.info("Deadline exceeded, retrying subscription")
continue
else:
# Unknown error, report and break
# self.logger.error(f"Failed to receive message: {e}")
# err_ch(e)
logger.error(f"Failed to receive message: {e}")
break
def parse_action_payload(self, payload : str):
@@ -124,30 +135,28 @@ class ActionListenerImpl(WorkerActionListener):
else:
# self.logger.error(f"Unknown action type: {action_type}")
return None
def retry_subscribe(self):
retries = 0
while retries < DEFAULT_ACTION_LISTENER_RETRY_COUNT:
try:
time.sleep(DEFAULT_ACTION_LISTENER_RETRY_INTERVAL)
self.listen_client = self.client.Listen(WorkerListenRequest(
def get_listen_client(self):
if self.retries > DEFAULT_ACTION_LISTENER_RETRY_COUNT:
raise Exception(f"Could not subscribe to the worker after {DEFAULT_ACTION_LISTENER_RETRY_COUNT} retries")
elif self.retries > 1:
# logger.info
# if we are retrying, we wait for a bit. this should eventually be replaced with exp backoff + jitter
time.sleep(DEFAULT_ACTION_LISTENER_RETRY_INTERVAL)
return self.client.Listen(WorkerListenRequest(
tenantId=self.tenant_id,
workerId=self.worker_id
))
return
except grpc.RpcError as e:
retries += 1
# self.logger.error(f"Failed to retry subscription: {e}")
raise Exception(f"Could not subscribe to the worker after {DEFAULT_ACTION_LISTENER_RETRY_COUNT} retries")
), timeout=DEFAULT_ACTION_TIMEOUT)
def unregister(self):
try:
self.client.Unsubscribe(
WorkerUnsubscribeRequest(
tenant_id=self.tenant_id,
worker_id=self.worker_id
)
tenantId=self.tenant_id,
workerId=self.worker_id
),
timeout=DEFAULT_REGISTER_TIMEOUT,
)
except grpc.RpcError as e:
raise Exception(f"Failed to unsubscribe: {e}")
@@ -166,15 +175,9 @@ class DispatcherClientImpl(DispatcherClient):
workerName=req.worker_name,
actions=req.actions,
services=req.services
))
), timeout=DEFAULT_REGISTER_TIMEOUT)
# Subscribe to the worker
listener = self.client.Listen(WorkerListenRequest(
tenantId=self.tenant_id,
workerId=response.workerId,
))
return ActionListenerImpl(self.client, self.tenant_id, listener, response.workerId)
return ActionListenerImpl(self.client, self.tenant_id, response.workerId)
def send_action_event(self, in_: ActionEvent):
response : ActionEventResponse = self.client.SendActionEvent(in_)

View File

@@ -2,12 +2,16 @@ from .client import new_client
from typing import List
from .workflow import WorkflowMeta
from .worker import Worker
from .logger import logger
class Hatchet:
def __init__(self):
def __init__(self, debug=False):
# initialize a client
self.client = new_client()
if not debug:
logger.disable("hatchet_sdk")
def workflow(self, name : str='', on_events : list=[], on_crons : list=[]):
def inner(cls):
cls.on_events = on_events

View File

@@ -0,0 +1,12 @@
import os
import sys
from loguru import logger
# loguru config
config = {
"handlers": [
{"sink": sys.stdout, "format": "hatchet -- {time} - {message}"},
],
}
logger.configure(**config)

View File

@@ -1,4 +1,8 @@
import json
import signal
import sys
import grpc
from typing import Any, Callable, Dict
from .workflow import WorkflowMeta
from .clients.dispatcher import GetActionListenerRequest, ActionListenerImpl, Action
@@ -7,15 +11,22 @@ from .client import new_client
from concurrent.futures import ThreadPoolExecutor, Future
from google.protobuf.timestamp_pb2 import Timestamp
from .context import Context
from .logger import logger
# Worker class
class Worker:
def __init__(self, name: str, max_threads: int = 200):
def __init__(self, name: str, max_threads: int = 200, debug=False, handle_kill=True):
self.name = name
self.thread_pool = ThreadPoolExecutor(max_workers=max_threads)
self.futures: Dict[str, Future] = {} # Store step run ids and futures
self.action_registry : dict[str, Callable[..., Any]] = {}
signal.signal(signal.SIGINT, self.exit_gracefully)
signal.signal(signal.SIGTERM, self.exit_gracefully)
self.killing = False
self.handle_kill = handle_kill
def handle_start_step_run(self, action : Action):
action_name = action.action_id # Assuming action object has 'name' attribute
context = Context(action.action_payload) # Assuming action object has 'context' attribute
@@ -29,8 +40,7 @@ class Worker:
try:
output = future.result()
except Exception as e:
# TODO: handle errors
print("error:", e)
logger.error(f"Error on action finished event: {e}")
raise e
# TODO: case on cancelled errors and such
@@ -39,7 +49,7 @@ class Worker:
try:
event = self.get_action_finished_event(action, output)
except Exception as e:
print("error on action finished event:", e)
logger.error(f"Could not get action finished event: {e}")
raise e
# Send the action event to the dispatcher
@@ -59,7 +69,7 @@ class Worker:
try:
event = self.get_action_event(action, STEP_EVENT_TYPE_STARTED)
except Exception as e:
print("error on action event:", e)
logger.error(f"Could not create action event: {e}")
# Send the action event to the dispatcher
self.client.dispatcher.send_action_event(event)
@@ -96,7 +106,7 @@ class Worker:
try:
event = self.get_action_event(action, STEP_EVENT_TYPE_COMPLETED)
except Exception as e:
print("error on get action event:", e)
logger.error(f"Could not create action finished event: {e}")
raise e
output_bytes = ''
@@ -116,22 +126,52 @@ class Worker:
for action_name, action_func in workflow.get_actions():
self.action_registry[action_name] = create_action_function(action_func)
def exit_gracefully(self, signum, frame):
self.killing = True
# wait for futures to complete
for future in self.futures.values():
future.result()
try:
self.listener.unregister()
except Exception as e:
logger.error(f"Could not unregister worker: {e}")
if self.handle_kill:
logger.info("Exiting...")
sys.exit(0)
def start(self):
def start(self, retry_count=1):
logger.info("Starting worker...")
self.client = new_client()
listener : ActionListenerImpl = self.client.dispatcher.get_action_listener(GetActionListenerRequest(
worker_name="test-worker",
services=["default"],
actions=self.action_registry.keys(),
))
try:
self.listener : ActionListenerImpl = self.client.dispatcher.get_action_listener(GetActionListenerRequest(
worker_name=self.name,
services=["default"],
actions=self.action_registry.keys(),
))
generator = listener.actions()
generator = self.listener.actions()
for action in generator:
if action.action_type == ActionType.START_STEP_RUN:
self.handle_start_step_run(action)
elif action.action_type == ActionType.CANCEL_STEP_RUN:
self.handle_cancel_step_run(action)
for action in generator:
if action.action_type == ActionType.START_STEP_RUN:
self.handle_start_step_run(action)
elif action.action_type == ActionType.CANCEL_STEP_RUN:
self.handle_cancel_step_run(action)
pass # Replace this with your actual processing code
pass # Replace this with your actual processing code
except grpc.RpcError as rpc_error:
logger.error(f"Could not start worker: {rpc_error}")
# if we are here, but not killing, then we should retry start
if not self.killing:
if retry_count > 5:
raise Exception("Could not start worker after 5 retries")
logger.info("Could not start worker, retrying...")
self.start(retry_count + 1)

View File

@@ -1,5 +1,16 @@
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
[[package]]
name = "colorama"
version = "0.4.6"
description = "Cross-platform colored terminal text."
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
files = [
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
[[package]]
name = "grpcio"
version = "1.60.0"
@@ -134,6 +145,24 @@ grpcio = ">=1.60.0"
protobuf = ">=4.21.6,<5.0dev"
setuptools = "*"
[[package]]
name = "loguru"
version = "0.7.2"
description = "Python logging made (stupidly) simple"
optional = false
python-versions = ">=3.5"
files = [
{file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"},
{file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"},
]
[package.dependencies]
colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""}
win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
[package.extras]
dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"]
[[package]]
name = "protobuf"
version = "4.25.2"
@@ -243,7 +272,21 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments
testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
[[package]]
name = "win32-setctime"
version = "1.1.0"
description = "A small Python utility to set file creation time on Windows"
optional = false
python-versions = ">=3.5"
files = [
{file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"},
{file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"},
]
[package.extras]
dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.8"
content-hash = "e552286ee1c1bf75f33fb747f3380cc865da3e6a0bc5a25fd25e2525d328fd40"
content-hash = "5f02802b37a104cad9bf4a41e3f263ef65e6359a18f18fd1e493afcc84b179ed"

View File

@@ -12,6 +12,7 @@ python-dotenv = "^1.0.0"
protobuf = "^4.25.2"
pyyaml = "^6.0.1"
grpcio-tools = "^1.60.0"
loguru = "^0.7.2"
[build-system]