This commit is contained in:
folbrich
2025-04-27 16:35:50 +02:00
parent 91833d0fd7
commit 293f299ad6
8 changed files with 194 additions and 101 deletions
+10 -11
View File
@@ -13,19 +13,18 @@ const luaErrorMetatableName = "Error"
func (s *LuaScript) RegisterErrorType() {
L := s.L
mt := L.NewTypeMetatable(luaErrorMetatableName)
L.SetGlobal("Error", mt)
L.SetGlobal(luaErrorMetatableName, mt)
// static attributes
L.SetField(mt, "new", L.NewFunction(newError))
L.SetField(mt, "new", L.NewFunction(
func(L *lua.LState) int {
err := errors.New(L.CheckString(1))
L.Push(userDataWithMetatable(L, luaErrorMetatableName, err))
return 1
}))
// methods
L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{
"error": getter(errorGetError),
"error": getter(func(L *lua.LState, r error) { L.Push(lua.LString(r.Error())) }),
}))
}
func newError(L *lua.LState) int {
err := errors.New(L.CheckString(1))
L.Push(userDataWithMetatable(L, luaErrorMetatableName, err))
return 1
}
func errorGetError(L *lua.LState, r error) { L.Push(lua.LString(r.Error())) }
+10
View File
@@ -41,3 +41,13 @@ func setter[T any](f func(*lua.LState, T)) func(*lua.LState) int {
return 1
}
}
func getUserDataArg[T any](L *lua.LState, n int) (T, bool) {
ud := L.CheckUserData(n)
v, ok := ud.Value.(T)
if !ok {
L.ArgError(n, fmt.Sprintf("expected %v, got %T", reflect.TypeFor[T](), ud.Value))
return v, false
}
return v, true
}
+78 -38
View File
@@ -14,45 +14,85 @@ func (s *LuaScript) RegisterMessageType() {
mt := L.NewTypeMetatable(luaMessageMetatableName)
L.SetGlobal(luaMessageMetatableName, mt)
// static attributes
L.SetField(mt, "new", L.NewFunction(newMessage))
L.SetField(mt, "new", L.NewFunction(func(L *lua.LState) int {
L.Push(userDataWithMetatable(L, luaMessageMetatableName, new(dns.Msg)))
return 1
}))
// methods
L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{
"get_question": getter(messageGetQuestion),
"set_question": setter(messageSetQuestion),
"get_questions": getter(func(L *lua.LState, msg *dns.Msg) {
table := L.CreateTable(len(msg.Question), 0)
for _, q := range msg.Question {
lv := userDataWithMetatable(L, luaQuestionMetatableName, &q)
table.Append(lv)
}
L.Push(table)
}),
"set_questions": setter(func(L *lua.LState, msg *dns.Msg) {
table := L.CheckTable(2)
n := table.Len()
questions := make([]dns.Question, 0, n)
for i := range n {
element := table.RawGetInt(i + 1)
if element.Type() != lua.LTUserData {
L.ArgError(1, "invalid type, expected userdata")
return
}
lq := element.(*lua.LUserData)
q, ok := lq.Value.(*dns.Question)
if !ok {
L.ArgError(1, "invalid type, expected question")
return
}
questions = append(questions, *q)
}
msg.Question = questions
}),
"set_question": setter(func(L *lua.LState, msg *dns.Msg) {
msg.SetQuestion(L.CheckString(2), uint16(L.CheckNumber(3)))
}),
"set_id": setter(func(L *lua.LState, msg *dns.Msg) {
msg.Id = uint16(L.CheckInt(2))
}),
"get_id": getter(func(L *lua.LState, msg *dns.Msg) {
L.Push(lua.LNumber(msg.Id))
}),
"set_response": setter(func(L *lua.LState, msg *dns.Msg) {
msg.Response = L.CheckBool(2)
}),
"get_response": getter(func(L *lua.LState, msg *dns.Msg) {
L.Push(lua.LBool(msg.Response))
}),
"set_reply": setter(func(L *lua.LState, msg *dns.Msg) {
request, ok := getUserDataArg[*dns.Msg](L, 2)
if !ok {
return
}
msg.SetReply(request)
}),
"set_rcode": setter(func(L *lua.LState, msg *dns.Msg) {
msg.Rcode = L.CheckInt(2)
}),
"get_rcode": getter(func(L *lua.LState, msg *dns.Msg) {
L.Push(lua.LNumber(msg.Rcode))
}),
"set_rd": setter(func(L *lua.LState, msg *dns.Msg) {
msg.RecursionDesired = L.CheckBool(2)
}),
"get_rd": getter(func(L *lua.LState, msg *dns.Msg) {
L.Push(lua.LBool(msg.RecursionDesired))
}),
"set_ra": setter(func(L *lua.LState, msg *dns.Msg) {
msg.RecursionAvailable = L.CheckBool(2)
}),
"get_ra": getter(func(L *lua.LState, msg *dns.Msg) {
L.Push(lua.LBool(msg.RecursionAvailable))
}),
"set_ad": setter(func(L *lua.LState, msg *dns.Msg) {
msg.AuthenticatedData = L.CheckBool(2)
}),
"get_ad": getter(func(L *lua.LState, msg *dns.Msg) {
L.Push(lua.LBool(msg.AuthenticatedData))
}),
}))
}
func newMessage(L *lua.LState) int {
L.Push(userDataWithMetatable(L, luaMessageMetatableName, new(dns.Msg)))
return 1
}
func messageGetQuestion(L *lua.LState, msg *dns.Msg) {
table := L.CreateTable(len(msg.Question), 0)
for _, q := range msg.Question {
lv := userDataWithMetatable(L, luaQuestionMetatableName, &q)
table.Append(lv)
}
L.Push(table)
}
func messageSetQuestion(L *lua.LState, msg *dns.Msg) {
table := L.CheckTable(2)
n := table.Len()
questions := make([]dns.Question, 0, n)
for i := range n {
element := table.RawGetInt(i + 1)
if element.Type() != lua.LTUserData {
L.ArgError(1, "invalid type, expected userdata")
return
}
lq := element.(*lua.LUserData)
q, ok := lq.Value.(*dns.Question)
if !ok {
L.ArgError(1, "invalid type, expected question")
return
}
questions = append(questions, *q)
}
msg.Question = questions
}
+24 -21
View File
@@ -12,29 +12,32 @@ const luaQuestionMetatableName = "Question"
func (s *LuaScript) RegisterQuestionType() {
L := s.L
mt := L.NewTypeMetatable(luaQuestionMetatableName)
L.SetGlobal("Question", mt)
L.SetGlobal(luaQuestionMetatableName, mt)
// static attributes
L.SetField(mt, "new", L.NewFunction(newQuestion))
L.SetField(mt, "new", L.NewFunction(
func(L *lua.LState) int {
q := &dns.Question{Qclass: dns.ClassINET}
nArgs := L.GetTop()
if nArgs >= 1 { // Name provided
q.Name = L.CheckString(1)
}
if nArgs >= 2 { // Name and type
q.Qtype = uint16(L.CheckNumber(2))
}
if nArgs >= 3 { // Name, type and class
q.Qclass = uint16(L.CheckNumber(3))
}
L.Push(userDataWithMetatable(L, luaQuestionMetatableName, q))
return 1
}))
// methods
L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{
"get_name": getter(questionGetName),
"get_qtype": getter(questionGetQType),
"get_qclass": getter(questionGetQClass),
"set_name": setter(questionSetName),
"set_qtype": setter(questionSetQType),
"set_qclass": setter(questionSetQClass),
"get_name": getter(func(L *lua.LState, r *dns.Question) { L.Push(lua.LString(r.Name)) }),
"get_qtype": getter(func(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qtype)) }),
"get_qclass": getter(func(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qclass)) }),
"set_name": setter(func(L *lua.LState, r *dns.Question) { r.Name = L.CheckString(2) }),
"set_qtype": setter(func(L *lua.LState, r *dns.Question) { r.Qtype = uint16(L.CheckInt(2)) }),
"set_qclass": setter(func(L *lua.LState, r *dns.Question) { r.Qclass = uint16(L.CheckInt(2)) }),
}))
}
func newQuestion(L *lua.LState) int {
L.Push(userDataWithMetatable(L, luaQuestionMetatableName, new(dns.Question)))
return 1
}
func questionGetName(L *lua.LState, r *dns.Question) { L.Push(lua.LString(r.Name)) }
func questionGetQType(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qtype)) }
func questionGetQClass(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qclass)) }
func questionSetName(L *lua.LState, r *dns.Question) { r.Name = L.CheckString(2) }
func questionSetQType(L *lua.LState, r *dns.Question) { r.Qtype = uint16(L.CheckInt(2)) }
func questionSetQClass(L *lua.LState, r *dns.Question) { r.Qclass = uint16(L.CheckInt(2)) }
-13
View File
@@ -1,9 +1,6 @@
package rdns
import (
"fmt"
"reflect"
"github.com/miekg/dns"
lua "github.com/yuin/gopher-lua"
)
@@ -62,13 +59,3 @@ func resolverResolve(L *lua.LState) int {
return 2
}
func getUserDataArg[T any](L *lua.LState, n int) (T, bool) {
ud := L.CheckUserData(n)
v, ok := ud.Value.(T)
if !ok {
L.ArgError(n, fmt.Sprintf("expected %v, got %T", reflect.TypeFor[T](), ud.Value))
return v, false
}
return v, true
}
+23 -1
View File
@@ -1,6 +1,28 @@
package rdns
import lua "github.com/yuin/gopher-lua"
import (
"github.com/miekg/dns"
lua "github.com/yuin/gopher-lua"
)
func (s *LuaScript) RegisterConstants() {
L := s.L
// Register TypeA, TypeAAAA, etc
for value, name := range dns.TypeToString {
L.SetGlobal("Type"+name, lua.LNumber(value))
}
// Register ClassINET, etc
for value, name := range dns.ClassToString {
L.SetGlobal("Class"+name, lua.LNumber(value))
}
// Register Rcodes, RcodeNOERROR, RcodeNXDOMAIN, etc
for value, name := range dns.RcodeToString {
L.SetGlobal("Rcode"+name, lua.LNumber(value))
}
}
func userDataWithMetatable(L *lua.LState, mtName string, value any) *lua.LUserData {
ud := L.NewUserData()
+1
View File
@@ -96,6 +96,7 @@ func (r *Lua) newScript() (*LuaScript, error) {
}
// Register types and methods
s.RegisterConstants()
s.RegisterMessageType()
s.RegisterQuestionType()
s.RegisterErrorType()
+48 -17
View File
@@ -10,8 +10,6 @@ import (
func TestLuaSimplePassthrough(t *testing.T) {
opt := LuaOptions{
Script: `
function resolve(msg, ci)
end
function Resolve(msg, ci)
local resolver = Resolvers[1]
local answer, err = resolver:resolve(msg, ci)
@@ -70,28 +68,61 @@ end`,
}
func TestLuaStaticAnswer(t *testing.T) {
opt := LuaOptions{
Script: `
tests := map[string]LuaOptions{
"set_questions": {
Script: `
function Resolve(msg, ci)
local question = Question.new("example.com.", TypeA)
local answer = Message.new()
local question = Question.new()
question:set_name("example.com.")
answer:set_question({question})
answer:set_id(msg:get_id())
answer:set_questions({question})
answer:set_response(true)
answer:set_rcode(RcodeNXDOMAIN)
return answer, nil
end`,
},
"set_question": {
Script: `
function Resolve(msg, ci)
local answer = Message.new()
answer:set_question("example.com.", TypeA)
answer:set_id(msg:get_id())
answer:set_response(true)
answer:set_rcode(RcodeNXDOMAIN)
return answer, nil
end`,
},
"set_reply": {
Script: `
function Resolve(msg, ci)
local answer = Message.new()
answer:set_reply(msg)
answer:set_rcode(RcodeNXDOMAIN)
return answer, nil
end`,
},
}
var ci ClientInfo
resolver := new(TestResolver)
for name, opt := range tests {
t.Run(name, func(t *testing.T) {
var ci ClientInfo
resolver := new(TestResolver)
r, err := NewLua("test-lua", opt, resolver)
require.NoError(t, err)
r, err := NewLua("test-lua", opt, resolver)
require.NoError(t, err)
q := new(dns.Msg)
q.SetQuestion("example.com.", dns.TypeA)
q := new(dns.Msg)
q.SetQuestion("example.com.", dns.TypeA)
q.Id = 1234
answer, err := r.Resolve(q, ci)
require.NoError(t, err)
require.Equal(t, 0, resolver.HitCount())
require.Equal(t, "example.com.", answer.Question[0].Name)
answer, err := r.Resolve(q, ci)
require.NoError(t, err)
require.Equal(t, 0, resolver.HitCount())
require.Equal(t, "example.com.", answer.Question[0].Name)
require.Equal(t, dns.TypeA, answer.Question[0].Qtype)
require.Equal(t, uint16(1234), answer.Id)
require.Equal(t, dns.RcodeNameError, answer.Rcode)
require.True(t, answer.Response)
})
}
}