mirror of
https://github.com/hatchet-dev/hatchet.git
synced 2026-01-08 09:49:41 -06:00
153 lines
3.9 KiB
Go
package loader
|
|
|
|
import (
|
|
"fmt"
|
|
"net/url"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/hatchet-dev/hatchet/pkg/config/client"
|
|
"github.com/hatchet-dev/hatchet/pkg/config/loader/loaderutils"
|
|
)
|
|
|
|
// ConfigLoader loads client configuration from files located under a base
// directory; directory is the path that config file names are joined onto.
type ConfigLoader struct{ directory string }
|
|
|
|
// LoadClientConfig loads the client configuration
|
|
func (c *ConfigLoader) LoadClientConfig(token *string) (res *client.ClientConfig, err error) {
|
|
sharedFilePath := filepath.Join(c.directory, "client.yaml")
|
|
configFileBytes, err := loaderutils.GetConfigBytes(sharedFilePath)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cf, err := LoadClientConfigFile(configFileBytes...)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if token != nil {
|
|
cf.Token = *token
|
|
}
|
|
|
|
return GetClientConfigFromConfigFile(cf)
|
|
}
|
|
|
|
// LoadClientConfigFile loads the worker config file via viper
|
|
func LoadClientConfigFile(files ...[]byte) (*client.ClientConfigFile, error) {
|
|
configFile := &client.ClientConfigFile{}
|
|
f := client.BindAllEnv
|
|
|
|
_, err := loaderutils.LoadConfigFromViper(f, configFile, files...)
|
|
|
|
return configFile, err
|
|
}
|
|
|
|
func GetClientConfigFromConfigFile(cf *client.ClientConfigFile) (res *client.ClientConfig, err error) {
|
|
f := client.BindAllEnv
|
|
|
|
_, err = loaderutils.LoadConfigFromViper(f, cf)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not load config from viper: %w", err)
|
|
}
|
|
|
|
// if token is empty, throw an error
|
|
if cf.Token == "" {
|
|
return nil, fmt.Errorf("API token is required. Set it via the HATCHET_CLIENT_TOKEN environment variable")
|
|
}
|
|
|
|
grpcBroadcastAddress := cf.HostPort
|
|
serverURL := cf.ServerURL
|
|
|
|
tokenConf, err := getConfFromJWT(cf.Token)
|
|
|
|
if err == nil {
|
|
if grpcBroadcastAddress == "" && tokenConf.grpcBroadcastAddress != "" {
|
|
grpcBroadcastAddress = tokenConf.grpcBroadcastAddress
|
|
}
|
|
|
|
if serverURL == "" && tokenConf.serverURL != "" {
|
|
serverURL = tokenConf.serverURL
|
|
}
|
|
}
|
|
|
|
// if there's no broadcast address at this point, throw an error
|
|
if grpcBroadcastAddress == "" {
|
|
return nil, fmt.Errorf("gRPC broadcast address is required. Set it via the HATCHET_CLIENT_HOST_PORT environment variable")
|
|
}
|
|
|
|
// if there's no server URL at this point, throw an error
|
|
if serverURL == "" {
|
|
return nil, fmt.Errorf("server URL is required. Set it via the HATCHET_CLIENT_SERVER_URL environment variable")
|
|
}
|
|
|
|
if cf.TenantId == "" {
|
|
cf.TenantId = tokenConf.tenantId
|
|
}
|
|
|
|
tlsServerName := cf.TLS.TLSServerName
|
|
|
|
// if the tls server name is empty, parse the domain from the host:port
|
|
if tlsServerName == "" {
|
|
// parse the domain from the host:port
|
|
domain, err := parseDomain(grpcBroadcastAddress)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not parse domain: %w", err)
|
|
}
|
|
|
|
tlsServerName = domain.Hostname()
|
|
}
|
|
|
|
tlsConf, err := loaderutils.LoadClientTLSConfig(&cf.TLS, tlsServerName)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not load TLS config: %w", err)
|
|
}
|
|
|
|
namespace := cf.Namespace
|
|
|
|
if namespace != "" {
|
|
namespace = strings.ToLower(namespace + "_")
|
|
}
|
|
|
|
var rawRunnableActions []string
|
|
if cf.RawRunnableActions != nil {
|
|
rawRunnableActions = []string{}
|
|
|
|
for _, action := range cf.RawRunnableActions {
|
|
rawRunnableActions = append(rawRunnableActions, namespace+strings.TrimSpace(action))
|
|
}
|
|
}
|
|
|
|
presetLabels := make(map[string]string)
|
|
|
|
if cf.AutoscalingTarget != "" {
|
|
presetLabels["hatchet-autoscaling-target"] = cf.AutoscalingTarget
|
|
}
|
|
|
|
return &client.ClientConfig{
|
|
TenantId: cf.TenantId,
|
|
TLSConfig: tlsConf,
|
|
Token: cf.Token,
|
|
ServerURL: serverURL,
|
|
GRPCBroadcastAddress: grpcBroadcastAddress,
|
|
Namespace: namespace,
|
|
CloudRegisterID: cf.CloudRegisterID,
|
|
RunnableActions: rawRunnableActions,
|
|
NoGrpcRetry: cf.NoGrpcRetry,
|
|
PresetWorkerLabels: presetLabels,
|
|
}, nil
|
|
}
|
|
|
|
// parseDomain parses a host, host:port, or http(s) URL string into a *url.URL.
// Input that does not already start with "http://" or "https://" is given an
// "https://" scheme first, so url.Parse reads the host:port as the URL's
// authority (and Hostname() yields the bare host).
func parseDomain(domain string) (*url.URL, error) {
	hasHTTPScheme := strings.HasPrefix(domain, "http://") || strings.HasPrefix(domain, "https://")
	if hasHTTPScheme {
		return url.Parse(domain)
	}

	return url.Parse("https://" + domain)
}
|