feat: streaming events (#309)

* feat: add stream event model

* docs: how to work with db models

* feat: put stream event

* chore: rm comments

* feat: add stream resource type

* feat: enqueue stream event

* fix: contracts

* feat: protos

* chore: set properties correctly for typing

* fix: stream example

* chore: rm old example

* fix: async on

* fix: bytea type

* fix: worker

* feat: put stream data

* feat: stream type

* fix: correct queue

* feat: streaming payloads

* fix: cleanup

* fix: validation

* feat: example file streaming

* chore: rm unused query

* fix: tenant check and read only consumer

* fix: check tenant-steprun relation

* Update prisma/schema.prisma

Co-authored-by: abelanger5 <belanger@sas.upenn.edu>

* chore: generate protos

* chore: rename migration

* release: 0.20.0

* feat(go-sdk): implement streaming in go

---------

Co-authored-by: gabriel ruttner <gabe@hatchet.run>
Co-authored-by: abelanger5 <belanger@sas.upenn.edu>
Gabe Ruttner committed 2024-04-01 12:46:21 -07:00 (committed by GitHub)
parent 7b7fbe3668
commit d8b6843dec
49 changed files with 1173 additions and 185 deletions
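
In short, this change lets a running step publish arbitrary chunks of data and lets clients subscribe to those chunks over the existing workflow-run listener. A minimal sketch of both halves, pieced together from the examples in this diff (the payload and run id are illustrative):

# producer side, inside a step:
context.put_stream("hello from step1")

# consumer side:
listener = hatchet.listener.stream(workflow_run_id)
async for event in listener:
    if event.type == StepRunEventType.STEP_RUN_EVENT_TYPE_STREAM:
        print(event.payload)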

(New binary file added: image, 177 KiB; not shown in this view.)


@@ -0,0 +1,40 @@
import os
from hatchet_sdk import new_client
from dotenv import load_dotenv
import json
import asyncio
from hatchet_sdk.clients.listener import StepRunEventType
import base64


async def main():
    load_dotenv()
    hatchet = new_client()

    workflowRunId = hatchet.admin.run_workflow("ManualTriggerWorkflow", {"test": "test"})

    listener = hatchet.listener.stream(workflowRunId)

    # Get the directory of the current script
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # Create the "out" directory if it doesn't exist
    out_dir = os.path.join(script_dir, "out")
    os.makedirs(out_dir, exist_ok=True)

    async for event in listener:
        if event.type == StepRunEventType.STEP_RUN_EVENT_TYPE_STREAM:
            # Decode the base64-encoded payload
            decoded_payload = base64.b64decode(event.payload)

            # Construct the path to the payload file in the "out" directory
            payload_path = os.path.join(out_dir, "payload.jpg")

            with open(payload_path, "wb") as f:
                f.write(decoded_payload)

        data = json.dumps({
            "type": event.type,
            "messageId": workflowRunId
        })

        print("data: " + data + "\n\n")


if __name__ == "__main__":
    asyncio.run(main())


@@ -1,14 +0,0 @@
from hatchet_sdk import new_client
from dotenv import load_dotenv
import json

load_dotenv()

client = new_client()

workflowRunId = client.admin.run_workflow("ManualTriggerWorkflow", {
    "test": "test"
})

client.listener.on(workflowRunId, lambda event: print(
    'EVENT: ' + event.type + ' ' + json.dumps(event.payload)))


@@ -1,5 +1,7 @@
from hatchet_sdk import Hatchet
from hatchet_sdk import Hatchet, Context
from dotenv import load_dotenv

import base64
import os

load_dotenv()
@@ -9,9 +11,27 @@ hatchet = Hatchet(debug=True)
@hatchet.workflow(on_events=["man:create"])
class ManualTriggerWorkflow:
    @hatchet.step()
    def step1(self, context):
    def step1(self, context: Context):
        res = context.playground('res', "HELLO")

        # Get the directory of the current script
        script_dir = os.path.dirname(os.path.abspath(__file__))

        # Construct the path to the image file relative to the script's directory
        image_path = os.path.join(script_dir, "image.jpeg")

        # Load the image file
        with open(image_path, "rb") as image_file:
            image_data = image_file.read()

        print(len(image_data))

        # Encode the image data as base64
        base64_image = base64.b64encode(image_data).decode('utf-8')

        # Stream the base64-encoded image data
        context.put_stream(base64_image)

        context.sleep(3)
        print("executed step1")
        return {"step1": "data1 "+res}
@@ -23,22 +43,6 @@ class ManualTriggerWorkflow:
        print("finished step2")
        return {"step2": "data2"}

    # @hatchet.step()
    # def stepb(self, context):
    #     res = context.playground('res', "HELLO")
    #     context.sleep(3)
    #     print("executed step1")
    #     return {"step1": "data1 "+res}

    # @hatchet.step(parents=["stepb"], timeout='4s')
    # def stepc(self, context):
    #     print("started step2")
    #     context.sleep(1)
    #     print("finished step2")
    #     return {"step2": "data2"}

workflow = ManualTriggerWorkflow()
worker = hatchet.worker('manual-worker', max_runs=4)
worker.register_workflow(workflow)


@@ -18,20 +18,11 @@ from .clients.rest.configuration import Configuration
from .clients.rest_client import RestApi
class Client:
    def admin(self):
        raise NotImplementedError

    def dispatcher(self):
        raise NotImplementedError

    def event(self):
        raise NotImplementedError

    def listener(self):
        raise NotImplementedError

    def rest(self):
        raise NotImplementedError
    admin: AdminClientImpl
    dispatcher: DispatcherClientImpl
    event: EventClientImpl
    listener: ListenerClientImpl
    rest: RestApi

class ClientImpl(Client):
@@ -43,30 +34,11 @@ class ClientImpl(Client):
        listener_client: ListenerClientImpl,
        rest_client: RestApi
    ):
        # self.conn = conn
        # self.tenant_id = tenant_id
        # self.logger = logger
        # self.validator = validator
        self.admin = admin_client
        self.dispatcher = dispatcher_client
        self.event = event_client
        self.listener = listener_client
        self.rest_client = rest_client

    def admin(self) -> AdminClientImpl:
        return self.admin

    def dispatcher(self) -> DispatcherClientImpl:
        return self.dispatcher

    def event(self) -> EventClientImpl:
        return self.event

    def listener(self) -> ListenerClientImpl:
        return self.listener

    def rest(self) -> RestApi:
        return self.rest_client
        self.rest = rest_client

def with_host_port(host: str, port: int):
    def with_host_port_impl(config: ClientConfig):


@@ -1,5 +1,5 @@
from ..events_pb2_grpc import EventsServiceStub
from ..events_pb2 import PushEventRequest, PutLogRequest
from ..events_pb2 import PushEventRequest, PutLogRequest, PutStreamEventRequest
import datetime

from ..loader import ClientConfig
@@ -22,7 +22,7 @@ def proto_timestamp_now():
    return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos)

class EventClientImpl:
    def __init__(self, client, token):
    def __init__(self, client: EventsServiceStub, token):
        self.client = client
        self.token = token
@@ -53,4 +53,22 @@ class EventClientImpl:
            self.client.PutLog(request, metadata=get_metadata(self.token))
        except Exception as e:
            raise ValueError(f"Error logging: {e}")

    def stream(self, data: str | bytes, step_run_id: str):
        try:
            if isinstance(data, str):
                data_bytes = data.encode('utf-8')
            elif isinstance(data, bytes):
                data_bytes = data
            else:
                raise ValueError("Invalid data type. Expected str, bytes, or file.")

            request = PutStreamEventRequest(
                stepRunId=step_run_id,
                createdAt=proto_timestamp_now(),
                message=data_bytes,
            )

            self.client.PutStreamEvent(request, metadata=get_metadata(self.token))
        except Exception as e:
            raise ValueError(f"Error putting stream event: {e}")


@@ -20,6 +20,7 @@ class StepRunEventType:
    STEP_RUN_EVENT_TYPE_FAILED = 'STEP_RUN_EVENT_TYPE_FAILED'
    STEP_RUN_EVENT_TYPE_CANCELLED = 'STEP_RUN_EVENT_TYPE_CANCELLED'
    STEP_RUN_EVENT_TYPE_TIMED_OUT = 'STEP_RUN_EVENT_TYPE_TIMED_OUT'
    STEP_RUN_EVENT_TYPE_STREAM = 'STEP_RUN_EVENT_TYPE_STREAM'

class WorkflowRunEventType:
    WORKFLOW_RUN_EVENT_TYPE_STARTED = 'WORKFLOW_RUN_EVENT_TYPE_STARTED'
@@ -34,6 +35,7 @@ step_run_event_type_mapping = {
    ResourceEventType.RESOURCE_EVENT_TYPE_FAILED: StepRunEventType.STEP_RUN_EVENT_TYPE_FAILED,
    ResourceEventType.RESOURCE_EVENT_TYPE_CANCELLED: StepRunEventType.STEP_RUN_EVENT_TYPE_CANCELLED,
    ResourceEventType.RESOURCE_EVENT_TYPE_TIMED_OUT: StepRunEventType.STEP_RUN_EVENT_TYPE_TIMED_OUT,
    ResourceEventType.RESOURCE_EVENT_TYPE_STREAM: StepRunEventType.STEP_RUN_EVENT_TYPE_STREAM,
}

workflow_run_event_type_mapping = {
@@ -95,6 +97,7 @@ class HatchetListener:
                if workflow_event.eventPayload:
                    payload = json.loads(workflow_event.eventPayload)
            except Exception as e:
                payload = workflow_event.eventPayload
                pass

            yield StepRunEvent(type=eventType, payload=payload)
@@ -166,8 +169,8 @@ class ListenerClientImpl:
    def stream(self, workflow_run_id: str):
        return HatchetListener(workflow_run_id, self.token, self.config)

    def on(self, workflow_run_id: str, handler: callable = None):
        for event in self.stream(workflow_run_id):
    async def on(self, workflow_run_id: str, handler: callable = None):
        async for event in self.stream(workflow_run_id):
            # call the handler if provided
            if handler:
                handler(event)
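
Note that on is now a coroutine that drives the stream with async for, so existing callers must await it. A sketch of the updated call site, assuming a configured client and a workflow run id as in the examples above (the handler is illustrative):

import asyncio

asyncio.run(client.listener.on(workflow_run_id, lambda event: print(event.type)))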


@@ -127,6 +127,7 @@ class Context:
        # FIXME: this limits the number of concurrent log requests to 1, which means we can do about
        # 100 log lines per second but this depends on network.
        self.logger_thread_pool = ThreadPoolExecutor(max_workers=1)
        self.stream_event_thread_pool = ThreadPoolExecutor(max_workers=1)

        # store each key in the overrides field in a lookup table
        # overrides_data is a dictionary of key-value pairs
@@ -216,3 +217,15 @@ class Context:
            return

        self.logger_thread_pool.submit(self._log, line)

    def _put_stream(self, data: str | bytes):
        try:
            self.client.event.stream(data=data, step_run_id=self.stepRunId)
        except Exception as e:
            logger.error(f"Error putting stream event: {e}")

    def put_stream(self, data: str | bytes):
        if self.stepRunId == "":
            return

        self.stream_event_thread_pool.submit(self._put_stream, data)
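
put_stream mirrors the existing logging path: it no-ops when there is no step run id and submits the payload to a single-worker thread pool, so the gRPC call never blocks the step body. Usage inside a step is fire-and-forget; a minimal fragment (the step body is illustrative, reusing the Hatchet instance and Context import from the worker example above):

@hatchet.step()
def stream_step(self, context: Context):
    # queued on stream_event_thread_pool and sent in the background
    context.put_stream("chunk 1")
    return {"status": "streamed"}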

File diff suppressed because one or more lines are too long


@@ -41,6 +41,7 @@ class ResourceEventType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
    RESOURCE_EVENT_TYPE_FAILED: _ClassVar[ResourceEventType]
    RESOURCE_EVENT_TYPE_CANCELLED: _ClassVar[ResourceEventType]
    RESOURCE_EVENT_TYPE_TIMED_OUT: _ClassVar[ResourceEventType]
    RESOURCE_EVENT_TYPE_STREAM: _ClassVar[ResourceEventType]

START_STEP_RUN: ActionType
CANCEL_STEP_RUN: ActionType
START_GET_GROUP_KEY: ActionType
@@ -61,6 +62,7 @@ RESOURCE_EVENT_TYPE_COMPLETED: ResourceEventType
RESOURCE_EVENT_TYPE_FAILED: ResourceEventType
RESOURCE_EVENT_TYPE_CANCELLED: ResourceEventType
RESOURCE_EVENT_TYPE_TIMED_OUT: ResourceEventType
RESOURCE_EVENT_TYPE_STREAM: ResourceEventType

class WorkerRegisterRequest(_message.Message):
    __slots__ = ("workerName", "actions", "services", "maxRuns")


@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: events.proto
# Protobuf Python Version: 4.25.0
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -15,7 +15,7 @@ _sym_db = _symbol_database.Default()
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x65vents.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"|\n\x05\x45vent\x12\x10\n\x08tenantId\x18\x01 \x01(\t\x12\x0f\n\x07\x65ventId\x18\x02 \x01(\t\x12\x0b\n\x03key\x18\x03 \x01(\t\x12\x0f\n\x07payload\x18\x04 \x01(\t\x12\x32\n\x0e\x65ventTimestamp\x18\x05 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"\x92\x01\n\rPutLogRequest\x12\x11\n\tstepRunId\x18\x01 \x01(\t\x12-\n\tcreatedAt\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0f\n\x07message\x18\x03 \x01(\t\x12\x12\n\x05level\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x10\n\x08metadata\x18\x05 \x01(\tB\x08\n\x06_level\"\x10\n\x0ePutLogResponse\"d\n\x10PushEventRequest\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x0f\n\x07payload\x18\x02 \x01(\t\x12\x32\n\x0e\x65ventTimestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"%\n\x12ReplayEventRequest\x12\x0f\n\x07\x65ventId\x18\x01 \x01(\t2\x95\x01\n\rEventsService\x12#\n\x04Push\x12\x11.PushEventRequest\x1a\x06.Event\"\x00\x12\x32\n\x11ReplaySingleEvent\x12\x13.ReplayEventRequest\x1a\x06.Event\"\x00\x12+\n\x06PutLog\x12\x0e.PutLogRequest\x1a\x0f.PutLogResponse\"\x00\x42GZEgithub.com/hatchet-dev/hatchet/internal/services/dispatcher/contractsb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x65vents.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"|\n\x05\x45vent\x12\x10\n\x08tenantId\x18\x01 \x01(\t\x12\x0f\n\x07\x65ventId\x18\x02 \x01(\t\x12\x0b\n\x03key\x18\x03 \x01(\t\x12\x0f\n\x07payload\x18\x04 \x01(\t\x12\x32\n\x0e\x65ventTimestamp\x18\x05 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"\x92\x01\n\rPutLogRequest\x12\x11\n\tstepRunId\x18\x01 \x01(\t\x12-\n\tcreatedAt\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0f\n\x07message\x18\x03 \x01(\t\x12\x12\n\x05level\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x10\n\x08metadata\x18\x05 \x01(\tB\x08\n\x06_level\"\x10\n\x0ePutLogResponse\"|\n\x15PutStreamEventRequest\x12\x11\n\tstepRunId\x18\x01 \x01(\t\x12-\n\tcreatedAt\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0f\n\x07message\x18\x03 \x01(\x0c\x12\x10\n\x08metadata\x18\x05 \x01(\t\"\x18\n\x16PutStreamEventResponse\"d\n\x10PushEventRequest\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x0f\n\x07payload\x18\x02 \x01(\t\x12\x32\n\x0e\x65ventTimestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"%\n\x12ReplayEventRequest\x12\x0f\n\x07\x65ventId\x18\x01 \x01(\t2\xda\x01\n\rEventsService\x12#\n\x04Push\x12\x11.PushEventRequest\x1a\x06.Event\"\x00\x12\x32\n\x11ReplaySingleEvent\x12\x13.ReplayEventRequest\x1a\x06.Event\"\x00\x12+\n\x06PutLog\x12\x0e.PutLogRequest\x1a\x0f.PutLogResponse\"\x00\x12\x43\n\x0ePutStreamEvent\x12\x16.PutStreamEventRequest\x1a\x17.PutStreamEventResponse\"\x00\x42GZEgithub.com/hatchet-dev/hatchet/internal/services/dispatcher/contractsb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -29,10 +29,14 @@ if _descriptor._USE_C_DESCRIPTORS == False:
  _globals['_PUTLOGREQUEST']._serialized_end=322
  _globals['_PUTLOGRESPONSE']._serialized_start=324
  _globals['_PUTLOGRESPONSE']._serialized_end=340
  _globals['_PUSHEVENTREQUEST']._serialized_start=342
  _globals['_PUSHEVENTREQUEST']._serialized_end=442
  _globals['_REPLAYEVENTREQUEST']._serialized_start=444
  _globals['_REPLAYEVENTREQUEST']._serialized_end=481
  _globals['_EVENTSSERVICE']._serialized_start=484
  _globals['_EVENTSSERVICE']._serialized_end=633
  _globals['_PUTSTREAMEVENTREQUEST']._serialized_start=342
  _globals['_PUTSTREAMEVENTREQUEST']._serialized_end=466
  _globals['_PUTSTREAMEVENTRESPONSE']._serialized_start=468
  _globals['_PUTSTREAMEVENTRESPONSE']._serialized_end=492
  _globals['_PUSHEVENTREQUEST']._serialized_start=494
  _globals['_PUSHEVENTREQUEST']._serialized_end=594
  _globals['_REPLAYEVENTREQUEST']._serialized_start=596
  _globals['_REPLAYEVENTREQUEST']._serialized_end=633
  _globals['_EVENTSSERVICE']._serialized_start=636
  _globals['_EVENTSSERVICE']._serialized_end=854
# @@protoc_insertion_point(module_scope)


@@ -37,6 +37,22 @@ class PutLogResponse(_message.Message):
    __slots__ = ()
    def __init__(self) -> None: ...

class PutStreamEventRequest(_message.Message):
    __slots__ = ("stepRunId", "createdAt", "message", "metadata")
    STEPRUNID_FIELD_NUMBER: _ClassVar[int]
    CREATEDAT_FIELD_NUMBER: _ClassVar[int]
    MESSAGE_FIELD_NUMBER: _ClassVar[int]
    METADATA_FIELD_NUMBER: _ClassVar[int]
    stepRunId: str
    createdAt: _timestamp_pb2.Timestamp
    message: bytes
    metadata: str
    def __init__(self, stepRunId: _Optional[str] = ..., createdAt: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., message: _Optional[bytes] = ..., metadata: _Optional[str] = ...) -> None: ...

class PutStreamEventResponse(_message.Message):
    __slots__ = ()
    def __init__(self) -> None: ...

class PushEventRequest(_message.Message):
    __slots__ = ("key", "payload", "eventTimestamp")
    KEY_FIELD_NUMBER: _ClassVar[int]


@@ -28,6 +28,11 @@ class EventsServiceStub(object):
                request_serializer=events__pb2.PutLogRequest.SerializeToString,
                response_deserializer=events__pb2.PutLogResponse.FromString,
                )
        self.PutStreamEvent = channel.unary_unary(
                '/EventsService/PutStreamEvent',
                request_serializer=events__pb2.PutStreamEventRequest.SerializeToString,
                response_deserializer=events__pb2.PutStreamEventResponse.FromString,
                )

class EventsServiceServicer(object):
@@ -51,6 +56,12 @@ class EventsServiceServicer(object):
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def PutStreamEvent(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

def add_EventsServiceServicer_to_server(servicer, server):
    rpc_method_handlers = {
@@ -69,6 +80,11 @@ def add_EventsServiceServicer_to_server(servicer, server):
            request_deserializer=events__pb2.PutLogRequest.FromString,
            response_serializer=events__pb2.PutLogResponse.SerializeToString,
            ),
            'PutStreamEvent': grpc.unary_unary_rpc_method_handler(
                    servicer.PutStreamEvent,
                    request_deserializer=events__pb2.PutStreamEventRequest.FromString,
                    response_serializer=events__pb2.PutStreamEventResponse.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'EventsService', rpc_method_handlers)
@@ -129,3 +145,20 @@ class EventsService(object):
            events__pb2.PutLogResponse.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def PutStreamEvent(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/EventsService/PutStreamEvent',
            events__pb2.PutStreamEventRequest.SerializeToString,
            events__pb2.PutStreamEventResponse.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
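
The regenerated stub exposes PutStreamEvent as an ordinary unary-unary RPC alongside PutLog. A hedged sketch of invoking it through the raw stub; the channel address, module paths, and auth metadata format here are assumptions, and the SDK's EventClientImpl wraps all of this for you:

import grpc
from hatchet_sdk import events_pb2, events_pb2_grpc  # assumed module paths

step_run_id = "..."  # placeholder
token = "..."        # placeholder API token

channel = grpc.insecure_channel("localhost:7070")  # assumed address
stub = events_pb2_grpc.EventsServiceStub(channel)
stub.PutStreamEvent(
    events_pb2.PutStreamEventRequest(stepRunId=step_run_id, message=b"chunk"),
    metadata=[("authorization", f"bearer {token}")],  # assumed auth header
)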


@@ -1,6 +1,6 @@
[tool.poetry]
name = "hatchet-sdk"
version = "0.19.0"
version = "0.20.0"
description = ""
authors = ["Alexander Belanger <alexander@hatchet.run>"]
readme = "README.md"