Change auth to SSO au

This commit is contained in:
lcdr
2020-02-23 13:21:54 +01:00
parent 6ebcb66325
commit fea8b1642c
9 changed files with 18 additions and 190 deletions

View File

@@ -1,6 +1,7 @@
syntax: glob
runtime/db/server_db*
runtime/db/cdclient.sqlite
runtime/db/client/*
runtime/packets/*
runtime/logs/*
documentation/*

View File

@@ -1,11 +1,5 @@
from typing import Dict, TYPE_CHECKING
try:
import bcrypt
from passlib.hash import bcrypt as hash
except ImportError:
from passlib.hash import pbkdf2_sha256 as hash
from persistent import Persistent
from persistent.mapping import PersistentMapping
if TYPE_CHECKING:
@@ -32,158 +26,6 @@ class Account(Persistent):
def set_password(self, password: str) -> None:
self.password = hash.hash(password)
import asyncio
import datetime
import logging
import random
import secrets
import time
from ssl import SSLContext
from typing import Optional
from bitstream import c_bool, c_ubyte, c_uint, c_ushort, ReadStream
from pyraknet.messages import Address
from . import commonserver
from .bitstream import WriteStream
from .messages import AuthServerMsg, MessageType, WorldClientMsg
log = logging.getLogger(__name__)
class LoginError(Exception):
pass
class _LoginReturnCode:
GeneralFailure = 0
Success = 1
AccountBanned = 2
InsufficientAccountPermissions = 5
InvalidUsernameOrPassword = 6
AccountLocked = 7 # when wrong password is entered too many times
# 8 is the same as 6, possible distinction between username/password?
AccountActivationPending = 9
AccountDisabled = 10
GameTimeExpired = 11
FreeTrialEnded = 12
PlaySchedule = 13
AccountNotActivated = 14
class _LoginMessage:
AccountBanned = "You have been banned until %s. If you believe this was in error, contact the server operator."
PasswordIsTemp = "Your password is one-use-only.\nSign in again to set the used password as your permanent password."
PasswordSet = "Password has been set."
SameTempPassword = "Password must not be the same as temporary password"
class AuthServer(commonserver.Server):
_PEER_TYPE = MessageType.AuthServer.value
def __init__(self, host: str, max_connections: int, db_conn, ssl: Optional[SSLContext]):
super().__init__((host, 1001), max_connections, db_conn, ssl)
self.db.servers.clear()
self.conn.transaction_manager.commit()
self._dispatcher.add_listener(AuthServerMsg.LoginRequest, self._on_log_req)
def _on_log_req(self, stream: ReadStream, conn: Connection) -> None:
asyncio.ensure_future(self._on_login_request(stream, conn))
async def _on_login_request(self, request: ReadStream, conn: Connection) -> None:
return_code = _LoginReturnCode.InsufficientAccountPermissions # needed to display error message
message = ""
redirect_host, redirect_port = "", 0
session_key = ""
try:
if not self.db.config["auth_enabled"]:
raise LoginError(self.db.config["auth_disabled_message"])
self.conn.sync()
username = request.read(str, allocated_length=33)
password = request.read(str, allocated_length=41)
if username not in self.db.accounts:
log.info("Login attempt with invalid username %s", username)
raise LoginError(_LoginReturnCode.InvalidUsernameOrPassword)
account = self.db.accounts[username]
if account.gm_level != GMLevel.Admin and account.banned_until > time.time():
raise LoginError(_LoginMessage.AccountBanned % datetime.datetime.fromtimestamp(account.banned_until))
if account.password_state == PasswordState.AcceptNew:
if hash.verify(password, account.password):
raise LoginError(_LoginMessage.SameTempPassword)
account.password = hash.hash(password)
account.password_state = PasswordState.Set
self.conn.transaction_manager.commit()
raise LoginError(_LoginMessage.PasswordSet)
if not hash.verify(password, account.password):
log.info("Login attempt with username %s and invalid password", username)
raise LoginError(_LoginReturnCode.InvalidUsernameOrPassword)
if account.password_state == PasswordState.Temp:
account.password_state = PasswordState.AcceptNew
self.conn.transaction_manager.commit()
raise LoginError(_LoginMessage.PasswordIsTemp)
"""
if account.address is not None and account.address != address:
log.info("Disconnecting duplicate at %s", account.address)
self.close_connection(account.address, server.DisconnectReason.DuplicateLogin)
duplicate_notify = WriteStream()
duplicate_notify.write_header(GeneralMsg.GeneralNotify)
duplicate_notify.write(c_uint(server.NotifyReason.DuplicateDisconnected))
self.send(duplicate_notify, address)
"""
session_key = secrets.token_hex(16)
#account.address = address
account.session_key = session_key
self.conn.transaction_manager.commit()
redirect_host, redirect_port = await self.address_for_world((0, 0, 0), conn.get_type())
log.info("Logging in %s to world %s with key %s", username, (redirect_host, redirect_port), session_key)
except LoginError as e:
if isinstance(e.args[0], str):
message = str(e)
else:
return_code = e.args[0]
except Exception:
import traceback
traceback.print_exc()
message = "Server error during login, contact server operator"
else:
return_code = _LoginReturnCode.Success
response = WriteStream()
response.write_header(WorldClientMsg.LoginResponse)
response.write(c_ubyte(return_code))
response.write(bytes(264))
# client version
response.write(c_ushort(1))
response.write(c_ushort(10))
response.write(c_ushort(64))
first_time_with_subscription = False # not implemented
is_ftp = False # not implemented
response.write(session_key, allocated_length=33)
response.write(redirect_host.encode("latin1"), allocated_length=33)
response.write(bytes(33))
response.write(c_ushort(redirect_port))
response.write(bytes(35))
response.write(bytes(36)) # b"00000000-0000-0000-0000-000000000000"
response.write(bytes(1)) # possibly terminator of the previous
response.write(bytes(4))
response.write(bytes(2)) # b"US"
response.write(bytes(1)) # possibly terminator of the previous
response.write(c_bool(first_time_with_subscription))
response.write(c_bool(is_ftp))
response.write(bytes(8)) # b"\x99\x0f\x05\x00\x00\x00\x00\x00"
response.write(message, length_type=c_ushort) # custom error message
response.write(c_uint(4)) # length of remaining bytes including this
# remaining would be optional debug "stamps"
conn.send(bytes(response))
class GMLevel:
Nothing = 0
Mod = 50

View File

@@ -2,7 +2,6 @@ from enum import Enum, IntEnum
class MessageType(Enum):
General = 0
AuthServer = 1
Social = 2
WorldServer = 4
WorldClient = 5
@@ -15,9 +14,6 @@ class GeneralMsg(LUMessage):
DisconnectNotify = 0x01
GeneralNotify = 0x02
class AuthServerMsg(LUMessage):
LoginRequest = 0x00
class SocialMsg(LUMessage):
GeneralChatMessage = 0x01
PrivateChatMessage = 0x02
@@ -66,7 +62,6 @@ class WorldClientMsg(LUMessage):
# Sadly no better way to get a mapping from headers to enums
MSG_TO_ENUM = {
MessageType.General.value: GeneralMsg,
MessageType.AuthServer.value: AuthServerMsg,
MessageType.Social.value: SocialMsg,
MessageType.WorldServer.value: WorldServerMsg,
MessageType.WorldClient.value: WorldClientMsg}

View File

@@ -71,6 +71,7 @@ import atexit
import importlib.util
import logging
import os.path
import urllib.request
from contextlib import AbstractContextManager as ACM
from ssl import SSLContext
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
@@ -118,13 +119,14 @@ class MultiInstanceAccess(ACM):
class WorldServer(Server):
_PEER_TYPE = MessageType.WorldServer.value
def __init__(self, address: Address, external_host: str, world_id: Tuple[int, int], max_connections: int, db_conn: Connection, ssl: Optional[SSLContext]):
def __init__(self, address: Address, external_host: str, verify_address: str, world_id: Tuple[int, int], max_connections: int, db_conn: Connection, ssl: Optional[SSLContext]):
excluded_packets = {"PositionUpdate", "GameMessage/DropClientLoot", "GameMessage/PickupItem", "GameMessage/ReadyForUpdates", "GameMessage/ScriptNetworkVarUpdate"}
super().__init__(address, max_connections, db_conn, ssl, excluded_packets)
self.replica_manager = ReplicaManager(self._dispatcher)
global _server
_server = self
self.external_host = external_host
self.verify_address = verify_address
self._dispatcher.add_listener(TransportEvent.NetworkInit, self._on_network_init)
self._dispatcher.add_listener(ConnectionEvent.Close, self._on_conn_close)
self.multi = MultiInstanceAccess()
@@ -149,7 +151,6 @@ class WorldServer(Server):
self.accounts: Dict[Connection, Account] = {}
atexit.register(self.shutdown)
asyncio.get_event_loop().call_later(60, self._autosave)
asyncio.get_event_loop().call_later(60 * 60, self._check_shutdown)
self._dispatcher.add_listener(WorldServerMsg.SessionInfo, self._on_session_info)
self._load_plugins()
self.set_world_id(world_id)
@@ -185,7 +186,11 @@ class WorldServer(Server):
def set_world_id(self, world_id: Tuple[int, int]) -> None:
self.world_id = world_id[0], self.instance_id, world_id[1]
if self.world_id[0] != 0: # char
if self.world_id[0] == 0: # char
self.db.servers.clear()
self.conn.transaction_manager.commit()
else:
asyncio.get_event_loop().call_later(60 * 60, self._check_shutdown)
custom_script, world_control_lot = self.db.world_info[self.world_id[0]]
if world_control_lot is None:
world_control_lot = 2365
@@ -250,12 +255,15 @@ class WorldServer(Server):
username = session_info.read(str, allocated_length=33)
session_key = session_info.read(str, allocated_length=33)
if username not in self.db.accounts:
try:
key_valid = urllib.request.urlopen(self.verify_address+f"/verify/{username}/{session_key}").read() == b"1"
except Exception:
log.error("User %s not found in database", username)
conn.close()
return
if self.db.accounts[username].session_key != session_key:
log.error("Database session key %s does not match supplied session key %s", self.db.accounts[username].session_key, session_key)
if not key_valid:
log.error(f"Supplied session key {session_key} for user {username} is invalid")
self.close_connection(conn, reason=DisconnectReason.InvalidSessionKey)
return

View File

@@ -1,4 +1,3 @@
passlib
toml
ZEO
hg+https://bitbucket.org/lcdr/pyraknet/

View File

@@ -10,7 +10,6 @@ import toml
import ZEO
from pyraknet.transports.abc import ConnectionType
from luserver.auth import AuthServer
from luserver.world import WorldServer
with open(os.path.normpath(os.path.join(__file__, "..", "instance.toml"))) as file:
@@ -31,7 +30,7 @@ else:
if len(sys.argv) == 1:
instance_id = "auth"
instance_id = "char"
else:
instance_id = sys.argv[1]+" "+sys.argv[2]
@@ -66,7 +65,7 @@ else:
context = None
if len(sys.argv) == 1:
a = AuthServer(config["connection"]["internal_host"], max_connections=8, db_conn=conn, ssl=context)
WorldServer((config["connection"]["internal_host"], 9999), config["connection"]["external_host"], config["auth"]["verify_address"], world_id=(0, 0), max_connections=8, db_conn=conn, ssl=context)
else:
world_id = int(sys.argv[1]), int(sys.argv[2])
if len(sys.argv) == 4:
@@ -86,7 +85,7 @@ else:
sys.exit()
else:
port = 0
WorldServer((config["connection"]["internal_host"], port), config["connection"]["external_host"], world_id, max_connections=8, db_conn=conn, ssl=context)
WorldServer((config["connection"]["internal_host"], port), config["connection"]["external_host"], config["auth"]["verify_address"], world_id, max_connections=8, db_conn=conn, ssl=context)
loop = asyncio.get_event_loop()
loop.run_forever()

View File

@@ -101,7 +101,6 @@ class Init:
def gen_config(self):
self.root.config = PersistentMapping()
self.root.config["auth_enabled"] = True
self.root.config["credits"] = "Created by lcdr"
for entry in self.config["defaults"]:
self.root.config[entry] = self.config["defaults"][entry]

View File

@@ -3,10 +3,8 @@ import code
import transaction
import ZEO
from luserver.auth import AuthServer
from luserver.world import WorldServer
conn = ZEO.connection(12345)
root = conn.root
c = transaction.commit

View File

@@ -22,19 +22,6 @@ from luserver.math.vector import Vector3
log = logging.getLogger(__name__)
class Auth(ChatCommand):
def __init__(self):
super().__init__("auth")
self.command.add_argument("enabled", type=normal_bool)
self.command.add_argument("--message", nargs="+")
def run(self, args, sender):
with server.multi:
server.db.config["auth_enabled"] = args.enabled
server.chat.sys_msg_sender("Auth is now %s" % args.enabled)
if args.message is not None:
server.db.config["auth_disabled_message"] = " ".join(args.message)
class Ban(ChatCommand):
def __init__(self):
super().__init__("ban")