mirror of
https://github.com/appium/appium.git
synced 2026-02-17 07:19:55 -06:00
fix(base-driver): Use shouldUpgradeCallback for proper upgrades processing (#21706)
This commit is contained in:
@@ -86,6 +86,39 @@ export function defaultToJSONContentType(req, res, next) {
|
||||
next();
|
||||
}
|
||||
|
||||
/**
|
||||
* Core function to handle WebSocket upgrade requests by matching the request path
|
||||
* against registered WebSocket handlers in the webSocketsMapping.
|
||||
*
|
||||
* @param {import('http').IncomingMessage} req - The HTTP request
|
||||
* @param {import('stream').Duplex} socket - The network socket
|
||||
* @param {Buffer} head - The first packet of the upgraded stream
|
||||
* @param {import('@appium/types').StringRecord<import('@appium/types').WSServer>} webSocketsMapping - Mapping of paths to WebSocket servers
|
||||
* @returns {boolean} - Returns true if the upgrade was handled, false otherwise
|
||||
*/
|
||||
export function tryHandleWebSocketUpgrade(req, socket, head, webSocketsMapping) {
|
||||
if (_.toLower(req.headers?.upgrade) !== 'websocket') {
|
||||
return false;
|
||||
}
|
||||
|
||||
let currentPathname;
|
||||
try {
|
||||
currentPathname = new URL(req.url ?? '', 'http://localhost').pathname;
|
||||
} catch {
|
||||
currentPathname = req.url ?? '';
|
||||
}
|
||||
for (const [pathname, wsServer] of _.toPairs(webSocketsMapping)) {
|
||||
if (match(pathname)(currentPathname)) {
|
||||
wsServer.handleUpgrade(req, socket, head, (ws) => {
|
||||
wsServer.emit('connection', ws, req);
|
||||
});
|
||||
return true;
|
||||
}
|
||||
}
|
||||
log.info(`Did not match the websocket upgrade request at ${currentPathname} to any known route`);
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {import('@appium/types').StringRecord<import('@appium/types').WSServer>} webSocketsMapping
|
||||
@@ -93,23 +126,9 @@ export function defaultToJSONContentType(req, res, next) {
|
||||
*/
|
||||
export function handleUpgrade(webSocketsMapping) {
|
||||
return (req, res, next) => {
|
||||
if (!req.headers?.upgrade || _.toLower(req.headers.upgrade) !== 'websocket') {
|
||||
return next();
|
||||
if (tryHandleWebSocketUpgrade(req, req.socket, Buffer.from(''), webSocketsMapping)) {
|
||||
return;
|
||||
}
|
||||
let currentPathname;
|
||||
try {
|
||||
currentPathname = new URL(req.url ?? '').pathname;
|
||||
} catch {
|
||||
currentPathname = req.url ?? '';
|
||||
}
|
||||
for (const [pathname, wsServer] of _.toPairs(webSocketsMapping)) {
|
||||
if (match(pathname)(currentPathname)) {
|
||||
return wsServer.handleUpgrade(req, req.socket, Buffer.from(''), (ws) => {
|
||||
wsServer.emit('connection', ws, req);
|
||||
});
|
||||
}
|
||||
}
|
||||
log.info(`Did not match the websocket upgrade request at ${currentPathname} to any known route`);
|
||||
next();
|
||||
};
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import {
|
||||
allowCrossDomainAsyncExecute,
|
||||
handleIdempotency,
|
||||
handleUpgrade,
|
||||
tryHandleWebSocketUpgrade,
|
||||
catch404Handler,
|
||||
handleLogContext,
|
||||
} from './middleware';
|
||||
@@ -31,6 +32,158 @@ import {fs, timing} from '@appium/support';
|
||||
|
||||
const KEEP_ALIVE_TIMEOUT_MS = 10 * 60 * 1000; // 10 minutes
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {ServerOpts} opts
|
||||
* @returns {Promise<AppiumServer>}
|
||||
*/
|
||||
export async function server(opts) {
|
||||
const {
|
||||
routeConfiguringFunction,
|
||||
port,
|
||||
hostname,
|
||||
cliArgs = /** @type {import('@appium/types').ServerArgs} */ ({}),
|
||||
allowCors = true,
|
||||
basePath = DEFAULT_BASE_PATH,
|
||||
extraMethodMap = {},
|
||||
serverUpdaters = [],
|
||||
keepAliveTimeout = KEEP_ALIVE_TIMEOUT_MS,
|
||||
requestTimeout,
|
||||
} = opts;
|
||||
|
||||
const app = express();
|
||||
const httpServer = await createServer(app, cliArgs);
|
||||
|
||||
return await new B(async (resolve, reject) => {
|
||||
// we put an async function as the promise constructor because we want some things to happen in
|
||||
// serial (application of plugin updates, for example). But we still need to use a promise here
|
||||
// because some elements of server start failure only happen in httpServer listeners. So the
|
||||
// way we resolve it is to use an async function here but to wrap all the inner logic in
|
||||
// try/catch so any errors can be passed to reject.
|
||||
try {
|
||||
const appiumServer = configureHttp({
|
||||
httpServer,
|
||||
reject,
|
||||
keepAliveTimeout,
|
||||
gracefulShutdownTimeout: cliArgs.shutdownTimeout,
|
||||
});
|
||||
const useLegacyUpgradeHandler = !hasShouldUpgradeCallback(httpServer);
|
||||
configureServer({
|
||||
app,
|
||||
addRoutes: routeConfiguringFunction,
|
||||
allowCors,
|
||||
basePath,
|
||||
extraMethodMap,
|
||||
webSocketsMapping: appiumServer.webSocketsMapping,
|
||||
useLegacyUpgradeHandler,
|
||||
});
|
||||
// allow extensions to update the app and http server objects
|
||||
for (const updater of serverUpdaters) {
|
||||
await updater(app, appiumServer, cliArgs);
|
||||
}
|
||||
|
||||
// once all configurations and updaters have been applied, make sure to set up a catchall
|
||||
// handler so that anything unknown 404s. But do this after everything else since we don't
|
||||
// want to block extensions' ability to add routes if they want.
|
||||
app.all('/*all', catch404Handler);
|
||||
|
||||
await startServer({
|
||||
httpServer,
|
||||
hostname,
|
||||
port,
|
||||
keepAliveTimeout,
|
||||
requestTimeout,
|
||||
});
|
||||
|
||||
resolve(appiumServer);
|
||||
} catch (err) {
|
||||
reject(err);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets up some Express middleware and stuff
|
||||
* @param {ConfigureServerOpts} opts
|
||||
*/
|
||||
export function configureServer({
|
||||
app,
|
||||
addRoutes,
|
||||
allowCors = true,
|
||||
basePath = DEFAULT_BASE_PATH,
|
||||
extraMethodMap = {},
|
||||
webSocketsMapping = {},
|
||||
useLegacyUpgradeHandler = true,
|
||||
}) {
|
||||
basePath = normalizeBasePath(basePath);
|
||||
|
||||
app.use(endLogFormatter);
|
||||
app.use(handleLogContext);
|
||||
|
||||
// set up static assets
|
||||
app.use(favicon(path.resolve(STATIC_DIR, 'favicon.ico')));
|
||||
// eslint-disable-next-line import/no-named-as-default-member
|
||||
app.use(express.static(STATIC_DIR));
|
||||
|
||||
// crash routes, for testing
|
||||
app.use(`${basePath}/produce_error`, produceError);
|
||||
app.use(`${basePath}/crash`, produceCrash);
|
||||
|
||||
// Only use legacy Express middleware for WebSocket upgrades if shouldUpgradeCallback is not available
|
||||
// When shouldUpgradeCallback is available, upgrades are handled directly on the HTTP server
|
||||
// to avoid Express middleware timeout issues with long-lived connections
|
||||
if (useLegacyUpgradeHandler) {
|
||||
app.use(handleUpgrade(webSocketsMapping));
|
||||
}
|
||||
if (allowCors) {
|
||||
app.use(allowCrossDomain);
|
||||
} else {
|
||||
app.use(allowCrossDomainAsyncExecute(basePath));
|
||||
}
|
||||
app.use(handleIdempotency);
|
||||
app.use(defaultToJSONContentType);
|
||||
app.use(bodyParser.urlencoded({extended: true}));
|
||||
app.use(methodOverride());
|
||||
app.use(catchAllHandler);
|
||||
|
||||
// make sure appium never fails because of a file size upload limit
|
||||
app.use(bodyParser.json({limit: '1gb'}));
|
||||
|
||||
// set up start logging (which depends on bodyParser doing its thing)
|
||||
app.use(startLogFormatter);
|
||||
|
||||
addRoutes(app, {basePath, extraMethodMap});
|
||||
|
||||
// dynamic routes for testing, etc.
|
||||
app.all('/welcome', welcome);
|
||||
app.all('/test/guinea-pig', guineaPig);
|
||||
app.all('/test/guinea-pig-scrollable', guineaPigScrollable);
|
||||
app.all('/test/guinea-pig-app-banner', guineaPigAppBanner);
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalize base path string
|
||||
* @param {string} basePath
|
||||
* @returns {string}
|
||||
*/
|
||||
export function normalizeBasePath(basePath) {
|
||||
if (!_.isString(basePath)) {
|
||||
throw new Error(`Invalid path prefix ${basePath}`);
|
||||
}
|
||||
|
||||
// ensure the path prefix does not end in '/', since our method map
|
||||
// starts all paths with '/'
|
||||
basePath = basePath.replace(/\/$/, '');
|
||||
|
||||
// likewise, ensure the path prefix does always START with /, unless the path
|
||||
// is empty meaning no base path at all
|
||||
if (basePath !== '' && !basePath.startsWith('/')) {
|
||||
basePath = `/${basePath}`;
|
||||
}
|
||||
|
||||
return basePath;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {import('express').Express} app
|
||||
@@ -70,127 +223,6 @@ async function createServer (app, cliArgs) {
|
||||
}, app);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {ServerOpts} opts
|
||||
* @returns {Promise<AppiumServer>}
|
||||
*/
|
||||
export async function server(opts) {
|
||||
const {
|
||||
routeConfiguringFunction,
|
||||
port,
|
||||
hostname,
|
||||
cliArgs = /** @type {import('@appium/types').ServerArgs} */ ({}),
|
||||
allowCors = true,
|
||||
basePath = DEFAULT_BASE_PATH,
|
||||
extraMethodMap = {},
|
||||
serverUpdaters = [],
|
||||
keepAliveTimeout = KEEP_ALIVE_TIMEOUT_MS,
|
||||
requestTimeout,
|
||||
} = opts;
|
||||
|
||||
const app = express();
|
||||
const httpServer = await createServer(app, cliArgs);
|
||||
|
||||
return await new B(async (resolve, reject) => {
|
||||
// we put an async function as the promise constructor because we want some things to happen in
|
||||
// serial (application of plugin updates, for example). But we still need to use a promise here
|
||||
// because some elements of server start failure only happen in httpServer listeners. So the
|
||||
// way we resolve it is to use an async function here but to wrap all the inner logic in
|
||||
// try/catch so any errors can be passed to reject.
|
||||
try {
|
||||
const appiumServer = configureHttp({
|
||||
httpServer,
|
||||
reject,
|
||||
keepAliveTimeout,
|
||||
gracefulShutdownTimeout: cliArgs.shutdownTimeout,
|
||||
});
|
||||
configureServer({
|
||||
app,
|
||||
addRoutes: routeConfiguringFunction,
|
||||
allowCors,
|
||||
basePath,
|
||||
extraMethodMap,
|
||||
webSocketsMapping: appiumServer.webSocketsMapping,
|
||||
});
|
||||
// allow extensions to update the app and http server objects
|
||||
for (const updater of serverUpdaters) {
|
||||
await updater(app, appiumServer, cliArgs);
|
||||
}
|
||||
|
||||
// once all configurations and updaters have been applied, make sure to set up a catchall
|
||||
// handler so that anything unknown 404s. But do this after everything else since we don't
|
||||
// want to block extensions' ability to add routes if they want.
|
||||
app.all('/*all', catch404Handler);
|
||||
|
||||
await startServer({
|
||||
httpServer,
|
||||
hostname,
|
||||
port,
|
||||
keepAliveTimeout,
|
||||
requestTimeout,
|
||||
});
|
||||
|
||||
resolve(appiumServer);
|
||||
} catch (err) {
|
||||
reject(err);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets up some Express middleware and stuff
|
||||
* @param {ConfigureServerOpts} opts
|
||||
*/
|
||||
export function configureServer({
|
||||
app,
|
||||
addRoutes,
|
||||
allowCors = true,
|
||||
basePath = DEFAULT_BASE_PATH,
|
||||
extraMethodMap = {},
|
||||
webSocketsMapping = {},
|
||||
}) {
|
||||
basePath = normalizeBasePath(basePath);
|
||||
|
||||
app.use(endLogFormatter);
|
||||
app.use(handleLogContext);
|
||||
|
||||
// set up static assets
|
||||
app.use(favicon(path.resolve(STATIC_DIR, 'favicon.ico')));
|
||||
// eslint-disable-next-line import/no-named-as-default-member
|
||||
app.use(express.static(STATIC_DIR));
|
||||
|
||||
// crash routes, for testing
|
||||
app.use(`${basePath}/produce_error`, produceError);
|
||||
app.use(`${basePath}/crash`, produceCrash);
|
||||
|
||||
app.use(handleUpgrade(webSocketsMapping));
|
||||
if (allowCors) {
|
||||
app.use(allowCrossDomain);
|
||||
} else {
|
||||
app.use(allowCrossDomainAsyncExecute(basePath));
|
||||
}
|
||||
app.use(handleIdempotency);
|
||||
app.use(defaultToJSONContentType);
|
||||
app.use(bodyParser.urlencoded({extended: true}));
|
||||
app.use(methodOverride());
|
||||
app.use(catchAllHandler);
|
||||
|
||||
// make sure appium never fails because of a file size upload limit
|
||||
app.use(bodyParser.json({limit: '1gb'}));
|
||||
|
||||
// set up start logging (which depends on bodyParser doing its thing)
|
||||
app.use(startLogFormatter);
|
||||
|
||||
addRoutes(app, {basePath, extraMethodMap});
|
||||
|
||||
// dynamic routes for testing, etc.
|
||||
app.all('/welcome', welcome);
|
||||
app.all('/test/guinea-pig', guineaPig);
|
||||
app.all('/test/guinea-pig-scrollable', guineaPigScrollable);
|
||||
app.all('/test/guinea-pig-app-banner', guineaPigAppBanner);
|
||||
}
|
||||
|
||||
/**
|
||||
* Monkeypatches the `http.Server` instance and returns a {@linkcode AppiumServer}.
|
||||
* This function _mutates_ the `httpServer` parameter.
|
||||
@@ -212,6 +244,20 @@ function configureHttp({httpServer, reject, keepAliveTimeout, gracefulShutdownTi
|
||||
return Boolean(this['_spdyState']?.secure);
|
||||
};
|
||||
|
||||
// This avoids Express middleware timeout issues with long-lived WebSocket connections
|
||||
// See: https://github.com/appium/appium/issues/20760
|
||||
// See: https://github.com/nodejs/node/pull/59824
|
||||
if (hasShouldUpgradeCallback(httpServer)) {
|
||||
// shouldUpgradeCallback only returns a boolean to indicate if the upgrade should proceed
|
||||
// eslint-disable-next-line dot-notation
|
||||
appiumServer['shouldUpgradeCallback'] = (req) => _.toLower(req.headers?.upgrade) === 'websocket';
|
||||
appiumServer.on('upgrade', (req, socket, head) => {
|
||||
if (!tryHandleWebSocketUpgrade(req, socket, head, appiumServer.webSocketsMapping)) {
|
||||
socket.destroy();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// http.Server.close() only stops new connections, but we need to wait until
|
||||
// all connections are closed and the `close` event is emitted
|
||||
const originalClose = appiumServer.close.bind(appiumServer);
|
||||
@@ -294,29 +340,23 @@ async function startServer({
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalize base path string
|
||||
* @param {string} basePath
|
||||
* @returns {string}
|
||||
* Checks if the provided server instance supports `shouldUpgradeCallback`.
|
||||
* This feature was added in Node.js v22.21.0 (LTS) and v24.9.0.
|
||||
* @param {import('http').Server} server - The HTTP server instance to check
|
||||
* @returns {boolean}
|
||||
*/
|
||||
export function normalizeBasePath(basePath) {
|
||||
if (!_.isString(basePath)) {
|
||||
throw new Error(`Invalid path prefix ${basePath}`);
|
||||
function hasShouldUpgradeCallback(server) {
|
||||
// Check if shouldUpgradeCallback is available on http.Server
|
||||
// This is a runtime check that works regardless of TypeScript types
|
||||
try {
|
||||
// Use bracket notation to access property that may not exist in type definitions
|
||||
// eslint-disable-next-line dot-notation
|
||||
return typeof server['shouldUpgradeCallback'] !== 'undefined';
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
|
||||
// ensure the path prefix does not end in '/', since our method map
|
||||
// starts all paths with '/'
|
||||
basePath = basePath.replace(/\/$/, '');
|
||||
|
||||
// likewise, ensure the path prefix does always START with /, unless the path
|
||||
// is empty meaning no base path at all
|
||||
if (basePath !== '' && !basePath.startsWith('/')) {
|
||||
basePath = `/${basePath}`;
|
||||
}
|
||||
|
||||
return basePath;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Options for {@linkcode startServer}.
|
||||
* @typedef StartServerOpts
|
||||
@@ -386,4 +426,6 @@ export function normalizeBasePath(basePath) {
|
||||
* @property {string} [basePath]
|
||||
* @property {MethodMap} [extraMethodMap]
|
||||
* @property {import('@appium/types').StringRecord} [webSocketsMapping={}]
|
||||
* @property {boolean} [useLegacyUpgradeHandler=true] - Whether to use legacy Express middleware for WebSocket upgrades.
|
||||
* Set to false when using shouldUpgradeCallback on the HTTP server (Node.js >= 22.21.0 or >= 24.9.0).
|
||||
*/
|
||||
|
||||
@@ -81,7 +81,7 @@ describe('fs', function () {
|
||||
|
||||
// Mock fsPromises.rename to simulate EXDEV error
|
||||
const originalRename = fs.rename;
|
||||
fs.rename = async (src, dst) => {
|
||||
fs.rename = async () => {
|
||||
const err = new Error('cross-device link not permitted');
|
||||
err.code = 'EXDEV';
|
||||
throw err;
|
||||
|
||||
Reference in New Issue
Block a user