mirror of
https://github.com/folbricht/routedns.git
synced 2026-04-29 04:29:17 -05:00
update
This commit is contained in:
+10
-11
@@ -13,19 +13,18 @@ const luaErrorMetatableName = "Error"
|
||||
func (s *LuaScript) RegisterErrorType() {
|
||||
L := s.L
|
||||
mt := L.NewTypeMetatable(luaErrorMetatableName)
|
||||
L.SetGlobal("Error", mt)
|
||||
L.SetGlobal(luaErrorMetatableName, mt)
|
||||
|
||||
// static attributes
|
||||
L.SetField(mt, "new", L.NewFunction(newError))
|
||||
L.SetField(mt, "new", L.NewFunction(
|
||||
func(L *lua.LState) int {
|
||||
err := errors.New(L.CheckString(1))
|
||||
L.Push(userDataWithMetatable(L, luaErrorMetatableName, err))
|
||||
return 1
|
||||
}))
|
||||
|
||||
// methods
|
||||
L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{
|
||||
"error": getter(errorGetError),
|
||||
"error": getter(func(L *lua.LState, r error) { L.Push(lua.LString(r.Error())) }),
|
||||
}))
|
||||
}
|
||||
|
||||
func newError(L *lua.LState) int {
|
||||
err := errors.New(L.CheckString(1))
|
||||
L.Push(userDataWithMetatable(L, luaErrorMetatableName, err))
|
||||
return 1
|
||||
}
|
||||
|
||||
func errorGetError(L *lua.LState, r error) { L.Push(lua.LString(r.Error())) }
|
||||
|
||||
@@ -41,3 +41,13 @@ func setter[T any](f func(*lua.LState, T)) func(*lua.LState) int {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
func getUserDataArg[T any](L *lua.LState, n int) (T, bool) {
|
||||
ud := L.CheckUserData(n)
|
||||
v, ok := ud.Value.(T)
|
||||
if !ok {
|
||||
L.ArgError(n, fmt.Sprintf("expected %v, got %T", reflect.TypeFor[T](), ud.Value))
|
||||
return v, false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
|
||||
+78
-38
@@ -14,45 +14,85 @@ func (s *LuaScript) RegisterMessageType() {
|
||||
mt := L.NewTypeMetatable(luaMessageMetatableName)
|
||||
L.SetGlobal(luaMessageMetatableName, mt)
|
||||
// static attributes
|
||||
L.SetField(mt, "new", L.NewFunction(newMessage))
|
||||
L.SetField(mt, "new", L.NewFunction(func(L *lua.LState) int {
|
||||
L.Push(userDataWithMetatable(L, luaMessageMetatableName, new(dns.Msg)))
|
||||
return 1
|
||||
}))
|
||||
// methods
|
||||
L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{
|
||||
"get_question": getter(messageGetQuestion),
|
||||
"set_question": setter(messageSetQuestion),
|
||||
"get_questions": getter(func(L *lua.LState, msg *dns.Msg) {
|
||||
table := L.CreateTable(len(msg.Question), 0)
|
||||
for _, q := range msg.Question {
|
||||
lv := userDataWithMetatable(L, luaQuestionMetatableName, &q)
|
||||
table.Append(lv)
|
||||
}
|
||||
L.Push(table)
|
||||
}),
|
||||
"set_questions": setter(func(L *lua.LState, msg *dns.Msg) {
|
||||
table := L.CheckTable(2)
|
||||
n := table.Len()
|
||||
questions := make([]dns.Question, 0, n)
|
||||
for i := range n {
|
||||
element := table.RawGetInt(i + 1)
|
||||
if element.Type() != lua.LTUserData {
|
||||
L.ArgError(1, "invalid type, expected userdata")
|
||||
return
|
||||
}
|
||||
lq := element.(*lua.LUserData)
|
||||
q, ok := lq.Value.(*dns.Question)
|
||||
if !ok {
|
||||
L.ArgError(1, "invalid type, expected question")
|
||||
return
|
||||
}
|
||||
questions = append(questions, *q)
|
||||
}
|
||||
msg.Question = questions
|
||||
}),
|
||||
"set_question": setter(func(L *lua.LState, msg *dns.Msg) {
|
||||
msg.SetQuestion(L.CheckString(2), uint16(L.CheckNumber(3)))
|
||||
}),
|
||||
"set_id": setter(func(L *lua.LState, msg *dns.Msg) {
|
||||
msg.Id = uint16(L.CheckInt(2))
|
||||
}),
|
||||
"get_id": getter(func(L *lua.LState, msg *dns.Msg) {
|
||||
L.Push(lua.LNumber(msg.Id))
|
||||
}),
|
||||
"set_response": setter(func(L *lua.LState, msg *dns.Msg) {
|
||||
msg.Response = L.CheckBool(2)
|
||||
}),
|
||||
"get_response": getter(func(L *lua.LState, msg *dns.Msg) {
|
||||
L.Push(lua.LBool(msg.Response))
|
||||
}),
|
||||
"set_reply": setter(func(L *lua.LState, msg *dns.Msg) {
|
||||
request, ok := getUserDataArg[*dns.Msg](L, 2)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
msg.SetReply(request)
|
||||
}),
|
||||
"set_rcode": setter(func(L *lua.LState, msg *dns.Msg) {
|
||||
msg.Rcode = L.CheckInt(2)
|
||||
}),
|
||||
"get_rcode": getter(func(L *lua.LState, msg *dns.Msg) {
|
||||
L.Push(lua.LNumber(msg.Rcode))
|
||||
}),
|
||||
"set_rd": setter(func(L *lua.LState, msg *dns.Msg) {
|
||||
msg.RecursionDesired = L.CheckBool(2)
|
||||
}),
|
||||
"get_rd": getter(func(L *lua.LState, msg *dns.Msg) {
|
||||
L.Push(lua.LBool(msg.RecursionDesired))
|
||||
}),
|
||||
"set_ra": setter(func(L *lua.LState, msg *dns.Msg) {
|
||||
msg.RecursionAvailable = L.CheckBool(2)
|
||||
}),
|
||||
"get_ra": getter(func(L *lua.LState, msg *dns.Msg) {
|
||||
L.Push(lua.LBool(msg.RecursionAvailable))
|
||||
}),
|
||||
"set_ad": setter(func(L *lua.LState, msg *dns.Msg) {
|
||||
msg.AuthenticatedData = L.CheckBool(2)
|
||||
}),
|
||||
"get_ad": getter(func(L *lua.LState, msg *dns.Msg) {
|
||||
L.Push(lua.LBool(msg.AuthenticatedData))
|
||||
}),
|
||||
}))
|
||||
}
|
||||
|
||||
func newMessage(L *lua.LState) int {
|
||||
L.Push(userDataWithMetatable(L, luaMessageMetatableName, new(dns.Msg)))
|
||||
return 1
|
||||
}
|
||||
|
||||
func messageGetQuestion(L *lua.LState, msg *dns.Msg) {
|
||||
table := L.CreateTable(len(msg.Question), 0)
|
||||
for _, q := range msg.Question {
|
||||
lv := userDataWithMetatable(L, luaQuestionMetatableName, &q)
|
||||
table.Append(lv)
|
||||
}
|
||||
L.Push(table)
|
||||
}
|
||||
|
||||
func messageSetQuestion(L *lua.LState, msg *dns.Msg) {
|
||||
table := L.CheckTable(2)
|
||||
n := table.Len()
|
||||
questions := make([]dns.Question, 0, n)
|
||||
for i := range n {
|
||||
element := table.RawGetInt(i + 1)
|
||||
if element.Type() != lua.LTUserData {
|
||||
L.ArgError(1, "invalid type, expected userdata")
|
||||
return
|
||||
}
|
||||
lq := element.(*lua.LUserData)
|
||||
q, ok := lq.Value.(*dns.Question)
|
||||
if !ok {
|
||||
L.ArgError(1, "invalid type, expected question")
|
||||
return
|
||||
}
|
||||
questions = append(questions, *q)
|
||||
}
|
||||
msg.Question = questions
|
||||
}
|
||||
|
||||
+24
-21
@@ -12,29 +12,32 @@ const luaQuestionMetatableName = "Question"
|
||||
func (s *LuaScript) RegisterQuestionType() {
|
||||
L := s.L
|
||||
mt := L.NewTypeMetatable(luaQuestionMetatableName)
|
||||
L.SetGlobal("Question", mt)
|
||||
L.SetGlobal(luaQuestionMetatableName, mt)
|
||||
// static attributes
|
||||
L.SetField(mt, "new", L.NewFunction(newQuestion))
|
||||
L.SetField(mt, "new", L.NewFunction(
|
||||
func(L *lua.LState) int {
|
||||
q := &dns.Question{Qclass: dns.ClassINET}
|
||||
nArgs := L.GetTop()
|
||||
if nArgs >= 1 { // Name provided
|
||||
q.Name = L.CheckString(1)
|
||||
}
|
||||
if nArgs >= 2 { // Name and type
|
||||
q.Qtype = uint16(L.CheckNumber(2))
|
||||
}
|
||||
if nArgs >= 3 { // Name, type and class
|
||||
q.Qclass = uint16(L.CheckNumber(3))
|
||||
}
|
||||
L.Push(userDataWithMetatable(L, luaQuestionMetatableName, q))
|
||||
return 1
|
||||
}))
|
||||
|
||||
// methods
|
||||
L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{
|
||||
"get_name": getter(questionGetName),
|
||||
"get_qtype": getter(questionGetQType),
|
||||
"get_qclass": getter(questionGetQClass),
|
||||
"set_name": setter(questionSetName),
|
||||
"set_qtype": setter(questionSetQType),
|
||||
"set_qclass": setter(questionSetQClass),
|
||||
"get_name": getter(func(L *lua.LState, r *dns.Question) { L.Push(lua.LString(r.Name)) }),
|
||||
"get_qtype": getter(func(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qtype)) }),
|
||||
"get_qclass": getter(func(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qclass)) }),
|
||||
"set_name": setter(func(L *lua.LState, r *dns.Question) { r.Name = L.CheckString(2) }),
|
||||
"set_qtype": setter(func(L *lua.LState, r *dns.Question) { r.Qtype = uint16(L.CheckInt(2)) }),
|
||||
"set_qclass": setter(func(L *lua.LState, r *dns.Question) { r.Qclass = uint16(L.CheckInt(2)) }),
|
||||
}))
|
||||
}
|
||||
|
||||
func newQuestion(L *lua.LState) int {
|
||||
L.Push(userDataWithMetatable(L, luaQuestionMetatableName, new(dns.Question)))
|
||||
return 1
|
||||
}
|
||||
|
||||
func questionGetName(L *lua.LState, r *dns.Question) { L.Push(lua.LString(r.Name)) }
|
||||
func questionGetQType(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qtype)) }
|
||||
func questionGetQClass(L *lua.LState, r *dns.Question) { L.Push(lua.LNumber(r.Qclass)) }
|
||||
|
||||
func questionSetName(L *lua.LState, r *dns.Question) { r.Name = L.CheckString(2) }
|
||||
func questionSetQType(L *lua.LState, r *dns.Question) { r.Qtype = uint16(L.CheckInt(2)) }
|
||||
func questionSetQClass(L *lua.LState, r *dns.Question) { r.Qclass = uint16(L.CheckInt(2)) }
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
package rdns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
@@ -62,13 +59,3 @@ func resolverResolve(L *lua.LState) int {
|
||||
|
||||
return 2
|
||||
}
|
||||
|
||||
func getUserDataArg[T any](L *lua.LState, n int) (T, bool) {
|
||||
ud := L.CheckUserData(n)
|
||||
v, ok := ud.Value.(T)
|
||||
if !ok {
|
||||
L.ArgError(n, fmt.Sprintf("expected %v, got %T", reflect.TypeFor[T](), ud.Value))
|
||||
return v, false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
|
||||
+23
-1
@@ -1,6 +1,28 @@
|
||||
package rdns
|
||||
|
||||
import lua "github.com/yuin/gopher-lua"
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
func (s *LuaScript) RegisterConstants() {
|
||||
L := s.L
|
||||
|
||||
// Register TypeA, TypeAAAA, etc
|
||||
for value, name := range dns.TypeToString {
|
||||
L.SetGlobal("Type"+name, lua.LNumber(value))
|
||||
}
|
||||
|
||||
// Register ClassINET, etc
|
||||
for value, name := range dns.ClassToString {
|
||||
L.SetGlobal("Class"+name, lua.LNumber(value))
|
||||
}
|
||||
|
||||
// Register Rcodes, RcodeNOERROR, RcodeNXDOMAIN, etc
|
||||
for value, name := range dns.RcodeToString {
|
||||
L.SetGlobal("Rcode"+name, lua.LNumber(value))
|
||||
}
|
||||
}
|
||||
|
||||
func userDataWithMetatable(L *lua.LState, mtName string, value any) *lua.LUserData {
|
||||
ud := L.NewUserData()
|
||||
|
||||
@@ -96,6 +96,7 @@ func (r *Lua) newScript() (*LuaScript, error) {
|
||||
}
|
||||
|
||||
// Register types and methods
|
||||
s.RegisterConstants()
|
||||
s.RegisterMessageType()
|
||||
s.RegisterQuestionType()
|
||||
s.RegisterErrorType()
|
||||
|
||||
+48
-17
@@ -10,8 +10,6 @@ import (
|
||||
func TestLuaSimplePassthrough(t *testing.T) {
|
||||
opt := LuaOptions{
|
||||
Script: `
|
||||
function resolve(msg, ci)
|
||||
end
|
||||
function Resolve(msg, ci)
|
||||
local resolver = Resolvers[1]
|
||||
local answer, err = resolver:resolve(msg, ci)
|
||||
@@ -70,28 +68,61 @@ end`,
|
||||
}
|
||||
|
||||
func TestLuaStaticAnswer(t *testing.T) {
|
||||
opt := LuaOptions{
|
||||
Script: `
|
||||
tests := map[string]LuaOptions{
|
||||
"set_questions": {
|
||||
Script: `
|
||||
function Resolve(msg, ci)
|
||||
local question = Question.new("example.com.", TypeA)
|
||||
local answer = Message.new()
|
||||
local question = Question.new()
|
||||
question:set_name("example.com.")
|
||||
answer:set_question({question})
|
||||
answer:set_id(msg:get_id())
|
||||
answer:set_questions({question})
|
||||
answer:set_response(true)
|
||||
answer:set_rcode(RcodeNXDOMAIN)
|
||||
return answer, nil
|
||||
end`,
|
||||
},
|
||||
"set_question": {
|
||||
Script: `
|
||||
function Resolve(msg, ci)
|
||||
local answer = Message.new()
|
||||
answer:set_question("example.com.", TypeA)
|
||||
answer:set_id(msg:get_id())
|
||||
answer:set_response(true)
|
||||
answer:set_rcode(RcodeNXDOMAIN)
|
||||
return answer, nil
|
||||
end`,
|
||||
},
|
||||
"set_reply": {
|
||||
Script: `
|
||||
function Resolve(msg, ci)
|
||||
local answer = Message.new()
|
||||
answer:set_reply(msg)
|
||||
answer:set_rcode(RcodeNXDOMAIN)
|
||||
return answer, nil
|
||||
end`,
|
||||
},
|
||||
}
|
||||
|
||||
var ci ClientInfo
|
||||
resolver := new(TestResolver)
|
||||
for name, opt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
var ci ClientInfo
|
||||
resolver := new(TestResolver)
|
||||
|
||||
r, err := NewLua("test-lua", opt, resolver)
|
||||
require.NoError(t, err)
|
||||
r, err := NewLua("test-lua", opt, resolver)
|
||||
require.NoError(t, err)
|
||||
|
||||
q := new(dns.Msg)
|
||||
q.SetQuestion("example.com.", dns.TypeA)
|
||||
q := new(dns.Msg)
|
||||
q.SetQuestion("example.com.", dns.TypeA)
|
||||
q.Id = 1234
|
||||
|
||||
answer, err := r.Resolve(q, ci)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, resolver.HitCount())
|
||||
require.Equal(t, "example.com.", answer.Question[0].Name)
|
||||
answer, err := r.Resolve(q, ci)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, resolver.HitCount())
|
||||
require.Equal(t, "example.com.", answer.Question[0].Name)
|
||||
require.Equal(t, dns.TypeA, answer.Question[0].Qtype)
|
||||
require.Equal(t, uint16(1234), answer.Id)
|
||||
require.Equal(t, dns.RcodeNameError, answer.Rcode)
|
||||
require.True(t, answer.Response)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user