Files
hatchet/python-sdk/hatchet_sdk/context.py
2024-02-26 00:16:40 -05:00

76 lines
2.2 KiB
Python

import inspect
from multiprocessing import Event
import os
from .clients.dispatcher import Action, DispatcherClient
from .dispatcher_pb2 import OverridesData
from .logger import logger
import json
def get_caller_file_path() -> str:
    """Return the filename of the code that invoked this function's caller.

    Stack index 0 is this function and index 1 is whoever called it
    (``Context.playground`` in this module), so index 2 is the code that
    called *that* function.

    Raises:
        IndexError: If the call stack is fewer than three frames deep.
    """
    # Only the filename is needed, so ask for zero lines of source context;
    # the default (context=1) loads source text for every frame on the stack.
    caller_frame = inspect.stack(0)[2]
    return caller_frame.filename
class Context:
    """Per-step-run context built from a dispatched action.

    Holds the decoded action payload and exposes helpers for reading the
    input and parent step outputs, for cooperative cancellation, and for
    playground overrides.
    """

    def __init__(self, action: Action, client: DispatcherClient):
        """Decode the action payload and initialise the context state.

        Args:
            action: The action being run. Its ``action_payload`` is decoded
                as JSON, its ``step_run_id`` is stored, and its
                ``get_group_key_run_id`` decides how ``input`` is derived.
            client: Dispatcher client used by ``playground`` to register
                override defaults.
        """
        try:
            payload = json.loads(action.action_payload)
        except Exception:
            # Best-effort: an unreadable payload is treated as an empty one.
            payload = {}

        # Valid JSON that is not an object (``null``, a list, a number, ...)
        # cannot be queried with ``.get``; treat it like an empty payload
        # instead of failing during construction.
        self.data = payload if isinstance(payload, dict) else {}

        self.stepRunId = action.step_run_id
        self.exit_flag = Event()
        self.client = client

        # store each key in the overrides field in a lookup table
        # overrides_data is a dictionary of key-value pairs
        overrides = self.data.get('overrides', {})
        # An explicit JSON ``null`` would break the membership test in
        # ``playground``; normalise it to an empty table.
        self.overrides_data = {} if overrides is None else overrides

        # NOTE(review): any value other than the empty string (including
        # None) takes the first branch -- confirm that is intended.
        if action.get_group_key_run_id != "":
            # The whole payload is the input in this case.
            self.input = self.data
        else:
            self.input = self.data.get('input', {})

    def step_output(self, step: str):
        """Return the output recorded for the parent step named ``step``.

        Raises:
            ValueError: If the payload has no usable ``parents`` mapping or
                it holds no entry for ``step``.
        """
        try:
            return self.data['parents'][step]
        except (KeyError, TypeError) as err:
            # TypeError covers ``parents`` being null or otherwise not a
            # mapping; callers see the same "not found" error either way.
            raise ValueError(f"Step output for '{step}' not found") from err

    def triggered_by_event(self) -> bool:
        """Return True when the payload's ``triggered_by`` field is ``'event'``."""
        return self.data.get('triggered_by', '') == 'event'

    def workflow_input(self):
        """Return the input selected for this context at construction time."""
        return self.input

    def sleep(self, seconds: float) -> None:
        """Wait up to ``seconds`` seconds, waking early on cancellation.

        Raises:
            Exception: With message ``"Context cancelled"`` if the context
                is cancelled before or during the wait.
        """
        self.exit_flag.wait(seconds)

        if self.exit_flag.is_set():
            raise Exception("Context cancelled")

    def cancel(self) -> None:
        """Mark the context as cancelled, waking any pending ``sleep``."""
        logger.info("Cancelling step...")
        self.exit_flag.set()

    def done(self) -> bool:
        """Return True if the context has been cancelled."""
        return self.exit_flag.is_set()

    def playground(self, name: str, default: str = None):
        """Return the override stored under ``name``, or register ``default``.

        When ``name`` has an entry in the overrides table that value is
        returned. Otherwise ``default`` (which may be None) is JSON-encoded
        and sent to the dispatcher together with the step run id and the
        calling file's path, and ``default`` itself is returned.
        """
        # if the key exists in the overrides_data field, return the value
        if name in self.overrides_data:
            return self.overrides_data[name]

        caller_file = get_caller_file_path()

        self.client.put_overrides_data(
            OverridesData(
                stepRunId=self.stepRunId,
                path=name,
                value=json.dumps(default),
                callerFilename=caller_file
            )
        )

        return default