feat: concurrency groups (#135)

* first pass at moving controllers around

* feat: concurrency limits for strategy CANCEL_IN_PROGRESS

* fix: linting

* chore: bump python sdk version
This commit is contained in:
abelanger5
2024-01-29 21:00:28 -08:00
committed by GitHub
parent 9841aa52d7
commit d63b66a837
72 changed files with 4777 additions and 924 deletions

View File

@@ -0,0 +1,13 @@
from hatchet_sdk import new_client
from dotenv import load_dotenv
load_dotenv()
client = new_client()
client.event.push(
"concurrency-test",
{
"test": "test"
}
)

View File

@@ -0,0 +1,32 @@
from hatchet_sdk import Hatchet
from dotenv import load_dotenv
load_dotenv()
hatchet = Hatchet(debug=True)
@hatchet.workflow(on_events=["concurrency-test"])
class ConcurrencyDemoWorkflow:
def __init__(self):
self.my_value = "test"
@hatchet.concurrency(max_runs=5)
def concurrency(self, context) -> str:
return "concurrency-key"
@hatchet.step()
def step1(self, context):
print("executed step1")
pass
@hatchet.step(parents=["step1"],timeout='4s')
def step2(self, context):
print("started step2")
context.sleep(1)
print("finished step2")
workflow = ConcurrencyDemoWorkflow()
worker = hatchet.worker('concurrency-demo-worker', max_threads=4)
worker.register_workflow(workflow)
worker.start()

View File

@@ -1,5 +1,5 @@
# relative imports
from ..dispatcher_pb2 import ActionEvent, ActionEventResponse, ActionType, AssignedAction, WorkerListenRequest, WorkerRegisterRequest, WorkerUnsubscribeRequest, WorkerRegisterResponse
from ..dispatcher_pb2 import GroupKeyActionEvent, StepActionEvent, ActionEventResponse, ActionType, AssignedAction, WorkerListenRequest, WorkerRegisterRequest, WorkerUnsubscribeRequest, WorkerRegisterResponse
from ..dispatcher_pb2_grpc import DispatcherStub
import time
@@ -21,7 +21,7 @@ class DispatcherClient:
def get_action_listener(self, ctx, req):
raise NotImplementedError
def send_action_event(self, ctx, in_):
def send_step_action_event(self, ctx, in_):
raise NotImplementedError
DEFAULT_ACTION_LISTENER_RETRY_INTERVAL = 1 # seconds
@@ -36,8 +36,10 @@ class GetActionListenerRequest:
self.actions = actions
class Action:
def __init__(self, worker_id: str, tenant_id: str, job_id: str, job_name: str, job_run_id: str, step_id: str, step_run_id: str, action_id: str, action_payload: str, action_type: ActionType):
def __init__(self, worker_id: str, tenant_id: str, workflow_run_id: str, get_group_key_run_id: str, job_id: str, job_name: str, job_run_id: str, step_id: str, step_run_id: str, action_id: str, action_payload: str, action_type: ActionType):
self.worker_id = worker_id
self.workflow_run_id = workflow_run_id
self.get_group_key_run_id = get_group_key_run_id
self.tenant_id = tenant_id
self.job_id = job_id
self.job_name = job_name
@@ -55,9 +57,9 @@ class WorkerActionListener:
def unregister(self):
raise NotImplementedError
# enum for START_STEP_RUN and CANCEL_STEP_RUN
START_STEP_RUN = 0
CANCEL_STEP_RUN = 1
START_GET_GROUP_KEY = 2
class ActionListenerImpl(WorkerActionListener):
def __init__(self, client : DispatcherStub, token, worker_id):
@@ -88,6 +90,8 @@ class ActionListenerImpl(WorkerActionListener):
action = Action(
tenant_id=assigned_action.tenantId,
worker_id=self.worker_id,
workflow_run_id=assigned_action.workflowRunId,
get_group_key_run_id=assigned_action.getGroupKeyRunId,
job_id=assigned_action.jobId,
job_name=assigned_action.jobName,
job_run_id=assigned_action.jobRunId,
@@ -132,6 +136,8 @@ class ActionListenerImpl(WorkerActionListener):
return START_STEP_RUN
elif action_type == ActionType.CANCEL_STEP_RUN:
return CANCEL_STEP_RUN
elif action_type == ActionType.START_GET_GROUP_KEY:
return START_GET_GROUP_KEY
else:
# self.logger.error(f"Unknown action type: {action_type}")
return None
@@ -180,8 +186,13 @@ class DispatcherClientImpl(DispatcherClient):
return ActionListenerImpl(self.client, self.token, response.workerId)
def send_action_event(self, in_: ActionEvent):
response : ActionEventResponse = self.client.SendActionEvent(in_, metadata=get_metadata(self.token),)
def send_step_action_event(self, in_: StepActionEvent):
response : ActionEventResponse = self.client.SendStepActionEvent(in_, metadata=get_metadata(self.token),)
return response
def send_group_key_action_event(self, in_: GroupKeyActionEvent):
response : ActionEventResponse = self.client.SendGroupKeyActionEvent(in_, metadata=get_metadata(self.token),)
return response

View File

@@ -15,7 +15,7 @@ _sym_db = _symbol_database.Default()
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x64ispatcher.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"N\n\x15WorkerRegisterRequest\x12\x12\n\nworkerName\x18\x01 \x01(\t\x12\x0f\n\x07\x61\x63tions\x18\x02 \x03(\t\x12\x10\n\x08services\x18\x03 \x03(\t\"P\n\x16WorkerRegisterResponse\x12\x10\n\x08tenantId\x18\x01 \x01(\t\x12\x10\n\x08workerId\x18\x02 \x01(\t\x12\x12\n\nworkerName\x18\x03 \x01(\t\"\xc1\x01\n\x0e\x41ssignedAction\x12\x10\n\x08tenantId\x18\x01 \x01(\t\x12\r\n\x05jobId\x18\x02 \x01(\t\x12\x0f\n\x07jobName\x18\x03 \x01(\t\x12\x10\n\x08jobRunId\x18\x04 \x01(\t\x12\x0e\n\x06stepId\x18\x05 \x01(\t\x12\x11\n\tstepRunId\x18\x06 \x01(\t\x12\x10\n\x08\x61\x63tionId\x18\x07 \x01(\t\x12\x1f\n\nactionType\x18\x08 \x01(\x0e\x32\x0b.ActionType\x12\x15\n\ractionPayload\x18\t \x01(\t\"\'\n\x13WorkerListenRequest\x12\x10\n\x08workerId\x18\x01 \x01(\t\",\n\x18WorkerUnsubscribeRequest\x12\x10\n\x08workerId\x18\x01 \x01(\t\"?\n\x19WorkerUnsubscribeResponse\x12\x10\n\x08tenantId\x18\x01 \x01(\t\x12\x10\n\x08workerId\x18\x02 \x01(\t\"\xe4\x01\n\x0b\x41\x63tionEvent\x12\x10\n\x08workerId\x18\x01 \x01(\t\x12\r\n\x05jobId\x18\x02 \x01(\t\x12\x10\n\x08jobRunId\x18\x03 \x01(\t\x12\x0e\n\x06stepId\x18\x04 \x01(\t\x12\x11\n\tstepRunId\x18\x05 \x01(\t\x12\x10\n\x08\x61\x63tionId\x18\x06 \x01(\t\x12\x32\n\x0e\x65ventTimestamp\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12#\n\teventType\x18\x08 \x01(\x0e\x32\x10.ActionEventType\x12\x14\n\x0c\x65ventPayload\x18\t \x01(\t\"9\n\x13\x41\x63tionEventResponse\x12\x10\n\x08tenantId\x18\x01 \x01(\t\x12\x10\n\x08workerId\x18\x02 \x01(\t*5\n\nActionType\x12\x12\n\x0eSTART_STEP_RUN\x10\x00\x12\x13\n\x0f\x43\x41NCEL_STEP_RUN\x10\x01*\x86\x01\n\x0f\x41\x63tionEventType\x12\x1b\n\x17STEP_EVENT_TYPE_UNKNOWN\x10\x00\x12\x1b\n\x17STEP_EVENT_TYPE_STARTED\x10\x01\x12\x1d\n\x19STEP_EVENT_TYPE_COMPLETED\x10\x02\x12\x1a\n\x16STEP_EVENT_TYPE_FAILED\x10\x03\x32\x81\x02\n\nDispatcher\x12=\n\x08Register\x12\x16.WorkerRegisterRequest\x1a\x17.WorkerRegisterResponse\"\x00\x12\x33\n\x06Listen\x12\x14.WorkerListenRequest\x1a\x0f.AssignedAction\"\x00\x30\x01\x12\x37\n\x0fSendActionEvent\x12\x0c.ActionEvent\x1a\x14.ActionEventResponse\"\x00\x12\x46\n\x0bUnsubscribe\x12\x19.WorkerUnsubscribeRequest\x1a\x1a.WorkerUnsubscribeResponse\"\x00\x42GZEgithub.com/hatchet-dev/hatchet/internal/services/dispatcher/contractsb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x64ispatcher.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"N\n\x15WorkerRegisterRequest\x12\x12\n\nworkerName\x18\x01 \x01(\t\x12\x0f\n\x07\x61\x63tions\x18\x02 \x03(\t\x12\x10\n\x08services\x18\x03 \x03(\t\"P\n\x16WorkerRegisterResponse\x12\x10\n\x08tenantId\x18\x01 \x01(\t\x12\x10\n\x08workerId\x18\x02 \x01(\t\x12\x12\n\nworkerName\x18\x03 \x01(\t\"\xf2\x01\n\x0e\x41ssignedAction\x12\x10\n\x08tenantId\x18\x01 \x01(\t\x12\x15\n\rworkflowRunId\x18\x02 \x01(\t\x12\x18\n\x10getGroupKeyRunId\x18\x03 \x01(\t\x12\r\n\x05jobId\x18\x04 \x01(\t\x12\x0f\n\x07jobName\x18\x05 \x01(\t\x12\x10\n\x08jobRunId\x18\x06 \x01(\t\x12\x0e\n\x06stepId\x18\x07 \x01(\t\x12\x11\n\tstepRunId\x18\x08 \x01(\t\x12\x10\n\x08\x61\x63tionId\x18\t \x01(\t\x12\x1f\n\nactionType\x18\n \x01(\x0e\x32\x0b.ActionType\x12\x15\n\ractionPayload\x18\x0b \x01(\t\"\'\n\x13WorkerListenRequest\x12\x10\n\x08workerId\x18\x01 \x01(\t\",\n\x18WorkerUnsubscribeRequest\x12\x10\n\x08workerId\x18\x01 \x01(\t\"?\n\x19WorkerUnsubscribeResponse\x12\x10\n\x08tenantId\x18\x01 \x01(\t\x12\x10\n\x08workerId\x18\x02 \x01(\t\"\xe1\x01\n\x13GroupKeyActionEvent\x12\x10\n\x08workerId\x18\x01 \x01(\t\x12\x15\n\rworkflowRunId\x18\x02 \x01(\t\x12\x18\n\x10getGroupKeyRunId\x18\x03 \x01(\t\x12\x10\n\x08\x61\x63tionId\x18\x04 \x01(\t\x12\x32\n\x0e\x65ventTimestamp\x18\x05 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12+\n\teventType\x18\x06 \x01(\x0e\x32\x18.GroupKeyActionEventType\x12\x14\n\x0c\x65ventPayload\x18\x07 \x01(\t\"\xec\x01\n\x0fStepActionEvent\x12\x10\n\x08workerId\x18\x01 \x01(\t\x12\r\n\x05jobId\x18\x02 \x01(\t\x12\x10\n\x08jobRunId\x18\x03 \x01(\t\x12\x0e\n\x06stepId\x18\x04 \x01(\t\x12\x11\n\tstepRunId\x18\x05 \x01(\t\x12\x10\n\x08\x61\x63tionId\x18\x06 \x01(\t\x12\x32\n\x0e\x65ventTimestamp\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\'\n\teventType\x18\x08 \x01(\x0e\x32\x14.StepActionEventType\x12\x14\n\x0c\x65ventPayload\x18\t \x01(\t\"9\n\x13\x41\x63tionEventResponse\x12\x10\n\x08tenantId\x18\x01 \x01(\t\x12\x10\n\x08workerId\x18\x02 \x01(\t*N\n\nActionType\x12\x12\n\x0eSTART_STEP_RUN\x10\x00\x12\x13\n\x0f\x43\x41NCEL_STEP_RUN\x10\x01\x12\x17\n\x13START_GET_GROUP_KEY\x10\x02*\xa2\x01\n\x17GroupKeyActionEventType\x12 \n\x1cGROUP_KEY_EVENT_TYPE_UNKNOWN\x10\x00\x12 \n\x1cGROUP_KEY_EVENT_TYPE_STARTED\x10\x01\x12\"\n\x1eGROUP_KEY_EVENT_TYPE_COMPLETED\x10\x02\x12\x1f\n\x1bGROUP_KEY_EVENT_TYPE_FAILED\x10\x03*\x8a\x01\n\x13StepActionEventType\x12\x1b\n\x17STEP_EVENT_TYPE_UNKNOWN\x10\x00\x12\x1b\n\x17STEP_EVENT_TYPE_STARTED\x10\x01\x12\x1d\n\x19STEP_EVENT_TYPE_COMPLETED\x10\x02\x12\x1a\n\x16STEP_EVENT_TYPE_FAILED\x10\x03\x32\xd2\x02\n\nDispatcher\x12=\n\x08Register\x12\x16.WorkerRegisterRequest\x1a\x17.WorkerRegisterResponse\"\x00\x12\x33\n\x06Listen\x12\x14.WorkerListenRequest\x1a\x0f.AssignedAction\"\x00\x30\x01\x12?\n\x13SendStepActionEvent\x12\x10.StepActionEvent\x1a\x14.ActionEventResponse\"\x00\x12G\n\x17SendGroupKeyActionEvent\x12\x14.GroupKeyActionEvent\x1a\x14.ActionEventResponse\"\x00\x12\x46\n\x0bUnsubscribe\x12\x19.WorkerUnsubscribeRequest\x1a\x1a.WorkerUnsubscribeResponse\"\x00\x42GZEgithub.com/hatchet-dev/hatchet/internal/services/dispatcher/contractsb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -23,26 +23,30 @@ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'dispatcher_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'ZEgithub.com/hatchet-dev/hatchet/internal/services/dispatcher/contracts'
_globals['_ACTIONTYPE']._serialized_start=853
_globals['_ACTIONTYPE']._serialized_end=906
_globals['_ACTIONEVENTTYPE']._serialized_start=909
_globals['_ACTIONEVENTTYPE']._serialized_end=1043
_globals['_ACTIONTYPE']._serialized_start=1138
_globals['_ACTIONTYPE']._serialized_end=1216
_globals['_GROUPKEYACTIONEVENTTYPE']._serialized_start=1219
_globals['_GROUPKEYACTIONEVENTTYPE']._serialized_end=1381
_globals['_STEPACTIONEVENTTYPE']._serialized_start=1384
_globals['_STEPACTIONEVENTTYPE']._serialized_end=1522
_globals['_WORKERREGISTERREQUEST']._serialized_start=53
_globals['_WORKERREGISTERREQUEST']._serialized_end=131
_globals['_WORKERREGISTERRESPONSE']._serialized_start=133
_globals['_WORKERREGISTERRESPONSE']._serialized_end=213
_globals['_ASSIGNEDACTION']._serialized_start=216
_globals['_ASSIGNEDACTION']._serialized_end=409
_globals['_WORKERLISTENREQUEST']._serialized_start=411
_globals['_WORKERLISTENREQUEST']._serialized_end=450
_globals['_WORKERUNSUBSCRIBEREQUEST']._serialized_start=452
_globals['_WORKERUNSUBSCRIBEREQUEST']._serialized_end=496
_globals['_WORKERUNSUBSCRIBERESPONSE']._serialized_start=498
_globals['_WORKERUNSUBSCRIBERESPONSE']._serialized_end=561
_globals['_ACTIONEVENT']._serialized_start=564
_globals['_ACTIONEVENT']._serialized_end=792
_globals['_ACTIONEVENTRESPONSE']._serialized_start=794
_globals['_ACTIONEVENTRESPONSE']._serialized_end=851
_globals['_DISPATCHER']._serialized_start=1046
_globals['_DISPATCHER']._serialized_end=1303
_globals['_ASSIGNEDACTION']._serialized_end=458
_globals['_WORKERLISTENREQUEST']._serialized_start=460
_globals['_WORKERLISTENREQUEST']._serialized_end=499
_globals['_WORKERUNSUBSCRIBEREQUEST']._serialized_start=501
_globals['_WORKERUNSUBSCRIBEREQUEST']._serialized_end=545
_globals['_WORKERUNSUBSCRIBERESPONSE']._serialized_start=547
_globals['_WORKERUNSUBSCRIBERESPONSE']._serialized_end=610
_globals['_GROUPKEYACTIONEVENT']._serialized_start=613
_globals['_GROUPKEYACTIONEVENT']._serialized_end=838
_globals['_STEPACTIONEVENT']._serialized_start=841
_globals['_STEPACTIONEVENT']._serialized_end=1077
_globals['_ACTIONEVENTRESPONSE']._serialized_start=1079
_globals['_ACTIONEVENTRESPONSE']._serialized_end=1136
_globals['_DISPATCHER']._serialized_start=1525
_globals['_DISPATCHER']._serialized_end=1863
# @@protoc_insertion_point(module_scope)

View File

@@ -11,19 +11,32 @@ class ActionType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
START_STEP_RUN: _ClassVar[ActionType]
CANCEL_STEP_RUN: _ClassVar[ActionType]
START_GET_GROUP_KEY: _ClassVar[ActionType]
class ActionEventType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
class GroupKeyActionEventType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
STEP_EVENT_TYPE_UNKNOWN: _ClassVar[ActionEventType]
STEP_EVENT_TYPE_STARTED: _ClassVar[ActionEventType]
STEP_EVENT_TYPE_COMPLETED: _ClassVar[ActionEventType]
STEP_EVENT_TYPE_FAILED: _ClassVar[ActionEventType]
GROUP_KEY_EVENT_TYPE_UNKNOWN: _ClassVar[GroupKeyActionEventType]
GROUP_KEY_EVENT_TYPE_STARTED: _ClassVar[GroupKeyActionEventType]
GROUP_KEY_EVENT_TYPE_COMPLETED: _ClassVar[GroupKeyActionEventType]
GROUP_KEY_EVENT_TYPE_FAILED: _ClassVar[GroupKeyActionEventType]
class StepActionEventType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
STEP_EVENT_TYPE_UNKNOWN: _ClassVar[StepActionEventType]
STEP_EVENT_TYPE_STARTED: _ClassVar[StepActionEventType]
STEP_EVENT_TYPE_COMPLETED: _ClassVar[StepActionEventType]
STEP_EVENT_TYPE_FAILED: _ClassVar[StepActionEventType]
START_STEP_RUN: ActionType
CANCEL_STEP_RUN: ActionType
STEP_EVENT_TYPE_UNKNOWN: ActionEventType
STEP_EVENT_TYPE_STARTED: ActionEventType
STEP_EVENT_TYPE_COMPLETED: ActionEventType
STEP_EVENT_TYPE_FAILED: ActionEventType
START_GET_GROUP_KEY: ActionType
GROUP_KEY_EVENT_TYPE_UNKNOWN: GroupKeyActionEventType
GROUP_KEY_EVENT_TYPE_STARTED: GroupKeyActionEventType
GROUP_KEY_EVENT_TYPE_COMPLETED: GroupKeyActionEventType
GROUP_KEY_EVENT_TYPE_FAILED: GroupKeyActionEventType
STEP_EVENT_TYPE_UNKNOWN: StepActionEventType
STEP_EVENT_TYPE_STARTED: StepActionEventType
STEP_EVENT_TYPE_COMPLETED: StepActionEventType
STEP_EVENT_TYPE_FAILED: StepActionEventType
class WorkerRegisterRequest(_message.Message):
__slots__ = ("workerName", "actions", "services")
@@ -46,8 +59,10 @@ class WorkerRegisterResponse(_message.Message):
def __init__(self, tenantId: _Optional[str] = ..., workerId: _Optional[str] = ..., workerName: _Optional[str] = ...) -> None: ...
class AssignedAction(_message.Message):
__slots__ = ("tenantId", "jobId", "jobName", "jobRunId", "stepId", "stepRunId", "actionId", "actionType", "actionPayload")
__slots__ = ("tenantId", "workflowRunId", "getGroupKeyRunId", "jobId", "jobName", "jobRunId", "stepId", "stepRunId", "actionId", "actionType", "actionPayload")
TENANTID_FIELD_NUMBER: _ClassVar[int]
WORKFLOWRUNID_FIELD_NUMBER: _ClassVar[int]
GETGROUPKEYRUNID_FIELD_NUMBER: _ClassVar[int]
JOBID_FIELD_NUMBER: _ClassVar[int]
JOBNAME_FIELD_NUMBER: _ClassVar[int]
JOBRUNID_FIELD_NUMBER: _ClassVar[int]
@@ -57,6 +72,8 @@ class AssignedAction(_message.Message):
ACTIONTYPE_FIELD_NUMBER: _ClassVar[int]
ACTIONPAYLOAD_FIELD_NUMBER: _ClassVar[int]
tenantId: str
workflowRunId: str
getGroupKeyRunId: str
jobId: str
jobName: str
jobRunId: str
@@ -65,7 +82,7 @@ class AssignedAction(_message.Message):
actionId: str
actionType: ActionType
actionPayload: str
def __init__(self, tenantId: _Optional[str] = ..., jobId: _Optional[str] = ..., jobName: _Optional[str] = ..., jobRunId: _Optional[str] = ..., stepId: _Optional[str] = ..., stepRunId: _Optional[str] = ..., actionId: _Optional[str] = ..., actionType: _Optional[_Union[ActionType, str]] = ..., actionPayload: _Optional[str] = ...) -> None: ...
def __init__(self, tenantId: _Optional[str] = ..., workflowRunId: _Optional[str] = ..., getGroupKeyRunId: _Optional[str] = ..., jobId: _Optional[str] = ..., jobName: _Optional[str] = ..., jobRunId: _Optional[str] = ..., stepId: _Optional[str] = ..., stepRunId: _Optional[str] = ..., actionId: _Optional[str] = ..., actionType: _Optional[_Union[ActionType, str]] = ..., actionPayload: _Optional[str] = ...) -> None: ...
class WorkerListenRequest(_message.Message):
__slots__ = ("workerId",)
@@ -87,7 +104,25 @@ class WorkerUnsubscribeResponse(_message.Message):
workerId: str
def __init__(self, tenantId: _Optional[str] = ..., workerId: _Optional[str] = ...) -> None: ...
class ActionEvent(_message.Message):
class GroupKeyActionEvent(_message.Message):
__slots__ = ("workerId", "workflowRunId", "getGroupKeyRunId", "actionId", "eventTimestamp", "eventType", "eventPayload")
WORKERID_FIELD_NUMBER: _ClassVar[int]
WORKFLOWRUNID_FIELD_NUMBER: _ClassVar[int]
GETGROUPKEYRUNID_FIELD_NUMBER: _ClassVar[int]
ACTIONID_FIELD_NUMBER: _ClassVar[int]
EVENTTIMESTAMP_FIELD_NUMBER: _ClassVar[int]
EVENTTYPE_FIELD_NUMBER: _ClassVar[int]
EVENTPAYLOAD_FIELD_NUMBER: _ClassVar[int]
workerId: str
workflowRunId: str
getGroupKeyRunId: str
actionId: str
eventTimestamp: _timestamp_pb2.Timestamp
eventType: GroupKeyActionEventType
eventPayload: str
def __init__(self, workerId: _Optional[str] = ..., workflowRunId: _Optional[str] = ..., getGroupKeyRunId: _Optional[str] = ..., actionId: _Optional[str] = ..., eventTimestamp: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., eventType: _Optional[_Union[GroupKeyActionEventType, str]] = ..., eventPayload: _Optional[str] = ...) -> None: ...
class StepActionEvent(_message.Message):
__slots__ = ("workerId", "jobId", "jobRunId", "stepId", "stepRunId", "actionId", "eventTimestamp", "eventType", "eventPayload")
WORKERID_FIELD_NUMBER: _ClassVar[int]
JOBID_FIELD_NUMBER: _ClassVar[int]
@@ -105,9 +140,9 @@ class ActionEvent(_message.Message):
stepRunId: str
actionId: str
eventTimestamp: _timestamp_pb2.Timestamp
eventType: ActionEventType
eventType: StepActionEventType
eventPayload: str
def __init__(self, workerId: _Optional[str] = ..., jobId: _Optional[str] = ..., jobRunId: _Optional[str] = ..., stepId: _Optional[str] = ..., stepRunId: _Optional[str] = ..., actionId: _Optional[str] = ..., eventTimestamp: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., eventType: _Optional[_Union[ActionEventType, str]] = ..., eventPayload: _Optional[str] = ...) -> None: ...
def __init__(self, workerId: _Optional[str] = ..., jobId: _Optional[str] = ..., jobRunId: _Optional[str] = ..., stepId: _Optional[str] = ..., stepRunId: _Optional[str] = ..., actionId: _Optional[str] = ..., eventTimestamp: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., eventType: _Optional[_Union[StepActionEventType, str]] = ..., eventPayload: _Optional[str] = ...) -> None: ...
class ActionEventResponse(_message.Message):
__slots__ = ("tenantId", "workerId")

View File

@@ -23,9 +23,14 @@ class DispatcherStub(object):
request_serializer=dispatcher__pb2.WorkerListenRequest.SerializeToString,
response_deserializer=dispatcher__pb2.AssignedAction.FromString,
)
self.SendActionEvent = channel.unary_unary(
'/Dispatcher/SendActionEvent',
request_serializer=dispatcher__pb2.ActionEvent.SerializeToString,
self.SendStepActionEvent = channel.unary_unary(
'/Dispatcher/SendStepActionEvent',
request_serializer=dispatcher__pb2.StepActionEvent.SerializeToString,
response_deserializer=dispatcher__pb2.ActionEventResponse.FromString,
)
self.SendGroupKeyActionEvent = channel.unary_unary(
'/Dispatcher/SendGroupKeyActionEvent',
request_serializer=dispatcher__pb2.GroupKeyActionEvent.SerializeToString,
response_deserializer=dispatcher__pb2.ActionEventResponse.FromString,
)
self.Unsubscribe = channel.unary_unary(
@@ -50,7 +55,13 @@ class DispatcherServicer(object):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendActionEvent(self, request, context):
def SendStepActionEvent(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendGroupKeyActionEvent(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
@@ -75,9 +86,14 @@ def add_DispatcherServicer_to_server(servicer, server):
request_deserializer=dispatcher__pb2.WorkerListenRequest.FromString,
response_serializer=dispatcher__pb2.AssignedAction.SerializeToString,
),
'SendActionEvent': grpc.unary_unary_rpc_method_handler(
servicer.SendActionEvent,
request_deserializer=dispatcher__pb2.ActionEvent.FromString,
'SendStepActionEvent': grpc.unary_unary_rpc_method_handler(
servicer.SendStepActionEvent,
request_deserializer=dispatcher__pb2.StepActionEvent.FromString,
response_serializer=dispatcher__pb2.ActionEventResponse.SerializeToString,
),
'SendGroupKeyActionEvent': grpc.unary_unary_rpc_method_handler(
servicer.SendGroupKeyActionEvent,
request_deserializer=dispatcher__pb2.GroupKeyActionEvent.FromString,
response_serializer=dispatcher__pb2.ActionEventResponse.SerializeToString,
),
'Unsubscribe': grpc.unary_unary_rpc_method_handler(
@@ -130,7 +146,7 @@ class Dispatcher(object):
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def SendActionEvent(request,
def SendStepActionEvent(request,
target,
options=(),
channel_credentials=None,
@@ -140,8 +156,25 @@ class Dispatcher(object):
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/Dispatcher/SendActionEvent',
dispatcher__pb2.ActionEvent.SerializeToString,
return grpc.experimental.unary_unary(request, target, '/Dispatcher/SendStepActionEvent',
dispatcher__pb2.StepActionEvent.SerializeToString,
dispatcher__pb2.ActionEventResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def SendGroupKeyActionEvent(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/Dispatcher/SendGroupKeyActionEvent',
dispatcher__pb2.GroupKeyActionEvent.SerializeToString,
dispatcher__pb2.ActionEventResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

View File

@@ -12,6 +12,15 @@ class Hatchet:
if not debug:
logger.disable("hatchet_sdk")
def concurrency(self, name : str='', max_runs : int = 1):
def inner(func):
func._concurrency_fn_name = name or func.__name__
func._concurrency_max_runs = max_runs
return func
return inner
def workflow(self, name : str='', on_events : list=[], on_crons : list=[], version : str=''):
def inner(cls):
cls.on_events = on_events

View File

@@ -10,7 +10,7 @@ import grpc
from typing import Any, Callable, Dict
from .workflow import WorkflowMeta
from .clients.dispatcher import GetActionListenerRequest, ActionListenerImpl, Action
from .dispatcher_pb2 import STEP_EVENT_TYPE_FAILED, ActionType, ActionEvent, ActionEventType, STEP_EVENT_TYPE_COMPLETED, STEP_EVENT_TYPE_STARTED
from .dispatcher_pb2 import ActionType, StepActionEvent, StepActionEventType, GroupKeyActionEvent, GroupKeyActionEventType, STEP_EVENT_TYPE_COMPLETED, STEP_EVENT_TYPE_STARTED, STEP_EVENT_TYPE_FAILED, GROUP_KEY_EVENT_TYPE_STARTED, GROUP_KEY_EVENT_TYPE_COMPLETED, GROUP_KEY_EVENT_TYPE_FAILED
from .client import new_client
from concurrent.futures import ThreadPoolExecutor, Future
from google.protobuf.timestamp_pb2 import Timestamp
@@ -53,24 +53,24 @@ class Worker:
errored = True
# This except is coming from the application itself, so we want to send that to the Hatchet instance
event = self.get_action_event(action, STEP_EVENT_TYPE_FAILED)
event = self.get_step_action_event(action, STEP_EVENT_TYPE_FAILED)
event.eventPayload = str(e)
try:
self.client.dispatcher.send_action_event(event)
self.client.dispatcher.send_step_action_event(event)
except Exception as e:
logger.error(f"Could not send action event: {e}")
if not errored:
# Create an action event
try:
event = self.get_action_finished_event(action, output)
event = self.get_step_action_finished_event(action, output)
except Exception as e:
logger.error(f"Could not get action finished event: {e}")
raise e
# Send the action event to the dispatcher
self.client.dispatcher.send_action_event(event)
self.client.dispatcher.send_step_action_event(event)
# Remove the future from the dictionary
if action.step_run_id in self.futures:
@@ -100,12 +100,86 @@ class Worker:
# send an event that the step run has started
try:
event = self.get_action_event(action, STEP_EVENT_TYPE_STARTED)
event = self.get_step_action_event(action, STEP_EVENT_TYPE_STARTED)
except Exception as e:
logger.error(f"Could not create action event: {e}")
# Send the action event to the dispatcher
self.client.dispatcher.send_action_event(event)
self.client.dispatcher.send_step_action_event(event)
def handle_start_group_key_run(self, action : Action):
action_name = action.action_id # Assuming action object has 'name' attribute
context = Context(action.action_payload) # Assuming action object has 'context' attribute
self.contexts[action.get_group_key_run_id] = context
# Find the corresponding action function from the registry
action_func = self.action_registry.get(action_name)
if action_func:
def callback(future : Future):
errored = False
# Get the output from the future
try:
output = future.result()
except Exception as e:
errored = True
# This except is coming from the application itself, so we want to send that to the Hatchet instance
event = self.get_group_key_action_event(action, GROUP_KEY_EVENT_TYPE_FAILED)
event.eventPayload = str(e)
try:
self.client.dispatcher.send_group_key_action_event(event)
except Exception as e:
logger.error(f"Could not send action event: {e}")
if not errored:
# Create an action event
try:
event = self.get_group_key_action_finished_event(action, output)
except Exception as e:
logger.error(f"Could not get action finished event: {e}")
raise e
# Send the action event to the dispatcher
self.client.dispatcher.send_group_key_action_event(event)
# Remove the future from the dictionary
if action.get_group_key_run_id in self.futures:
del self.futures[action.get_group_key_run_id]
# Submit the action to the thread pool
def wrapped_action_func(context):
# store the thread id
self.threads[action.get_group_key_run_id] = current_thread()
try:
res = action_func(context)
return res
except Exception as e:
logger.error(f"Could not execute action: {e}")
raise e
finally:
if action.get_group_key_run_id in self.threads:
# remove the thread id
logger.debug(f"Removing step run id {action.get_group_key_run_id} from threads")
del self.threads[action.get_group_key_run_id]
future = self.thread_pool.submit(wrapped_action_func, context)
future.add_done_callback(callback)
self.futures[action.get_group_key_run_id] = future
# send an event that the step run has started
try:
event = self.get_group_key_action_event(action, GROUP_KEY_EVENT_TYPE_STARTED)
except Exception as e:
logger.error(f"Could not create action event: {e}")
# Send the action event to the dispatcher
self.client.dispatcher.send_group_key_action_event(event)
def force_kill_thread(self, thread):
"""Terminate a python threading.Thread."""
@@ -136,42 +210,37 @@ class Worker:
except Exception as e:
logger.exception(f"Failed to terminate thread: {e}")
def handle_cancel_step_run(self, action : Action):
step_run_id = action.step_run_id
def handle_cancel_action(self, run_id: str):
# call cancel to signal the context to stop
context = self.contexts.get(step_run_id)
context = self.contexts.get(run_id)
context.cancel()
future = self.futures.get(step_run_id)
future = self.futures.get(run_id)
if future:
future.cancel()
if step_run_id in self.futures:
del self.futures[step_run_id]
if run_id in self.futures:
del self.futures[run_id]
# grace period of 1 second
time.sleep(1)
# check if thread is still running, if so, kill it
if step_run_id in self.threads:
thread = self.threads[step_run_id]
if run_id in self.threads:
thread = self.threads[run_id]
if thread:
self.force_kill_thread(thread)
if step_run_id in self.threads:
del self.threads[step_run_id]
if run_id in self.threads:
del self.threads[run_id]
def get_action_event(self, action : Action, event_type : ActionEventType) -> ActionEvent:
# timestamp
# eventTimestamp = datetime.datetime.now(datetime.timezone.utc)
# eventTimestamp = eventTimestamp.isoformat()
def get_step_action_event(self, action : Action, event_type : StepActionEventType) -> StepActionEvent:
eventTimestamp = Timestamp()
eventTimestamp.GetCurrentTime()
return ActionEvent(
return StepActionEvent(
workerId=action.worker_id,
jobId=action.job_id,
jobRunId=action.job_run_id,
@@ -182,9 +251,9 @@ class Worker:
eventType=event_type,
)
def get_action_finished_event(self, action : Action, output : Any) -> ActionEvent:
def get_step_action_finished_event(self, action : Action, output : Any) -> StepActionEvent:
try:
event = self.get_action_event(action, STEP_EVENT_TYPE_COMPLETED)
event = self.get_step_action_event(action, STEP_EVENT_TYPE_COMPLETED)
except Exception as e:
logger.error(f"Could not create action finished event: {e}")
raise e
@@ -198,6 +267,30 @@ class Worker:
return event
def get_group_key_action_event(self, action : Action, event_type : GroupKeyActionEventType) -> GroupKeyActionEvent:
eventTimestamp = Timestamp()
eventTimestamp.GetCurrentTime()
return GroupKeyActionEvent(
workerId=action.worker_id,
workflowRunId=action.workflow_run_id,
getGroupKeyRunId=action.get_group_key_run_id,
actionId=action.action_id,
eventTimestamp=eventTimestamp,
eventType=event_type,
)
def get_group_key_action_finished_event(self, action : Action, output : str) -> StepActionEvent:
try:
event = self.get_group_key_action_event(action, GROUP_KEY_EVENT_TYPE_COMPLETED)
except Exception as e:
logger.error(f"Could not create action finished event: {e}")
raise e
event.eventPayload = output
return event
def register_workflow(self, workflow : WorkflowMeta):
def create_action_function(action_func):
def action_function(context):
@@ -247,8 +340,10 @@ class Worker:
self.handle_start_step_run(action)
elif action.action_type == ActionType.CANCEL_STEP_RUN:
self.thread_pool.submit(self.handle_cancel_step_run, action)
pass # Replace this with your actual processing code
elif action.action_type == ActionType.START_GET_GROUP_KEY:
self.handle_start_group_key_run(action)
else:
logger.error(f"Unknown action type: {action.action_type}")
except grpc.RpcError as rpc_error:
logger.error(f"Could not start worker: {rpc_error}")

View File

@@ -1,5 +1,5 @@
from .client import new_client
from .workflows_pb2 import CreateWorkflowVersionOpts, CreateWorkflowJobOpts, CreateWorkflowStepOpts
from .workflows_pb2 import CreateWorkflowVersionOpts, CreateWorkflowJobOpts, CreateWorkflowStepOpts, WorkflowConcurrencyOpts
from typing import Callable, List, Tuple, Any
stepsType = List[Tuple[str, Callable[..., Any]]]
@@ -8,6 +8,7 @@ class WorkflowMeta(type):
def __new__(cls, name, bases, attrs):
serviceName = "default"
concurrencyActions: stepsType = [(name.lower() + "-" + func_name, attrs.pop(func_name)) for func_name, func in list(attrs.items()) if hasattr(func, '_concurrency_fn_name')]
steps: stepsType = [(name.lower() + "-" + func_name, attrs.pop(func_name)) for func_name, func in list(attrs.items()) if hasattr(func, '_step_name')]
# Define __init__ and get_step_order methods
@@ -18,7 +19,10 @@ class WorkflowMeta(type):
original_init(self, *args, **kwargs) # Call original __init__
def get_actions(self) -> stepsType:
return [(serviceName + ":" + func_name, func) for func_name, func in steps]
func_actions = [(serviceName + ":" + func_name, func) for func_name, func in steps]
concurrency_actions = [(serviceName + ":" + func_name, func) for func_name, func in concurrencyActions]
return func_actions + concurrency_actions
# Add these methods and steps to class attributes
attrs['__init__'] = __init__
@@ -48,6 +52,17 @@ class WorkflowMeta(type):
for func_name, func in attrs.items() if hasattr(func, '_step_name')
]
concurrency : WorkflowConcurrencyOpts | None = None
if len(concurrencyActions) > 0:
action = concurrencyActions[0]
concurrency = WorkflowConcurrencyOpts(
action="default:" + action[0],
max_runs=action[1]._concurrency_max_runs,
)
client.admin.put_workflow(CreateWorkflowVersionOpts(
name=name,
version=version,
@@ -59,7 +74,8 @@ class WorkflowMeta(type):
timeout="60s",
steps=createStepOpts,
)
]
],
concurrency=concurrency,
))
return super(WorkflowMeta, cls).__new__(cls, name, bases, attrs)

File diff suppressed because one or more lines are too long

View File

@@ -1,12 +1,22 @@
from google.protobuf import timestamp_pb2 as _timestamp_pb2
from google.protobuf import wrappers_pb2 as _wrappers_pb2
from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class ConcurrencyLimitStrategy(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
CANCEL_IN_PROGRESS: _ClassVar[ConcurrencyLimitStrategy]
DROP_NEWEST: _ClassVar[ConcurrencyLimitStrategy]
QUEUE_NEWEST: _ClassVar[ConcurrencyLimitStrategy]
CANCEL_IN_PROGRESS: ConcurrencyLimitStrategy
DROP_NEWEST: ConcurrencyLimitStrategy
QUEUE_NEWEST: ConcurrencyLimitStrategy
class PutWorkflowRequest(_message.Message):
__slots__ = ("opts",)
OPTS_FIELD_NUMBER: _ClassVar[int]
@@ -14,7 +24,7 @@ class PutWorkflowRequest(_message.Message):
def __init__(self, opts: _Optional[_Union[CreateWorkflowVersionOpts, _Mapping]] = ...) -> None: ...
class CreateWorkflowVersionOpts(_message.Message):
__slots__ = ("name", "description", "version", "event_triggers", "cron_triggers", "scheduled_triggers", "jobs")
__slots__ = ("name", "description", "version", "event_triggers", "cron_triggers", "scheduled_triggers", "jobs", "concurrency")
NAME_FIELD_NUMBER: _ClassVar[int]
DESCRIPTION_FIELD_NUMBER: _ClassVar[int]
VERSION_FIELD_NUMBER: _ClassVar[int]
@@ -22,6 +32,7 @@ class CreateWorkflowVersionOpts(_message.Message):
CRON_TRIGGERS_FIELD_NUMBER: _ClassVar[int]
SCHEDULED_TRIGGERS_FIELD_NUMBER: _ClassVar[int]
JOBS_FIELD_NUMBER: _ClassVar[int]
CONCURRENCY_FIELD_NUMBER: _ClassVar[int]
name: str
description: str
version: str
@@ -29,7 +40,18 @@ class CreateWorkflowVersionOpts(_message.Message):
cron_triggers: _containers.RepeatedScalarFieldContainer[str]
scheduled_triggers: _containers.RepeatedCompositeFieldContainer[_timestamp_pb2.Timestamp]
jobs: _containers.RepeatedCompositeFieldContainer[CreateWorkflowJobOpts]
def __init__(self, name: _Optional[str] = ..., description: _Optional[str] = ..., version: _Optional[str] = ..., event_triggers: _Optional[_Iterable[str]] = ..., cron_triggers: _Optional[_Iterable[str]] = ..., scheduled_triggers: _Optional[_Iterable[_Union[_timestamp_pb2.Timestamp, _Mapping]]] = ..., jobs: _Optional[_Iterable[_Union[CreateWorkflowJobOpts, _Mapping]]] = ...) -> None: ...
concurrency: WorkflowConcurrencyOpts
def __init__(self, name: _Optional[str] = ..., description: _Optional[str] = ..., version: _Optional[str] = ..., event_triggers: _Optional[_Iterable[str]] = ..., cron_triggers: _Optional[_Iterable[str]] = ..., scheduled_triggers: _Optional[_Iterable[_Union[_timestamp_pb2.Timestamp, _Mapping]]] = ..., jobs: _Optional[_Iterable[_Union[CreateWorkflowJobOpts, _Mapping]]] = ..., concurrency: _Optional[_Union[WorkflowConcurrencyOpts, _Mapping]] = ...) -> None: ...
class WorkflowConcurrencyOpts(_message.Message):
__slots__ = ("action", "max_runs", "limit_strategy")
ACTION_FIELD_NUMBER: _ClassVar[int]
MAX_RUNS_FIELD_NUMBER: _ClassVar[int]
LIMIT_STRATEGY_FIELD_NUMBER: _ClassVar[int]
action: str
max_runs: int
limit_strategy: ConcurrencyLimitStrategy
def __init__(self, action: _Optional[str] = ..., max_runs: _Optional[int] = ..., limit_strategy: _Optional[_Union[ConcurrencyLimitStrategy, str]] = ...) -> None: ...
class CreateWorkflowJobOpts(_message.Message):
__slots__ = ("name", "description", "timeout", "steps")

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "hatchet-sdk"
version = "0.6.0"
version = "0.7.0"
description = ""
authors = ["Alexander Belanger <alexander@hatchet.run>"]
readme = "README.md"