Use a derive macro for service message parsing

This commit is contained in:
lcdr
2020-06-17 14:08:02 +02:00
parent ad72b7e3aa
commit de557f0a5e
5 changed files with 97 additions and 120 deletions
+6
View File
@@ -1,6 +1,7 @@
mod from_variants;
mod game_message;
mod gm_deserialize;
mod service_message;
use proc_macro::TokenStream;
@@ -18,3 +19,8 @@ pub fn derive_game_message_deserialize(input: TokenStream) -> TokenStream {
pub fn derive_gm_deserialize(input: TokenStream) -> TokenStream {
gm_deserialize::derive(input)
}
#[proc_macro_derive(ServiceMessage)]
pub fn derive_service_message(input: TokenStream) -> TokenStream {
service_message::derive(input)
}
+67
View File
@@ -0,0 +1,67 @@
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{parse_macro_input, parse_quote, Data, DataEnum, DeriveInput, Fields};
pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let data = match &input.data {
Data::Enum(data) => data,
_ => unimplemented!(),
};
let name = &input.ident;
let deser_code = gen_deser_code_enum(data, &name);
let impl_generics = &mut input.generics.clone();
impl_generics.params.push(parse_quote!(__READER: ::std::io::Read));
let (impl_generics, _, _) = impl_generics.split_for_impl();
let (_, ty_generics, where_clause) = input.generics.split_for_impl();
(quote! {
impl #impl_generics ::endio::Deserialize<::endio::LE, __READER> for #name #ty_generics #where_clause {
fn deserialize(reader: &mut __READER) -> ::std::io::Result<Self> {
#deser_code
}
}
}).into()
}
fn gen_deser_code_fields(fields: &Fields) -> TokenStream {
match fields {
Fields::Named(_) => unimplemented!(),
Fields::Unnamed(fields) => {
let mut deser = vec![];
for _ in &fields.unnamed {
deser.push(quote! { ::endio::LERead::read(reader)?, });
}
quote! { ( #(#deser)* ) }
}
Fields::Unit => {
quote! { }
}
}
}
fn gen_deser_code_enum(data: &DataEnum, name: &Ident) -> TokenStream {
let last_disc: syn::ExprLit = parse_quote! { 0 };
let mut last_disc = &last_disc.into();
let mut disc_offset = 0;
let mut arms = vec![];
for f in &data.variants {
let ident = &f.ident;
if let Some((_, x)) = &f.discriminant {
last_disc = x;
disc_offset = 0;
}
let deser_fields = gen_deser_code_fields(&f.fields);
let arm = quote! { disc if disc == (#last_disc + (#disc_offset as u32)) => Self::#ident #deser_fields, };
disc_offset += 1;
arms.push(arm);
}
quote! {
let disc: u32 = ::endio::LERead::read(reader)?;
let _padding: u8 = ::endio::LERead::read(reader)?;
Ok(match disc {
#(#arms)*
_ => return ::std::result::Result::Err(::std::io::Error::new(::std::io::ErrorKind::InvalidData, format!("invalid discriminant value for {}: {}", stringify!(#name), disc)))
})
}
}
+5 -25
View File
@@ -1,10 +1,8 @@
//! All packets an auth server can receive.
use std::io::Result as Res;
use endio::Deserialize;
use lu_packets_derive::ServiceMessage;
use endio::{Deserialize, LERead};
use endio::LittleEndian as LE;
use crate::common::{err, LuWStr33, LuWStr41, LuWStr128, LuWStr256, ServiceId};
use crate::common::{LuWStr33, LuWStr41, LuWStr128, LuWStr256, ServiceId};
pub use crate::general::server::GeneralMessage;
pub type Message = crate::raknet::server::Message<LuMessage>;
@@ -17,30 +15,12 @@ pub enum LuMessage {
Auth(AuthMessage) = ServiceId::Auth as u16,
}
enum AuthId {
LoginRequest,
}
#[derive(Debug)]
#[derive(Debug, ServiceMessage)]
#[repr(u32)]
pub enum AuthMessage {
LoginRequest(LoginRequest)
}
impl<R: LERead> Deserialize<LE, R> for AuthMessage
where u8: Deserialize<LE, R>,
u32: Deserialize<LE, R>,
LoginRequest: Deserialize<LE, R>, {
fn deserialize(reader: &mut R) -> Res<Self> {
let packet_id: u32 = reader.read()?;
let _padding: u8 = reader.read()?;
if packet_id == AuthId::LoginRequest as u32 {
Ok(AuthMessage::LoginRequest(reader.read()?))
} else {
err("auth id", packet_id)
}
}
}
#[derive(Debug, Deserialize)]
pub struct LoginRequest {
pub username: LuWStr33,
+4 -21
View File
@@ -2,33 +2,16 @@ use std::io::Result as Res;
use endio::{Deserialize, LERead};
use endio::LittleEndian as LE;
use lu_packets_derive::ServiceMessage;
use crate::common::{err, ServiceId};
use crate::common::ServiceId;
enum GeneralId {
Handshake,
}
#[derive(Debug)]
#[derive(Debug, ServiceMessage)]
#[repr(u32)]
pub enum GeneralMessage {
Handshake(Handshake)
}
impl<R: LERead> Deserialize<LE, R> for GeneralMessage
where u8: Deserialize<LE, R>,
u32: Deserialize<LE, R>,
Handshake: Deserialize<LE, R> {
fn deserialize(reader: &mut R) -> Res<Self> {
let packet_id: u32 = reader.read()?;
let _padding: u8 = reader.read()?;
if packet_id == GeneralId::Handshake as u32 {
Ok(GeneralMessage::Handshake(reader.read()?))
} else {
err("general id", packet_id)
}
}
}
#[derive(Debug)]
pub struct Handshake {
pub network_version: u32,
+15 -74
View File
@@ -6,6 +6,7 @@ use std::io::Result as Res;
use endio::{Deserialize, LERead};
use endio::LittleEndian as LE;
use lu_packets_derive::ServiceMessage;
use crate::common::{err, ObjId, LuWStr33, LuWStr42, LuStr33, ServiceId, ZoneId};
use crate::chat::server::ChatMessage;
@@ -23,81 +24,21 @@ pub enum LuMessage {
World(WorldMessage) = ServiceId::World as u16,
}
enum WorldId {
ClientValidation = 1,
CharacterListRequest = 2,
CharacterCreateRequest = 3,
CharacterLoginRequest = 4,
SubjectGameMessage = 5,
CharacterDeleteRequest = 6,
GeneralChatMessage = 14,
LevelLoadComplete = 19,
RouteMessage = 21,
StringCheck = 25,
RequestFreeTrialRefresh = 32,
UgcDownloadFailed = 120,
}
#[derive(Debug)]
#[derive(Debug, ServiceMessage)]
#[repr(u32)]
pub enum WorldMessage {
ClientValidation(ClientValidation),
CharacterListRequest,
CharacterCreateRequest(CharacterCreateRequest),
CharacterLoginRequest(CharacterLoginRequest),
SubjectGameMessage(SubjectGameMessage),
CharacterDeleteRequest(CharacterDeleteRequest),
GeneralChatMessage(GeneralChatMessage),
LevelLoadComplete(LevelLoadComplete),
RouteMessage(RouteMessage),
StringCheck(StringCheck),
RequestFreeTrialRefresh,
UgcDownloadFailed(UgcDownloadFailed),
}
impl<R: LERead> Deserialize<LE, R> for WorldMessage
where u8: Deserialize<LE, R>,
u32: Deserialize<LE, R>,
ClientValidation: Deserialize<LE, R>,
CharacterCreateRequest: Deserialize<LE, R>,
CharacterLoginRequest: Deserialize<LE, R>,
SubjectGameMessage: Deserialize<LE, R>,
CharacterDeleteRequest: Deserialize<LE, R>,
GeneralChatMessage: Deserialize<LE, R>,
LevelLoadComplete: Deserialize<LE, R>,
RouteMessage: Deserialize<LE, R>,
StringCheck: Deserialize<LE, R>,
UgcDownloadFailed: Deserialize<LE, R> {
fn deserialize(reader: &mut R) -> Res<Self> {
let packet_id: u32 = reader.read()?;
let _padding: u8 = reader.read()?;
if packet_id == WorldId::ClientValidation as u32 {
Ok(Self::ClientValidation(reader.read()?))
} else if packet_id == WorldId::CharacterListRequest as u32 {
Ok(Self::CharacterListRequest)
} else if packet_id == WorldId::CharacterCreateRequest as u32 {
Ok(Self::CharacterCreateRequest(reader.read()?))
} else if packet_id == WorldId::CharacterLoginRequest as u32 {
Ok(Self::CharacterLoginRequest(reader.read()?))
} else if packet_id == WorldId::SubjectGameMessage as u32 {
Ok(Self::SubjectGameMessage(reader.read()?))
} else if packet_id == WorldId::CharacterDeleteRequest as u32 {
Ok(Self::CharacterDeleteRequest(reader.read()?))
} else if packet_id == WorldId::GeneralChatMessage as u32 {
Ok(Self::GeneralChatMessage(reader.read()?))
} else if packet_id == WorldId::LevelLoadComplete as u32 {
Ok(Self::LevelLoadComplete(reader.read()?))
} else if packet_id == WorldId::RouteMessage as u32 {
Ok(Self::RouteMessage(reader.read()?))
} else if packet_id == WorldId::StringCheck as u32 {
Ok(Self::StringCheck(reader.read()?))
} else if packet_id == WorldId::RequestFreeTrialRefresh as u32 {
Ok(Self::RequestFreeTrialRefresh)
} else if packet_id == WorldId::UgcDownloadFailed as u32 {
Ok(Self::UgcDownloadFailed(reader.read()?))
} else {
err("world id", packet_id)
}
}
ClientValidation(ClientValidation) = 1,
CharacterListRequest = 2,
CharacterCreateRequest(CharacterCreateRequest) = 3,
CharacterLoginRequest(CharacterLoginRequest) = 4,
SubjectGameMessage(SubjectGameMessage) = 5,
CharacterDeleteRequest(CharacterDeleteRequest) = 6,
GeneralChatMessage(GeneralChatMessage) = 14,
LevelLoadComplete(LevelLoadComplete) = 19,
RouteMessage(RouteMessage) = 21,
StringCheck(StringCheck) = 25,
RequestFreeTrialRefresh = 32,
UgcDownloadFailed(UgcDownloadFailed) = 120,
}
#[derive(Debug)]