Files
routedns/lua.go
folbrich f9cc8ac892 dev
2025-06-01 17:13:33 +02:00

116 lines
2.3 KiB
Go

package rdns
import (
"errors"
"fmt"
"strings"
"github.com/miekg/dns"
)
type Lua struct {
id string
resolvers []Resolver
scripts chan *LuaScript
bytecode ByteCode
opt LuaOptions
}
var _ Resolver = &Lua{}
type LuaOptions struct {
Script string
Concurrency uint
}
func NewLua(id string, opt LuaOptions, resolvers ...Resolver) (*Lua, error) {
if opt.Concurrency == 0 {
opt.Concurrency = 4
}
// Compile the script
bytecode, err := LuaCompile(strings.NewReader(opt.Script), id)
if err != nil {
return nil, err
}
r := &Lua{
id: id,
resolvers: resolvers,
opt: opt,
scripts: make(chan *LuaScript, opt.Concurrency),
bytecode: bytecode,
}
// Initialize scripts
for range opt.Concurrency {
s, err := r.newScript()
if err != nil {
return nil, err
}
r.scripts <- s
}
return r, nil
}
func (r *Lua) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) {
s := <-r.scripts
defer func() { r.scripts <- s }()
log := logger(r.id, q, ci)
// Call the "resolve" function in the script. It should return 2 values.
ret, err := s.Call("Resolve", 2, q, ci)
if err != nil {
log.Error("failed to run lua script", "error", err)
return nil, err
}
// Extract the answer and error from the returned values
if len(ret) != 2 {
return nil, fmt.Errorf("invalid return value, expected 2, got %d", len(ret))
}
answer, ok := ret[0].(*dns.Msg)
if ret[0] != nil && !ok {
return nil, fmt.Errorf("invalid return value, expected Message, got %T", ret[0])
}
err, ok = ret[1].(error)
if ret[1] != nil && !ok {
return nil, fmt.Errorf("invalid return value, expected Error, got %T", ret[1])
}
return answer, err
}
func (r *Lua) String() string {
return r.id
}
func (r *Lua) newScript() (*LuaScript, error) {
s, err := NewScriptFromByteCode(r.bytecode)
if err != nil {
return nil, err
}
// Register types and methods
s.RegisterConstants()
s.RegisterMessageType()
s.RegisterQuestionType()
s.RegisterRRTypes()
s.RegisterEDNS0Types()
s.RegisterErrorType()
// Inject the resolvers into the state (so they can be used in the script)
s.InjectResolvers(r.resolvers)
// The script must contain a Resolve() function which is the entry point
if !s.HasFunction("Resolve") {
return nil, errors.New("no Resolve() function found in lua script")
}
return s, nil
}