Files
hatchet/python-client/hatchet_sdk/worker.py
abelanger5 cfa4b5c8f4 fix: python sdk graceful shutdown and retry errors (#118)
* fix: python sdk graceful shutdown and retry errors

* chore: address changes from review
2024-01-22 19:52:02 -05:00

177 lines
6.4 KiB
Python

import json
import signal
import sys
import grpc
from typing import Any, Callable, Dict
from .workflow import WorkflowMeta
from .clients.dispatcher import GetActionListenerRequest, ActionListenerImpl, Action
from .dispatcher_pb2 import ActionType, ActionEvent, ActionEventType, STEP_EVENT_TYPE_COMPLETED, STEP_EVENT_TYPE_STARTED
from .client import new_client
from concurrent.futures import ThreadPoolExecutor, Future
from google.protobuf.timestamp_pb2 import Timestamp
from .context import Context
from .logger import logger
# Worker class
class Worker:
def __init__(self, name: str, max_threads: int = 200, debug=False, handle_kill=True):
self.name = name
self.thread_pool = ThreadPoolExecutor(max_workers=max_threads)
self.futures: Dict[str, Future] = {} # Store step run ids and futures
self.action_registry : dict[str, Callable[..., Any]] = {}
signal.signal(signal.SIGINT, self.exit_gracefully)
signal.signal(signal.SIGTERM, self.exit_gracefully)
self.killing = False
self.handle_kill = handle_kill
def handle_start_step_run(self, action : Action):
action_name = action.action_id # Assuming action object has 'name' attribute
context = Context(action.action_payload) # Assuming action object has 'context' attribute
# Find the corresponding action function from the registry
action_func = self.action_registry.get(action_name)
if action_func:
def callback(future : Future):
# Get the output from the future
try:
output = future.result()
except Exception as e:
logger.error(f"Error on action finished event: {e}")
raise e
# TODO: case on cancelled errors and such
# Create an action event
try:
event = self.get_action_finished_event(action, output)
except Exception as e:
logger.error(f"Could not get action finished event: {e}")
raise e
# Send the action event to the dispatcher
self.client.dispatcher.send_action_event(event)
# Remove the future from the dictionary
del self.futures[action.step_run_id]
del self.futures[action.step_run_id + "_callback"]
# Submit the action to the thread pool
future = self.thread_pool.submit(action_func, context)
callback = self.thread_pool.submit(callback, future)
self.futures[action.step_run_id] = future
self.futures[action.step_run_id + "_callback"] = callback
# send an event that the step run has started
try:
event = self.get_action_event(action, STEP_EVENT_TYPE_STARTED)
except Exception as e:
logger.error(f"Could not create action event: {e}")
# Send the action event to the dispatcher
self.client.dispatcher.send_action_event(event)
def handle_cancel_step_run(self, action : Action):
step_run_id = action.step_run_id
future = self.futures.get(step_run_id)
if future:
future.cancel()
del self.futures[step_run_id]
def get_action_event(self, action : Action, event_type : ActionEventType) -> ActionEvent:
# timestamp
# eventTimestamp = datetime.datetime.now(datetime.timezone.utc)
# eventTimestamp = eventTimestamp.isoformat()
eventTimestamp = Timestamp()
eventTimestamp.GetCurrentTime()
return ActionEvent(
tenantId=action.tenant_id,
workerId=action.worker_id,
jobId=action.job_id,
jobRunId=action.job_run_id,
stepId=action.step_id,
stepRunId=action.step_run_id,
actionId=action.action_id,
eventTimestamp=eventTimestamp,
eventType=event_type,
)
def get_action_finished_event(self, action : Action, output : Any) -> ActionEvent:
try:
event = self.get_action_event(action, STEP_EVENT_TYPE_COMPLETED)
except Exception as e:
logger.error(f"Could not create action finished event: {e}")
raise e
output_bytes = ''
if output is not None:
output_bytes = json.dumps(output)
event.eventPayload = output_bytes
return event
def register_workflow(self, workflow : WorkflowMeta):
def create_action_function(action_func):
def action_function(context):
return action_func(workflow, context)
return action_function
for action_name, action_func in workflow.get_actions():
self.action_registry[action_name] = create_action_function(action_func)
def exit_gracefully(self, signum, frame):
self.killing = True
# wait for futures to complete
for future in self.futures.values():
future.result()
try:
self.listener.unregister()
except Exception as e:
logger.error(f"Could not unregister worker: {e}")
if self.handle_kill:
logger.info("Exiting...")
sys.exit(0)
def start(self, retry_count=1):
logger.info("Starting worker...")
self.client = new_client()
try:
self.listener : ActionListenerImpl = self.client.dispatcher.get_action_listener(GetActionListenerRequest(
worker_name=self.name,
services=["default"],
actions=self.action_registry.keys(),
))
generator = self.listener.actions()
for action in generator:
if action.action_type == ActionType.START_STEP_RUN:
self.handle_start_step_run(action)
elif action.action_type == ActionType.CANCEL_STEP_RUN:
self.handle_cancel_step_run(action)
pass # Replace this with your actual processing code
except grpc.RpcError as rpc_error:
logger.error(f"Could not start worker: {rpc_error}")
# if we are here, but not killing, then we should retry start
if not self.killing:
if retry_count > 5:
raise Exception("Could not start worker after 5 retries")
logger.info("Could not start worker, retrying...")
self.start(retry_count + 1)